00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00034 #include <acado/code_generation/integrators/dirk_export.hpp>
00035
00036 #include <sstream>
00037 using namespace std;
00038
00039
00040
00041 BEGIN_NAMESPACE_ACADO
00042
00043
00044
00045
00046
00047
00048 DiagonallyImplicitRKExport::DiagonallyImplicitRKExport( UserInteraction* _userInteraction,
00049 const std::string& _commonHeaderName
00050 ) : ForwardIRKExport( _userInteraction,_commonHeaderName )
00051 {
00052
00053 }
00054
00055 DiagonallyImplicitRKExport::DiagonallyImplicitRKExport( const DiagonallyImplicitRKExport& arg ) : ForwardIRKExport( arg )
00056 {
00057
00058 }
00059
00060
00061 DiagonallyImplicitRKExport::~DiagonallyImplicitRKExport( )
00062 {
00063 if ( solver )
00064 delete solver;
00065 solver = 0;
00066
00067 clear( );
00068 }
00069
00070
00071 DiagonallyImplicitRKExport& DiagonallyImplicitRKExport::operator=( const DiagonallyImplicitRKExport& arg ){
00072
00073 if( this != &arg ){
00074
00075 ForwardIRKExport::operator=( arg );
00076 }
00077 return *this;
00078 }
00079
00080
00081 returnValue DiagonallyImplicitRKExport::solveInputSystem( ExportStatementBlock* block, const ExportIndex& index1, const ExportIndex& index2, const ExportIndex& index3, const ExportIndex& tmp_index, const ExportVariable& Ah )
00082 {
00083 if( NX1 > 0 ) {
00084 ExportForLoop loop( index1,0,numStages );
00085 loop.addStatement( rk_xxx.getCols(0,NX1) == rk_eta.getCols(0,NX1) );
00086 ExportForLoop loop01( index2,0,NX1 );
00087 ExportForLoop loop02( index3,0,index1 );
00088 loop02.addStatement( rk_xxx.getCol( index2 ) += Ah.getElement(index1,index3)*rk_kkk.getElement(index2,index3) );
00089 loop01.addStatement( loop02 );
00090 loop.addStatement( loop01 );
00091 loop.addFunctionCall( lin_input.getName(), rk_xxx, rk_b.getAddress(0,0) );
00092
00093 ExportForLoop loop5( index2,0,NX1 );
00094 loop5.addStatement( tmp_index == index1*NX1+index2 );
00095 loop5.addStatement( rk_kkk.getElement(index2,index1) == rk_mat1.getElement(tmp_index,0)*rk_b.getRow(0) );
00096 ExportForLoop loop6( index3,1,NX1 );
00097 loop6.addStatement( rk_kkk.getElement(index2,index1) += rk_mat1.getElement(tmp_index,index3)*rk_b.getRow(index3) );
00098 loop5.addStatement(loop6);
00099 loop.addStatement(loop5);
00100 block->addStatement(loop);
00101 }
00102
00103 return SUCCESSFUL_RETURN;
00104 }
00105
00106
00107 returnValue DiagonallyImplicitRKExport::prepareInputSystem( ExportStatementBlock& code )
00108 {
00109 if( NX1 > 0 ) {
00110 DMatrix mat1 = formMatrix( M11, A11 );
00111 rk_mat1 = ExportVariable( "rk_mat1", mat1, STATIC_CONST_REAL );
00112 code.addDeclaration( rk_mat1 );
00113
00114 rk_mat1 = ExportVariable( "rk_mat1", numStages*NX1, NX1, STATIC_CONST_REAL, ACADO_LOCAL );
00115 double h = (grid.getLastTime() - grid.getFirstTime())/grid.getNumIntervals();
00116
00117 DMatrix sens = zeros<double>(NX1*(NX1+NU), numStages);
00118 uint i, j, k, s1, s2;
00119 for( i = 0; i < NX1; i++ ) {
00120 DVector vec(NX1);
00121 for( j = 0; j < numStages; j++ ) {
00122 for( k = 0; k < NX1; k++ ) {
00123 vec(k) = A11(k,i);
00124 for( s1 = 0; s1 < j; s1++ ) {
00125 for( s2 = 0; s2 < NX1; s2++ ) {
00126 vec(k) = vec(k) + AA(j,s1)*h*A11(k,s2)*sens(i*NX1+s2,s1);
00127 }
00128 }
00129 }
00130 DVector sol = mat1*vec;
00131 for( k = 0; k < NX1; k++ ) {
00132 sens(i*NX1+k,j) = sol(k);
00133 }
00134 }
00135 }
00136 for( i = 0; i < NU; i++ ) {
00137 DVector vec(NX1);
00138 for( j = 0; j < numStages; j++ ) {
00139 for( k = 0; k < NX1; k++ ) {
00140 vec(k) = B11(k,i);
00141 for( s1 = 0; s1 < j; s1++ ) {
00142 for( s2 = 0; s2 < NX1; s2++ ) {
00143 vec(k) = vec(k) + AA(j,s1)*h*A11(k,s2)*sens(NX1*NX1+i*NX1+s2,s1);
00144 }
00145 }
00146 }
00147 DVector sol = mat1*vec;
00148 for( k = 0; k < NX1; k++ ) {
00149 sens(NX1*NX1+i*NX1+k,j) = sol(k);
00150 }
00151 }
00152 }
00153 rk_dk1 = ExportVariable( "rk_dk1", sens, STATIC_CONST_REAL );
00154 code.addDeclaration( rk_dk1 );
00155
00156 rk_dk1 = ExportVariable( "rk_dk1", NX1*(NX1+NU), numStages, STATIC_CONST_REAL, ACADO_LOCAL );
00157 }
00158
00159 return SUCCESSFUL_RETURN;
00160 }
00161
00162
00163 DMatrix DiagonallyImplicitRKExport::formMatrix( const DMatrix& mass, const DMatrix& jacobian ) {
00164 if( jacobian.getNumRows() != jacobian.getNumCols() ) {
00165 return RET_UNABLE_TO_EXPORT_CODE;
00166 }
00167 double h = (grid.getLastTime() - grid.getFirstTime())/grid.getNumIntervals();
00168 uint vars = jacobian.getNumRows();
00169 uint i1, i2, j2;
00170 DMatrix result = zeros<double>(numStages*vars, vars);
00171 DMatrix tmp = zeros<double>(vars, vars);
00172 for( i1 = 0; i1 < numStages; i1++ ){
00173 for( i2 = 0; i2 < vars; i2++ ){
00174 for( j2 = 0; j2 < vars; j2++ ) {
00175 tmp(i2, j2) = mass(i2,j2) - AA(i1,i1)*h*jacobian(i2,j2);
00176 }
00177 }
00178 tmp = tmp.inverse();
00179 for( i2 = 0; i2 < vars; i2++ ){
00180 for( j2 = 0; j2 < vars; j2++ ) {
00181 result(i1*vars+i2, j2) = tmp(i2, j2);
00182 }
00183 }
00184 }
00185
00186 return result;
00187 }
00188
00189
00190 returnValue DiagonallyImplicitRKExport::solveImplicitSystem( ExportStatementBlock* block, const ExportIndex& index1, const ExportIndex& index2, const ExportIndex& index3, const ExportIndex& tmp_index, const ExportVariable& Ah, const ExportVariable& C, const ExportVariable& det, bool DERIVATIVES )
00191 {
00192 if( NX2 > 0 || NXA > 0 ) {
00193
00194 if( REUSE ) block->addStatement( std::string( "if( " ) + reset_int.getFullName() + " ) {\n" );
00195
00196 ExportForLoop loop11( index2,0,numStages );
00197 ExportForLoop loop1( index1,0,numItsInit+1 );
00198 evaluateMatrix( &loop1, index2, index3, tmp_index, Ah, C, true, DERIVATIVES );
00199 loop1.addStatement( det.getFullName() + " = " + solver->getNameSolveFunction() + "( &" + rk_A.get(index2*(NX2+NXA),0) + ", " + rk_b.getFullName() + ", &" + rk_auxSolver.get(index2,0) + " );\n" );
00200 loop1.addStatement( rk_kkk.getSubMatrix( NX1,NX1+NX2,index2,index2+1 ) += rk_b.getRows( 0,NX2 ) );
00201 if(NXA > 0) loop1.addStatement( rk_kkk.getSubMatrix( NX,NX+NXA,index2,index2+1 ) += rk_b.getRows( NX2,NX2+NXA ) );
00202 loop11.addStatement( loop1 );
00203 block->addStatement( loop11 );
00204 if( REUSE ) block->addStatement( std::string( "}\n" ) );
00205
00206
00207 ExportForLoop loop21( index2,0,numStages );
00208 ExportForLoop loop2( index1,0,numIts );
00209 evaluateStatesImplicitSystem( &loop2, Ah, C, index2, index3, tmp_index );
00210 evaluateRhsImplicitSystem( &loop2, index2 );
00211 loop2.addFunctionCall( solver->getNameSolveReuseFunction(),rk_A.getAddress(index2*(NX2+NXA),0),rk_b.getAddress(0,0),rk_auxSolver.getAddress(index2,0) );
00212 loop2.addStatement( rk_kkk.getSubMatrix( NX1,NX1+NX2,index2,index2+1 ) += rk_b.getRows( 0,NX2 ) );
00213 if(NXA > 0) loop2.addStatement( rk_kkk.getSubMatrix( NX,NX+NXA,index2,index2+1 ) += rk_b.getRows( NX2,NX2+NXA ) );
00214 loop21.addStatement( loop2 );
00215 block->addStatement( loop21 );
00216
00217 if( DERIVATIVES ) {
00218
00219 ExportForLoop loop3( index2,0,numStages );
00220 evaluateMatrix( &loop3, index2, index3, tmp_index, Ah, C, false, DERIVATIVES );
00221 block->addStatement( loop3 );
00222 }
00223
00224
00225 int debugMode;
00226 get( INTEGRATOR_DEBUG_MODE, debugMode );
00227 if ( (bool)debugMode == true ) {
00228 block->addStatement( debug_mat == rk_A );
00229 }
00230 }
00231
00232 return SUCCESSFUL_RETURN;
00233 }
00234
00235
00236 returnValue DiagonallyImplicitRKExport::sensitivitiesImplicitSystem( ExportStatementBlock* block, const ExportIndex& index1, const ExportIndex& index2, const ExportIndex& index3, const ExportIndex& tmp_index1, const ExportIndex& tmp_index2, const ExportVariable& Ah, const ExportVariable& Bh, const ExportVariable& det, bool STATES, uint number )
00237 {
00238 if( NX2 > 0 ) {
00239 DMatrix zeroM = zeros<double>( NX2+NXA,1 );
00240 DMatrix tempCoefs( evaluateDerivedPolynomial( 0.0 ) );
00241 uint i;
00242
00243 ExportForLoop loop1( index2,0,numStages );
00244 if( STATES && number == 1 ) {
00245 ExportForLoop loop2( index3,0,NX1 );
00246 loop2.addStatement( std::string(rk_rhsTemp.get( index3,0 )) + " = -(" + index3.getName() + " == " + index1.getName() + ");\n" );
00247 ExportForLoop loop21( tmp_index1,0,index2+1 );
00248 loop21.addStatement( rk_rhsTemp.getRow( index3 ) -= rk_diffK.getElement( index3,tmp_index1 )*Ah.getElement(index2,tmp_index1) );
00249 loop2.addStatement( loop21 );
00250 loop1.addStatement( loop2 );
00251 ExportForLoop loop3( index3,0,NX2+NXA );
00252 loop3.addStatement( rk_b.getRow( index3 ) == rk_diffsTemp2.getSubMatrix( index2,index2+1,index3*(NVARS2),index3*(NVARS2)+NX1 )*rk_rhsTemp.getRows(0,NX1) );
00253 if( NDX2 > 0 ) {
00254 loop3.addStatement( rk_b.getRow( index3 ) -= rk_diffsTemp2.getSubMatrix( index2,index2+1,index3*(NVARS2)+NVARS2-NX1-NX2,index3*(NVARS2)+NVARS2-NX2 )*rk_diffK.getSubMatrix( 0,NX1,index2,index2+1 ) );
00255 }
00256 loop1.addStatement( loop3 );
00257 }
00258 else if( STATES && number == 2 ) {
00259 for( i = 0; i < NX2+NXA; i++ ) {
00260 loop1.addStatement( rk_b.getRow( i ) == zeroM.getRow( 0 ) - rk_diffsTemp2.getElement( index2,index1+i*(NVARS2) ) );
00261 }
00262 }
00263 else {
00264 ExportForLoop loop2( index3,0,NX1 );
00265 loop2.addStatement( rk_rhsTemp.getRow( index3 ) == rk_diffK.getElement( index3,0 )*Ah.getElement(index2,0) );
00266 ExportForLoop loop21( tmp_index1,1,index2+1 );
00267 loop21.addStatement( rk_rhsTemp.getRow( index3 ) += rk_diffK.getElement( index3,tmp_index1 )*Ah.getElement(index2,tmp_index1) );
00268 loop2.addStatement( loop21 );
00269 loop1.addStatement( loop2 );
00270 ExportForLoop loop3( index3,0,NX2+NXA );
00271 loop3.addStatement( tmp_index2 == index1+index3*(NVARS2) );
00272 loop3.addStatement( rk_b.getRow( index3 ) == zeroM.getRow( 0 ) - rk_diffsTemp2.getElement( index2,tmp_index2+NX1+NX2+NXA ) );
00273 loop3.addStatement( rk_b.getRow( index3 ) -= rk_diffsTemp2.getSubMatrix( index2,index2+1,index3*(NVARS2),index3*(NVARS2)+NX1 )*rk_rhsTemp.getRows(0,NX1) );
00274 if( NDX2 > 0 ) {
00275 loop3.addStatement( rk_b.getRow( index3 ) -= rk_diffsTemp2.getSubMatrix( index2,index2+1,index3*(NVARS2)+NVARS2-NX1-NX2,index3*(NVARS2)+NVARS2-NX2 )*rk_diffK.getSubMatrix( 0,NX1,index2,index2+1 ) );
00276 }
00277 loop1.addStatement( loop3 );
00278 }
00279 ExportForLoop loop11( index3,0,NX2+NXA );
00280 ExportForLoop loop12( tmp_index1,0,index2 );
00281 ExportForLoop loop13( tmp_index2,NX1,NX1+NX2 );
00282 loop13.addStatement( std::string( rk_b.get(index3,0) ) + " -= " + Ah.get(index2,tmp_index1) + "*" + rk_diffsTemp2.get(index2,index3*NVARS2+tmp_index2) + "*" + rk_diffK.get(tmp_index2,tmp_index1) + ";\n" );
00283 loop12.addStatement( loop13 );
00284 loop11.addStatement( loop12 );
00285 loop1.addStatement( loop11 );
00286 if( STATES && (number == 1 || NX1 == 0) ) {
00287 loop1.addStatement( std::string( "if( 0 == " ) + index1.getName() + " ) {\n" );
00288 loop1.addStatement( det.getFullName() + " = " + solver->getNameSolveFunction() + "( &" + rk_A.get(index2*(NX2+NXA),0) + ", " + rk_b.getFullName() + ", &" + rk_auxSolver.get(index2,0) + " );\n" );
00289 loop1.addStatement( std::string( "}\n else {\n" ) );
00290 }
00291 loop1.addFunctionCall( solver->getNameSolveReuseFunction(),rk_A.getAddress(index2*(NX2+NXA),0),rk_b.getAddress(0,0),rk_auxSolver.getAddress(index2,0) );
00292 if( STATES && (number == 1 || NX1 == 0) ) loop1.addStatement( std::string( "}\n" ) );
00293
00294 loop1.addStatement( rk_diffK.getSubMatrix(NX1,NX1+NX2,index2,index2+1) == rk_b.getRows(0,NX2) );
00295 loop1.addStatement( rk_diffK.getSubMatrix(NX,NX+NXA,index2,index2+1) == rk_b.getRows(NX2,NX2+NXA) );
00296 block->addStatement( loop1 );
00297
00298 ExportForLoop loop3( index2,0,NX2 );
00299 if( STATES && number == 2 ) loop3.addStatement( std::string(rk_diffsNew2.get( index2,index1 )) + " = (" + index2.getName() + " == " + index1.getName() + "-" + toString(NX1) + ");\n" );
00300
00301 if( STATES && number == 2 ) loop3.addStatement( rk_diffsNew2.getElement( index2,index1 ) += rk_diffK.getRow( NX1+index2 )*Bh );
00302 else if( STATES ) loop3.addStatement( rk_diffsNew2.getElement( index2,index1 ) == rk_diffK.getRow( NX1+index2 )*Bh );
00303 else loop3.addStatement( rk_diffsNew2.getElement( index2,index1+NX1+NX2 ) == rk_diffK.getRow( NX1+index2 )*Bh );
00304 block->addStatement( loop3 );
00305 if( NXA > 0 ) {
00306 block->addStatement( std::string("if( run == 0 ) {\n") );
00307 ExportForLoop loop4( index2,0,NXA );
00308 if( STATES ) loop4.addStatement( rk_diffsNew2.getElement( index2+NX2,index1 ) == rk_diffK.getRow( NX+index2 )*tempCoefs );
00309 else loop4.addStatement( rk_diffsNew2.getElement( index2+NX2,index1+NX1+NX2 ) == rk_diffK.getRow( NX+index2 )*tempCoefs );
00310 block->addStatement( loop4 );
00311 block->addStatement( std::string("}\n") );
00312 }
00313 }
00314
00315 return SUCCESSFUL_RETURN;
00316 }
00317
00318
00319 returnValue DiagonallyImplicitRKExport::evaluateMatrix( ExportStatementBlock* block, const ExportIndex& index1, const ExportIndex& index2, const ExportIndex& tmp_index, const ExportVariable& Ah, const ExportVariable& C, bool evaluateB, bool DERIVATIVES )
00320 {
00321 evaluateStatesImplicitSystem( block, Ah, C, index1, index2, tmp_index );
00322
00323 ExportIndex indexDiffs(index1);
00324 if( !DERIVATIVES ) indexDiffs = ExportIndex(0);
00325
00326 block->addFunctionCall( getNameDiffsRHS(), rk_xxx, rk_diffsTemp2.getAddress(indexDiffs,0) );
00327 ExportForLoop loop2( index2,0,NX2+NXA );
00328 loop2.addStatement( tmp_index == index1*(NX2+NXA)+index2 );
00329 if( NDX2 == 0 ) {
00330 loop2.addStatement( rk_A.getSubMatrix( tmp_index,tmp_index+1,0,NX2 ) == Ah.getElement( 0,0 )*rk_diffsTemp2.getSubMatrix( indexDiffs,indexDiffs+1,index2*(NVARS2)+NX1,index2*(NVARS2)+NX1+NX2 ) );
00331 loop2.addStatement( rk_A.getElement( tmp_index,index2 ) -= 1 );
00332 }
00333 else {
00334 loop2.addStatement( rk_A.getSubMatrix( tmp_index,tmp_index+1,0,NX2 ) == Ah.getElement( 0,0 )*rk_diffsTemp2.getSubMatrix( indexDiffs,indexDiffs+1,index2*(NVARS2)+NX1,index2*(NVARS2)+NX1+NX2 ) );
00335 loop2.addStatement( rk_A.getSubMatrix( tmp_index,tmp_index+1,0,NX2 ) += rk_diffsTemp2.getSubMatrix( indexDiffs,indexDiffs+1,index2*(NVARS2)+NVARS2-NX2,index2*(NVARS2)+NVARS2 ) );
00336 }
00337 if( NXA > 0 ) {
00338 DMatrix zeroM = zeros<double>( 1,NXA );
00339 loop2.addStatement( rk_A.getSubMatrix( tmp_index,tmp_index+1,NX2,NX2+NXA ) == rk_diffsTemp2.getSubMatrix( indexDiffs,indexDiffs+1,index2*(NVARS2)+NX1+NX2,index2*(NVARS2)+NX1+NX2+NXA ) );
00340 }
00341 block->addStatement( loop2 );
00342 if( evaluateB ) {
00343 evaluateRhsImplicitSystem( block, index1 );
00344 }
00345
00346 return SUCCESSFUL_RETURN;
00347 }
00348
00349
00350 returnValue DiagonallyImplicitRKExport::evaluateStatesImplicitSystem( ExportStatementBlock* block, const ExportVariable& Ah, const ExportVariable& C, const ExportIndex& stage, const ExportIndex& i, const ExportIndex& j )
00351 {
00352 ExportForLoop loop1( i, 0, NX1+NX2 );
00353 loop1.addStatement( rk_xxx.getCol( i ) == rk_eta.getCol( i ) );
00354 ExportForLoop loop2( j, 0, stage+1 );
00355 loop2.addStatement( rk_xxx.getCol( i ) += Ah.getElement(stage,j)*rk_kkk.getElement( i,j ) );
00356 loop1.addStatement( loop2 );
00357 block->addStatement( loop1 );
00358
00359 ExportForLoop loop3( i, 0, NXA );
00360 loop3.addStatement( rk_xxx.getCol( NX+i ) == rk_kkk.getElement( NX+i,stage ) );
00361 block->addStatement( loop3 );
00362
00363 ExportForLoop loop4( i, 0, NDX2 );
00364 loop4.addStatement( rk_xxx.getCol( inputDim-diffsDim+i ) == rk_kkk.getElement( i,stage ) );
00365 block->addStatement( loop4 );
00366
00367 if( C.getDim() > 0 ) {
00368 block->addStatement( rk_xxx.getCol( inputDim-diffsDim+NDX2 ) == C.getCol(stage) );
00369 }
00370
00371 return SUCCESSFUL_RETURN;
00372 }
00373
00374
00375 returnValue DiagonallyImplicitRKExport::evaluateRhsImplicitSystem( ExportStatementBlock* block, const ExportIndex& stage )
00376 {
00377 DMatrix zeroM = zeros<double>( NX2+NXA,1 );
00378 block->addFunctionCall( getNameRHS(), rk_xxx, rk_rhsTemp.getAddress(0,0) );
00379
00380 if( NDX2 == 0 ) {
00381 block->addStatement( rk_b.getRows( 0,NX2 ) == rk_kkk.getSubMatrix( NX1,NX1+NX2,stage,stage+1 ) - rk_rhsTemp.getRows( 0,NX2 ) );
00382 }
00383 else {
00384 block->addStatement( rk_b.getRows( 0,NX2 ) == zeroM.getRows( 0,NX2-1 ) - rk_rhsTemp.getRows( 0,NX2 ) );
00385 }
00386 if( NXA > 0 ) {
00387 block->addStatement( rk_b.getRows( NX2,NX2+NXA ) == zeroM.getRows( 0,NXA-1 ) - rk_rhsTemp.getRows( NX2,NX2+NXA ) );
00388 }
00389
00390 return SUCCESSFUL_RETURN;
00391 }
00392
00393
00394 returnValue DiagonallyImplicitRKExport::solveOutputSystem( ExportStatementBlock* block, const ExportIndex& index1, const ExportIndex& index2, const ExportIndex& index3, const ExportIndex& tmp_index, const ExportVariable& Ah, bool DERIVATIVES )
00395 {
00396 if( NX3 > 0 ) {
00397 ExportForLoop loop( index1,0,numStages );
00398 evaluateStatesOutputSystem( &loop, Ah, index1 );
00399 ExportForLoop loop01( index2,NX1+NX2,NX );
00400 ExportForLoop loop02( index3,0,index1 );
00401 loop02.addStatement( rk_xxx.getCol( index2 ) += Ah.getElement(index1,index3)*rk_kkk.getElement(index2,index3) );
00402 loop01.addStatement( loop02 );
00403 loop.addStatement( loop01 );
00404 loop.addFunctionCall( getNameOutputRHS(), rk_xxx, rk_b.getAddress(0,0) );
00405 if( DERIVATIVES ) loop.addFunctionCall( getNameOutputDiffs(), rk_xxx, rk_diffsTemp3.getAddress(index1,0) );
00406
00407 ExportForLoop loop5( index2,0,NX3 );
00408 loop5.addStatement( tmp_index == index1*NX3+index2 );
00409 loop5.addStatement( rk_kkk.getElement(NX1+NX2+index2,index1) == rk_mat3.getElement(tmp_index,0)*rk_b.getRow(0) );
00410 ExportForLoop loop6( index3,1,NX3 );
00411 loop6.addStatement( rk_kkk.getElement(NX1+NX2+index2,index1) += rk_mat3.getElement(tmp_index,index3)*rk_b.getRow(index3) );
00412 loop5.addStatement(loop6);
00413 loop.addStatement(loop5);
00414 block->addStatement(loop);
00415 }
00416
00417 return SUCCESSFUL_RETURN;
00418 }
00419
00420
/**
 * Emits code propagating forward sensitivities through the linear output
 * subsystem (states NX1+NX2 ... NX-1).
 *
 * - STATES && number == 3: the stage sensitivities are copied directly from
 *   the precomputed table rk_dk3.
 * - Otherwise, for every stage index2 a right-hand side rk_b is assembled from
 *   rk_diffsTemp3 and rk_diffK, plus the explicit A33 coupling to earlier
 *   stages. It is then multiplied by the stage's block of rk_mat3 to give
 *   rk_diffK.
 * Finally the NX3 rows of rk_diffsNew3 are updated with rk_diffK*Bh. When the
 * grid has more than one interval or is not equidistant, this update sits
 * inside a generated "if( run == 0 )" guard.
 *
 * @param block      Statement block receiving the generated code.
 * @param index1     Column (state or control) w.r.t. which the sensitivity is taken.
 * @param index2     Loop index over stages / result rows.
 * @param index3     Auxiliary loop index.
 * @param index4     Auxiliary loop index.
 * @param tmp_index1 Auxiliary index.
 * @param tmp_index2 Auxiliary index.
 * @param Ah         Stage coefficient matrix.
 * @param Bh         Weight vector for the final accumulation.
 * @param STATES     true for state sensitivities, false for control sensitivities.
 * @param number     Selects the state-column branch (1, 2 or 3); presumably the
 *                   subsystem the column belongs to - TODO confirm.
 * @return SUCCESSFUL_RETURN.
 */
