00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00034 #include <acado/code_generation/integrators/erk_export.hpp>
00035 #include <acado/code_generation/integrators/erk_3sweep_export.hpp>
00036
00037 using namespace std;
00038
00039 BEGIN_NAMESPACE_ACADO
00040
00041
00042
00043
00044
00045 ThreeSweepsERKExport::ThreeSweepsERKExport( UserInteraction* _userInteraction,
00046 const std::string& _commonHeaderName
00047 ) : AdjointERKExport( _userInteraction,_commonHeaderName )
00048 {
00049 }
00050
00051
00052 ThreeSweepsERKExport::ThreeSweepsERKExport( const ThreeSweepsERKExport& arg
00053 ) : AdjointERKExport( arg )
00054 {
00055 }
00056
00057
00058 ThreeSweepsERKExport::~ThreeSweepsERKExport( )
00059 {
00060 clear( );
00061 }
00062
00063
00064
00065 returnValue ThreeSweepsERKExport::setDifferentialEquation( const Expression& rhs_ )
00066 {
00067 int sensGen;
00068 get( DYNAMIC_SENSITIVITY,sensGen );
00069
00070 OnlineData dummy0;
00071 Control dummy1;
00072 DifferentialState dummy2;
00073 AlgebraicState dummy3;
00074 DifferentialStateDerivative dummy4;
00075 dummy0.clearStaticCounters();
00076 dummy1.clearStaticCounters();
00077 dummy2.clearStaticCounters();
00078 dummy3.clearStaticCounters();
00079 dummy4.clearStaticCounters();
00080
00081 x = DifferentialState("", NX, 1);
00082 dx = DifferentialStateDerivative("", NDX, 1);
00083 z = AlgebraicState("", NXA, 1);
00084 u = Control("", NU, 1);
00085 od = OnlineData("", NOD, 1);
00086
00087 if( NDX > 0 && NDX != NX ) {
00088 return ACADOERROR( RET_INVALID_OPTION );
00089 }
00090 if( rhs_.getNumRows() != (NX+NXA) ) {
00091 return ACADOERROR( RET_INVALID_OPTION );
00092 }
00093
00094 DifferentialEquation f, g, h, f_ODE;
00095
00096 f_ODE << rhs_;
00097 if( f_ODE.getNDX() > 0 ) {
00098 return ACADOERROR( RET_INVALID_OPTION );
00099 }
00100
00101 uint numX = NX*(NX+1)/2.0;
00102 uint numU = NU*(NU+1)/2.0;
00103 uint numZ = (NX+NU)*(NX+NU+1)/2.0;
00104 if( (ExportSensitivityType)sensGen == THREE_SWEEPS ) {
00105
00106
00107 f << rhs_;
00108
00109
00110
00111
00112 DifferentialState lx("", NX,1);
00113
00114 Expression tmp = backwardDerivative(rhs_, x, lx);
00115 g << tmp;
00116
00117
00118
00119
00120 DifferentialState Gx("", NX,NX), Gu("", NX,NU);
00121 DifferentialState H("", numZ,1);
00122
00123 Expression S = Gx;
00124 S.appendCols(Gu);
00125 Expression arg;
00126 arg << x;
00127 arg << u;
00128
00129
00130 Expression S_tmp = S;
00131 S_tmp.appendRows(zeros<double>(NU,NX).appendCols(eye<double>(NU)));
00132
00133 Expression dfS;
00134 Expression h_tmp = symmetricDerivative( rhs_, arg, S_tmp, lx, &dfS );
00135 Expression VDE_X;
00136 Expression VDE_U;
00137 for( uint i = 0; i < NX; i++ ) {
00138 VDE_X.appendCols(dfS.getCol(i));
00139 }
00140 for( uint i = NX; i < NX+NU; i++ ) {
00141 VDE_U.appendCols(dfS.getCol(i));
00142 }
00143 h << VDE_X;
00144 h << VDE_U;
00145 h << returnLowerTriangular( h_tmp );
00146
00147
00148
00149
00150
00151
00152
00153
00154
00155
00156
00157
00158
00160
00161
00162
00163
00164 }
00165 else {
00166 return ACADOERROR( RET_INVALID_OPTION );
00167 }
00168 if( f.getNT() > 0 ) timeDependant = true;
00169
00170 return rhs.init(f, "acado_forward", NX, 0, NU, NP, NDX, NOD)
00171 & diffs_rhs.init(g, "acado_backward", 2*NX, 0, NU, NP, NDX, NOD)
00172 & diffs_sweep3.init(h, "acado_forward_sweep3", 2*NX + NX*(NX+NU) + numX + NX*NU + numU, 0, NU, NP, NDX, NOD);
00173 }
00174
00175
00176 returnValue ThreeSweepsERKExport::setup( )
00177 {
00178 int sensGen;
00179 get( DYNAMIC_SENSITIVITY,sensGen );
00180 if ( (ExportSensitivityType)sensGen != THREE_SWEEPS ) ACADOERROR( RET_INVALID_OPTION );
00181
00182
00183 if( !equidistantControlGrid() ) ACADOERROR( RET_INVALID_OPTION );
00184
00185
00186 if( !is_symmetric ) ACADOERROR( RET_INVALID_OPTION );
00187
00188 LOG( LVL_DEBUG ) << "Preparing to export ThreeSweepsERKExport... " << endl;
00189
00190
00191 uint numX = NX*(NX+1)/2.0;
00192 uint numU = NU*(NU+1)/2.0;
00193 uint rhsDim = NX + NX + NX*(NX+NU) + numX + NX*NU + numU;
00194 inputDim = rhsDim + NU + NOD;
00195 const uint rkOrder = getNumStages();
00196
00197 double h = (grid.getLastTime() - grid.getFirstTime())/grid.getNumIntervals();
00198
00199 ExportVariable Ah ( "A*h", DMatrix( AA )*=h );
00200 ExportVariable b4h( "b4*h", DMatrix( bb )*=h );
00201
00202 rk_index = ExportVariable( "rk_index", 1, 1, INT, ACADO_LOCAL, true );
00203 rk_eta = ExportVariable( "rk_eta", 1, inputDim );
00204
00205
00206 int useOMP;
00207 get(CG_USE_OPENMP, useOMP);
00208 ExportStruct structWspace;
00209 structWspace = useOMP ? ACADO_LOCAL : ACADO_WORKSPACE;
00210
00211 rk_ttt.setup( "rk_ttt", 1, 1, REAL, structWspace, true );
00212 uint timeDep = 0;
00213 if( timeDependant ) timeDep = 1;
00214
00215 rk_xxx.setup("rk_xxx", 1, inputDim+timeDep, REAL, structWspace);
00216 uint numK = NX*(NX+NU)+numX+NX*NU+numU;
00217 rk_kkk.setup("rk_kkk", rkOrder, numK, REAL, structWspace);
00218 rk_forward_sweep.setup("rk_sweep1", 1, grid.getNumIntervals()*rkOrder*NX, REAL, structWspace);
00219 rk_backward_sweep.setup("rk_sweep2", 1, grid.getNumIntervals()*rkOrder*NX, REAL, structWspace);
00220
00221 if ( useOMP )
00222 {
00223 ExportVariable auxVar;
00224
00225 auxVar = diffs_rhs.getGlobalExportVariable();
00226 auxVar.setName( "odeAuxVar" );
00227 auxVar.setDataStruct( ACADO_LOCAL );
00228 diffs_rhs.setGlobalExportVariable( auxVar );
00229 }
00230
00231 ExportIndex run( "run1" );
00232
00233
00234 integrate = ExportFunction( "integrate", rk_eta, reset_int );
00235 integrate.setReturnValue( error_code );
00236 rk_eta.setDoc( "Working array to pass the input values and return the results." );
00237 reset_int.setDoc( "The internal memory of the integrator can be reset." );
00238 rk_index.setDoc( "Number of the shooting interval." );
00239 error_code.setDoc( "Status code of the integrator." );
00240 integrate.doc( "Performs the integration and sensitivity propagation for one shooting interval." );
00241 integrate.addIndex( run );
00242
00243 integrate.addStatement( rk_ttt == DMatrix(grid.getFirstTime()) );
00244
00245 if( inputDim > rhsDim ) {
00246
00247
00248 DMatrix idX = eye<double>( NX );
00249 DMatrix zeroXU = zeros<double>( NX,NU );
00250 integrate.addStatement( rk_eta.getCols( 2*NX,NX*(2+NX) ) == idX.makeVector().transpose() );
00251 integrate.addStatement( rk_eta.getCols( NX*(2+NX),NX*(2+NX+NU) ) == zeroXU.makeVector().transpose() );
00252
00253 integrate.addStatement( rk_eta.getCols( NX*(2+NX+NU),rhsDim ) == zeros<double>( 1,numX+NX*NU+numU ) );
00254
00255 integrate.addStatement( rk_xxx.getCols( NX,NX+NU+NOD ) == rk_eta.getCols( rhsDim,inputDim ) );
00256 }
00257 integrate.addLinebreak( );
00258
00259
00260 ExportForLoop loop = ExportForLoop( run, 0, grid.getNumIntervals() );
00261 for( uint run1 = 0; run1 < rkOrder; run1++ )
00262 {
00263 loop.addStatement( rk_xxx.getCols( 0,NX ) == rk_eta.getCols( 0,NX ) + Ah.getRow(run1)*rk_kkk.getCols( 0,NX ) );
00264
00265 loop.addStatement( rk_forward_sweep.getCols( run*rkOrder*NX+run1*NX,run*rkOrder*NX+(run1+1)*NX ) == rk_xxx.getCols( 0,NX ) );
00266 if( timeDependant ) loop.addStatement( rk_xxx.getCol( NX+NU+NOD ) == rk_ttt + ((double)cc(run1))/grid.getNumIntervals() );
00267 loop.addFunctionCall( getNameRHS(),rk_xxx,rk_kkk.getAddress(run1,0) );
00268 }
00269 loop.addStatement( rk_eta.getCols( 0,NX ) += b4h^rk_kkk.getCols( 0,NX ) );
00270 loop.addStatement( rk_ttt += DMatrix(1.0/grid.getNumIntervals()) );
00271
00272 integrate.addStatement( loop );
00273
00274 if( inputDim > rhsDim ) {
00275
00276 integrate.addStatement( rk_xxx.getCols( 2*NX,2*NX+NU+NOD ) == rk_eta.getCols( rhsDim,inputDim ) );
00277 }
00278
00279 ExportForLoop loop2 = ExportForLoop( run, 0, grid.getNumIntervals() );
00280 for( uint run1 = 0; run1 < rkOrder; run1++ )
00281 {
00282
00283 loop2.addStatement( rk_xxx.getCols( 0,NX ) == rk_forward_sweep.getCols( (grid.getNumIntervals()-run)*rkOrder*NX-(run1+1)*NX,(grid.getNumIntervals()-run)*rkOrder*NX-run1*NX ) );
00284 loop2.addStatement( rk_xxx.getCols( NX,2*NX ) == rk_eta.getCols( NX,2*NX ) + Ah.getRow(run1)*rk_kkk.getCols( 0,NX ) );
00285
00286 loop2.addStatement( rk_backward_sweep.getCols( run*rkOrder*NX+run1*NX,run*rkOrder*NX+(run1+1)*NX ) == rk_xxx.getCols( NX,2*NX ) );
00287 if( timeDependant ) loop2.addStatement( rk_xxx.getCol( 2*NX+NU+NOD ) == rk_ttt - ((double)cc(run1))/grid.getNumIntervals() );
00288 loop2.addFunctionCall( getNameDiffsRHS(),rk_xxx,rk_kkk.getAddress(run1,0) );
00289 }
00290 loop2.addStatement( rk_eta.getCols( NX,2*NX ) += b4h^rk_kkk.getCols( 0,NX ) );
00291 loop2.addStatement( rk_ttt -= DMatrix(1.0/grid.getNumIntervals()) );
00292
00293 integrate.addStatement( loop2 );
00294
00295 if( inputDim > rhsDim ) {
00296
00297 integrate.addStatement( rk_xxx.getCols( rhsDim,inputDim ) == rk_eta.getCols( rhsDim,inputDim ) );
00298 }
00299
00300 ExportForLoop loop3 = ExportForLoop( run, 0, grid.getNumIntervals() );
00301 for( uint run1 = 0; run1 < rkOrder; run1++ )
00302 {
00303
00304 loop3.addStatement( rk_xxx.getCols( 0,NX ) == rk_forward_sweep.getCols( run*rkOrder*NX+run1*NX,run*rkOrder*NX+(run1+1)*NX ) );
00305
00306 loop3.addStatement( rk_xxx.getCols( NX,2*NX ) == rk_backward_sweep.getCols( (grid.getNumIntervals()-run)*rkOrder*NX-(run1+1)*NX,(grid.getNumIntervals()-run)*rkOrder*NX-run1*NX ) );
00307 loop3.addStatement( rk_xxx.getCols( 2*NX,rhsDim ) == rk_eta.getCols( 2*NX,rhsDim ) + Ah.getRow(run1)*rk_kkk.getCols( 0,NX*(NX+NU)+numX+NX*NU+numU ) );
00308 if( timeDependant ) loop3.addStatement( rk_xxx.getCol( inputDim ) == rk_ttt + ((double)cc(run1))/grid.getNumIntervals() );
00309 loop3.addFunctionCall( diffs_sweep3.getName(),rk_xxx,rk_kkk.getAddress(run1,0) );
00310 }
00311 loop3.addStatement( rk_eta.getCols( 2*NX,rhsDim ) += b4h^rk_kkk.getCols( 0,NX*(NX+NU)+numX+NX*NU+numU ) );
00312 loop3.addStatement( rk_ttt += DMatrix(1.0/grid.getNumIntervals()) );
00313
00314 integrate.addStatement( loop3 );
00315
00316 integrate.addStatement( error_code == 0 );
00317
00318 LOG( LVL_DEBUG ) << "done" << endl;
00319
00320 return SUCCESSFUL_RETURN;
00321 }
00322
00323
00324 returnValue ThreeSweepsERKExport::getDataDeclarations( ExportStatementBlock& declarations,
00325 ExportStruct dataStruct
00326 ) const
00327 {
00328 AdjointERKExport::getDataDeclarations( declarations, dataStruct );
00329
00330 declarations.addDeclaration( rk_backward_sweep,dataStruct );
00331
00332 return SUCCESSFUL_RETURN;
00333 }
00334
00335
00336 returnValue ThreeSweepsERKExport::getCode( ExportStatementBlock& code
00337 )
00338 {
00339 int useOMP;
00340 get(CG_USE_OPENMP, useOMP);
00341 if ( useOMP )
00342 {
00343 getDataDeclarations( code, ACADO_LOCAL );
00344
00345 code << "#pragma omp threadprivate( "
00346 << getAuxVariable().getFullName() << ", "
00347 << rk_xxx.getFullName() << ", "
00348 << rk_ttt.getFullName() << ", "
00349 << rk_kkk.getFullName() << ", "
00350 << rk_forward_sweep.getFullName() << ", "
00351 << rk_backward_sweep.getFullName()
00352 << " )\n\n";
00353 }
00354
00355 int sensGen;
00356 get( DYNAMIC_SENSITIVITY,sensGen );
00357 if( exportRhs ) {
00358 code.addFunction( rhs );
00359 code.addFunction( diffs_rhs );
00360 code.addFunction( diffs_sweep3 );
00361 }
00362
00363 double h = (grid.getLastTime() - grid.getFirstTime())/grid.getNumIntervals();
00364 code.addComment(std::string("Fixed step size:") + toString(h));
00365 code.addFunction( integrate );
00366
00367 return SUCCESSFUL_RETURN;
00368 }
00369
00370
00371 Expression ThreeSweepsERKExport::returnLowerTriangular( const Expression& expr ) {
00372
00373 ASSERT( expr.getNumRows() == expr.getNumCols() );
00374
00375 Expression new_expr;
00376 for( uint i = 0; i < expr.getNumRows(); i++ ) {
00377 for( uint j = 0; j <= i; j++ ) {
00378 new_expr << expr(i,j);
00379 }
00380 }
00381 return new_expr;
00382 }
00383
00384
00385 Expression ThreeSweepsERKExport::symmetricDoubleProduct( const Expression& expr, const Expression& arg ) {
00386
00387
00388 uint dim = arg.getNumCols();
00389 uint dim2 = arg.getNumRows();
00390
00391 IntermediateState inter_res = zeros<double>(dim2,dim);
00392 for( uint i = 0; i < dim; i++ ) {
00393 for( uint k1 = 0; k1 < dim2; k1++ ) {
00394 for( uint k2 = 0; k2 <= k1; k2++ ) {
00395 inter_res(k1,i) += expr(k1,k2)*arg(k2,i);
00396 }
00397 for( uint k2 = k1+1; k2 < dim2; k2++ ) {
00398 inter_res(k1,i) += expr(k2,k1)*arg(k2,i);
00399 }
00400 }
00401 }
00402
00403 Expression new_expr;
00404 for( uint i = 0; i < dim; i++ ) {
00405 for( uint j = 0; j <= i; j++ ) {
00406 Expression new_tmp = 0;
00407 for( uint k1 = 0; k1 < dim2; k1++ ) {
00408 new_tmp = new_tmp+arg(k1,i)*inter_res(k1,j);
00409 }
00410 new_expr << new_tmp;
00411 }
00412 }
00413 return new_expr;
00414
00415 }
00416
00417
00418 ExportVariable ThreeSweepsERKExport::getAuxVariable() const
00419 {
00420 ExportVariable max;
00421 max = rhs.getGlobalExportVariable();
00422 if( diffs_rhs.getGlobalExportVariable().getDim() > max.getDim() ) {
00423 max = diffs_rhs.getGlobalExportVariable();
00424 }
00425 if( diffs_sweep3.getGlobalExportVariable().getDim() > max.getDim() ) {
00426 max = diffs_sweep3.getGlobalExportVariable();
00427 }
00428 return max;
00429 }
00430
00431
00432
00433
00434
00435 CLOSE_NAMESPACE_ACADO
00436
00437