00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00034 #include <acado/code_generation/integrators/erk_export.hpp>
00035 #include <acado/code_generation/integrators/erk_adjoint_export.hpp>
00036
00037 using namespace std;
00038
00039 BEGIN_NAMESPACE_ACADO
00040
00041
00042
00043
00044
00045 AdjointERKExport::AdjointERKExport( UserInteraction* _userInteraction,
00046 const std::string& _commonHeaderName
00047 ) : ExplicitRungeKuttaExport( _userInteraction,_commonHeaderName )
00048 {
00049 }
00050
00051
00052 AdjointERKExport::AdjointERKExport( const AdjointERKExport& arg
00053 ) : ExplicitRungeKuttaExport( arg )
00054 {
00055 }
00056
00057
00058 AdjointERKExport::~AdjointERKExport( )
00059 {
00060 clear( );
00061 }
00062
00063
00064
00065 returnValue AdjointERKExport::setDifferentialEquation( const Expression& rhs_ )
00066 {
00067 int sensGen;
00068 get( DYNAMIC_SENSITIVITY,sensGen );
00069
00070 OnlineData dummy0;
00071 Control dummy1;
00072 DifferentialState dummy2;
00073 AlgebraicState dummy3;
00074 DifferentialStateDerivative dummy4;
00075 dummy0.clearStaticCounters();
00076 dummy1.clearStaticCounters();
00077 dummy2.clearStaticCounters();
00078 dummy3.clearStaticCounters();
00079 dummy4.clearStaticCounters();
00080
00081 x = DifferentialState("", NX, 1);
00082 dx = DifferentialStateDerivative("", NDX, 1);
00083 z = AlgebraicState("", NXA, 1);
00084 u = Control("", NU, 1);
00085 od = OnlineData("", NOD, 1);
00086
00087 if( NDX > 0 && NDX != NX ) {
00088 return ACADOERROR( RET_INVALID_OPTION );
00089 }
00090 if( rhs_.getNumRows() != (NX+NXA) ) {
00091 return ACADOERROR( RET_INVALID_OPTION );
00092 }
00093
00094 DifferentialEquation f, f_ODE;
00095
00096 f_ODE << rhs_;
00097 if( f_ODE.getNDX() > 0 ) {
00098 return ACADOERROR( RET_INVALID_OPTION );
00099 }
00100
00101 if( (ExportSensitivityType)sensGen == BACKWARD ) {
00102 DifferentialState lx("", NX,1), lu("", NU,1);
00103
00104 f << backwardDerivative(rhs_, x, lx);
00105 f << backwardDerivative(rhs_, u, lx);
00106 }
00107 else {
00108 return ACADOERROR( RET_INVALID_OPTION );
00109 }
00110 if( f.getNT() > 0 ) timeDependant = true;
00111
00112 return rhs.init(f_ODE, "acado_rhs", NX, 0, NU, NP, NDX, NOD)
00113 & diffs_rhs.init(f, "acado_rhs_back", 2*NX + NU, 0, NU, NP, NDX, NOD);
00114 }
00115
00116
00117 returnValue AdjointERKExport::setup( )
00118 {
00119 int sensGen;
00120 get( DYNAMIC_SENSITIVITY,sensGen );
00121 if ( (ExportSensitivityType)sensGen != BACKWARD ) ACADOERROR( RET_INVALID_OPTION );
00122
00123
00124 if( !equidistantControlGrid() ) ACADOERROR( RET_INVALID_OPTION );
00125
00126
00127 if( !is_symmetric ) ACADOERROR( RET_INVALID_OPTION );
00128
00129 LOG( LVL_DEBUG ) << "Preparing to export AdjointERKExport... " << endl;
00130
00131
00132 uint rhsDim = 2*NX+NU;
00133 inputDim = 2*NX+NU + NU + NOD;
00134 const uint rkOrder = getNumStages();
00135
00136 double h = (grid.getLastTime() - grid.getFirstTime())/grid.getNumIntervals();
00137
00138 ExportVariable Ah ( "A*h", DMatrix( AA )*=h );
00139 ExportVariable b4h( "b4*h", DMatrix( bb )*=h );
00140
00141 rk_index = ExportVariable( "rk_index", 1, 1, INT, ACADO_LOCAL, true );
00142 rk_eta = ExportVariable( "rk_eta", 1, inputDim );
00143
00144
00145 int useOMP;
00146 get(CG_USE_OPENMP, useOMP);
00147 ExportStruct structWspace;
00148 structWspace = useOMP ? ACADO_LOCAL : ACADO_WORKSPACE;
00149
00150 rk_ttt.setup( "rk_ttt", 1, 1, REAL, structWspace, true );
00151 uint timeDep = 0;
00152 if( timeDependant ) timeDep = 1;
00153
00154 rk_xxx.setup("rk_xxx", 1, inputDim+timeDep, REAL, structWspace);
00155 rk_kkk.setup("rk_kkk", rkOrder, NX+NU, REAL, structWspace);
00156 rk_forward_sweep.setup("rk_sweep1", 1, grid.getNumIntervals()*rkOrder*NX, REAL, structWspace);
00157
00158 if ( useOMP )
00159 {
00160 ExportVariable auxVar;
00161
00162 auxVar = getAuxVariable();
00163 auxVar.setName( "odeAuxVar" );
00164 auxVar.setDataStruct( ACADO_LOCAL );
00165 rhs.setGlobalExportVariable( auxVar );
00166 diffs_rhs.setGlobalExportVariable( auxVar );
00167 }
00168
00169 ExportIndex run( "run1" );
00170
00171
00172 integrate = ExportFunction( "integrate", rk_eta, reset_int );
00173 integrate.setReturnValue( error_code );
00174 rk_eta.setDoc( "Working array to pass the input values and return the results." );
00175 reset_int.setDoc( "The internal memory of the integrator can be reset." );
00176 rk_index.setDoc( "Number of the shooting interval." );
00177 error_code.setDoc( "Status code of the integrator." );
00178 integrate.doc( "Performs the integration and sensitivity propagation for one shooting interval." );
00179 integrate.addIndex( run );
00180
00181 integrate.addStatement( rk_ttt == DMatrix(grid.getFirstTime()) );
00182
00183 if( inputDim > rhsDim ) {
00184
00185 integrate.addStatement( rk_eta.getCols( 2*NX,2*NX+NU ) == zeros<double>( 1,NU ) );
00186
00187 integrate.addStatement( rk_xxx.getCols( NX,NX+NU+NOD ) == rk_eta.getCols( rhsDim,inputDim ) );
00188 }
00189 integrate.addLinebreak( );
00190
00191
00192 ExportForLoop loop = ExportForLoop( run, 0, grid.getNumIntervals() );
00193 for( uint run1 = 0; run1 < rkOrder; run1++ )
00194 {
00195 loop.addStatement( rk_xxx.getCols( 0,NX ) == rk_eta.getCols( 0,NX ) + Ah.getRow(run1)*rk_kkk.getCols( 0,NX ) );
00196
00197 loop.addStatement( rk_forward_sweep.getCols( run*rkOrder*NX+run1*NX,run*rkOrder*NX+run1*NX+NX ) == rk_xxx.getCols( 0,NX ) );
00198 if( timeDependant ) loop.addStatement( rk_xxx.getCol( NX+NU+NOD ) == rk_ttt + ((double)cc(run1))/grid.getNumIntervals() );
00199 loop.addFunctionCall( getNameRHS(),rk_xxx,rk_kkk.getAddress(run1,0) );
00200 }
00201 loop.addStatement( rk_eta.getCols( 0,NX ) += b4h^rk_kkk.getCols( 0,NX ) );
00202 loop.addStatement( rk_ttt += DMatrix(1.0/grid.getNumIntervals()) );
00203
00204 integrate.addStatement( loop );
00205
00206
00207
00208
00209 if( inputDim > rhsDim ) {
00210
00211 integrate.addStatement( rk_xxx.getCols( rhsDim,inputDim ) == rk_eta.getCols( rhsDim,inputDim ) );
00212 }
00213
00214 ExportForLoop loop2 = ExportForLoop( run, 0, grid.getNumIntervals() );
00215 for( uint run1 = 0; run1 < rkOrder; run1++ )
00216 {
00217
00218
00219 loop2.addStatement( rk_xxx.getCols( 0,NX ) == rk_forward_sweep.getCols( (grid.getNumIntervals()-run)*rkOrder*NX-run1*NX-NX,(grid.getNumIntervals()-run)*rkOrder*NX-run1*NX ) );
00220
00221 loop2.addStatement( rk_xxx.getCols( NX,2*NX+NU ) == rk_eta.getCols( NX,2*NX+NU ) + Ah.getRow(run1)*rk_kkk );
00222 if( timeDependant ) loop2.addStatement( rk_xxx.getCol( inputDim ) == rk_ttt - ((double)cc(run1))/grid.getNumIntervals() );
00223 loop2.addFunctionCall( getNameDiffsRHS(),rk_xxx,rk_kkk.getAddress(run1,0) );
00224
00225
00226
00227 }
00228 loop2.addStatement( rk_eta.getCols( NX,2*NX+NU ) += b4h^rk_kkk );
00229 loop2.addStatement( rk_ttt -= DMatrix(1.0/grid.getNumIntervals()) );
00230
00231 integrate.addStatement( loop2 );
00232
00233 integrate.addStatement( error_code == 0 );
00234
00235 LOG( LVL_DEBUG ) << "done" << endl;
00236
00237 return SUCCESSFUL_RETURN;
00238 }
00239
00240
00241 returnValue AdjointERKExport::getDataDeclarations( ExportStatementBlock& declarations,
00242 ExportStruct dataStruct
00243 ) const
00244 {
00245 ExplicitRungeKuttaExport::getDataDeclarations( declarations, dataStruct );
00246
00247 declarations.addDeclaration( rk_forward_sweep,dataStruct );
00248
00249 return SUCCESSFUL_RETURN;
00250 }
00251
00252
00253 returnValue AdjointERKExport::getCode( ExportStatementBlock& code
00254 )
00255 {
00256 int useOMP;
00257 get(CG_USE_OPENMP, useOMP);
00258 if ( useOMP )
00259 {
00260 getDataDeclarations( code, ACADO_LOCAL );
00261
00262 code << "#pragma omp threadprivate( "
00263 << getAuxVariable().getFullName() << ", "
00264 << rk_xxx.getFullName() << ", "
00265 << rk_ttt.getFullName() << ", "
00266 << rk_kkk.getFullName() << ", "
00267 << rk_forward_sweep.getFullName()
00268 << " )\n\n";
00269 }
00270
00271 int sensGen;
00272 get( DYNAMIC_SENSITIVITY,sensGen );
00273 if( exportRhs ) {
00274 code.addFunction( rhs );
00275 code.addFunction( diffs_rhs );
00276 }
00277
00278 double h = (grid.getLastTime() - grid.getFirstTime())/grid.getNumIntervals();
00279 code.addComment(std::string("Fixed step size:") + toString(h));
00280 code.addFunction( integrate );
00281
00282 return SUCCESSFUL_RETURN;
00283 }
00284
00285
00286
00287
00288
00289
00290 CLOSE_NAMESPACE_ACADO
00291
00292