export_cholesky_solver.cpp
Go to the documentation of this file.
00001 /*
00002  *    This file is part of ACADO Toolkit.
00003  *
00004  *    ACADO Toolkit -- A Toolkit for Automatic Control and Dynamic Optimization.
00005  *    Copyright (C) 2008-2014 by Boris Houska, Hans Joachim Ferreau,
00006  *    Milan Vukov, Rien Quirynen, KU Leuven.
00007  *    Developed within the Optimization in Engineering Center (OPTEC)
00008  *    under supervision of Moritz Diehl. All rights reserved.
00009  *
00010  *    ACADO Toolkit is free software; you can redistribute it and/or
00011  *    modify it under the terms of the GNU Lesser General Public
00012  *    License as published by the Free Software Foundation; either
00013  *    version 3 of the License, or (at your option) any later version.
00014  *
00015  *    ACADO Toolkit is distributed in the hope that it will be useful,
00016  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
00017  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018  *    Lesser General Public License for more details.
00019  *
00020  *    You should have received a copy of the GNU Lesser General Public
00021  *    License along with ACADO Toolkit; if not, write to the Free Software
00022  *    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
00023  *
00024  */
00025 
00026 
00027 
00034 #include <acado/code_generation/linear_solvers/export_cholesky_solver.hpp>
00035 
00036 using namespace std;
00037 
00038 BEGIN_NAMESPACE_ACADO
00039 
00040 ExportCholeskySolver::ExportCholeskySolver(     UserInteraction* _userInteraction,
00041                                                                                         const std::string& _commonHeaderName
00042                                                                                         ) : ExportLinearSolver(_userInteraction, _commonHeaderName)
00043 {
00044         nColsB = 0;
00045 }
00046 
00047 ExportCholeskySolver::~ExportCholeskySolver()
00048 {}
00049 
00050 returnValue ExportCholeskySolver::init( unsigned _dimA,
00051                                                                                 unsigned _numColsB,
00052                                                                                 const std::string& _id
00053                                                                                 )
00054 {
00055         nRows = nCols = _dimA;
00056         nColsB = _numColsB;
00057 
00058         identifier = _id;
00059 
00060         A.setup("A", nRows, nCols, REAL, ACADO_LOCAL);
00061         B.setup("B", nRows, nColsB, REAL, ACADO_LOCAL);
00062 
00063         chol.setup(identifier + "_chol", A);
00064         solve.setup(identifier + "_solve", A, B);
00065 
00066         REUSE = false;
00067 
00068         return SUCCESSFUL_RETURN;
00069 }
00070 
00071 returnValue ExportCholeskySolver::setup()
00072 {
00073         unsigned flopsChol, flopsSolve;
00074 
00075         if (REUSE == true)
00076                 return RET_NOT_IMPLEMENTED_YET;
00077 
00078         ExportVariable sum("sum", 1, 1, REAL, ACADO_LOCAL, true);
00079         ExportVariable div("div", 1, 1, REAL, ACADO_LOCAL, true);
00080         ExportVariable ret("ret", 1, 1, INT, ACADO_LOCAL, true);
00081 
00082         chol.addVariable( sum );
00083         chol.addVariable( div );
00084         chol.setReturnValue( ret );
00085         chol.addStatement( ret == 0 );
00086 
00087         // Approximate number of flops
00088         flopsChol = nRows * nRows * nRows / 3;
00089 
00090         if (flopsChol < 128)
00091                 for(int ii = 0; ii < (int)nRows; ++ii)
00092                 {
00093                         for (int k = 0; k < ii; ++k)
00094                                 chol.addStatement( A.getElement(ii, k) == 0.0 );
00095 
00096                         /* j == i */
00097                         //              sum = H[ii * nCols + ii];
00098                         chol.addStatement( sum == A.getElement(ii, ii) );
00099                         for(int k = (ii - 1); k >= 0; --k)
00100                                 //                      sum -= A[k*NVMAX + i] * A[k*NVMAX + i];
00101                                 chol.addStatement( sum -= A.getElement(k, ii) * A.getElement(k, ii) );
00102 
00103                         chol << "if (" << sum.getFullName() << "< 0.0) return 1;\n";
00104 
00105                         //              if ( sum > 0.0 )
00106                         //                      R[i*NVMAX + i] = sqrt( sum );
00107                         //              else
00108                         //              {
00109                         //                      hessianType = HST_SEMIDEF;
00110                         //                      return THROWERROR( RET_HESSIAN_NOT_SPD );
00111                         //              }
00112 
00113                         chol << A.getElement(ii, ii).get(0, 0) << " = sqrt(" << sum.getFullName() << ");\n";
00114                         chol << div.getFullName() << " = 1.0 / " << A.getElement(ii, ii).get(0, 0) << ";\n";
00115 
00116                         /* j > i */
00117                         for(int jj = (ii + 1); jj < (int)nRows; ++jj)
00118                         {
00119                                 //                      jj = FR_idx[j];
00120                                 //                      sum = H[jj*NVMAX + ii];
00121                                 chol.addStatement( sum == A.getElement(jj, ii) );
00122 
00123                                 for(int k = (ii - 1); k >= 0; --k)
00124                                         //                              sum -= R[k * NVMAX + ii] * R[k * NVMAX + jj];
00125                                         chol.addStatement( sum -= A.getElement(k, ii) * A.getElement(k, jj) );
00126 
00127                                 //                      R[ii * NVMAX + jj] = sum / R[ii * NVMAX + ii];
00128                                 chol.addStatement( A.getElement(ii, jj) == sum * div );
00129                         }
00130                 }
00131         else
00132         {
00133                 ExportIndex ii, jj, k;
00134                 chol.acquire( ii ).acquire( jj ).acquire( k );
00135 
00136                 ExportForLoop iiLoop(ii, 0, nRows);
00137 
00138                 ExportForLoop kLoop(k, 0, ii);
00139                 kLoop.addStatement( A.getElement(ii, k) == 0.0 );
00140                 iiLoop.addStatement( kLoop );
00141 
00142                 iiLoop.addStatement( sum == A.getElement(ii, ii) );
00143 
00144                 ExportForLoop kLoop2(k, ii - 1, -1, -1);
00145                 kLoop2.addStatement( sum -= A.getElement(k, ii) * A.getElement(k, ii) );
00146                 iiLoop.addStatement( kLoop2 );
00147 
00148                 iiLoop << "if (" << sum.getFullName() << "< 0.0) return 1;\n";
00149                 iiLoop << A.getElement(ii, ii).get(0, 0) << " = sqrt(" << sum.getFullName() << ");\n";
00150                 iiLoop << div.getFullName() << " = 1.0 / " << A.getElement(ii, ii).get(0, 0) << ";\n";
00151 
00152                 ExportForLoop jjLoop(jj, ii + 1, nRows);
00153                 jjLoop.addStatement( sum == A.getElement(jj, ii) );
00154 
00155                 ExportForLoop kLoop3(k, ii - 1, -1, -1);
00156                 kLoop3.addStatement( sum -= A.getElement(k, ii) * A.getElement(k, jj) );
00157                 jjLoop.addStatement( kLoop3 );
00158 
00159                 jjLoop.addStatement( A.getElement(ii, jj) == sum * div );
00160 
00161                 iiLoop.addStatement( jjLoop );
00162 
00163                 chol.addStatement( iiLoop );
00164                 chol.release( ii ).release( jj ).release( k );
00165         }
00166 
00167         //
00168         // Setup evaluation of the solve function
00169         // Implements R^T X = B -> X = R^{-T} * B. B is replaced by the solution.
00170         //
00171 
00172         // Approximate number of flops
00173         flopsSolve = nRows * nRows * nColsB;
00174 
00175         solve.addVariable( sum );
00176 
00177         if (flopsSolve < 128)
00178                 for (unsigned col = 0; col < nColsB; ++col)
00179                         for(int i = 0; i < int(nRows); ++i)
00180                         {
00181                                 //                      sum = b[i];
00182                                 solve.addStatement( sum == B.getElement(i, col) );
00183 
00184                                 for(int j = 0; j < i; ++j)
00185                                         //                              sum -= R[j*NVMAX + i] * a[j];
00186                                         solve.addStatement( sum-= A.getElement(j, i) * B.getElement(j, col) );
00187 
00188                                 // TODO Error checking
00189                                 //                      if ( getAbs( R[i*NVMAX + i] ) > ZERO )
00190                                 //                              a[i] = sum / R[i*NVMAX + i];
00191                                 //                      else
00192                                 //                              return THROWERROR( RET_DIV_BY_ZERO );
00193 
00194                                 solve << B.getElement(i, col).get(0, 0) << " = " << sum.getFullName() << " / " << A.getElement(i, i).get(0, 0) << ";\n";
00195                         }
00196         else
00197         {
00198                 ExportIndex col, i, j;
00199                 solve.acquire( col ).acquire( i ).acquire( j );
00200 
00201                 ExportForLoop colLoop(col, 0, nColsB);
00202 
00203                 ExportForLoop iLoop(i, 0, nRows);
00204                 iLoop.addStatement( sum == B.getElement(i, col) );
00205 
00206                 ExportForLoop jLoop(j, 0, i);
00207                 jLoop.addStatement( sum-= A.getElement(j, i) * B.getElement(j, col) );
00208                 iLoop << jLoop;
00209 
00210                 iLoop << B.getElement(i, col).get(0, 0) << " = " << sum.getFullName() << " / " << A.getElement(i, i).get(0, 0) << ";\n";
00211 
00212                 colLoop << iLoop;
00213                 solve << colLoop;
00214                 solve.release( col ).release( i ).release( j );
00215         }
00216 
00217         return SUCCESSFUL_RETURN;
00218 }
00219 
00220 returnValue ExportCholeskySolver::getCode( ExportStatementBlock& code )
00221 {
00222         code.addFunction( chol );
00223         code.addFunction( solve );
00224 
00225         return SUCCESSFUL_RETURN;
00226 }
00227 
00228 returnValue ExportCholeskySolver::getDataDeclarations(  ExportStatementBlock& declarations,
00229                                                                                                                 ExportStruct dataStruct
00230                                                                                                                 ) const
00231 {
00232         return SUCCESSFUL_RETURN;
00233 }
00234 
00235 returnValue ExportCholeskySolver::getFunctionDeclarations(      ExportStatementBlock& declarations
00236                                                                                                                         ) const
00237 {
00238         declarations.addDeclaration( chol );
00239         declarations.addDeclaration( solve );
00240 
00241         return SUCCESSFUL_RETURN;
00242 }
00243 
00244 const ExportFunction& ExportCholeskySolver::getCholeskyFunction() const
00245 {
00246         return chol;
00247 }
00248 
00249 const ExportFunction& ExportCholeskySolver::getSolveFunction() const
00250 {
00251         return solve;
00252 }
00253 
00254 returnValue ExportCholeskySolver::appendVariableNames( std::stringstream& string )
00255 {
00256         return SUCCESSFUL_RETURN;
00257 }
00258 
00259 CLOSE_NAMESPACE_ACADO
00260 
00261 // end of file.


acado
Author(s): Milan Vukov, Rien Quirynen
autogenerated on Sat Jun 8 2019 19:37:01