returnValue DiagonallyImplicitRKExport::sensitivitiesOutputSystem( ExportStatementBlock* block, const ExportIndex& index1, const ExportIndex& index2, const ExportIndex& index3, const ExportIndex& index4, const ExportIndex& tmp_index1, const ExportIndex& tmp_index2, const ExportVariable& Ah, const ExportVariable& Bh, bool STATES, uint number )
{
	if( NX3 > 0 ) {
		uint i, j;
		ExportForLoop loop1( index2,0,numStages );
		if( STATES && number == 1 ) {
			// rk_rhsTemp(0:NX1) = e_index1 + sum_i Ah(index2,i)*rk_diffK(:,i)
			ExportForLoop loop2( index3,0,NX1 );
			loop2.addStatement( std::string(rk_rhsTemp.get( index3,0 )) + " = (" + index3.getName() + " == " + index1.getName() + ");\n" );
			for( i = 0; i < numStages; i++ ) {
				loop2.addStatement( rk_rhsTemp.getRow( index3 ) += rk_diffK.getElement( index3,i )*Ah.getElement(index2,i) );
			}
			loop1.addStatement( loop2 );
			// rk_rhsTemp(NX1:NX1+NX2) = sum_i Ah(index2,i)*rk_diffK(:,i)
			ExportForLoop loop3( index3,NX1,NX1+NX2 );
			loop3.addStatement( rk_rhsTemp.getRow( index3 ) == 0.0 );
			for( i = 0; i < numStages; i++ ) {
				loop3.addStatement( rk_rhsTemp.getRow( index3 ) += rk_diffK.getElement( index3,i )*Ah.getElement(index2,i) );
			}
			loop1.addStatement( loop3 );
			// rk_b = Jacobian of the output rhs (x1/x2 columns) times rk_rhsTemp, plus algebraic and dx terms.
			ExportForLoop loop4( index3,0,NX3 );
			loop4.addStatement( rk_b.getRow( index3 ) == rk_diffsTemp3.getSubMatrix( index2,index2+1,index3*(NVARS3),index3*(NVARS3)+NX1+NX2 )*rk_rhsTemp.getRows(0,NX1+NX2) );
			if( NXA3 > 0 ) {
				loop4.addStatement( rk_b.getRow( index3 ) += rk_diffsTemp3.getSubMatrix( index2,index2+1,index3*(NVARS3)+NX1+NX2,index3*(NVARS3)+NX1+NX2+NXA )*rk_diffK.getSubMatrix( NX,NX+NXA,index2,index2+1 ) );
			}
			if( NDX3 > 0 ) {
				loop4.addStatement( rk_b.getRow( index3 ) += rk_diffsTemp3.getSubMatrix( index2,index2+1,index3*(NVARS3)+NVARS3-NX1-NX2,index3*(NVARS3)+NVARS3 )*rk_diffK.getSubMatrix( 0,NX1+NX2,index2,index2+1 ) );
			}
			loop1.addStatement( loop4 );
		}
		else if( STATES && number == 2 ) {
			// Only the NX2 block contributes; rk_rhsTemp(NX1:NX1+NX2) = e_index1 + stage sum.
			ExportForLoop loop3( index3,NX1,NX1+NX2 );
			loop3.addStatement( std::string(rk_rhsTemp.get( index3,0 )) + " = (" + index3.getName() + " == " + index1.getName() + ");\n" );
			for( i = 0; i < numStages; i++ ) {
				loop3.addStatement( rk_rhsTemp.getRow( index3 ) += rk_diffK.getElement( index3,i )*Ah.getElement(index2,i) );
			}
			loop1.addStatement( loop3 );
			ExportForLoop loop4( index3,0,NX3 );
			loop4.addStatement( rk_b.getRow( index3 ) == rk_diffsTemp3.getSubMatrix( index2,index2+1,index3*(NVARS3)+NX1,index3*(NVARS3)+NX1+NX2 )*rk_rhsTemp.getRows(NX1,NX1+NX2) );
			if( NXA3 > 0 ) {
				loop4.addStatement( rk_b.getRow( index3 ) += rk_diffsTemp3.getSubMatrix( index2,index2+1,index3*(NVARS3)+NX1+NX2,index3*(NVARS3)+NX1+NX2+NXA )*rk_diffK.getSubMatrix( NX,NX+NXA,index2,index2+1 ) );
			}
			if( NDX3 > 0 ) {
				loop4.addStatement( rk_b.getRow( index3 ) += rk_diffsTemp3.getSubMatrix( index2,index2+1,index3*(NVARS3)+NVARS3-NX2,index3*(NVARS3)+NVARS3 )*rk_diffK.getSubMatrix( NX1,NX1+NX2,index2,index2+1 ) );
			}
			loop1.addStatement( loop4 );
		}
		else if( !STATES ) {
			// Control sensitivities: stage sums for the NX1 and NX2 blocks (no unit vector).
			ExportForLoop loop2( index3,0,NX1 );
			loop2.addStatement( rk_rhsTemp.getRow( index3 ) == rk_diffK.getElement( index3,0 )*Ah.getElement(index2,0) );
			for( i = 1; i < numStages; i++ ) {
				loop2.addStatement( rk_rhsTemp.getRow( index3 ) += rk_diffK.getElement( index3,i )*Ah.getElement(index2,i) );
			}
			loop1.addStatement( loop2 );
			ExportForLoop loop3( index3,NX1,NX1+NX2 );
			loop3.addStatement( rk_rhsTemp.getRow( index3 ) == rk_diffK.getElement( index3,0 )*Ah.getElement(index2,0) );
			for( i = 1; i < numStages; i++ ) {
				loop3.addStatement( rk_rhsTemp.getRow( index3 ) += rk_diffK.getElement( index3,i )*Ah.getElement(index2,i) );
			}
			loop1.addStatement( loop3 );
			// rk_b starts from the direct derivative w.r.t. control index1.
			ExportForLoop loop4( index3,0,NX3 );
			loop4.addStatement( tmp_index2 == index1+index3*(NVARS3) );
			loop4.addStatement( rk_b.getRow( index3 ) == rk_diffsTemp3.getElement( index2,tmp_index2+NX1+NX2+NXA3 ) );
			loop4.addStatement( rk_b.getRow( index3 ) += rk_diffsTemp3.getSubMatrix( index2,index2+1,index3*(NVARS3),index3*(NVARS3)+NX1+NX2 )*rk_rhsTemp.getRows(0,NX1+NX2) );
			if( NXA3 > 0 ) {
				loop4.addStatement( rk_b.getRow( index3 ) += rk_diffsTemp3.getSubMatrix( index2,index2+1,index3*(NVARS3)+NX1+NX2,index3*(NVARS3)+NX1+NX2+NXA )*rk_diffK.getSubMatrix( NX,NX+NXA,index2,index2+1 ) );
			}
			if( NDX3 > 0 ) {
				loop4.addStatement( rk_b.getRow( index3 ) += rk_diffsTemp3.getSubMatrix( index2,index2+1,index3*(NVARS3)+NVARS3-NX1-NX2,index3*(NVARS3)+NVARS3 )*rk_diffK.getSubMatrix( 0,NX1+NX2,index2,index2+1 ) );
			}
			loop1.addStatement( loop4 );
		}
		if( !STATES || number != 3 ) {
			// Explicit coupling through A33 to the output sensitivities of earlier stages;
			// entries of A33 that round to zero are skipped at generation time.
			ExportForLoop loop12( tmp_index1,0,index2 );
			for( i = 0; i < NX3; i++ ) {
				for( j = NX1+NX2; j < NX; j++ ) {
					if( acadoRoundAway(A33(i,j-NX1-NX2)) != 0 ) {
						loop12.addStatement( std::string( rk_b.get(i,0) ) + " += " + Ah.get(index2,tmp_index1) + "*" + toString(A33(i,j-NX1-NX2)) + "*" + rk_diffK.get(j,tmp_index1) + ";\n" );
					}
				}
			}
			loop1.addStatement( loop12 );
		}


		if( STATES && number == 3 ) {
			// Sensitivities w.r.t. an output state are constant: copy the rk_dk3 rows of state index1-(NX1+NX2).
			block->addStatement( rk_diffK.getRows(NX1+NX2,NX) == rk_dk3.getRows(index1*NX3-(NX1+NX2)*NX3,index1*NX3+NX3-(NX1+NX2)*NX3) );
		}
		else {
			// rk_diffK(NX1+NX2+index3, index2) = rk_mat3(index2*NX3+index3, :) * rk_b (stage-specific inverse block).
			ExportForLoop loop5( index3,0,NX3 );
			loop5.addStatement( tmp_index1 == index2*NX3+index3 );
			loop5.addStatement( rk_diffK.getElement(NX1+NX2+index3,index2) == rk_mat3.getElement(tmp_index1,0)*rk_b.getRow(0) );
			ExportForLoop loop6( index4,1,NX3 );
			loop6.addStatement( rk_diffK.getElement(NX1+NX2+index3,index2) += rk_mat3.getElement(tmp_index1,index4)*rk_b.getRow(index4) );
			loop5.addStatement(loop6);
			loop1.addStatement(loop5);
			block->addStatement( loop1 );
		}

		// Accumulate the NX3 rows of rk_diffsNew3; guarded by a generated "run == 0" test on non-trivial grids.
		if( grid.getNumIntervals() > 1 || !equidistantControlGrid() ) block->addStatement( std::string( "if( run == 0 ) {\n" ) );
		ExportForLoop loop8( index2,0,NX3 );
		if( STATES && number == 3 ) loop8.addStatement( std::string(rk_diffsNew3.get( index2,index1 )) + " = (" + index2.getName() + " == " + index1.getName() + "-" + toString(NX1+NX2) + ");\n" );

		if( STATES && number == 3 ) loop8.addStatement( rk_diffsNew3.getElement( index2,index1 ) += rk_diffK.getRow( NX1+NX2+index2 )*Bh );
		else if( STATES ) loop8.addStatement( rk_diffsNew3.getElement( index2,index1 ) == rk_diffK.getRow( NX1+NX2+index2 )*Bh );
		else loop8.addStatement( rk_diffsNew3.getElement( index2,index1+NX ) == rk_diffK.getRow( NX1+NX2+index2 )*Bh );
		block->addStatement( loop8 );
		if( grid.getNumIntervals() > 1 || !equidistantControlGrid() ) block->addStatement( std::string( "}\n" ) );
	}

	return SUCCESSFUL_RETURN;
}
00531
00532
00533 returnValue DiagonallyImplicitRKExport::prepareOutputSystem( ExportStatementBlock& code )
00534 {
00535 if( NX3 > 0 ) {
00536 DMatrix mat3 = formMatrix( M33, A33 );
00537 rk_mat3 = ExportVariable( "rk_mat3", mat3, STATIC_CONST_REAL );
00538 code.addDeclaration( rk_mat3 );
00539
00540 rk_mat3 = ExportVariable( "rk_mat3", numStages*NX3, NX3, STATIC_CONST_REAL, ACADO_LOCAL );
00541 double h = (grid.getLastTime() - grid.getFirstTime())/grid.getNumIntervals();
00542
00543 DMatrix sens = zeros<double>(NX3*NX3, numStages);
00544 uint i, j, k, s1, s2;
00545 for( i = 0; i < NX3; i++ ) {
00546 DVector vec(NX3);
00547 for( j = 0; j < numStages; j++ ) {
00548 for( k = 0; k < NX3; k++ ) {
00549 vec(k) = A33(k,i);
00550 for( s1 = 0; s1 < j; s1++ ) {
00551 for( s2 = 0; s2 < NX3; s2++ ) {
00552 vec(k) = vec(k) + AA(j,s1)*h*A33(k,s2)*sens(i*NX3+s2,s1);
00553 }
00554 }
00555 }
00556 DVector sol = mat3*vec;
00557 for( k = 0; k < NX3; k++ ) {
00558 sens(i*NX3+k,j) = sol(k);
00559 }
00560 }
00561 }
00562 rk_dk3 = ExportVariable( "rk_dk3", sens, STATIC_CONST_REAL );
00563 code.addDeclaration( rk_dk3 );
00564
00565 rk_dk3 = ExportVariable( "rk_dk3", NX3*NX3, numStages, STATIC_CONST_REAL, ACADO_LOCAL );
00566 }
00567
00568 return SUCCESSFUL_RETURN;
00569 }
00570
00571
00572 returnValue DiagonallyImplicitRKExport::setup( )
00573 {
00574 returnValue IRKsetup = ForwardIRKExport::setup();
00575
00576 int debugMode;
00577 get( INTEGRATOR_DEBUG_MODE, debugMode );
00578
00579 int useOMP;
00580 get(CG_USE_OPENMP, useOMP);
00581 ExportStruct structWspace;
00582 structWspace = useOMP ? ACADO_LOCAL : ACADO_WORKSPACE;
00583
00584 rk_A = ExportVariable( "rk_A", numStages*(NX2+NXA), NX2+NXA, REAL, structWspace );
00585 if ( (bool)debugMode == true && useOMP ) {
00586 return ACADOERROR( RET_INVALID_OPTION );
00587 }
00588 else {
00589 debug_mat = ExportVariable( "debug_mat", numStages*(NX2+NXA), NX2+NXA, REAL, ACADO_VARIABLES );
00590 }
00591 uint Xmax = NX1;
00592 if( NX2 > Xmax ) Xmax = NX2;
00593 if( NX3 > Xmax ) Xmax = NX3;
00594 rk_b = ExportVariable( "rk_b", Xmax+NXA, 1, REAL, structWspace );
00595
00596 if( NX2 > 0 || NXA > 0 ) {
00597 solver->init( NX2+NXA );
00598 solver->setup();
00599 rk_auxSolver = solver->getGlobalExportVariable( numStages );
00600 }
00601
00602 return IRKsetup;
00603 }
00604
00605
00606 CLOSE_NAMESPACE_ACADO
00607
00608