00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00034 #include <acado/code_generation/linear_solvers/householder_qr_export.hpp>
00035
00036 using namespace std;
00037
00038 BEGIN_NAMESPACE_ACADO
00039
00040
00041
00042
00043
00044 ExportHouseholderQR::ExportHouseholderQR( UserInteraction* _userInteraction,
00045 const std::string& _commonHeaderName
00046 ) : ExportLinearSolver( _userInteraction,_commonHeaderName )
00047 {
00048 }
00049
00050
00051 ExportHouseholderQR::~ExportHouseholderQR( )
00052 {
00053 }
00054
00055
00056 returnValue ExportHouseholderQR::getDataDeclarations( ExportStatementBlock& declarations,
00057 ExportStruct dataStruct
00058 ) const
00059 {
00060 return SUCCESSFUL_RETURN;
00061 }
00062
00063
00064 returnValue ExportHouseholderQR::getFunctionDeclarations( ExportStatementBlock& declarations
00065 ) const
00066 {
00067 declarations.addDeclaration( solve );
00068 declarations.addDeclaration( solveTriangular );
00069 if( REUSE ) {
00070 declarations.addDeclaration( solveReuse );
00071 }
00072
00073 return SUCCESSFUL_RETURN;
00074 }
00075
00076
00077 returnValue ExportHouseholderQR::getCode( ExportStatementBlock& code
00078 )
00079 {
00080 unsigned run1, run2, run3;
00081
00082
00083
00084
00085 for (run1 = nCols; run1 > (nCols - nBacksolves); run1--)
00086 {
00087 for (run2 = nCols - 1; run2 > (run1 - 1); run2--)
00088 {
00089 solveTriangular.addStatement(
00090 b.getRow(run1 - 1) -= A.getSubMatrix((run1 - 1), run1, run2, run2 + 1) * b.getRow(run2));
00091 }
00092 solveTriangular <<
00093 "b[" << toString((run1 - 1)) << "] = b["
00094 << toString((run1 - 1)) << "]/A["
00095 << toString((run1 - 1) * nCols + (run1 - 1)) << "];\n";
00096 }
00097 code.addFunction(solveTriangular);
00098
00099
00100
00101
00102 solve.addStatement( determinant == 1.0 );
00103
00104 if( UNROLLING || nRows <= 5 )
00105 {
00106
00107 for (run1 = 0; run1 < nCols; run1++)
00108 {
00109 for (run2 = run1; run2 < nRows; run2++)
00110 {
00111 solve.addStatement(
00112 rk_temp.getCol(run2)
00113 == A.getSubMatrix(run2, run2 + 1, run1,
00114 run1 + 1));
00115 }
00116
00117 solve.addStatement(rk_temp.getCol(nRows) ==
00118 rk_temp.getCols(run1, nRows) * rk_temp.getTranspose().getRows(run1, nRows));
00119 solve << rk_temp.getFullName() << "[" << toString(nRows) << "] = sqrt("
00120 << rk_temp.getFullName() << "[" << toString(nRows)
00121 << "]);\n";
00122
00123
00124 solve << rk_temp.getFullName() << "[" << toString(run1) << "] += ("
00125 << rk_temp.getFullName() << "[" << toString(run1)
00126 << "] < 0 ? -1 : 1)*" << rk_temp.getFullName()
00127 << "[" << toString(nRows) << "];\n";
00128
00129
00130 solve.addStatement(rk_temp.getCol(nRows) ==
00131 rk_temp.getCols(run1, nRows) * rk_temp.getTranspose().getRows(run1, nRows));
00132 solve << rk_temp.getFullName() << "[" << toString(nRows) << "] = sqrt("
00133 << rk_temp.getFullName() << "[" << toString(nRows)
00134 << "]);\n";
00135
00136
00137 for (run2 = run1; run2 < nRows; run2++)
00138 {
00139 solve << rk_temp.getFullName() << "[" << toString(run2) << "] = "
00140 << rk_temp.getFullName() << "[" << toString(run2)
00141 << "]/" << rk_temp.getFullName() << "["
00142 << toString(nRows) << "];\n";
00143 }
00144
00145
00146 solve.addStatement(
00147 rk_temp.getCol(nRows)
00148 == rk_temp.getCols(run1, nRows)
00149 * A.getSubMatrix(run1, nRows, run1,
00150 run1 + 1));
00151 solve << rk_temp.getFullName() << "[" << toString(nRows) << "] *= 2;\n";
00152 solve.addStatement(
00153 A.getSubMatrix(run1, run1 + 1, run1, run1 + 1) -=
00154 rk_temp.getCol(run1) * rk_temp.getCol(nRows));
00155
00156 solve.addStatement( determinant == determinant * A.getElement(run1, run1) );
00157
00158 if (REUSE)
00159 {
00160
00161 for (run2 = run1; run2 < nRows - 1; run2++)
00162 {
00163 solve.addStatement(
00164 A.getSubMatrix(run2 + 1, run2 + 2, run1, run1 + 1)
00165 == rk_temp.getCol(run2));
00166 }
00167 }
00168
00169
00170 for (run2 = run1 + 1; run2 < nCols; run2++)
00171 {
00172 solve.addStatement(
00173 rk_temp.getCol(nRows)
00174 == rk_temp.getCols(run1, nRows)
00175 * A.getSubMatrix(run1, nRows, run2,
00176 run2 + 1));
00177 solve << rk_temp.getFullName() << "[" << toString(nRows) << "] *= 2;\n";
00178 for (run3 = run1; run3 < nRows; run3++)
00179 {
00180 solve.addStatement(
00181 A.getSubMatrix(run3, run3 + 1, run2, run2 + 1) -=
00182 rk_temp.getCol(run3) * rk_temp.getCol(nRows));
00183 }
00184 }
00185
00186 solve.addStatement(
00187 rk_temp.getCol(nRows)
00188 == rk_temp.getCols(run1, nRows)
00189 * b.getRows(run1, nRows));
00190 solve << rk_temp.getFullName() << "[" << toString(nRows) << "] *= 2;\n";
00191 for (run3 = run1; run3 < nRows; run3++)
00192 {
00193 solve.addStatement( b.getRow(run3) -= rk_temp.getCol(run3) * rk_temp.getCol(nRows));
00194 }
00195
00196 if (REUSE)
00197 {
00198
00199 solve.addStatement(
00200 rk_temp.getCol(run1) == rk_temp.getCol(nRows - 1));
00201 }
00202 }
00203 }
00204 else
00205 {
00206 ExportIndex i( "i" );
00207 ExportIndex j( "j" );
00208 ExportIndex k( "k" );
00209
00210 solve.addIndex( i );
00211 solve.addIndex( j );
00212 solve.addIndex( k );
00213
00214 solve << "for( i=0; i < " << toString( nCols ) << "; i++ ) {\n";
00215 solve << " for( j=i; j < " << toString( nRows ) << "; j++ ) {\n";
00216 solve << " " << rk_temp.getFullName() << "[j] = A[j*" << toString( nCols ) << "+i];\n";
00217 solve << " }\n";
00218 solve << " " << rk_temp.getFullName() << "[" << toString( nRows ) << "] = " << rk_temp.getFullName() << "[i]*" << rk_temp.getFullName() << "[i];\n";
00219 solve << " for( j=i+1; j < " << toString( nRows ) << "; j++ ) {\n";
00220 solve << " " << rk_temp.getFullName() << "[" << toString( nRows ) << "] += " << rk_temp.getFullName() << "[j]*" << rk_temp.getFullName() << "[j];\n";
00221 solve << " }\n";
00222 solve << " " << rk_temp.getFullName() << "[" << toString( nRows ) << "] = sqrt(" << rk_temp.getFullName() << "[" << toString( nRows ) << "]);\n";
00223
00224 solve << " " << rk_temp.getFullName() << "[i] += (" << rk_temp.getFullName() << "[i] < 0 ? -1 : 1)*" << rk_temp.getFullName() << "[" << toString( nRows ) << "];\n";
00225 solve << " " << rk_temp.getFullName() << "[" << toString( nRows ) << "] = " << rk_temp.getFullName() << "[i]*" << rk_temp.getFullName() << "[i];\n";
00226 solve << " for( j=i+1; j < " << toString( nRows ) << "; j++ ) {\n";
00227 solve << " " << rk_temp.getFullName() << "[" << toString( nRows ) << "] += " << rk_temp.getFullName() << "[j]*" << rk_temp.getFullName() << "[j];\n";
00228 solve << " }\n";
00229 solve << " " << rk_temp.getFullName() << "[" << toString( nRows ) << "] = sqrt(" << rk_temp.getFullName() << "[" << toString( nRows ) << "]);\n";
00230 solve << " for( j=i; j < " << toString( nRows ) << "; j++ ) {\n";
00231 solve << " " << rk_temp.getFullName() << "[j] = " << rk_temp.getFullName() << "[j]/" << rk_temp.getFullName() << "[" << toString( nRows ) << "];\n";
00232 solve << " }\n";
00233 solve << " " << rk_temp.getFullName() << "[" << toString( nRows ) << "] = " << rk_temp.getFullName() << "[i]*A[i*" << toString( nCols ) << "+i];\n";
00234 solve << " for( j=i+1; j < " << toString( nRows ) << "; j++ ) {\n";
00235 solve << " " << rk_temp.getFullName() << "[" << toString( nRows ) << "] += " << rk_temp.getFullName() << "[j]*A[j*" << toString( nCols ) << "+i];\n";
00236 solve << " }\n";
00237 solve << " " << rk_temp.getFullName() << "[" << toString( nRows ) << "] *= 2;\n";
00238 solve << " A[i*" << toString( nCols ) << "+i] -= " << rk_temp.getFullName() << "[i]*" << rk_temp.getFullName() << "[" << toString( nRows ) << "];\n";
00239
00240 solve << " " << determinant.getFullName() << " *= " << " A[i * " << toString( nCols ) << " + i];\n";
00241
00242 if( REUSE ) {
00243 solve << " for( j=i; j < (" << toString( nRows ) << "-1); j++ ) {\n";
00244 solve << " A[(j+1)*" << toString( nCols ) << "+i] = " << rk_temp.getFullName() << "[j];\n";
00245 solve << " }\n";
00246 }
00247 solve << " for( j=i+1; j < " << toString( nCols ) << "; j++ ) {\n";
00248 solve << " " << rk_temp.getFullName() << "[" << toString( nRows ) << "] = " << rk_temp.getFullName() << "[i]*A[i*" << toString( nCols ) << "+j];\n";
00249 solve << " for( k=i+1; k < " << toString( nRows ) << "; k++ ) {\n";
00250 solve << " " << rk_temp.getFullName() << "[" << toString( nRows ) << "] += " << rk_temp.getFullName() << "[k]*A[k*" << toString( nCols ) << "+j];\n";
00251 solve << " }\n";
00252 solve << " " << rk_temp.getFullName() << "[" << toString( nRows ) << "] *= 2;\n";
00253 solve << " for( k=i; k < " << toString( nRows ) << "; k++ ) {\n";
00254 solve << " A[k*" << toString( nCols ) << "+j] -= " << rk_temp.getFullName() << "[k]*" << rk_temp.getFullName() << "[" << toString( nRows ) << "];\n";
00255 solve << " }\n";
00256 solve << " }\n";
00257 solve << " " << rk_temp.getFullName() << "[" << toString( nRows ) << "] = " << rk_temp.getFullName() << "[i]*b[i];\n";
00258 solve << " for( k=i+1; k < " << toString( nRows ) << "; k++ ) {\n";
00259 solve << " " << rk_temp.getFullName() << "[" << toString( nRows ) << "] += " << rk_temp.getFullName() << "[k]*b[k];\n";
00260 solve << " }\n";
00261 solve << " " << rk_temp.getFullName() << "[" << toString( nRows ) << "] *= 2;\n";
00262 solve << " for( k=i; k < " << toString( nRows ) << "; k++ ) {\n";
00263 solve << " b[k] -= " << rk_temp.getFullName() << "[k]*" << rk_temp.getFullName() << "[" << toString( nRows ) << "];\n";
00264 solve << " }\n";
00265 if( REUSE ) {
00266 solve << " " << rk_temp.getFullName() << "[i] = " << rk_temp.getFullName() << "[" << toString( nRows-1 ) << "];\n";
00267 }
00268 solve << "}\n";
00269 }
00270 solve.addLinebreak();
00271
00272 solve.addFunctionCall(solveTriangular, A, b);
00273 code.addFunction( solve );
00274
00275 code.addLinebreak( 2 );
00276 if( REUSE ) {
00277
00278 for( run1 = 0; run1 < nCols; run1++ ) {
00279 solveReuse.addStatement( rk_temp.getCol( nRows ) == A.getSubMatrix( run1+1,run1+2,run1,run1+1 )*b.getRow( run1 ) );
00280 for( run2 = run1+1; run2 < (nRows-1); run2++ ) {
00281 solveReuse.addStatement( rk_temp.getCol( nRows ) += A.getSubMatrix( run2+1,run2+2,run1,run1+1 )*b.getRow( run2 ) );
00282 }
00283 solveReuse.addStatement( rk_temp.getCol( nRows ) += rk_temp.getCol( run1 )*b.getRow( nRows-1 ) );
00284 solveReuse << rk_temp.getFullName() << "[" << toString( nRows ) << "] *= 2;\n" ;
00285 for( run3 = run1; run3 < (nRows-1); run3++ ) {
00286 solveReuse.addStatement( b.getRow( run3 ) -= A.getSubMatrix( run3+1,run3+2,run1,run1+1 )*rk_temp.getCol( nRows ) );
00287 }
00288 solveReuse.addStatement( b.getRow( nRows-1 ) -= rk_temp.getCol( run1 )*rk_temp.getCol( nRows ) );
00289 }
00290 solveReuse.addLinebreak();
00291
00292 solveReuse.addFunctionCall( solveTriangular, A, b );
00293 code.addFunction( solveReuse );
00294 }
00295
00296 return SUCCESSFUL_RETURN;
00297 }
00298
00299
00300 returnValue ExportHouseholderQR::appendVariableNames( stringstream& string )
00301 {
00302 return SUCCESSFUL_RETURN;
00303 }
00304
00305
00306 returnValue ExportHouseholderQR::setup( )
00307 {
00308 int useOMP;
00309 get(CG_USE_OPENMP, useOMP);
00310
00311 A = ExportVariable("A", nRows, nCols, REAL);
00312 b = ExportVariable("b", nRows, 1, REAL);
00313 rk_temp = ExportVariable("rk_temp", 1, nRows + 1, REAL);
00314 solve = ExportFunction(getNameSolveFunction(), A, b, rk_temp);
00315 solve.setReturnValue(determinant, false);
00316 solve.addLinebreak( );
00317 solveTriangular = ExportFunction( std::string( "solve_" ) + identifier + "triangular", A, b);
00318 solveTriangular.addLinebreak( );
00319
00320 if (REUSE)
00321 {
00322 solveReuse = ExportFunction(getNameSolveReuseFunction(), A, b, rk_temp);
00323 solveReuse.addLinebreak();
00324 }
00325
00326 int unrollOpt;
00327 userInteraction->get(UNROLL_LINEAR_SOLVER, unrollOpt);
00328 UNROLLING = (bool) unrollOpt;
00329
00330 return SUCCESSFUL_RETURN;
00331 }
00332
00333
00334 ExportVariable ExportHouseholderQR::getGlobalExportVariable( const uint factor ) const {
00335
00336 int useOMP;
00337 get(CG_USE_OPENMP, useOMP);
00338 ExportStruct structWspace;
00339 structWspace = useOMP ? ACADO_LOCAL : ACADO_WORKSPACE;
00340
00341 return ExportVariable( std::string( "rk_" ) + identifier + "temp", factor, nRows+1, REAL, structWspace );
00342 }
00343
00344
00345
00346
00347
00348
00349
00350
00351 CLOSE_NAMESPACE_ACADO
00352
00353