00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00034 #include <acado/code_generation/integrators/erk_export.hpp>
00035 #include <acado/code_generation/integrators/erk_fob_export.hpp>
00036
00037 using namespace std;
00038
00039 BEGIN_NAMESPACE_ACADO
00040
00041
00042
00043
00044
00045 ForwardOverBackwardERKExport::ForwardOverBackwardERKExport( UserInteraction* _userInteraction,
00046 const std::string& _commonHeaderName
00047 ) : AdjointERKExport( _userInteraction,_commonHeaderName )
00048 {
00049 }
00050
00051
00052 ForwardOverBackwardERKExport::ForwardOverBackwardERKExport( const ForwardOverBackwardERKExport& arg
00053 ) : AdjointERKExport( arg )
00054 {
00055 }
00056
00057
00058 ForwardOverBackwardERKExport::~ForwardOverBackwardERKExport( )
00059 {
00060 clear( );
00061 }
00062
00063
00064
00065 returnValue ForwardOverBackwardERKExport::setDifferentialEquation( const Expression& rhs_ )
00066 {
00067 int sensGen;
00068 get( DYNAMIC_SENSITIVITY,sensGen );
00069
00070 OnlineData dummy0;
00071 Control dummy1;
00072 DifferentialState dummy2;
00073 AlgebraicState dummy3;
00074 DifferentialStateDerivative dummy4;
00075 dummy0.clearStaticCounters();
00076 dummy1.clearStaticCounters();
00077 dummy2.clearStaticCounters();
00078 dummy3.clearStaticCounters();
00079 dummy4.clearStaticCounters();
00080
00081 x = DifferentialState("", NX, 1);
00082 dx = DifferentialStateDerivative("", NDX, 1);
00083 z = AlgebraicState("", NXA, 1);
00084 u = Control("", NU, 1);
00085 od = OnlineData("", NOD, 1);
00086
00087 if( NDX > 0 && NDX != NX ) {
00088 return ACADOERROR( RET_INVALID_OPTION );
00089 }
00090 if( rhs_.getNumRows() != (NX+NXA) ) {
00091 return ACADOERROR( RET_INVALID_OPTION );
00092 }
00093
00094 DifferentialEquation f, g, f_ODE;
00095
00096 f_ODE << rhs_;
00097 if( f_ODE.getNDX() > 0 ) {
00098 return ACADOERROR( RET_INVALID_OPTION );
00099 }
00100
00101 if( (ExportSensitivityType)sensGen == FORWARD_OVER_BACKWARD ) {
00102 DifferentialState Gx("", NX,NX), Gu("", NX,NU);
00103
00104
00105
00106 f << rhs_;
00107
00108
00109
00110
00111 f << multipleForwardDerivative( rhs_, x, Gx );
00112
00113
00114
00115
00116 f << multipleForwardDerivative( rhs_, x, Gu ) + forwardDerivative( rhs_, u );
00117
00118
00119
00120
00121
00122
00123
00124 DifferentialState lx("", NX,1);
00125
00126 Expression tmp = backwardDerivative(rhs_, x, lx);
00127 g << tmp;
00128
00129 DifferentialState Sxx("", NX,NX), Sux("", NU,NX), Suu("", NU,NU);
00130
00131 g << multipleForwardDerivative(tmp, x, Gx) + multipleBackwardDerivative(rhs_, x, Sxx);
00132 g << multipleBackwardDerivative(tmp, x, Gu).transpose() + forwardDerivative(tmp, u).transpose() + multipleBackwardDerivative(rhs_, x, Sux.transpose()).transpose();
00133 g << forwardDerivative(backwardDerivative(rhs_, u, lx), u) + multipleBackwardDerivative(tmp, u, Gu) + multipleBackwardDerivative(rhs_, u, Sux.transpose());
00134 }
00135 else {
00136 return ACADOERROR( RET_INVALID_OPTION );
00137 }
00138 if( f.getNT() > 0 ) timeDependant = true;
00139
00140 return rhs.init(f, "acado_forward", NX*(NX+NU+1), 0, NU, NP, NDX, NOD)
00141 & diffs_rhs.init(g, "acado_backward", NX*(NX+NU+1) + NX + NX*NX + NX*NU + NU*NU, 0, NU, NP, NDX, NOD);
00142 }
00143
00144
00145 returnValue ForwardOverBackwardERKExport::setup( )
00146 {
00147 int sensGen;
00148 get( DYNAMIC_SENSITIVITY,sensGen );
00149 if ( (ExportSensitivityType)sensGen != FORWARD_OVER_BACKWARD ) ACADOERROR( RET_INVALID_OPTION );
00150
00151
00152 if( !equidistantControlGrid() ) ACADOERROR( RET_INVALID_OPTION );
00153
00154
00155 if( !is_symmetric ) ACADOERROR( RET_INVALID_OPTION );
00156
00157 LOG( LVL_DEBUG ) << "Preparing to export ForwardOverBackwardERKExport... " << endl;
00158
00159
00160 uint rhsDim = NX*(NX+NU+1) + NX + NX*NX + NX*NU + NU*NU;
00161 inputDim = rhsDim + NU + NOD;
00162 const uint rkOrder = getNumStages();
00163
00164 double h = (grid.getLastTime() - grid.getFirstTime())/grid.getNumIntervals();
00165
00166 ExportVariable Ah ( "A*h", DMatrix( AA )*=h );
00167 ExportVariable b4h( "b4*h", DMatrix( bb )*=h );
00168
00169 rk_index = ExportVariable( "rk_index", 1, 1, INT, ACADO_LOCAL, true );
00170 rk_eta = ExportVariable( "rk_eta", 1, inputDim );
00171
00172
00173 int useOMP;
00174 get(CG_USE_OPENMP, useOMP);
00175 ExportStruct structWspace;
00176 structWspace = useOMP ? ACADO_LOCAL : ACADO_WORKSPACE;
00177
00178 rk_ttt.setup( "rk_ttt", 1, 1, REAL, structWspace, true );
00179 uint timeDep = 0;
00180 if( timeDependant ) timeDep = 1;
00181
00182 rk_xxx.setup("rk_xxx", 1, inputDim+timeDep, REAL, structWspace);
00183 rk_kkk.setup("rk_kkk", rkOrder, NX+NX*NX+NX*NU+NU*NU, REAL, structWspace);
00184 rk_forward_sweep.setup("rk_sweep1", 1, grid.getNumIntervals()*rkOrder*NX*(NX+NU+1), REAL, structWspace);
00185
00186 if ( useOMP )
00187 {
00188 ExportVariable auxVar;
00189
00190 auxVar = getAuxVariable();
00191 auxVar.setName( "odeAuxVar" );
00192 auxVar.setDataStruct( ACADO_LOCAL );
00193 rhs.setGlobalExportVariable( auxVar );
00194 diffs_rhs.setGlobalExportVariable( auxVar );
00195 }
00196
00197 ExportIndex run( "run1" );
00198
00199
00200 integrate = ExportFunction( "integrate", rk_eta, reset_int );
00201 integrate.setReturnValue( error_code );
00202 rk_eta.setDoc( "Working array to pass the input values and return the results." );
00203 reset_int.setDoc( "The internal memory of the integrator can be reset." );
00204 rk_index.setDoc( "Number of the shooting interval." );
00205 error_code.setDoc( "Status code of the integrator." );
00206 integrate.doc( "Performs the integration and sensitivity propagation for one shooting interval." );
00207 integrate.addIndex( run );
00208
00209 integrate.addStatement( rk_ttt == DMatrix(grid.getFirstTime()) );
00210
00211 if( inputDim > rhsDim ) {
00212
00213 DMatrix idX = eye<double>( NX );
00214 DMatrix zeroXU = zeros<double>( NX,NU );
00215 integrate.addStatement( rk_eta.getCols( NX,NX*(1+NX) ) == idX.makeVector().transpose() );
00216 integrate.addStatement( rk_eta.getCols( NX*(1+NX),NX*(1+NX+NU) ) == zeroXU.makeVector().transpose() );
00217
00218
00219 integrate.addStatement( rk_eta.getCols( NX*(2+NX+NU),rhsDim ) == zeros<double>( 1,NX*NX+NX*NU+NU*NU ) );
00220
00221 integrate.addStatement( rk_xxx.getCols( NX*(1+NX+NU),NX*(1+NX+NU)+NU+NOD ) == rk_eta.getCols( rhsDim,inputDim ) );
00222 }
00223 integrate.addLinebreak( );
00224
00225
00226 ExportForLoop loop = ExportForLoop( run, 0, grid.getNumIntervals() );
00227 for( uint run1 = 0; run1 < rkOrder; run1++ )
00228 {
00229 loop.addStatement( rk_xxx.getCols( 0,NX*(1+NX+NU) ) == rk_eta.getCols( 0,NX*(1+NX+NU) ) + Ah.getRow(run1)*rk_kkk.getCols( 0,NX*(1+NX+NU) ) );
00230
00231 loop.addStatement( rk_forward_sweep.getCols( run*rkOrder*NX*(1+NX+NU)+run1*NX*(1+NX+NU),run*rkOrder*NX*(1+NX+NU)+(run1+1)*NX*(1+NX+NU) ) == rk_xxx.getCols( 0,NX*(1+NX+NU) ) );
00232 if( timeDependant ) loop.addStatement( rk_xxx.getCol( NX*(NX+NU+1)+NU+NOD ) == rk_ttt + ((double)cc(run1))/grid.getNumIntervals() );
00233 loop.addFunctionCall( getNameRHS(),rk_xxx,rk_kkk.getAddress(run1,0) );
00234 }
00235 loop.addStatement( rk_eta.getCols( 0,NX*(1+NX+NU) ) += b4h^rk_kkk.getCols( 0,NX*(1+NX+NU) ) );
00236 loop.addStatement( rk_ttt += DMatrix(1.0/grid.getNumIntervals()) );
00237
00238 integrate.addStatement( loop );
00239
00240 if( inputDim > rhsDim ) {
00241
00242 integrate.addStatement( rk_xxx.getCols( rhsDim,inputDim ) == rk_eta.getCols( rhsDim,inputDim ) );
00243 }
00244
00245 ExportForLoop loop2 = ExportForLoop( run, 0, grid.getNumIntervals() );
00246 for( uint run1 = 0; run1 < rkOrder; run1++ )
00247 {
00248
00249 loop2.addStatement( rk_xxx.getCols( 0,NX*(1+NX+NU) ) == rk_forward_sweep.getCols( (grid.getNumIntervals()-run)*rkOrder*NX*(1+NX+NU)-(run1+1)*NX*(1+NX+NU),(grid.getNumIntervals()-run)*rkOrder*NX*(1+NX+NU)-run1*NX*(1+NX+NU) ) );
00250 loop2.addStatement( rk_xxx.getCols( NX*(1+NX+NU),rhsDim ) == rk_eta.getCols( NX*(1+NX+NU),rhsDim ) + Ah.getRow(run1)*rk_kkk );
00251 if( timeDependant ) loop2.addStatement( rk_xxx.getCol( inputDim ) == rk_ttt - ((double)cc(run1))/grid.getNumIntervals() );
00252 loop2.addFunctionCall( getNameDiffsRHS(),rk_xxx,rk_kkk.getAddress(run1,0) );
00253 }
00254 loop2.addStatement( rk_eta.getCols( NX*(1+NX+NU),rhsDim ) += b4h^rk_kkk );
00255 loop2.addStatement( rk_ttt -= DMatrix(1.0/grid.getNumIntervals()) );
00256
00257 integrate.addStatement( loop2 );
00258
00259 integrate.addStatement( error_code == 0 );
00260
00261 LOG( LVL_DEBUG ) << "done" << endl;
00262
00263 return SUCCESSFUL_RETURN;
00264 }
00265
00266
00267
00268
00269
00270
00271 CLOSE_NAMESPACE_ACADO
00272
00273