00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00034 #include <acado/code_generation/linear_solvers/gaussian_elimination_export.hpp>
00035
00036 using namespace std;
00037
00038 BEGIN_NAMESPACE_ACADO
00039
00040
00041
00042
00043
00044 ExportGaussElim::ExportGaussElim( UserInteraction* _userInteraction,
00045 const std::string& _commonHeaderName
00046 ) : ExportLinearSolver( _userInteraction,_commonHeaderName )
00047 {
00048 }
00049
00050 ExportGaussElim::~ExportGaussElim( )
00051 {}
00052
00053 returnValue ExportGaussElim::getDataDeclarations( ExportStatementBlock& declarations,
00054 ExportStruct dataStruct
00055 ) const
00056 {
00057 declarations.addDeclaration( rk_swap,dataStruct );
00058 if( REUSE ) {
00059 declarations.addDeclaration( rk_bPerm,dataStruct );
00060 }
00061
00062 return SUCCESSFUL_RETURN;
00063 }
00064
00065
00066 returnValue ExportGaussElim::getFunctionDeclarations( ExportStatementBlock& declarations
00067 ) const
00068 {
00069 declarations.addDeclaration( solve );
00070 declarations.addDeclaration( solveTriangular );
00071 if( REUSE ) {
00072 declarations.addDeclaration( solveReuse );
00073 }
00074
00075 return SUCCESSFUL_RETURN;
00076 }
00077
00078
00079 returnValue ExportGaussElim::getCode( ExportStatementBlock& code
00080 )
00081 {
00082 uint run1, run2, run3;
00083
00084 for( run1 = dim; run1 > 0; run1--) {
00085 for( run2 = dim-1; run2 > (run1-1); run2--) {
00086 solveTriangular.addStatement( b.getRow( (run1-1) ) -= A.getSubMatrix( (run1-1),(run1-1)+1,run2,run2+1 ) * b.getRow( run2 ) );
00087 }
00088 solveTriangular << "b[" << toString( (run1-1) ) << "] = b[" << toString( (run1-1) ) << "]/A[" << toString( (run1-1)*dim+(run1-1) ) << "];\n";
00089 }
00090 code.addFunction( solveTriangular );
00091
00092 ExportIndex i( "i" );
00093 solve.addIndex( i );
00094 ExportIndex j( "j" );
00095 ExportIndex k( "k" );
00096 ExportVariable indexMax( "indexMax", 1, 1, INT, ACADO_LOCAL, true );
00097 ExportVariable valueMax( "valueMax", 1, 1, REAL, ACADO_LOCAL, true );
00098 ExportVariable temp( "temp", 1, 1, REAL, ACADO_LOCAL, true );
00099 if( !UNROLLING ) {
00100 solve.addIndex( j );
00101 solve.addIndex( k );
00102 solve.addDeclaration( indexMax );
00103 solve.addDeclaration( valueMax );
00104 solve.addDeclaration( temp );
00105 }
00106
00107
00108 if( REUSE ) {
00109 ExportForLoop loop1( i,0,dim );
00110 loop1 << rk_perm.get( 0,i ) << " = " << i.getName() << ";\n";
00111 solve.addStatement( loop1 );
00112 }
00113
00114 solve.addStatement( determinant == 1 );
00115 if( UNROLLING || dim <= 5 ) {
00116
00117 for( run1 = 0; run1 < (dim-1); run1++ ) {
00118
00119 for( run2 = run1+1; run2 < dim; run2++ ) {
00120
00121 stringstream test;
00122 if( run2 == (run1+1) ) {
00123 test << "if(";
00124 } else {
00125 test << "else if(";
00126 }
00127 test << "fabs(A[" << toString( run2*dim+run1 ) << "]) > fabs(A[" << toString( run1*dim+run1 ) << "])";
00128 for( run3 = run1+1; run3 < dim; run3++ ) {
00129 if( run3 != run2) {
00130 test << " && fabs(A[" << toString( run2*dim+run1 ) << "]) > fabs(A[" << toString( run3*dim+run1 ) << "])";
00131 }
00132 }
00133 test << ") {\n";
00134 solve.addStatement( test.str() );
00135
00136
00137
00138 for( run3 = 0; run3 < dim; run3++ ) {
00139 solve.addStatement( rk_swap == A.getSubMatrix( run1,run1+1,run3,run3+1 ) );
00140 solve.addStatement( A.getSubMatrix( run1,run1+1,run3,run3+1 ) == A.getSubMatrix( run2,run2+1,run3,run3+1 ) );
00141 solve.addStatement( A.getSubMatrix( run2,run2+1,run3,run3+1 ) == rk_swap );
00142 }
00143
00144 solve.addStatement( rk_swap == b.getRow( run1 ) );
00145 solve.addStatement( b.getRow( run1 ) == b.getRow( run2 ) );
00146 solve.addStatement( b.getRow( run2 ) == rk_swap );
00147
00148 if( REUSE ) {
00149 solve.addStatement( rk_swap == rk_perm.getCol( run1 ) );
00150 solve.addStatement( rk_perm.getCol( run1 ) == rk_perm.getCol( run2 ) );
00151 solve.addStatement( rk_perm.getCol( run2 ) == rk_swap );
00152 }
00153
00154 solve.addStatement( "}\n" );
00155 }
00156
00157 solve.addLinebreak();
00158
00159 for( run2 = run1+1; run2 < dim; run2++ ) {
00160 solve << "A[" << toString( run2*dim+run1 ) << "] = -A[" << toString( run2*dim+run1 ) << "]/A[" << toString( run1*dim+run1 ) << "];\n";
00161 solve.addStatement( A.getSubMatrix( run2,run2+1,run1+1,dim ) += A.getSubMatrix( run2,run2+1,run1,run1+1 ) * A.getSubMatrix( run1,run1+1,run1+1,dim ) );
00162 solve.addStatement( b.getRow( run2 ) += A.getSubMatrix( run2,run2+1,run1,run1+1 ) * b.getRow( run1 ) );
00163 solve.addLinebreak();
00164 }
00165 solve.addStatement( determinant == determinant*A.getSubMatrix(run1,run1+1,run1,run1+1) );
00166 solve.addLinebreak();
00167 }
00168 solve.addStatement( determinant == determinant*A.getSubMatrix(dim-1,dim,dim-1,dim) );
00169 solve.addLinebreak();
00170 }
00171 else {
00172 solve << "for( i=0; i < (" << toString( dim-1 ) << "); i++ ) {\n";
00173 solve << " indexMax = i;\n";
00174 solve << " valueMax = fabs(A[i*" << toString( dim ) << "+i]);\n";
00175 solve << " for( j=(i+1); j < " << toString( dim ) << "; j++ ) {\n";
00176 solve << " temp = fabs(A[j*" << toString( dim ) << "+i]);\n";
00177 solve << " if( temp > valueMax ) {\n";
00178 solve << " indexMax = j;\n";
00179 solve << " valueMax = temp;\n";
00180 solve << " }\n";
00181 solve << " }\n";
00182 solve << " if( indexMax > i ) {\n";
00183 ExportForLoop loop2( k,0,dim );
00184 loop2 << " " << rk_swap.getFullName() << " = A[i*" << toString( dim ) << "+" << k.getName() << "];\n";
00185 loop2 << " A[i*" << toString( dim ) << "+" << k.getName() << "] = A[indexMax*" << toString( dim ) << "+" << k.getName() << "];\n";
00186 loop2 << " A[indexMax*" << toString( dim ) << "+" << k.getName() << "] = " << rk_swap.getFullName() << ";\n";
00187 solve.addStatement( loop2 );
00188 solve << " " << rk_swap.getFullName() << " = b[i];\n";
00189 solve << " b[i] = b[indexMax];\n";
00190 solve << " b[indexMax] = " << rk_swap.getFullName() << ";\n";
00191 if( REUSE ) {
00192 solve << " " << rk_swap.getFullName() << " = " << rk_perm.getFullName() << "[i];\n";
00193 solve << " " << rk_perm.getFullName() << "[i] = " << rk_perm.getFullName() << "[indexMax];\n";
00194 solve << " " << rk_perm.getFullName() << "[indexMax] = " << rk_swap.getFullName() << ";\n";
00195 }
00196 solve << " }\n";
00197 solve << " " << determinant.getFullName() << " *= A[i*" << toString( dim ) << "+i];\n";
00198 solve << " for( j=i+1; j < " << toString( dim ) << "; j++ ) {\n";
00199 solve << " A[j*" << toString( dim ) << "+i] = -A[j*" << toString( dim ) << "+i]/A[i*" << toString( dim ) << "+i];\n";
00200 solve << " for( k=i+1; k < " << toString( dim ) << "; k++ ) {\n";
00201 solve << " A[j*" << toString( dim ) << "+k] += A[j*" << toString( dim ) << "+i] * A[i*" << toString( dim ) << "+k];\n";
00202 solve << " }\n";
00203 solve << " b[j] += A[j*" << toString( dim ) << "+i] * b[i];\n";
00204 solve << " }\n";
00205 solve << "}\n";
00206 solve << determinant.getFullName() << " *= A[" << toString( (dim-1)*dim+(dim-1) ) << "];\n";
00207 }
00208 solve << determinant.getFullName() << " = fabs(" << determinant.getFullName() << ");\n";
00209
00210 solve.addFunctionCall( solveTriangular, A, b );
00211 code.addFunction( solve );
00212
00213 code.addLinebreak( 2 );
00214 if( REUSE ) {
00215 for( run1 = 0; run1 < dim; run1++ ) {
00216 solveReuse << rk_bPerm.get( run1,0 ) << " = b[" << rk_perm.getFullName() << "[" << toString( run1 ) << "]];\n";
00217 }
00218
00219 for( run2 = 1; run2 < dim; run2++ ) {
00220 for( run1 = 0; run1 < run2; run1++ ) {
00221 solveReuse << rk_bPerm.get( run2,0 ) << " += A[" << toString( run2*dim+run1 ) << "]*" << rk_bPerm.getFullName() << "[" << toString( run1 ) << "];\n";
00222 }
00223 solveReuse.addLinebreak();
00224 }
00225 solveReuse.addLinebreak();
00226
00227 solveReuse.addFunctionCall( solveTriangular, A, rk_bPerm );
00228 solveReuse.addStatement( b == rk_bPerm );
00229
00230 code.addFunction( solveReuse );
00231 }
00232
00233 return SUCCESSFUL_RETURN;
00234 }
00235
00236
00237 returnValue ExportGaussElim::appendVariableNames( stringstream& string ) {
00238
00239 string << ", " << rk_swap.getFullName();
00240 if( REUSE ) {
00241
00242 string << ", " << rk_bPerm.getFullName();
00243 }
00244
00245 return SUCCESSFUL_RETURN;
00246 }
00247
00248
00249 returnValue ExportGaussElim::setup( )
00250 {
00251
00252 ASSERT_RETURN(nCols == nRows);
00253
00254 int useOMP;
00255 get(CG_USE_OPENMP, useOMP);
00256 ExportStruct structWspace;
00257 structWspace = useOMP ? ACADO_LOCAL : ACADO_WORKSPACE;
00258
00259 rk_swap = ExportVariable( std::string( "rk_" ) + identifier + "swap", 1, 1, REAL, structWspace, true );
00260 rk_bPerm = ExportVariable( std::string( "rk_" ) + identifier + "bPerm", dim, 1, REAL, structWspace );
00261 A = ExportVariable( "A", dim, dim, REAL );
00262 b = ExportVariable( "b", dim, 1, REAL );
00263 rk_perm = ExportVariable( "rk_perm", 1, dim, INT );
00264 solve = ExportFunction( getNameSolveFunction(), A, b, rk_perm );
00265 solve.setReturnValue( determinant, false );
00266 solve.addLinebreak( );
00267 solveTriangular = ExportFunction( std::string( "solve_" ) + identifier + "triangular", A, b );
00268 solveTriangular.addLinebreak( );
00269
00270 if( REUSE ) {
00271 solveReuse = ExportFunction( getNameSolveReuseFunction(), A, b, rk_perm );
00272 solveReuse.addLinebreak( );
00273 }
00274
00275 int unrollOpt;
00276 userInteraction->get( UNROLL_LINEAR_SOLVER, unrollOpt );
00277 UNROLLING = (bool) unrollOpt;
00278
00279 return SUCCESSFUL_RETURN;
00280 }
00281
00282
00283 ExportVariable ExportGaussElim::getGlobalExportVariable( const uint factor ) const {
00284
00285 int useOMP;
00286 get(CG_USE_OPENMP, useOMP);
00287 ExportStruct structWspace;
00288 structWspace = useOMP ? ACADO_LOCAL : ACADO_WORKSPACE;
00289
00290 return ExportVariable( std::string( "rk_" ) + identifier + "perm", factor, dim, INT, structWspace );
00291 }
00292
00293
00294
00295
00296
00297
00298
00299
00300 CLOSE_NAMESPACE_ACADO
00301
00302