householder_qr_export.cpp
Go to the documentation of this file.
00001 /*
00002  *    This file is part of ACADO Toolkit.
00003  *
00004  *    ACADO Toolkit -- A Toolkit for Automatic Control and Dynamic Optimization.
00005  *    Copyright (C) 2008-2014 by Boris Houska, Hans Joachim Ferreau,
00006  *    Milan Vukov, Rien Quirynen, KU Leuven.
00007  *    Developed within the Optimization in Engineering Center (OPTEC)
00008  *    under supervision of Moritz Diehl. All rights reserved.
00009  *
00010  *    ACADO Toolkit is free software; you can redistribute it and/or
00011  *    modify it under the terms of the GNU Lesser General Public
00012  *    License as published by the Free Software Foundation; either
00013  *    version 3 of the License, or (at your option) any later version.
00014  *
00015  *    ACADO Toolkit is distributed in the hope that it will be useful,
00016  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00017  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018  *    Lesser General Public License for more details.
00019  *
00020  *    You should have received a copy of the GNU Lesser General Public
00021  *    License along with ACADO Toolkit; if not, write to the Free Software
00022  *    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
00023  *
00024  */
00025 
00026 
00027 
00034 #include <acado/code_generation/linear_solvers/householder_qr_export.hpp>
00035 
00036 using namespace std;
00037 
00038 BEGIN_NAMESPACE_ACADO
00039 
00040 //
00041 // PUBLIC MEMBER FUNCTIONS:
00042 //
00043 
00044 ExportHouseholderQR::ExportHouseholderQR( UserInteraction* _userInteraction,
00045                                                                         const std::string& _commonHeaderName
00046                                                                         ) : ExportLinearSolver( _userInteraction,_commonHeaderName )
00047 {
00048 }
00049 
00050 
00051 ExportHouseholderQR::~ExportHouseholderQR( )
00052 {
00053 }
00054 
00055 
00056 returnValue ExportHouseholderQR::getDataDeclarations(   ExportStatementBlock& declarations,
00057                                                                                                                 ExportStruct dataStruct
00058                                                                                                                 ) const
00059 {
00060         return SUCCESSFUL_RETURN;
00061 }
00062 
00063 
00064 returnValue ExportHouseholderQR::getFunctionDeclarations(       ExportStatementBlock& declarations
00065                                                                                                                         ) const
00066 {
00067         declarations.addDeclaration( solve );
00068         declarations.addDeclaration( solveTriangular );
00069         if( REUSE ) {
00070                 declarations.addDeclaration( solveReuse );
00071         }
00072 
00073         return SUCCESSFUL_RETURN;
00074 }
00075 
00076 
/**
 * Generates the solver routines and appends them to `code`.
 *
 * Up to three export functions are filled in and registered through
 * code.addFunction():
 *   - solveTriangular : back substitution on the upper triangle of A,
 *   - solve           : Householder reflections applied column by column to A
 *                       and b, followed by a call to solveTriangular,
 *   - solveReuse      : (only if REUSE) re-applies the stored reflector vectors
 *                       to a new b, followed by a call to solveTriangular.
 *
 * rk_temp is used as scratch storage: entries [0, nRows) hold the current
 * reflector vector, entry [nRows] holds a scalar (a norm or a dot product).
 *
 * Note: inside addStatement(), `lhs == rhs` appears to build an assignment in
 * the generated code -- the looped branch below writes the matching `=` / `*=`
 * statements as plain text.
 *
 * @param code  Statement block that receives the generated functions.
 * @return SUCCESSFUL_RETURN
 */
