erk_3sweep_export.cpp
Go to the documentation of this file.
00001 /*
00002  *    This file is part of ACADO Toolkit.
00003  *
00004  *    ACADO Toolkit -- A Toolkit for Automatic Control and Dynamic Optimization.
00005  *    Copyright (C) 2008-2014 by Boris Houska, Hans Joachim Ferreau,
00006  *    Milan Vukov, Rien Quirynen, KU Leuven.
00007  *    Developed within the Optimization in Engineering Center (OPTEC)
00008  *    under supervision of Moritz Diehl. All rights reserved.
00009  *
00010  *    ACADO Toolkit is free software; you can redistribute it and/or
00011  *    modify it under the terms of the GNU Lesser General Public
00012  *    License as published by the Free Software Foundation; either
00013  *    version 3 of the License, or (at your option) any later version.
00014  *
00015  *    ACADO Toolkit is distributed in the hope that it will be useful,
00016  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00017  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018  *    Lesser General Public License for more details.
00019  *
00020  *    You should have received a copy of the GNU Lesser General Public
00021  *    License along with ACADO Toolkit; if not, write to the Free Software
00022  *    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
00023  *
00024  */
00025 
00026 
00027 
00034 #include <acado/code_generation/integrators/erk_export.hpp>
00035 #include <acado/code_generation/integrators/erk_3sweep_export.hpp>
00036 
00037 using namespace std;
00038 
00039 BEGIN_NAMESPACE_ACADO
00040 
00041 //
00042 // PUBLIC MEMBER FUNCTIONS:
00043 //
00044 
00045 ThreeSweepsERKExport::ThreeSweepsERKExport(     UserInteraction* _userInteraction,
00046                                                                         const std::string& _commonHeaderName
00047                                                                         ) : AdjointERKExport( _userInteraction,_commonHeaderName )
00048 {
00049 }
00050 
00051 
00052 ThreeSweepsERKExport::ThreeSweepsERKExport(     const ThreeSweepsERKExport& arg
00053                                                                         ) : AdjointERKExport( arg )
00054 {
00055 }
00056 
00057 
00058 ThreeSweepsERKExport::~ThreeSweepsERKExport( )
00059 {
00060         clear( );
00061 }
00062 
00063 
00064 
00065 returnValue ThreeSweepsERKExport::setDifferentialEquation(      const Expression& rhs_ )
00066 {
00067         int sensGen;
00068         get( DYNAMIC_SENSITIVITY,sensGen );
00069 
00070         OnlineData        dummy0;
00071         Control           dummy1;
00072         DifferentialState dummy2;
00073         AlgebraicState    dummy3;
00074         DifferentialStateDerivative dummy4;
00075         dummy0.clearStaticCounters();
00076         dummy1.clearStaticCounters();
00077         dummy2.clearStaticCounters();
00078         dummy3.clearStaticCounters();
00079         dummy4.clearStaticCounters();
00080 
00081         x = DifferentialState("", NX, 1);
00082         dx = DifferentialStateDerivative("", NDX, 1);
00083         z = AlgebraicState("", NXA, 1);
00084         u = Control("", NU, 1);
00085         od = OnlineData("", NOD, 1);
00086 
00087         if( NDX > 0 && NDX != NX ) {
00088                 return ACADOERROR( RET_INVALID_OPTION );
00089         }
00090         if( rhs_.getNumRows() != (NX+NXA) ) {
00091                 return ACADOERROR( RET_INVALID_OPTION );
00092         }
00093 
00094         DifferentialEquation f, g, h, f_ODE;
00095         // add usual ODE
00096         f_ODE << rhs_;
00097         if( f_ODE.getNDX() > 0 ) {
00098                 return ACADOERROR( RET_INVALID_OPTION );
00099         }
00100 
00101         uint numX = NX*(NX+1)/2.0;
00102         uint numU = NU*(NU+1)/2.0;
00103         uint numZ = (NX+NU)*(NX+NU+1)/2.0;
00104         if( (ExportSensitivityType)sensGen == THREE_SWEEPS ) {
00105                 // SWEEP 1:
00106                 // ---------
00107                 f << rhs_;
00108 
00109 
00110                 // SWEEP 2:
00111                 // ---------
00112                 DifferentialState lx("", NX,1);
00113 
00114                 Expression tmp = backwardDerivative(rhs_, x, lx);
00115                 g << tmp;
00116 
00117 
00118                 // SWEEP 3:
00119                 // ---------
00120                 DifferentialState Gx("", NX,NX), Gu("", NX,NU);
00121                 DifferentialState H("", numZ,1);
00122 
00123                 Expression S = Gx;
00124                 S.appendCols(Gu);
00125                 Expression arg;
00126                 arg << x;
00127                 arg << u;
00128 
00129                 // SYMMETRIC DERIVATIVES
00130                 Expression S_tmp = S;
00131                 S_tmp.appendRows(zeros<double>(NU,NX).appendCols(eye<double>(NU)));
00132 
00133                 Expression dfS;
00134                 Expression h_tmp = symmetricDerivative( rhs_, arg, S_tmp, lx, &dfS );
00135                 Expression VDE_X;
00136                 Expression VDE_U;
00137                 for( uint i = 0; i < NX; i++ ) {
00138                         VDE_X.appendCols(dfS.getCol(i));
00139                 }
00140                 for( uint i = NX; i < NX+NU; i++ ) {
00141                         VDE_U.appendCols(dfS.getCol(i));
00142                 }
00143                 h << VDE_X;
00144                 h << VDE_U;
00145                 h << returnLowerTriangular( h_tmp );
00146 
00147                 // OLD VERSION:
00148 //              // add VDE for differential states
00149 //              h << multipleForwardDerivative( rhs_, x, Gx );
00150 //
00151 //              // add VDE for control inputs
00152 //              h << multipleForwardDerivative( rhs_, x, Gu ) + forwardDerivative( rhs_, u );
00153 //
00154 //              IntermediateState tmp2 = forwardDerivative(tmp, x);
00155 //              Expression tmp3 = backwardDerivative(rhs_, u, lx);
00156 //              Expression tmp4 = multipleForwardDerivative(tmp3, x, Gu);
00157 //
00158 //              // TODO: include a symmetric_AD_operator to strongly improve the symmetric left-right multiplied second order derivative computations !!
00160 //              h << symmetricDoubleProduct(tmp2, Gx);
00161 //              h << Gu.transpose()*tmp2*Gx + multipleForwardDerivative(tmp3, x, Gx);
00162 //              Expression tmp7 = tmp4 + tmp4.transpose() + forwardDerivative(tmp3, u);
00163 //              h << symmetricDoubleProduct(tmp2, Gu) + returnLowerTriangular(tmp7);
00164         }
00165         else {
00166                 return ACADOERROR( RET_INVALID_OPTION );
00167         }
00168         if( f.getNT() > 0 ) timeDependant = true;
00169 
00170         return rhs.init(f, "acado_forward", NX, 0, NU, NP, NDX, NOD)
00171                         & diffs_rhs.init(g, "acado_backward", 2*NX, 0, NU, NP, NDX, NOD)
00172                         & diffs_sweep3.init(h, "acado_forward_sweep3", 2*NX + NX*(NX+NU) + numX + NX*NU + numU, 0, NU, NP, NDX, NOD);
00173 }
00174 
00175 
00176 returnValue ThreeSweepsERKExport::setup( )
00177 {
00178         int sensGen;
00179         get( DYNAMIC_SENSITIVITY,sensGen );
00180         if ( (ExportSensitivityType)sensGen != THREE_SWEEPS ) ACADOERROR( RET_INVALID_OPTION );
00181 
00182         // NOT SUPPORTED: since the forward sweep needs to be saved
00183         if( !equidistantControlGrid() )         ACADOERROR( RET_INVALID_OPTION );
00184 
00185         // NOT SUPPORTED: since the adjoint derivatives could be 'arbitrarily bad'
00186         if( !is_symmetric )                             ACADOERROR( RET_INVALID_OPTION );
00187 
00188         LOG( LVL_DEBUG ) << "Preparing to export ThreeSweepsERKExport... " << endl;
00189 
00190         // export RK scheme
00191         uint numX = NX*(NX+1)/2.0;
00192         uint numU = NU*(NU+1)/2.0;
00193         uint rhsDim   = NX + NX + NX*(NX+NU) + numX + NX*NU + numU;
00194         inputDim = rhsDim + NU + NOD;
00195         const uint rkOrder  = getNumStages();
00196 
00197         double h = (grid.getLastTime() - grid.getFirstTime())/grid.getNumIntervals();    
00198 
00199         ExportVariable Ah ( "A*h",  DMatrix( AA )*=h );
00200         ExportVariable b4h( "b4*h", DMatrix( bb )*=h );
00201 
00202         rk_index = ExportVariable( "rk_index", 1, 1, INT, ACADO_LOCAL, true );
00203         rk_eta = ExportVariable( "rk_eta", 1, inputDim );
00204 //      seed_backward.setup( "seed", 1, NX );
00205 
00206         int useOMP;
00207         get(CG_USE_OPENMP, useOMP);
00208         ExportStruct structWspace;
00209         structWspace = useOMP ? ACADO_LOCAL : ACADO_WORKSPACE;
00210 
00211         rk_ttt.setup( "rk_ttt", 1, 1, REAL, structWspace, true );
00212         uint timeDep = 0;
00213         if( timeDependant ) timeDep = 1;
00214         
00215         rk_xxx.setup("rk_xxx", 1, inputDim+timeDep, REAL, structWspace);
00216         uint numK = NX*(NX+NU)+numX+NX*NU+numU;
00217         rk_kkk.setup("rk_kkk", rkOrder, numK, REAL, structWspace);
00218         rk_forward_sweep.setup("rk_sweep1", 1, grid.getNumIntervals()*rkOrder*NX, REAL, structWspace);
00219         rk_backward_sweep.setup("rk_sweep2", 1, grid.getNumIntervals()*rkOrder*NX, REAL, structWspace);
00220 
00221         if ( useOMP )
00222         {
00223                 ExportVariable auxVar;
00224 
00225                 auxVar = diffs_rhs.getGlobalExportVariable();
00226                 auxVar.setName( "odeAuxVar" );
00227                 auxVar.setDataStruct( ACADO_LOCAL );
00228                 diffs_rhs.setGlobalExportVariable( auxVar );
00229         }
00230 
00231         ExportIndex run( "run1" );
00232 
00233         // setup INTEGRATE function
00234         integrate = ExportFunction( "integrate", rk_eta, reset_int );
00235         integrate.setReturnValue( error_code );
00236         rk_eta.setDoc( "Working array to pass the input values and return the results." );
00237         reset_int.setDoc( "The internal memory of the integrator can be reset." );
00238         rk_index.setDoc( "Number of the shooting interval." );
00239         error_code.setDoc( "Status code of the integrator." );
00240         integrate.doc( "Performs the integration and sensitivity propagation for one shooting interval." );
00241         integrate.addIndex( run );
00242         
00243         integrate.addStatement( rk_ttt == DMatrix(grid.getFirstTime()) );
00244 
00245         if( inputDim > rhsDim ) {
00246                 // initialize sensitivities:
00247 //              integrate.addStatement( rk_eta.getCols( NX,2*NX ) == seed_backward );
00248                 DMatrix idX    = eye<double>( NX );
00249                 DMatrix zeroXU = zeros<double>( NX,NU );
00250                 integrate.addStatement( rk_eta.getCols( 2*NX,NX*(2+NX) ) == idX.makeVector().transpose() );
00251                 integrate.addStatement( rk_eta.getCols( NX*(2+NX),NX*(2+NX+NU) ) == zeroXU.makeVector().transpose() );
00252 
00253                 integrate.addStatement( rk_eta.getCols( NX*(2+NX+NU),rhsDim ) == zeros<double>( 1,numX+NX*NU+numU ) );
00254                 // FORWARD SWEEP FIRST
00255                 integrate.addStatement( rk_xxx.getCols( NX,NX+NU+NOD ) == rk_eta.getCols( rhsDim,inputDim ) );
00256         }
00257         integrate.addLinebreak( );
00258 
00259     // integrator loop: FORWARD SWEEP
00260         ExportForLoop loop = ExportForLoop( run, 0, grid.getNumIntervals() );
00261         for( uint run1 = 0; run1 < rkOrder; run1++ )
00262         {
00263                 loop.addStatement( rk_xxx.getCols( 0,NX ) == rk_eta.getCols( 0,NX ) + Ah.getRow(run1)*rk_kkk.getCols( 0,NX ) );
00264                 // save forward trajectory
00265                 loop.addStatement( rk_forward_sweep.getCols( run*rkOrder*NX+run1*NX,run*rkOrder*NX+(run1+1)*NX ) == rk_xxx.getCols( 0,NX ) );
00266                 if( timeDependant ) loop.addStatement( rk_xxx.getCol( NX+NU+NOD ) == rk_ttt + ((double)cc(run1))/grid.getNumIntervals() );
00267                 loop.addFunctionCall( getNameRHS(),rk_xxx,rk_kkk.getAddress(run1,0) );
00268         }
00269         loop.addStatement( rk_eta.getCols( 0,NX ) += b4h^rk_kkk.getCols( 0,NX ) );
00270         loop.addStatement( rk_ttt += DMatrix(1.0/grid.getNumIntervals()) );
00271     // end of integrator loop: FORWARD SWEEP
00272         integrate.addStatement( loop );
00273 
00274         if( inputDim > rhsDim ) {
00275                 // BACKWARD SWEEP NEXT
00276                 integrate.addStatement( rk_xxx.getCols( 2*NX,2*NX+NU+NOD ) == rk_eta.getCols( rhsDim,inputDim ) );
00277         }
00278     // integrator loop: BACKWARD SWEEP
00279         ExportForLoop loop2 = ExportForLoop( run, 0, grid.getNumIntervals() );
00280         for( uint run1 = 0; run1 < rkOrder; run1++ )
00281         {
00282                 // load forward trajectory
00283                 loop2.addStatement( rk_xxx.getCols( 0,NX ) == rk_forward_sweep.getCols( (grid.getNumIntervals()-run)*rkOrder*NX-(run1+1)*NX,(grid.getNumIntervals()-run)*rkOrder*NX-run1*NX ) );
00284                 loop2.addStatement( rk_xxx.getCols( NX,2*NX ) == rk_eta.getCols( NX,2*NX ) + Ah.getRow(run1)*rk_kkk.getCols( 0,NX ) );
00285                 // save backward trajectory
00286                 loop2.addStatement( rk_backward_sweep.getCols( run*rkOrder*NX+run1*NX,run*rkOrder*NX+(run1+1)*NX ) == rk_xxx.getCols( NX,2*NX ) );
00287                 if( timeDependant ) loop2.addStatement( rk_xxx.getCol( 2*NX+NU+NOD ) == rk_ttt - ((double)cc(run1))/grid.getNumIntervals() );
00288                 loop2.addFunctionCall( getNameDiffsRHS(),rk_xxx,rk_kkk.getAddress(run1,0) );
00289         }
00290         loop2.addStatement( rk_eta.getCols( NX,2*NX ) += b4h^rk_kkk.getCols( 0,NX ) );
00291         loop2.addStatement( rk_ttt -= DMatrix(1.0/grid.getNumIntervals()) );
00292     // end of integrator loop: BACKWARD SWEEP
00293         integrate.addStatement( loop2 );
00294 
00295         if( inputDim > rhsDim ) {
00296                 // THIRD SWEEP NEXT
00297                 integrate.addStatement( rk_xxx.getCols( rhsDim,inputDim ) == rk_eta.getCols( rhsDim,inputDim ) );
00298         }
00299     // integrator loop: THIRD SWEEP
00300         ExportForLoop loop3 = ExportForLoop( run, 0, grid.getNumIntervals() );
00301         for( uint run1 = 0; run1 < rkOrder; run1++ )
00302         {
00303                 // load forward trajectory
00304                 loop3.addStatement( rk_xxx.getCols( 0,NX ) == rk_forward_sweep.getCols( run*rkOrder*NX+run1*NX,run*rkOrder*NX+(run1+1)*NX ) );
00305                 // load backward trajectory
00306                 loop3.addStatement( rk_xxx.getCols( NX,2*NX ) == rk_backward_sweep.getCols( (grid.getNumIntervals()-run)*rkOrder*NX-(run1+1)*NX,(grid.getNumIntervals()-run)*rkOrder*NX-run1*NX ) );
00307                 loop3.addStatement( rk_xxx.getCols( 2*NX,rhsDim ) == rk_eta.getCols( 2*NX,rhsDim ) + Ah.getRow(run1)*rk_kkk.getCols( 0,NX*(NX+NU)+numX+NX*NU+numU ) );
00308                 if( timeDependant ) loop3.addStatement( rk_xxx.getCol( inputDim ) == rk_ttt + ((double)cc(run1))/grid.getNumIntervals() );
00309                 loop3.addFunctionCall( diffs_sweep3.getName(),rk_xxx,rk_kkk.getAddress(run1,0) );
00310         }
00311         loop3.addStatement( rk_eta.getCols( 2*NX,rhsDim ) += b4h^rk_kkk.getCols( 0,NX*(NX+NU)+numX+NX*NU+numU ) );
00312         loop3.addStatement( rk_ttt += DMatrix(1.0/grid.getNumIntervals()) );
00313     // end of integrator loop: THIRD SWEEP
00314         integrate.addStatement( loop3 );
00315 
00316         integrate.addStatement( error_code == 0 );
00317 
00318         LOG( LVL_DEBUG ) << "done" << endl;
00319 
00320         return SUCCESSFUL_RETURN;
00321 }
00322 
00323 
00324 returnValue ThreeSweepsERKExport::getDataDeclarations(  ExportStatementBlock& declarations,
00325                                                                                                                 ExportStruct dataStruct
00326                                                                                                                 ) const
00327 {
00328         AdjointERKExport::getDataDeclarations( declarations, dataStruct );
00329 
00330         declarations.addDeclaration( rk_backward_sweep,dataStruct );
00331 
00332     return SUCCESSFUL_RETURN;
00333 }
00334 
00335 
00336 returnValue ThreeSweepsERKExport::getCode(      ExportStatementBlock& code
00337                                                                                         )
00338 {
00339         int useOMP;
00340         get(CG_USE_OPENMP, useOMP);
00341         if ( useOMP )
00342         {
00343                 getDataDeclarations( code, ACADO_LOCAL );
00344 
00345                 code << "#pragma omp threadprivate( "
00346                                 << getAuxVariable().getFullName()  << ", "
00347                                 << rk_xxx.getFullName() << ", "
00348                                 << rk_ttt.getFullName() << ", "
00349                                 << rk_kkk.getFullName() << ", "
00350                                 << rk_forward_sweep.getFullName() << ", "
00351                                 << rk_backward_sweep.getFullName()
00352                                 << " )\n\n";
00353         }
00354 
00355         int sensGen;
00356         get( DYNAMIC_SENSITIVITY,sensGen );
00357         if( exportRhs ) {
00358                 code.addFunction( rhs );
00359                 code.addFunction( diffs_rhs );
00360                 code.addFunction( diffs_sweep3 );
00361         }
00362 
00363         double h = (grid.getLastTime() - grid.getFirstTime())/grid.getNumIntervals();
00364         code.addComment(std::string("Fixed step size:") + toString(h));
00365         code.addFunction( integrate );
00366 
00367         return SUCCESSFUL_RETURN;
00368 }
00369 
00370 
00371 Expression ThreeSweepsERKExport::returnLowerTriangular( const Expression& expr ) {
00372 //      std::cout << "returnLowerTriangular with " << expr.getNumRows() << " rows and " << expr.getNumCols() << " columns\n";
00373         ASSERT( expr.getNumRows() == expr.getNumCols() );
00374 
00375         Expression new_expr;
00376         for( uint i = 0; i < expr.getNumRows(); i++ ) {
00377                 for( uint j = 0; j <= i; j++ ) {
00378                         new_expr << expr(i,j);
00379                 }
00380         }
00381         return new_expr;
00382 }
00383 
00384 
00385 Expression ThreeSweepsERKExport::symmetricDoubleProduct( const Expression& expr, const Expression& arg ) {
00386 
00387         // NOTE: the speedup of the three-sweeps-propagation approach is strongly dependent on the support for this specific operator which shows many symmetries
00388         uint dim = arg.getNumCols();
00389         uint dim2 = arg.getNumRows();
00390 
00391         IntermediateState inter_res = zeros<double>(dim2,dim);
00392         for( uint i = 0; i < dim; i++ ) {
00393                 for( uint k1 = 0; k1 < dim2; k1++ ) {
00394                         for( uint k2 = 0; k2 <= k1; k2++ ) {
00395                                 inter_res(k1,i) += expr(k1,k2)*arg(k2,i);
00396                         }
00397                         for( uint k2 = k1+1; k2 < dim2; k2++ ) {
00398                                 inter_res(k1,i) += expr(k2,k1)*arg(k2,i);
00399                         }
00400                 }
00401         }
00402 
00403         Expression new_expr;
00404         for( uint i = 0; i < dim; i++ ) {
00405                 for( uint j = 0; j <= i; j++ ) {
00406                         Expression new_tmp = 0;
00407                         for( uint k1 = 0; k1 < dim2; k1++ ) {
00408                                 new_tmp = new_tmp+arg(k1,i)*inter_res(k1,j);
00409                         }
00410                         new_expr << new_tmp;
00411                 }
00412         }
00413         return new_expr;
00414 //      return returnLowerTriangular(arg.transpose()*expr*arg, dim);
00415 }
00416 
00417 
00418 ExportVariable ThreeSweepsERKExport::getAuxVariable() const
00419 {
00420         ExportVariable max;
00421         max = rhs.getGlobalExportVariable();
00422         if( diffs_rhs.getGlobalExportVariable().getDim() > max.getDim() ) {
00423                 max = diffs_rhs.getGlobalExportVariable();
00424         }
00425         if( diffs_sweep3.getGlobalExportVariable().getDim() > max.getDim() ) {
00426                 max = diffs_sweep3.getGlobalExportVariable();
00427         }
00428         return max;
00429 }
00430 
00431 // PROTECTED:
00432 
00433 
00434 
00435 CLOSE_NAMESPACE_ACADO
00436 
00437 // end of file.


acado
Author(s): Milan Vukov, Rien Quirynen
autogenerated on Sat Jun 8 2019 19:37:01