00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00034 #include <acado/code_generation/integrators/discrete_export.hpp>
00035
00036 using namespace std;
00037
00038 BEGIN_NAMESPACE_ACADO
00039
00040
00041
00042
00043
00044 DiscreteTimeExport::DiscreteTimeExport( UserInteraction* _userInteraction,
00045 const std::string& _commonHeaderName
00046 ) : IntegratorExport( _userInteraction,_commonHeaderName )
00047 {
00048 }
00049
00050
00051 DiscreteTimeExport::DiscreteTimeExport( const DiscreteTimeExport& arg
00052 ) : IntegratorExport( arg )
00053 {
00054 copy( arg );
00055 }
00056
00057
00058 DiscreteTimeExport::~DiscreteTimeExport( )
00059 {
00060 clear( );
00061 }
00062
00063
00064 returnValue DiscreteTimeExport::setDifferentialEquation( const Expression& rhs_ )
00065 {
00066 if( rhs_.getDim() > 0 ) {
00067 OnlineData dummy0;
00068 Control dummy1;
00069 DifferentialState dummy2;
00070 AlgebraicState dummy3;
00071 DifferentialStateDerivative dummy4;
00072 dummy0.clearStaticCounters();
00073 dummy1.clearStaticCounters();
00074 dummy2.clearStaticCounters();
00075 dummy3.clearStaticCounters();
00076 dummy4.clearStaticCounters();
00077
00078 NX2 = rhs_.getDim() - NXA;
00079 x = DifferentialState("", NX1+NX2, 1);
00080 z = AlgebraicState("", NXA, 1);
00081 dx = DifferentialStateDerivative("", NDX, 1);
00082 u = Control("", NU, 1);
00083 od = OnlineData("", NOD, 1);
00084
00085 DifferentialEquation f;
00086 f << rhs_;
00087
00088 DifferentialEquation g;
00089 for( uint i = 0; i < rhs_.getDim(); i++ ) {
00090 g << forwardDerivative( rhs_(i), x );
00091 g << forwardDerivative( rhs_(i), u );
00092
00093 }
00094
00095 return (rhs.init(f, "acado_rhs", NX, NXA, NU, NP, NDX, NOD) &
00096 diffs_rhs.init(g, "acado_diffs", NX, NXA, NU, NP, NDX, NOD));
00097 }
00098 return SUCCESSFUL_RETURN;
00099 }
00100
00101
00102 returnValue DiscreteTimeExport::getDataDeclarations( ExportStatementBlock& declarations,
00103 ExportStruct dataStruct
00104 ) const
00105 {
00106 ExportVariable max = getAuxVariable();
00107 declarations.addDeclaration( max,dataStruct );
00108 declarations.addDeclaration( rk_xxx,dataStruct );
00109 declarations.addDeclaration( reset_int,dataStruct );
00110
00111 declarations.addDeclaration( rk_diffsPrev1,dataStruct );
00112 declarations.addDeclaration( rk_diffsPrev2,dataStruct );
00113 declarations.addDeclaration( rk_diffsPrev3,dataStruct );
00114
00115 declarations.addDeclaration( rk_diffsNew1,dataStruct );
00116 declarations.addDeclaration( rk_diffsNew2,dataStruct );
00117 declarations.addDeclaration( rk_diffsNew3,dataStruct );
00118 declarations.addDeclaration( rk_diffsTemp3,dataStruct );
00119
00120 return SUCCESSFUL_RETURN;
00121 }
00122
00123
00124 returnValue DiscreteTimeExport::getFunctionDeclarations( ExportStatementBlock& declarations
00125 ) const
00126 {
00127 declarations.addDeclaration( integrate );
00128
00129 if( NX2 != NX ) declarations.addDeclaration( fullRhs );
00130 else declarations.addDeclaration( rhs );
00131
00132 return SUCCESSFUL_RETURN;
00133 }
00134
00135
00136 returnValue DiscreteTimeExport::setup( )
00137 {
00138 int sensGen;
00139 get( DYNAMIC_SENSITIVITY,sensGen );
00140 if ( (ExportSensitivityType)sensGen != FORWARD ) ACADOERROR( RET_INVALID_OPTION );
00141
00142 int useOMP;
00143 get(CG_USE_OPENMP, useOMP);
00144 ExportStruct structWspace;
00145 structWspace = useOMP ? ACADO_LOCAL : ACADO_WORKSPACE;
00146
00147 LOG( LVL_DEBUG ) << "Preparing to export DiscreteTimeExport... " << endl;
00148
00149 ExportIndex run( "run" );
00150 ExportIndex i( "i" );
00151 ExportIndex j( "j" );
00152 ExportIndex k( "k" );
00153 ExportIndex tmp_index("tmp_index");
00154 diffsDim = NX*(NX+NU);
00155 inputDim = NX*(NX+NU+1) + NU + NOD;
00156
00157 rk_index = ExportVariable( "rk_index", 1, 1, INT, ACADO_LOCAL, true );
00158 rk_eta = ExportVariable( "rk_eta", 1, inputDim, REAL );
00159 if( equidistantControlGrid() ) {
00160 integrate = ExportFunction( "integrate", rk_eta, reset_int );
00161 }
00162 else {
00163 integrate = ExportFunction( "integrate", rk_eta, reset_int, rk_index );
00164 }
00165 integrate.setReturnValue( error_code );
00166 rk_eta.setDoc( "Working array to pass the input values and return the results." );
00167 reset_int.setDoc( "The internal memory of the integrator can be reset." );
00168 rk_index.setDoc( "Number of the shooting interval." );
00169 error_code.setDoc( "Status code of the integrator." );
00170 integrate.doc( "Performs the integration and sensitivity propagation for one shooting interval." );
00171 integrate.addIndex( run );
00172 integrate.addIndex( i );
00173 integrate.addIndex( j );
00174 integrate.addIndex( k );
00175 integrate.addIndex( tmp_index );
00176 rhs_in = ExportVariable( "x", inputDim-diffsDim, 1, REAL, ACADO_LOCAL );
00177 rhs_out = ExportVariable( "f", NX, 1, REAL, ACADO_LOCAL );
00178 fullRhs = ExportFunction( "full_rhs", rhs_in, rhs_out );
00179 rhs_in.setDoc( "The state and parameter values." );
00180 rhs_out.setDoc( "Right-hand side evaluation." );
00181 fullRhs.doc( "Evaluates the right-hand side of the full model." );
00182 rk_xxx = ExportVariable( "rk_xxx", 1, inputDim-diffsDim, REAL, structWspace );
00183 if( grid.getNumIntervals() > 1 || !equidistantControlGrid() ) {
00184 rk_diffsPrev1 = ExportVariable( "rk_diffsPrev1", NX1, NX1+NU, REAL, structWspace );
00185 rk_diffsPrev2 = ExportVariable( "rk_diffsPrev2", NX2, NX1+NX2+NU, REAL, structWspace );
00186 rk_diffsPrev3 = ExportVariable( "rk_diffsPrev3", NX3, NX+NU, REAL, structWspace );
00187 }
00188 rk_diffsNew1 = ExportVariable( "rk_diffsNew1", NX1, NX1+NU, REAL, structWspace );
00189 rk_diffsNew2 = ExportVariable( "rk_diffsNew2", NX2, NX1+NX2+NU, REAL, structWspace );
00190 rk_diffsNew3 = ExportVariable( "rk_diffsNew3", NX3, NX+NU, REAL, structWspace );
00191 rk_diffsTemp3 = ExportVariable( "rk_diffsTemp3", NX3, NX1+NX2+NU, REAL, structWspace );
00192
00193 ExportVariable numInt( "numInts", 1, 1, INT );
00194 if( !equidistantControlGrid() ) {
00195 ExportVariable numStepsV( "numSteps", numSteps, STATIC_CONST_INT );
00196 integrate.addStatement( std::string( "int " ) + numInt.getName() + " = " + numStepsV.getName() + "[" + rk_index.getName() + "];\n" );
00197 }
00198
00199 integrate.addStatement( rk_xxx.getCols( NX,inputDim-diffsDim ) == rk_eta.getCols( NX+diffsDim,inputDim ) );
00200 integrate.addLinebreak( );
00201
00202 if( NX1 > 0 ) {
00203 for( uint i1 = 0; i1 < NX1; i1++ ) {
00204 for( uint i2 = 0; i2 < NX1; i2++ ) {
00205 integrate.addStatement( rk_diffsNew1.getSubMatrix(i1,i1+1,i2,i2+1) == A11(i1,i2) );
00206 }
00207 for( uint i2 = 0; i2 < NU; i2++ ) {
00208 integrate.addStatement( rk_diffsNew1.getSubMatrix(i1,i1+1,NX1+i2,NX1+i2+1) == B11(i1,i2) );
00209 }
00210 }
00211 }
00212
00213 if( NX1 > 0 ) {
00214 for( uint i1 = 0; i1 < NX3; i1++ ) {
00215 for( uint i2 = 0; i2 < NX3; i2++ ) {
00216 integrate.addStatement( rk_diffsNew3.getSubMatrix(i1,i1+1,NX-NX3+i2,NX-NX3+i2+1) == A33(i1,i2) );
00217 }
00218 }
00219 }
00220 integrate.addLinebreak( );
00221
00222
00223 ExportForLoop tmpLoop( run, 0, grid.getNumIntervals() );
00224 ExportStatementBlock *loop;
00225 if( equidistantControlGrid() ) {
00226 loop = &tmpLoop;
00227 }
00228 else {
00229 loop = &integrate;
00230 loop->addStatement( std::string("for(") + run.getName() + " = 0; " + run.getName() + " < " + numInt.getName() + "; " + run.getName() + "++ ) {\n" );
00231 }
00232
00233 loop->addStatement( rk_xxx.getCols( 0,NX ) == rk_eta.getCols( 0,NX ) );
00234
00235 if( grid.getNumIntervals() > 1 || !equidistantControlGrid() ) {
00236
00237 loop->addStatement( std::string("if( run > 0 ) {\n") );
00238 if( NX1 > 0 ) {
00239 ExportForLoop loopTemp1( i,0,NX1 );
00240 loopTemp1.addStatement( rk_diffsPrev1.getSubMatrix( i,i+1,0,NX1 ) == rk_eta.getCols( i*NX+NX+NXA,i*NX+NX+NXA+NX1 ) );
00241 if( NU > 0 ) loopTemp1.addStatement( rk_diffsPrev1.getSubMatrix( i,i+1,NX1,NX1+NU ) == rk_eta.getCols( i*NU+(NX+NXA)*(NX+1),i*NU+(NX+NXA)*(NX+1)+NU ) );
00242 loop->addStatement( loopTemp1 );
00243 }
00244 if( NX2 > 0 ) {
00245 ExportForLoop loopTemp2( i,0,NX2 );
00246 loopTemp2.addStatement( rk_diffsPrev2.getSubMatrix( i,i+1,0,NX1+NX2 ) == rk_eta.getCols( i*NX+NX+NXA+NX1*NX,i*NX+NX+NXA+NX1*NX+NX1+NX2 ) );
00247 if( NU > 0 ) loopTemp2.addStatement( rk_diffsPrev2.getSubMatrix( i,i+1,NX1+NX2,NX1+NX2+NU ) == rk_eta.getCols( i*NU+(NX+NXA)*(NX+1)+NX1*NU,i*NU+(NX+NXA)*(NX+1)+NX1*NU+NU ) );
00248 loop->addStatement( loopTemp2 );
00249 }
00250 if( NX3 > 0 ) {
00251 ExportForLoop loopTemp3( i,0,NX3 );
00252 loopTemp3.addStatement( rk_diffsPrev3.getSubMatrix( i,i+1,0,NX ) == rk_eta.getCols( i*NX+NX+NXA+(NX1+NX2)*NX,i*NX+NX+NXA+(NX1+NX2)*NX+NX ) );
00253 if( NU > 0 ) loopTemp3.addStatement( rk_diffsPrev3.getSubMatrix( i,i+1,NX,NX+NU ) == rk_eta.getCols( i*NU+(NX+NXA)*(NX+1)+(NX1+NX2)*NU,i*NU+(NX+NXA)*(NX+1)+(NX1+NX2)*NU+NU ) );
00254 loop->addStatement( loopTemp3 );
00255 }
00256 loop->addStatement( std::string("}\n") );
00257 }
00258
00259
00260 if( NX1 > 0 ) {
00261 loop->addFunctionCall( lin_input.getName(), rk_xxx, rk_eta.getAddress(0,0) );
00262 }
00263 if( NX2 > 0 ) {
00264 loop->addFunctionCall( getNameRHS(), rk_xxx, rk_eta.getAddress(0,NX1) );
00265 }
00266 if( NX3 > 0 ) {
00267 loop->addFunctionCall( getNameOutputRHS(), rk_xxx, rk_eta.getAddress(0,NX1+NX2) );
00268 }
00269
00270
00271 if( NX2 > 0 ) {
00272 loop->addFunctionCall( getNameDiffsRHS(), rk_xxx, rk_diffsNew2.getAddress(0,0) );
00273 }
00274 if( NX3 > 0 ) {
00275 loop->addFunctionCall( getNameOutputDiffs(), rk_xxx, rk_diffsTemp3.getAddress(0,0) );
00276 ExportForLoop loop1( i,0,NX3 );
00277 ExportForLoop loop2( j,0,NX1+NX2 );
00278 loop2.addStatement( rk_diffsNew3.getSubMatrix(i,i+1,j,j+1) == rk_diffsTemp3.getSubMatrix(i,i+1,j,j+1) );
00279 loop1.addStatement( loop2 );
00280 loop2 = ExportForLoop( j,0,NU );
00281 loop2.addStatement( rk_diffsNew3.getSubMatrix(i,i+1,NX+j,NX+j+1) == rk_diffsTemp3.getSubMatrix(i,i+1,NX1+NX2+j,NX1+NX2+j+1) );
00282 loop1.addStatement( loop2 );
00283 loop->addStatement( loop1 );
00284 }
00285
00286
00287 if( grid.getNumIntervals() > 1 || !equidistantControlGrid() ) {
00288 loop->addStatement( std::string( "if( run == 0 ) {\n" ) );
00289 }
00290
00291 updateInputSystem(loop, i, j, tmp_index);
00292
00293 updateImplicitSystem(loop, i, j, tmp_index);
00294
00295 updateOutputSystem(loop, i, j, tmp_index);
00296
00297 if( grid.getNumIntervals() > 1 || !equidistantControlGrid() ) {
00298 loop->addStatement( std::string( "}\n" ) );
00299 loop->addStatement( std::string( "else {\n" ) );
00300
00301 propagateInputSystem(loop, i, j, k, tmp_index);
00302
00303 propagateImplicitSystem(loop, i, j, k, tmp_index);
00304
00305 propagateOutputSystem(loop, i, j, k, tmp_index);
00306 loop->addStatement( std::string( "}\n" ) );
00307 }
00308
00309
00310 if( !equidistantControlGrid() ) {
00311 loop->addStatement( "}\n" );
00312 }
00313 else {
00314 integrate.addStatement( *loop );
00315 }
00316
00317 if( NX1 > 0 ) {
00318 DMatrix zeroR = zeros<double>(1, NX2+NX3);
00319 ExportForLoop loop1( i,0,NX1 );
00320 loop1.addStatement( rk_eta.getCols( i*NX+NX+NXA+NX1,i*NX+NX+NXA+NX ) == zeroR );
00321 integrate.addStatement( loop1 );
00322 }
00323
00324 DMatrix zeroR = zeros<double>(1, NX3);
00325 if( NX2 > 0 ) {
00326 ExportForLoop loop2( i,NX1,NX1+NX2 );
00327 loop2.addStatement( rk_eta.getCols( i*NX+NX+NXA+NX1+NX2,i*NX+NX+NXA+NX ) == zeroR );
00328 integrate.addStatement( loop2 );
00329 }
00330
00331 LOG( LVL_DEBUG ) << "done" << endl;
00332
00333 return SUCCESSFUL_RETURN;
00334 }
00335
00336
00337 returnValue DiscreteTimeExport::getCode( ExportStatementBlock& code
00338 )
00339 {
00340 int useOMP;
00341 get(CG_USE_OPENMP, useOMP);
00342 if ( useOMP ) {
00343 ExportVariable max = getAuxVariable();
00344 max.setName( "auxVar" );
00345 max.setDataStruct( ACADO_LOCAL );
00346 if( NX2 > 0 ) {
00347 rhs.setGlobalExportVariable( max );
00348 diffs_rhs.setGlobalExportVariable( max );
00349 }
00350 if( NX3 > 0 ) {
00351 rhs3.setGlobalExportVariable( max );
00352 diffs_rhs3.setGlobalExportVariable( max );
00353 }
00354
00355 getDataDeclarations( code, ACADO_LOCAL );
00356
00357 stringstream s;
00358 s << "#pragma omp threadprivate( "
00359 << max.getFullName() << ", "
00360 << rk_xxx.getFullName();
00361 if( NX1 > 0 ) {
00362 if( grid.getNumIntervals() > 1 || !equidistantControlGrid() ) s << ", " << rk_diffsPrev1.getFullName();
00363 s << ", " << rk_diffsNew1.getFullName();
00364 }
00365 if( NX2 > 0 || NXA > 0 ) {
00366 if( grid.getNumIntervals() > 1 || !equidistantControlGrid() ) s << ", " << rk_diffsPrev2.getFullName();
00367 s << ", " << rk_diffsNew2.getFullName();
00368 }
00369 if( NX3 > 0 ) {
00370 if( grid.getNumIntervals() > 1 || !equidistantControlGrid() ) s << ", " << rk_diffsPrev3.getFullName();
00371 s << ", " << rk_diffsNew3.getFullName();
00372 s << ", " << rk_diffsTemp3.getFullName();
00373 }
00374 s << " )" << endl << endl;
00375 code.addStatement( s.str().c_str() );
00376 }
00377
00378 if( NX1 > 0 ) {
00379 code.addFunction( lin_input );
00380 code.addStatement( "\n\n" );
00381 }
00382
00383 if( NX2 > 0 ) {
00384 code.addFunction( rhs );
00385 code.addStatement( "\n\n" );
00386 code.addFunction( diffs_rhs );
00387 code.addStatement( "\n\n" );
00388 }
00389
00390 if( NX3 > 0 ) {
00391 code.addFunction( rhs3 );
00392 code.addStatement( "\n\n" );
00393 code.addFunction( diffs_rhs3 );
00394 code.addStatement( "\n\n" );
00395 }
00396
00397 if( !equidistantControlGrid() ) {
00398 ExportVariable numStepsV( "numSteps", numSteps, STATIC_CONST_INT );
00399 code.addDeclaration( numStepsV );
00400 code.addLinebreak( 2 );
00401 }
00402 double h = (grid.getLastTime() - grid.getFirstTime())/grid.getNumIntervals();
00403 code.addComment(std::string("Fixed step size:") + toString(h));
00404
00405 code.addFunction( integrate );
00406
00407 return SUCCESSFUL_RETURN;
00408 }
00409
00410
00411 returnValue DiscreteTimeExport::setNARXmodel( const uint delay, const DMatrix& parms ) {
00412
00413 return RET_INVALID_OPTION;
00414 }
00415
00416
00417 returnValue DiscreteTimeExport::setupOutput( const std::vector<Grid> outputGrids_, const std::vector<Expression> _rhs ) {
00418
00419 return ACADOERROR( RET_INVALID_OPTION );
00420 }
00421
00422
00423 returnValue DiscreteTimeExport::setupOutput( const std::vector<Grid> outputGrids_,
00424 const std::vector<std::string> _outputNames,
00425 const std::vector<std::string> _diffs_outputNames,
00426 const std::vector<uint> _dims_output ) {
00427
00428 return ACADOERROR( RET_INVALID_OPTION );
00429 }
00430
00431
00432 returnValue DiscreteTimeExport::setupOutput( const std::vector<Grid> outputGrids_,
00433 const std::vector<std::string> _outputNames,
00434 const std::vector<std::string> _diffs_outputNames,
00435 const std::vector<uint> _dims_output,
00436 const std::vector<DMatrix> _outputDependencies ) {
00437
00438 return ACADOERROR( RET_INVALID_OPTION );
00439 }
00440
00441
00442 DiscreteTimeExport& DiscreteTimeExport::operator=( const DiscreteTimeExport& arg
00443 )
00444 {
00445 if( this != &arg )
00446 {
00447 clear( );
00448 IntegratorExport::operator=( arg );
00449 copy( arg );
00450 }
00451 return *this;
00452 }
00453
00454
00455
00456
00457
00458
00459
00460
00461
00462 IntegratorExport* createDiscreteTimeExport( UserInteraction* _userInteraction,
00463 const std::string &_commonHeaderName )
00464 {
00465 return new DiscreteTimeExport(_userInteraction, _commonHeaderName);
00466 }
00467
00468
00469 ExportVariable DiscreteTimeExport::getAuxVariable() const
00470 {
00471 ExportVariable max;
00472 if( NX1 > 0 ) {
00473 max = lin_input.getGlobalExportVariable();
00474 }
00475 if( NX2 > 0 ) {
00476 if( rhs.getGlobalExportVariable().getDim() >= max.getDim() ) {
00477 max = rhs.getGlobalExportVariable();
00478 }
00479 if( diffs_rhs.getGlobalExportVariable().getDim() >= max.getDim() ) {
00480 max = diffs_rhs.getGlobalExportVariable();
00481 }
00482 }
00483 if( NX3 > 0 ) {
00484 if( rhs3.getGlobalExportVariable().getDim() >= max.getDim() ) {
00485 max = rhs3.getGlobalExportVariable();
00486 }
00487 if( diffs_rhs3.getGlobalExportVariable().getDim() >= max.getDim() ) {
00488 max = diffs_rhs3.getGlobalExportVariable();
00489 }
00490 }
00491
00492 return max;
00493 }
00494
00495
00496 returnValue DiscreteTimeExport::copy( const DiscreteTimeExport& arg
00497 )
00498 {
00499 rhs = arg.rhs;
00500 diffs_rhs = arg.diffs_rhs;
00501
00502
00503 rk_ttt = arg.rk_ttt;
00504 rk_xxx = arg.rk_xxx;
00505
00506
00507 integrate = arg.integrate;
00508
00509 return SUCCESSFUL_RETURN;
00510 }
00511
00512
00513
00514 CLOSE_NAMESPACE_ACADO
00515
00516