gaussian_elimination_export.cpp
Go to the documentation of this file.
00001 /*
00002  *    This file is part of ACADO Toolkit.
00003  *
00004  *    ACADO Toolkit -- A Toolkit for Automatic Control and Dynamic Optimization.
00005  *    Copyright (C) 2008-2014 by Boris Houska, Hans Joachim Ferreau,
00006  *    Milan Vukov, Rien Quirynen, KU Leuven.
00007  *    Developed within the Optimization in Engineering Center (OPTEC)
00008  *    under supervision of Moritz Diehl. All rights reserved.
00009  *
00010  *    ACADO Toolkit is free software; you can redistribute it and/or
00011  *    modify it under the terms of the GNU Lesser General Public
00012  *    License as published by the Free Software Foundation; either
00013  *    version 3 of the License, or (at your option) any later version.
00014  *
00015  *    ACADO Toolkit is distributed in the hope that it will be useful,
00016  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00017  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018  *    Lesser General Public License for more details.
00019  *
00020  *    You should have received a copy of the GNU Lesser General Public
00021  *    License along with ACADO Toolkit; if not, write to the Free Software
00022  *    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
00023  *
00024  */
00025 
00026 
00027 
00034 #include <acado/code_generation/linear_solvers/gaussian_elimination_export.hpp>
00035 
00036 using namespace std;
00037 
00038 BEGIN_NAMESPACE_ACADO
00039 
00040 //
00041 // PUBLIC MEMBER FUNCTIONS:
00042 //
00043 
00044 ExportGaussElim::ExportGaussElim( UserInteraction* _userInteraction,
00045                                                                         const std::string& _commonHeaderName
00046                                                                         ) : ExportLinearSolver( _userInteraction,_commonHeaderName )
00047 {
00048 }
00049 
00050 ExportGaussElim::~ExportGaussElim( )
00051 {}
00052 
00053 returnValue ExportGaussElim::getDataDeclarations(       ExportStatementBlock& declarations,
00054                                                                                                                 ExportStruct dataStruct
00055                                                                                                                 ) const
00056 {
00057         declarations.addDeclaration( rk_swap,dataStruct );                      // needed for the row swaps
00058         if( REUSE ) {
00059                 declarations.addDeclaration( rk_bPerm,dataStruct );             // reordered right-hand side
00060         }
00061 
00062         return SUCCESSFUL_RETURN;
00063 }
00064 
00065 
00066 returnValue ExportGaussElim::getFunctionDeclarations(   ExportStatementBlock& declarations
00067                                                                                                                         ) const
00068 {
00069         declarations.addDeclaration( solve );
00070         declarations.addDeclaration( solveTriangular );
00071         if( REUSE ) {
00072                 declarations.addDeclaration( solveReuse );
00073         }
00074 
00075         return SUCCESSFUL_RETURN;
00076 }
00077 
00078 
00079 returnValue ExportGaussElim::getCode(   ExportStatementBlock& code
00080                                                                                         )
00081 {
00082         uint run1, run2, run3;
00083         // Solve the upper triangular system of equations:
00084         for( run1 = dim; run1 > 0; run1--) {
00085                 for( run2 = dim-1; run2 > (run1-1); run2--) {
00086                         solveTriangular.addStatement( b.getRow( (run1-1) ) -= A.getSubMatrix( (run1-1),(run1-1)+1,run2,run2+1 ) * b.getRow( run2 ) );
00087                 }
00088                 solveTriangular << "b[" << toString( (run1-1) ) << "] = b[" << toString( (run1-1) ) << "]/A[" << toString( (run1-1)*dim+(run1-1) ) << "];\n";
00089         }
00090         code.addFunction( solveTriangular );
00091 
00092         ExportIndex i( "i" );
00093         solve.addIndex( i );
00094         ExportIndex j( "j" );
00095         ExportIndex k( "k" );
00096         ExportVariable indexMax( "indexMax", 1, 1, INT, ACADO_LOCAL, true );
00097         ExportVariable valueMax( "valueMax", 1, 1, REAL, ACADO_LOCAL, true );
00098         ExportVariable temp( "temp", 1, 1, REAL, ACADO_LOCAL, true );
00099         if( !UNROLLING ) {
00100                 solve.addIndex( j );
00101                 solve.addIndex( k );
00102                 solve.addDeclaration( indexMax );
00103                 solve.addDeclaration( valueMax );
00104                 solve.addDeclaration( temp );
00105         }
00106         
00107         // initialise rk_perm (the permutation vector)
00108         if( REUSE ) {
00109                 ExportForLoop loop1( i,0,dim );
00110                 loop1 << rk_perm.get( 0,i ) << " = " << i.getName() << ";\n";
00111                 solve.addStatement( loop1 );
00112         }
00113         
00114         solve.addStatement( determinant == 1 );
00115         if( UNROLLING || dim <= 5 ) {
00116                 // Start the factorization:
00117                 for( run1 = 0; run1 < (dim-1); run1++ ) {
00118                         // Search for pivot in column run1:
00119                         for( run2 = run1+1; run2 < dim; run2++ ) {
00120                                 // add the test (if or else if):
00121                                 stringstream test;
00122                                 if( run2 == (run1+1) ) {
00123                                         test << "if(";
00124                                 } else {
00125                                         test << "else if(";
00126                                 }
00127                                 test << "fabs(A[" << toString( run2*dim+run1 ) << "]) > fabs(A[" << toString( run1*dim+run1 ) << "])";
00128                                 for( run3 = run1+1; run3 < dim; run3++ ) {
00129                                         if( run3 != run2) {
00130                                                 test << " && fabs(A[" << toString( run2*dim+run1 ) << "]) > fabs(A[" << toString( run3*dim+run1 ) << "])";
00131                                         }
00132                                 }
00133                                 test << ") {\n";
00134                                 solve.addStatement( test.str() );
00135                         
00136                                 // do the row swaps:
00137                                 // for A:
00138                                 for( run3 = 0; run3 < dim; run3++ ) {
00139                                         solve.addStatement( rk_swap == A.getSubMatrix( run1,run1+1,run3,run3+1 ) );
00140                                         solve.addStatement( A.getSubMatrix( run1,run1+1,run3,run3+1 ) == A.getSubMatrix( run2,run2+1,run3,run3+1 ) );
00141                                         solve.addStatement( A.getSubMatrix( run2,run2+1,run3,run3+1 ) == rk_swap );
00142                                 }
00143                                 // for b:
00144                                 solve.addStatement( rk_swap == b.getRow( run1 ) );
00145                                 solve.addStatement( b.getRow( run1 ) == b.getRow( run2 ) );
00146                                 solve.addStatement( b.getRow( run2 ) == rk_swap );
00147                         
00148                                 if( REUSE ) { // rk_perm also needs to be updated if it needs to be possible to reuse the factorization
00149                                         solve.addStatement( rk_swap == rk_perm.getCol( run1 ) );
00150                                         solve.addStatement( rk_perm.getCol( run1 ) == rk_perm.getCol( run2 ) );
00151                                         solve.addStatement( rk_perm.getCol( run2 ) == rk_swap );
00152                                 }
00153                         
00154                                 solve.addStatement( "}\n" );
00155                         }
00156                         // potentially needed row swaps are done
00157                         solve.addLinebreak();
00158                         // update of the next rows:
00159                         for( run2 = run1+1; run2 < dim; run2++ ) {
00160                                 solve << "A[" << toString( run2*dim+run1 ) << "] = -A[" << toString( run2*dim+run1 ) << "]/A[" << toString( run1*dim+run1 ) << "];\n";
00161                                 solve.addStatement( A.getSubMatrix( run2,run2+1,run1+1,dim ) += A.getSubMatrix( run2,run2+1,run1,run1+1 ) * A.getSubMatrix( run1,run1+1,run1+1,dim ) );
00162                                 solve.addStatement( b.getRow( run2 ) += A.getSubMatrix( run2,run2+1,run1,run1+1 ) * b.getRow( run1 ) );
00163                                 solve.addLinebreak();
00164                         }
00165                         solve.addStatement( determinant == determinant*A.getSubMatrix(run1,run1+1,run1,run1+1) );
00166                         solve.addLinebreak();
00167                 }
00168                 solve.addStatement( determinant == determinant*A.getSubMatrix(dim-1,dim,dim-1,dim) );
00169                 solve.addLinebreak();
00170         }
00171         else { // without UNROLLING:
00172                 solve << "for( i=0; i < (" << toString( dim-1 ) << "); i++ ) {\n";
00173                 solve << "      indexMax = i;\n";
00174                 solve << "      valueMax = fabs(A[i*" << toString( dim ) << "+i]);\n";
00175                 solve << "      for( j=(i+1); j < " << toString( dim ) << "; j++ ) {\n";
00176                 solve << "              temp = fabs(A[j*" << toString( dim ) << "+i]);\n";
00177                 solve << "              if( temp > valueMax ) {\n";
00178                 solve << "                      indexMax = j;\n";
00179                 solve << "                      valueMax = temp;\n";
00180                 solve << "              }\n";
00181                 solve << "      }\n";
00182                 solve << "      if( indexMax > i ) {\n";
00183                 ExportForLoop loop2( k,0,dim );
00184                 loop2 << "      " << rk_swap.getFullName() << " = A[i*" << toString( dim ) << "+" << k.getName() << "];\n";
00185                 loop2 << "      A[i*" << toString( dim ) << "+" << k.getName() << "] = A[indexMax*" << toString( dim ) << "+" << k.getName() << "];\n";
00186                 loop2 << "      A[indexMax*" << toString( dim ) << "+" << k.getName() << "] = " << rk_swap.getFullName() << ";\n";
00187                 solve.addStatement( loop2 );
00188                 solve << "      " << rk_swap.getFullName() << " = b[i];\n";
00189                 solve << "      b[i] = b[indexMax];\n";
00190                 solve << "      b[indexMax] = " << rk_swap.getFullName() << ";\n";
00191                 if( REUSE ) {
00192                         solve << "      " << rk_swap.getFullName() << " = " << rk_perm.getFullName() << "[i];\n";
00193                         solve << "      " << rk_perm.getFullName() << "[i] = " << rk_perm.getFullName() << "[indexMax];\n";
00194                         solve << "      " << rk_perm.getFullName() << "[indexMax] = " << rk_swap.getFullName() << ";\n";
00195                 }
00196                 solve << "      }\n";
00197                 solve << "      " << determinant.getFullName() << " *= A[i*" << toString( dim ) << "+i];\n";
00198                 solve << "      for( j=i+1; j < " << toString( dim ) << "; j++ ) {\n";
00199                 solve << "              A[j*" << toString( dim ) << "+i] = -A[j*" << toString( dim ) << "+i]/A[i*" << toString( dim ) << "+i];\n";
00200                 solve << "              for( k=i+1; k < " << toString( dim ) << "; k++ ) {\n";
00201                 solve << "                      A[j*" << toString( dim ) << "+k] += A[j*" << toString( dim ) << "+i] * A[i*" << toString( dim ) << "+k];\n";
00202                 solve << "              }\n";
00203                 solve << "              b[j] += A[j*" << toString( dim ) << "+i] * b[i];\n";
00204                 solve << "      }\n";
00205                 solve << "}\n";
00206                 solve << determinant.getFullName() << " *= A[" << toString( (dim-1)*dim+(dim-1) ) << "];\n";
00207         }
00208         solve << determinant.getFullName() << " = fabs(" << determinant.getFullName() << ");\n";
00209         
00210         solve.addFunctionCall( solveTriangular, A, b );
00211         code.addFunction( solve );
00212         
00213     code.addLinebreak( 2 );
00214         if( REUSE ) { // Also export the extra function which reuses the factorization of the matrix A
00215                 for( run1 = 0; run1 < dim; run1++ ) {
00216                         solveReuse << rk_bPerm.get( run1,0 ) << " = b[" << rk_perm.getFullName() << "[" << toString( run1 ) << "]];\n";
00217                 }
00218 
00219                 for( run2 = 1; run2 < dim; run2++ ) {           // row run2
00220                         for( run1 = 0; run1 < run2; run1++ ) {  // column run1
00221                                 solveReuse << rk_bPerm.get( run2,0 ) << " += A[" << toString( run2*dim+run1 ) << "]*" << rk_bPerm.getFullName() << "[" << toString( run1 ) << "];\n";
00222                         }
00223                         solveReuse.addLinebreak();
00224                 }
00225                 solveReuse.addLinebreak();
00226 
00227                 solveReuse.addFunctionCall( solveTriangular, A, rk_bPerm );
00228                 solveReuse.addStatement( b == rk_bPerm );
00229 
00230                 code.addFunction( solveReuse );
00231         }
00232         
00233         return SUCCESSFUL_RETURN;
00234 }
00235 
00236 
00237 returnValue ExportGaussElim::appendVariableNames( stringstream& string ) {
00238 
00239         string << ", " << rk_swap.getFullName();
00240         if( REUSE ) {
00241 //              string << ", " << rk_perm.getFullName().getName();
00242                 string << ", " << rk_bPerm.getFullName();
00243         }
00244 
00245         return SUCCESSFUL_RETURN;
00246 }
00247 
00248 
00249 returnValue ExportGaussElim::setup( )
00250 {
00251         // Other cases are not implemented...
00252         ASSERT_RETURN(nCols == nRows);
00253 
00254         int useOMP;
00255         get(CG_USE_OPENMP, useOMP);
00256         ExportStruct structWspace;
00257         structWspace = useOMP ? ACADO_LOCAL : ACADO_WORKSPACE;
00258 
00259         rk_swap = ExportVariable( std::string( "rk_" ) + identifier + "swap", 1, 1, REAL, structWspace, true );
00260         rk_bPerm = ExportVariable( std::string( "rk_" ) + identifier + "bPerm", dim, 1, REAL, structWspace );
00261         A = ExportVariable( "A", dim, dim, REAL );
00262         b = ExportVariable( "b", dim, 1, REAL );
00263         rk_perm = ExportVariable( "rk_perm", 1, dim, INT );
00264         solve = ExportFunction( getNameSolveFunction(), A, b, rk_perm );
00265         solve.setReturnValue( determinant, false );
00266         solve.addLinebreak( );  // FIX: TO MAKE SURE IT GETS EXPORTED
00267         solveTriangular = ExportFunction( std::string( "solve_" ) + identifier + "triangular", A, b );
00268         solveTriangular.addLinebreak( );        // FIX: TO MAKE SURE IT GETS EXPORTED
00269         
00270         if( REUSE ) {
00271                 solveReuse = ExportFunction( getNameSolveReuseFunction(), A, b, rk_perm );
00272                 solveReuse.addLinebreak( );     // FIX: TO MAKE SURE IT GETS EXPORTED
00273         }
00274         
00275         int unrollOpt;
00276         userInteraction->get( UNROLL_LINEAR_SOLVER, unrollOpt );
00277         UNROLLING = (bool) unrollOpt;
00278 
00279     return SUCCESSFUL_RETURN;
00280 }
00281 
00282 
00283 ExportVariable ExportGaussElim::getGlobalExportVariable( const uint factor ) const {
00284 
00285         int useOMP;
00286         get(CG_USE_OPENMP, useOMP);
00287         ExportStruct structWspace;
00288         structWspace = useOMP ? ACADO_LOCAL : ACADO_WORKSPACE;
00289 
00290         return ExportVariable( std::string( "rk_" ) + identifier + "perm", factor, dim, INT, structWspace );
00291 }
00292 
00293 
00294 //
00295 // PROTECTED MEMBER FUNCTIONS:
00296 //
00297 
00298 
00299 
00300 CLOSE_NAMESPACE_ACADO
00301 
00302 // end of file.


acado
Author(s): Milan Vukov, Rien Quirynen
autogenerated on Sat Jun 8 2019 19:37:10