erk_adjoint_export.cpp
Go to the documentation of this file.
00001 /*
00002  *    This file is part of ACADO Toolkit.
00003  *
00004  *    ACADO Toolkit -- A Toolkit for Automatic Control and Dynamic Optimization.
00005  *    Copyright (C) 2008-2014 by Boris Houska, Hans Joachim Ferreau,
00006  *    Milan Vukov, Rien Quirynen, KU Leuven.
00007  *    Developed within the Optimization in Engineering Center (OPTEC)
00008  *    under supervision of Moritz Diehl. All rights reserved.
00009  *
00010  *    ACADO Toolkit is free software; you can redistribute it and/or
00011  *    modify it under the terms of the GNU Lesser General Public
00012  *    License as published by the Free Software Foundation; either
00013  *    version 3 of the License, or (at your option) any later version.
00014  *
00015  *    ACADO Toolkit is distributed in the hope that it will be useful,
00016  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00017  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018  *    Lesser General Public License for more details.
00019  *
00020  *    You should have received a copy of the GNU Lesser General Public
00021  *    License along with ACADO Toolkit; if not, write to the Free Software
00022  *    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
00023  *
00024  */
00025 
00026 
00027 
00034 #include <acado/code_generation/integrators/erk_export.hpp>
00035 #include <acado/code_generation/integrators/erk_adjoint_export.hpp>
00036 
00037 using namespace std;
00038 
00039 BEGIN_NAMESPACE_ACADO
00040 
00041 //
00042 // PUBLIC MEMBER FUNCTIONS:
00043 //
00044 
00045 AdjointERKExport::AdjointERKExport(     UserInteraction* _userInteraction,
00046                                                                         const std::string& _commonHeaderName
00047                                                                         ) : ExplicitRungeKuttaExport( _userInteraction,_commonHeaderName )
00048 {
00049 }
00050 
00051 
00052 AdjointERKExport::AdjointERKExport(     const AdjointERKExport& arg
00053                                                                         ) : ExplicitRungeKuttaExport( arg )
00054 {
00055 }
00056 
00057 
00058 AdjointERKExport::~AdjointERKExport( )
00059 {
00060         clear( );
00061 }
00062 
00063 
00064 
00065 returnValue AdjointERKExport::setDifferentialEquation(  const Expression& rhs_ )
00066 {
00067         int sensGen;
00068         get( DYNAMIC_SENSITIVITY,sensGen );
00069 
00070         OnlineData        dummy0;
00071         Control           dummy1;
00072         DifferentialState dummy2;
00073         AlgebraicState    dummy3;
00074         DifferentialStateDerivative dummy4;
00075         dummy0.clearStaticCounters();
00076         dummy1.clearStaticCounters();
00077         dummy2.clearStaticCounters();
00078         dummy3.clearStaticCounters();
00079         dummy4.clearStaticCounters();
00080 
00081         x = DifferentialState("", NX, 1);
00082         dx = DifferentialStateDerivative("", NDX, 1);
00083         z = AlgebraicState("", NXA, 1);
00084         u = Control("", NU, 1);
00085         od = OnlineData("", NOD, 1);
00086 
00087         if( NDX > 0 && NDX != NX ) {
00088                 return ACADOERROR( RET_INVALID_OPTION );
00089         }
00090         if( rhs_.getNumRows() != (NX+NXA) ) {
00091                 return ACADOERROR( RET_INVALID_OPTION );
00092         }
00093 
00094         DifferentialEquation f, f_ODE;
00095         // add usual ODE
00096         f_ODE << rhs_;
00097         if( f_ODE.getNDX() > 0 ) {
00098                 return ACADOERROR( RET_INVALID_OPTION );
00099         }
00100 
00101         if( (ExportSensitivityType)sensGen == BACKWARD ) {
00102                 DifferentialState lx("", NX,1), lu("", NU,1);
00103 
00104                 f << backwardDerivative(rhs_, x, lx);
00105                 f << backwardDerivative(rhs_, u, lx);
00106         }
00107         else {
00108                 return ACADOERROR( RET_INVALID_OPTION );
00109         }
00110         if( f.getNT() > 0 ) timeDependant = true;
00111 
00112         return rhs.init(f_ODE, "acado_rhs", NX, 0, NU, NP, NDX, NOD)
00113                         & diffs_rhs.init(f, "acado_rhs_back", 2*NX + NU, 0, NU, NP, NDX, NOD);
00114 }
00115 
00116 
00117 returnValue AdjointERKExport::setup( )
00118 {
00119         int sensGen;
00120         get( DYNAMIC_SENSITIVITY,sensGen );
00121         if ( (ExportSensitivityType)sensGen != BACKWARD ) ACADOERROR( RET_INVALID_OPTION );
00122 
00123         // NOT SUPPORTED: since the forward sweep needs to be saved
00124         if( !equidistantControlGrid() )         ACADOERROR( RET_INVALID_OPTION );
00125 
00126         // NOT SUPPORTED: since the adjoint derivatives could be 'arbitrarily bad'
00127         if( !is_symmetric )                             ACADOERROR( RET_INVALID_OPTION );
00128 
00129         LOG( LVL_DEBUG ) << "Preparing to export AdjointERKExport... " << endl;
00130 
00131         // export RK scheme
00132         uint rhsDim   = 2*NX+NU;
00133         inputDim = 2*NX+NU + NU + NOD;
00134         const uint rkOrder  = getNumStages();
00135 
00136         double h = (grid.getLastTime() - grid.getFirstTime())/grid.getNumIntervals();    
00137 
00138         ExportVariable Ah ( "A*h",  DMatrix( AA )*=h );
00139         ExportVariable b4h( "b4*h", DMatrix( bb )*=h );
00140 
00141         rk_index = ExportVariable( "rk_index", 1, 1, INT, ACADO_LOCAL, true );
00142         rk_eta = ExportVariable( "rk_eta", 1, inputDim );
00143 //      seed_backward.setup( "seed", 1, NX );
00144 
00145         int useOMP;
00146         get(CG_USE_OPENMP, useOMP);
00147         ExportStruct structWspace;
00148         structWspace = useOMP ? ACADO_LOCAL : ACADO_WORKSPACE;
00149 
00150         rk_ttt.setup( "rk_ttt", 1, 1, REAL, structWspace, true );
00151         uint timeDep = 0;
00152         if( timeDependant ) timeDep = 1;
00153         
00154         rk_xxx.setup("rk_xxx", 1, inputDim+timeDep, REAL, structWspace);
00155         rk_kkk.setup("rk_kkk", rkOrder, NX+NU, REAL, structWspace);
00156         rk_forward_sweep.setup("rk_sweep1", 1, grid.getNumIntervals()*rkOrder*NX, REAL, structWspace);
00157 
00158         if ( useOMP )
00159         {
00160                 ExportVariable auxVar;
00161 
00162                 auxVar = getAuxVariable();
00163                 auxVar.setName( "odeAuxVar" );
00164                 auxVar.setDataStruct( ACADO_LOCAL );
00165                 rhs.setGlobalExportVariable( auxVar );
00166                 diffs_rhs.setGlobalExportVariable( auxVar );
00167         }
00168 
00169         ExportIndex run( "run1" );
00170 
00171         // setup INTEGRATE function
00172         integrate = ExportFunction( "integrate", rk_eta, reset_int );
00173         integrate.setReturnValue( error_code );
00174         rk_eta.setDoc( "Working array to pass the input values and return the results." );
00175         reset_int.setDoc( "The internal memory of the integrator can be reset." );
00176         rk_index.setDoc( "Number of the shooting interval." );
00177         error_code.setDoc( "Status code of the integrator." );
00178         integrate.doc( "Performs the integration and sensitivity propagation for one shooting interval." );
00179         integrate.addIndex( run );
00180         
00181         integrate.addStatement( rk_ttt == DMatrix(grid.getFirstTime()) );
00182 
00183         if( inputDim > rhsDim ) {
00184 //              integrate.addStatement( rk_eta.getCols( NX,2*NX ) == seed_backward );
00185                 integrate.addStatement( rk_eta.getCols( 2*NX,2*NX+NU ) == zeros<double>( 1,NU ) );
00186                 // FORWARD SWEEP FIRST
00187                 integrate.addStatement( rk_xxx.getCols( NX,NX+NU+NOD ) == rk_eta.getCols( rhsDim,inputDim ) );
00188         }
00189         integrate.addLinebreak( );
00190 
00191     // integrator loop: FORWARD SWEEP
00192         ExportForLoop loop = ExportForLoop( run, 0, grid.getNumIntervals() );
00193         for( uint run1 = 0; run1 < rkOrder; run1++ )
00194         {
00195                 loop.addStatement( rk_xxx.getCols( 0,NX ) == rk_eta.getCols( 0,NX ) + Ah.getRow(run1)*rk_kkk.getCols( 0,NX ) );
00196                 // save forward trajectory
00197                 loop.addStatement( rk_forward_sweep.getCols( run*rkOrder*NX+run1*NX,run*rkOrder*NX+run1*NX+NX ) == rk_xxx.getCols( 0,NX ) );
00198                 if( timeDependant ) loop.addStatement( rk_xxx.getCol( NX+NU+NOD ) == rk_ttt + ((double)cc(run1))/grid.getNumIntervals() );
00199                 loop.addFunctionCall( getNameRHS(),rk_xxx,rk_kkk.getAddress(run1,0) );
00200         }
00201         loop.addStatement( rk_eta.getCols( 0,NX ) += b4h^rk_kkk.getCols( 0,NX ) );
00202         loop.addStatement( rk_ttt += DMatrix(1.0/grid.getNumIntervals()) );
00203     // end of integrator loop: FORWARD SWEEP
00204         integrate.addStatement( loop );
00205 
00206 //      if( !is_symmetric ) {
00207 //              integrate.addStatement( rk_xxx.getCols( 0,NX ) == rk_eta.getCols( 0,NX ) );
00208 //      }
00209         if( inputDim > rhsDim ) {
00210                 // BACKWARD SWEEP NEXT
00211                 integrate.addStatement( rk_xxx.getCols( rhsDim,inputDim ) == rk_eta.getCols( rhsDim,inputDim ) );
00212         }
00213     // integrator loop: BACKWARD SWEEP
00214         ExportForLoop loop2 = ExportForLoop( run, 0, grid.getNumIntervals() );
00215         for( uint run1 = 0; run1 < rkOrder; run1++ )
00216         {
00217                 // load forward trajectory
00218 //              if( is_symmetric ) {
00219                         loop2.addStatement( rk_xxx.getCols( 0,NX ) == rk_forward_sweep.getCols( (grid.getNumIntervals()-run)*rkOrder*NX-run1*NX-NX,(grid.getNumIntervals()-run)*rkOrder*NX-run1*NX ) );
00220 //              }
00221                 loop2.addStatement( rk_xxx.getCols( NX,2*NX+NU ) == rk_eta.getCols( NX,2*NX+NU ) + Ah.getRow(run1)*rk_kkk );
00222                 if( timeDependant ) loop2.addStatement( rk_xxx.getCol( inputDim ) == rk_ttt - ((double)cc(run1))/grid.getNumIntervals() );
00223                 loop2.addFunctionCall( getNameDiffsRHS(),rk_xxx,rk_kkk.getAddress(run1,0) );
00224 //              if( !is_symmetric ) {
00225 //                      loop2.addStatement( rk_xxx.getCols( 0,NX ) == rk_forward_sweep.getCols( (grid.getNumIntervals()-run)*rkOrder*NX-run1*NX-NX,(grid.getNumIntervals()-run)*rkOrder*NX-run1*NX ) );
00226 //              }
00227         }
00228         loop2.addStatement( rk_eta.getCols( NX,2*NX+NU ) += b4h^rk_kkk );
00229         loop2.addStatement( rk_ttt -= DMatrix(1.0/grid.getNumIntervals()) );
00230     // end of integrator loop: BACKWARD SWEEP
00231         integrate.addStatement( loop2 );
00232 
00233         integrate.addStatement( error_code == 0 );
00234 
00235         LOG( LVL_DEBUG ) << "done" << endl;
00236 
00237         return SUCCESSFUL_RETURN;
00238 }
00239 
00240 
00241 returnValue AdjointERKExport::getDataDeclarations(      ExportStatementBlock& declarations,
00242                                                                                                         ExportStruct dataStruct
00243                                                                                                         ) const
00244 {
00245         ExplicitRungeKuttaExport::getDataDeclarations( declarations, dataStruct );
00246 
00247         declarations.addDeclaration( rk_forward_sweep,dataStruct );
00248 
00249     return SUCCESSFUL_RETURN;
00250 }
00251 
00252 
00253 returnValue AdjointERKExport::getCode(  ExportStatementBlock& code
00254                                                                                 )
00255 {
00256         int useOMP;
00257         get(CG_USE_OPENMP, useOMP);
00258         if ( useOMP )
00259         {
00260                 getDataDeclarations( code, ACADO_LOCAL );
00261 
00262                 code << "#pragma omp threadprivate( "
00263                                 << getAuxVariable().getFullName()  << ", "
00264                                 << rk_xxx.getFullName() << ", "
00265                                 << rk_ttt.getFullName() << ", "
00266                                 << rk_kkk.getFullName() << ", "
00267                                 << rk_forward_sweep.getFullName()
00268                                 << " )\n\n";
00269         }
00270 
00271         int sensGen;
00272         get( DYNAMIC_SENSITIVITY,sensGen );
00273         if( exportRhs ) {
00274                 code.addFunction( rhs );
00275                 code.addFunction( diffs_rhs );
00276         }
00277 
00278         double h = (grid.getLastTime() - grid.getFirstTime())/grid.getNumIntervals();
00279         code.addComment(std::string("Fixed step size:") + toString(h));
00280         code.addFunction( integrate );
00281 
00282         return SUCCESSFUL_RETURN;
00283 }
00284 
00285 
00286 // PROTECTED:
00287 
00288 
00289 
00290 CLOSE_NAMESPACE_ACADO
00291 
00292 // end of file.


acado
Author(s): Milan Vukov, Rien Quirynen
autogenerated on Sat Jun 8 2019 19:37:01