00077 returnValue ExportHouseholderQR::getCode(       ExportStatementBlock& code
00078                                                                                         )
00079 {
00080         unsigned run1, run2, run3;
00081 
00082         //
00083         // Solve the upper triangular system of equations:
00084         //
              // Rows are processed bottom-up; run1 is the 1-based row number and only
              // the last nBacksolves rows are emitted.
00085         for (run1 = nCols; run1 > (nCols - nBacksolves); run1--)
00086         {
                      // subtract the contributions of the already computed entries b[run2], run2 >= run1
00087                 for (run2 = nCols - 1; run2 > (run1 - 1); run2--)
00088                 {
00089                         solveTriangular.addStatement(
00090                                         b.getRow(run1 - 1) -= A.getSubMatrix((run1 - 1), run1, run2, run2 + 1) * b.getRow(run2));
00091                 }
                      // divide by the diagonal entry; A is addressed as a flat array with row stride nCols
00092                 solveTriangular <<
00093                                 "b[" << toString((run1 - 1)) << "] = b["
00094                                 << toString((run1 - 1)) << "]/A["
00095                                 << toString((run1 - 1) * nCols + (run1 - 1)) << "];\n";
00096         }
00097         code.addFunction(solveTriangular);
00098         
00099         //
00100         // Main solver function
00101         //
              // the return value starts at 1 and is multiplied by every diagonal entry produced below
00102         solve.addStatement( determinant == 1.0 );
00103 
              // Fully unrolled statements are emitted when requested or when nRows <= 5;
              // otherwise the same steps are emitted as for-loops over i, j, k.
00104         if( UNROLLING || nRows <= 5 )
00105         {
00106                 // Start the factorization:
00107                 for (run1 = 0; run1 < nCols; run1++)
00108                 {
                              // copy column run1 of A, from the diagonal downwards, into rk_temp
00109                         for (run2 = run1; run2 < nRows; run2++)
00110                         {
00111                                 solve.addStatement(
00112                                                 rk_temp.getCol(run2)
00113                                                                 == A.getSubMatrix(run2, run2 + 1, run1,
00114                                                                                 run1 + 1));
00115                         }
00116                         // calculate norm: rk_temp[nRows] = sqrt( sum of squares of rk_temp[run1..nRows-1] )
00117                         solve.addStatement(rk_temp.getCol(nRows) ==
00118                                         rk_temp.getCols(run1, nRows) * rk_temp.getTranspose().getRows(run1, nRows));
00119                         solve << rk_temp.getFullName() << "[" << toString(nRows) << "] = sqrt("
00120                                         << rk_temp.getFullName() << "[" << toString(nRows)
00121                                         << "]);\n";
00122 
00123                         // update first element: add the norm with the sign of the element itself
00124                         solve << rk_temp.getFullName() << "[" << toString(run1) << "] += ("
00125                                         << rk_temp.getFullName() << "[" << toString(run1)
00126                                         << "] < 0 ? -1 : 1)*" << rk_temp.getFullName()
00127                                         << "[" << toString(nRows) << "];\n";
00128 
00129                         // calculate norm: same computation again, now for the modified vector
00130                         solve.addStatement(rk_temp.getCol(nRows) ==
00131                                         rk_temp.getCols(run1, nRows) * rk_temp.getTranspose().getRows(run1, nRows));
00132                         solve << rk_temp.getFullName() << "[" << toString(nRows) << "] = sqrt("
00133                                         << rk_temp.getFullName() << "[" << toString(nRows)
00134                                         << "]);\n";
00135 
00136                         // normalization: divide rk_temp[run1..nRows-1] by the norm stored in rk_temp[nRows]
00137                         for (run2 = run1; run2 < nRows; run2++)
00138                         {
00139                                 solve << rk_temp.getFullName() << "[" << toString(run2) << "] = "
00140                                                 << rk_temp.getFullName() << "[" << toString(run2)
00141                                                 << "]/" << rk_temp.getFullName() << "["
00142                                                 << toString(nRows) << "];\n";
00143                         }
00144 
00145                         // update current column: rk_temp[nRows] = 2 * (vector . column run1 of A);
                              // only the diagonal entry A(run1,run1) is written back
00146                         solve.addStatement(
00147                                         rk_temp.getCol(nRows)
00148                                                         == rk_temp.getCols(run1, nRows)
00149                                                                         * A.getSubMatrix(run1, nRows, run1,
00150                                                                                         run1 + 1));
00151                         solve << rk_temp.getFullName() << "[" << toString(nRows) << "] *= 2;\n";
00152                         solve.addStatement(
00153                                         A.getSubMatrix(run1, run1 + 1, run1, run1 + 1) -=
00154                                                         rk_temp.getCol(run1) * rk_temp.getCol(nRows));
00155 
                              // accumulate the product of the resulting diagonal entries
00156                         solve.addStatement( determinant == determinant * A.getElement(run1, run1) );
00157 
00158                         if (REUSE)
00159                         {
00160                                 // replace zeros by results that can be reused:
                                      // vector entry run2 is stored one row lower, in A(run2+1, run1);
                                      // the last entry rk_temp[nRows-1] does not fit and is saved further below
00161                                 for (run2 = run1; run2 < nRows - 1; run2++)
00162                                 {
00163                                         solve.addStatement(
00164                                                         A.getSubMatrix(run2 + 1, run2 + 2, run1, run1 + 1)
00165                                                                         == rk_temp.getCol(run2));
00166                                 }
00167                         }
00168 
00169                         // update following columns: for each column run2 > run1,
                              // rk_temp[nRows] = 2 * (vector . column), then column -= vector * rk_temp[nRows]
00170                         for (run2 = run1 + 1; run2 < nCols; run2++)
00171                         {
00172                                 solve.addStatement(
00173                                                 rk_temp.getCol(nRows)
00174                                                                 == rk_temp.getCols(run1, nRows)
00175                                                                                 * A.getSubMatrix(run1, nRows, run2,
00176                                                                                                 run2 + 1));
00177                                 solve <<  rk_temp.getFullName() << "[" << toString(nRows) << "] *= 2;\n";
00178                                 for (run3 = run1; run3 < nRows; run3++)
00179                                 {
00180                                         solve.addStatement(
00181                                                         A.getSubMatrix(run3, run3 + 1, run2, run2 + 1) -=
00182                                                                         rk_temp.getCol(run3) * rk_temp.getCol(nRows));
00183                                 }
00184                         }
00185                         // update right-hand side: the same update applied to b
00186                         solve.addStatement(
00187                                         rk_temp.getCol(nRows)
00188                                                         == rk_temp.getCols(run1, nRows)
00189                                                                         * b.getRows(run1, nRows));
00190                         solve << rk_temp.getFullName() << "[" << toString(nRows) << "] *= 2;\n";
00191                         for (run3 = run1; run3 < nRows; run3++)
00192                         {
00193                                 solve.addStatement( b.getRow(run3) -= rk_temp.getCol(run3) * rk_temp.getCol(nRows));
00194                         }
00195 
00196                         if (REUSE)
00197                         {
00198                                 // store last element to be reused:
                                      // rk_temp[run1] is not overwritten by later columns, which only write indices > run1
00199                                 solve.addStatement(
00200                                                 rk_temp.getCol(run1) == rk_temp.getCol(nRows - 1));
00201                         }
00202                 }
00203         }
00204         else
00205         {
                      // Looped variant: emits the same sequence of steps as the unrolled branch,
                      // written as literal C text with loop indices i (column), j and k.
00206                 ExportIndex i( "i" );
00207                 ExportIndex j( "j" );
00208                 ExportIndex k( "k" );
00209 
00210                 solve.addIndex( i );
00211                 solve.addIndex( j );
00212                 solve.addIndex( k );
00213 
00214                 solve << "for( i=0; i < " << toString( nCols ) << "; i++ ) {\n";
                      // copy column i of A (rows i..nRows-1) into rk_temp
00215                 solve << "      for( j=i; j < " << toString( nRows ) << "; j++ ) {\n";
00216                 solve << "              " << rk_temp.getFullName() << "[j] = A[j*" << toString( nCols ) << "+i];\n";
00217                 solve << "      }\n";
                      // norm of rk_temp[i..nRows-1] into rk_temp[nRows]
00218                 solve << "      " << rk_temp.getFullName() << "[" << toString( nRows ) << "] = " << rk_temp.getFullName() << "[i]*" << rk_temp.getFullName() << "[i];\n";
00219                 solve << "      for( j=i+1; j < " << toString( nRows ) << "; j++ ) {\n";
00220                 solve << "              " << rk_temp.getFullName() << "[" << toString( nRows ) << "] += " << rk_temp.getFullName() << "[j]*" << rk_temp.getFullName() << "[j];\n";
00221                 solve << "      }\n";
00222                 solve << "      " << rk_temp.getFullName() << "[" << toString( nRows ) << "] = sqrt(" << rk_temp.getFullName() << "[" << toString( nRows ) << "]);\n";
00223                 // update first element:
00224                 solve << "      " << rk_temp.getFullName() << "[i] += (" << rk_temp.getFullName() << "[i] < 0 ? -1 : 1)*" << rk_temp.getFullName() << "[" << toString( nRows ) << "];\n";
                      // norm of the modified vector
00225                 solve << "      " << rk_temp.getFullName() << "[" << toString( nRows ) << "] = " << rk_temp.getFullName() << "[i]*" << rk_temp.getFullName() << "[i];\n";
00226                 solve << "      for( j=i+1; j < " << toString( nRows ) << "; j++ ) {\n";
00227                 solve << "              " << rk_temp.getFullName() << "[" << toString( nRows ) << "] += " << rk_temp.getFullName() << "[j]*" << rk_temp.getFullName() << "[j];\n";
00228                 solve << "      }\n";
00229                 solve << "      " << rk_temp.getFullName() << "[" << toString( nRows ) << "] = sqrt(" << rk_temp.getFullName() << "[" << toString( nRows ) << "]);\n";
                      // normalization
00230                 solve << "      for( j=i; j < " << toString( nRows ) << "; j++ ) {\n";
00231                 solve << "              " << rk_temp.getFullName() << "[j] = " << rk_temp.getFullName() << "[j]/" << rk_temp.getFullName() << "[" << toString( nRows ) << "];\n";
00232                 solve << "      }\n";
                      // update current column: only the diagonal entry A[i*nCols+i] is written
00233                 solve << "      " << rk_temp.getFullName() << "[" << toString( nRows ) << "] = " << rk_temp.getFullName() << "[i]*A[i*" << toString( nCols ) << "+i];\n";
00234                 solve << "      for( j=i+1; j < " << toString( nRows ) << "; j++ ) {\n";
00235                 solve << "              " << rk_temp.getFullName() << "[" << toString( nRows ) << "] += " << rk_temp.getFullName() << "[j]*A[j*" << toString( nCols ) << "+i];\n";
00236                 solve << "      }\n";
00237                 solve << "      " << rk_temp.getFullName() << "[" << toString( nRows ) << "] *= 2;\n";
00238                 solve << "      A[i*" << toString( nCols ) << "+i] -= " << rk_temp.getFullName() << "[i]*" << rk_temp.getFullName() << "[" << toString( nRows ) << "];\n";
00239 
                      // accumulate the product of the diagonal entries
00240                 solve << "      " << determinant.getFullName() << " *= " << "   A[i * " << toString( nCols ) << " + i];\n";
00241 
00242                 if( REUSE ) {
                              // store vector entry j one row lower, in A[(j+1)*nCols+i]
00243                         solve << "      for( j=i; j < (" << toString( nRows ) << "-1); j++ ) {\n";
00244                         solve << "              A[(j+1)*" << toString( nCols ) << "+i] = " << rk_temp.getFullName() << "[j];\n";
00245                         solve << "      }\n";
00246                 }
                      // update following columns j > i
00247                 solve << "      for( j=i+1; j < " << toString( nCols ) << "; j++ ) {\n";
00248                 solve << "              " << rk_temp.getFullName() << "[" << toString( nRows ) << "] = " << rk_temp.getFullName() << "[i]*A[i*" << toString( nCols ) << "+j];\n";
00249                 solve << "              for( k=i+1; k < " << toString( nRows ) << "; k++ ) {\n";
00250                 solve << "                      " << rk_temp.getFullName() << "[" << toString( nRows ) << "] += " << rk_temp.getFullName() << "[k]*A[k*" << toString( nCols ) << "+j];\n";
00251                 solve << "              }\n";
00252                 solve << "              " << rk_temp.getFullName() << "[" << toString( nRows ) << "] *= 2;\n";
00253                 solve << "              for( k=i; k < " << toString( nRows ) << "; k++ ) {\n";
00254                 solve << "                      A[k*" << toString( nCols ) << "+j] -= " << rk_temp.getFullName() << "[k]*" << rk_temp.getFullName() << "[" << toString( nRows ) << "];\n";
00255                 solve << "              }\n";
00256                 solve << "      }\n";
                      // update right-hand side
00257                 solve << "      " << rk_temp.getFullName() << "[" << toString( nRows ) << "] = " << rk_temp.getFullName() << "[i]*b[i];\n";
00258                 solve << "      for( k=i+1; k < " << toString( nRows ) << "; k++ ) {\n";
00259                 solve << "              " << rk_temp.getFullName() << "[" << toString( nRows ) << "] += " << rk_temp.getFullName() << "[k]*b[k];\n";
00260                 solve << "      }\n";
00261                 solve << "      " << rk_temp.getFullName() << "[" << toString( nRows ) << "] *= 2;\n";
00262                 solve << "      for( k=i; k < " << toString( nRows ) << "; k++ ) {\n";
00263                 solve << "              b[k] -= " << rk_temp.getFullName() << "[k]*" << rk_temp.getFullName() << "[" << toString( nRows ) << "];\n";
00264                 solve << "      }\n";
00265                 if( REUSE ) {
                              // keep the last vector entry in rk_temp[i] for solveReuse
00266                         solve << "      " << rk_temp.getFullName() << "[i] = " << rk_temp.getFullName() << "[" << toString( nRows-1 ) << "];\n";
00267                 }
00268                 solve << "}\n";
00269         }
00270         solve.addLinebreak();
00271 
              // finish with the back substitution on the transformed system
00272         solve.addFunctionCall(solveTriangular, A, b);
00273         code.addFunction( solve );
00274         
00275     code.addLinebreak( 2 );
00276         if( REUSE ) { // Also export the extra function which reuses the factorization of the matrix A
00277                 // update right-hand side:
                      // For column run1 the stored vector is read back from A(run2+1, run1),
                      // run2 = run1..nRows-2, plus rk_temp[run1] for the last entry (see solve above).
                      // NOTE(review): when nRows == nCols and run1 == nCols-1, getSubMatrix(run1+1, run1+2, ...)
                      // refers to row index nRows, which looks out of range -- confirm for square systems.
00278                 for( run1 = 0; run1 < nCols; run1++ ) {
00279                         solveReuse.addStatement( rk_temp.getCol( nRows ) == A.getSubMatrix( run1+1,run1+2,run1,run1+1 )*b.getRow( run1 ) );
00280                         for( run2 = run1+1; run2 < (nRows-1); run2++ ) {
00281                                 solveReuse.addStatement( rk_temp.getCol( nRows ) += A.getSubMatrix( run2+1,run2+2,run1,run1+1 )*b.getRow( run2 ) );
00282                         }
00283                         solveReuse.addStatement( rk_temp.getCol( nRows ) += rk_temp.getCol( run1 )*b.getRow( nRows-1 ) );
00284                         solveReuse << rk_temp.getFullName() << "[" << toString( nRows ) << "] *= 2;\n" ;
                              // b -= vector * rk_temp[nRows], using the same storage layout
00285                         for( run3 = run1; run3 < (nRows-1); run3++ ) {
00286                                 solveReuse.addStatement( b.getRow( run3 ) -= A.getSubMatrix( run3+1,run3+2,run1,run1+1 )*rk_temp.getCol( nRows ) );
00287                         }
00288                         solveReuse.addStatement( b.getRow( nRows-1 ) -= rk_temp.getCol( run1 )*rk_temp.getCol( nRows ) );
00289                 }
00290                 solveReuse.addLinebreak();
00291 
00292                 solveReuse.addFunctionCall( solveTriangular, A, b );
00293                 code.addFunction( solveReuse );
00294         }
00295         
00296         return SUCCESSFUL_RETURN;
00297 }
00298 
00299 
00300 returnValue ExportHouseholderQR::appendVariableNames( stringstream& string )
00301 {
00302         return SUCCESSFUL_RETURN;
00303 }
00304 
00305 
00306 returnValue ExportHouseholderQR::setup( )
00307 {
00308         int useOMP;
00309         get(CG_USE_OPENMP, useOMP);
00310 
00311         A = ExportVariable("A", nRows, nCols, REAL);
00312         b = ExportVariable("b", nRows, 1, REAL);
00313         rk_temp = ExportVariable("rk_temp", 1, nRows + 1, REAL);
00314         solve = ExportFunction(getNameSolveFunction(), A, b, rk_temp);
00315         solve.setReturnValue(determinant, false);
00316         solve.addLinebreak( );  // FIX: TO MAKE SURE IT GETS EXPORTED
00317         solveTriangular = ExportFunction( std::string( "solve_" ) + identifier + "triangular", A, b);
00318         solveTriangular.addLinebreak( );        // FIX: TO MAKE SURE IT GETS EXPORTED
00319         
00320         if (REUSE)
00321         {
00322                 solveReuse = ExportFunction(getNameSolveReuseFunction(), A, b, rk_temp);
00323                 solveReuse.addLinebreak();      // FIX: TO MAKE SURE IT GETS EXPORTED
00324         }
00325         
00326         int unrollOpt;
00327         userInteraction->get(UNROLL_LINEAR_SOLVER, unrollOpt);
00328         UNROLLING = (bool) unrollOpt;
00329 
00330     return SUCCESSFUL_RETURN;
00331 }
00332 
00333 
00334 ExportVariable ExportHouseholderQR::getGlobalExportVariable( const uint factor ) const {
00335 
00336         int useOMP;
00337         get(CG_USE_OPENMP, useOMP);
00338         ExportStruct structWspace;
00339         structWspace = useOMP ? ACADO_LOCAL : ACADO_WORKSPACE;
00340 
00341         return ExportVariable( std::string( "rk_" ) + identifier + "temp", factor, nRows+1, REAL, structWspace );
00342 }
00343 
00344 
00345 //
00346 // PROTECTED MEMBER FUNCTIONS:
00347 //
00348 
00349 
00350 
00351 CLOSE_NAMESPACE_ACADO
00352 
00353 // end of file.


acado
Author(s): Milan Vukov, Rien Quirynen
autogenerated on Sat Jun 8 2019 19:37:19