00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00034 #include <acado/code_generation/linear_solvers/export_cholesky_solver.hpp>
00035
00036 using namespace std;
00037
00038 BEGIN_NAMESPACE_ACADO
00039
00040 ExportCholeskySolver::ExportCholeskySolver( UserInteraction* _userInteraction,
00041 const std::string& _commonHeaderName
00042 ) : ExportLinearSolver(_userInteraction, _commonHeaderName)
00043 {
00044 nColsB = 0;
00045 }
00046
00047 ExportCholeskySolver::~ExportCholeskySolver()
00048 {}
00049
00050 returnValue ExportCholeskySolver::init( unsigned _dimA,
00051 unsigned _numColsB,
00052 const std::string& _id
00053 )
00054 {
00055 nRows = nCols = _dimA;
00056 nColsB = _numColsB;
00057
00058 identifier = _id;
00059
00060 A.setup("A", nRows, nCols, REAL, ACADO_LOCAL);
00061 B.setup("B", nRows, nColsB, REAL, ACADO_LOCAL);
00062
00063 chol.setup(identifier + "_chol", A);
00064 solve.setup(identifier + "_solve", A, B);
00065
00066 REUSE = false;
00067
00068 return SUCCESSFUL_RETURN;
00069 }
00070
00071 returnValue ExportCholeskySolver::setup()
00072 {
00073 unsigned flopsChol, flopsSolve;
00074
00075 if (REUSE == true)
00076 return RET_NOT_IMPLEMENTED_YET;
00077
00078 ExportVariable sum("sum", 1, 1, REAL, ACADO_LOCAL, true);
00079 ExportVariable div("div", 1, 1, REAL, ACADO_LOCAL, true);
00080 ExportVariable ret("ret", 1, 1, INT, ACADO_LOCAL, true);
00081
00082 chol.addVariable( sum );
00083 chol.addVariable( div );
00084 chol.setReturnValue( ret );
00085 chol.addStatement( ret == 0 );
00086
00087
00088 flopsChol = nRows * nRows * nRows / 3;
00089
00090 if (flopsChol < 128)
00091 for(int ii = 0; ii < (int)nRows; ++ii)
00092 {
00093 for (int k = 0; k < ii; ++k)
00094 chol.addStatement( A.getElement(ii, k) == 0.0 );
00095
00096
00097
00098 chol.addStatement( sum == A.getElement(ii, ii) );
00099 for(int k = (ii - 1); k >= 0; --k)
00100
00101 chol.addStatement( sum -= A.getElement(k, ii) * A.getElement(k, ii) );
00102
00103 chol << "if (" << sum.getFullName() << "< 0.0) return 1;\n";
00104
00105
00106
00107
00108
00109
00110
00111
00112
00113 chol << A.getElement(ii, ii).get(0, 0) << " = sqrt(" << sum.getFullName() << ");\n";
00114 chol << div.getFullName() << " = 1.0 / " << A.getElement(ii, ii).get(0, 0) << ";\n";
00115
00116
00117 for(int jj = (ii + 1); jj < (int)nRows; ++jj)
00118 {
00119
00120
00121 chol.addStatement( sum == A.getElement(jj, ii) );
00122
00123 for(int k = (ii - 1); k >= 0; --k)
00124
00125 chol.addStatement( sum -= A.getElement(k, ii) * A.getElement(k, jj) );
00126
00127
00128 chol.addStatement( A.getElement(ii, jj) == sum * div );
00129 }
00130 }
00131 else
00132 {
00133 ExportIndex ii, jj, k;
00134 chol.acquire( ii ).acquire( jj ).acquire( k );
00135
00136 ExportForLoop iiLoop(ii, 0, nRows);
00137
00138 ExportForLoop kLoop(k, 0, ii);
00139 kLoop.addStatement( A.getElement(ii, k) == 0.0 );
00140 iiLoop.addStatement( kLoop );
00141
00142 iiLoop.addStatement( sum == A.getElement(ii, ii) );
00143
00144 ExportForLoop kLoop2(k, ii - 1, -1, -1);
00145 kLoop2.addStatement( sum -= A.getElement(k, ii) * A.getElement(k, ii) );
00146 iiLoop.addStatement( kLoop2 );
00147
00148 iiLoop << "if (" << sum.getFullName() << "< 0.0) return 1;\n";
00149 iiLoop << A.getElement(ii, ii).get(0, 0) << " = sqrt(" << sum.getFullName() << ");\n";
00150 iiLoop << div.getFullName() << " = 1.0 / " << A.getElement(ii, ii).get(0, 0) << ";\n";
00151
00152 ExportForLoop jjLoop(jj, ii + 1, nRows);
00153 jjLoop.addStatement( sum == A.getElement(jj, ii) );
00154
00155 ExportForLoop kLoop3(k, ii - 1, -1, -1);
00156 kLoop3.addStatement( sum -= A.getElement(k, ii) * A.getElement(k, jj) );
00157 jjLoop.addStatement( kLoop3 );
00158
00159 jjLoop.addStatement( A.getElement(ii, jj) == sum * div );
00160
00161 iiLoop.addStatement( jjLoop );
00162
00163 chol.addStatement( iiLoop );
00164 chol.release( ii ).release( jj ).release( k );
00165 }
00166
00167
00168
00169
00170
00171
00172
00173 flopsSolve = nRows * nRows * nColsB;
00174
00175 solve.addVariable( sum );
00176
00177 if (flopsSolve < 128)
00178 for (unsigned col = 0; col < nColsB; ++col)
00179 for(int i = 0; i < int(nRows); ++i)
00180 {
00181
00182 solve.addStatement( sum == B.getElement(i, col) );
00183
00184 for(int j = 0; j < i; ++j)
00185
00186 solve.addStatement( sum-= A.getElement(j, i) * B.getElement(j, col) );
00187
00188
00189
00190
00191
00192
00193
00194 solve << B.getElement(i, col).get(0, 0) << " = " << sum.getFullName() << " / " << A.getElement(i, i).get(0, 0) << ";\n";
00195 }
00196 else
00197 {
00198 ExportIndex col, i, j;
00199 solve.acquire( col ).acquire( i ).acquire( j );
00200
00201 ExportForLoop colLoop(col, 0, nColsB);
00202
00203 ExportForLoop iLoop(i, 0, nRows);
00204 iLoop.addStatement( sum == B.getElement(i, col) );
00205
00206 ExportForLoop jLoop(j, 0, i);
00207 jLoop.addStatement( sum-= A.getElement(j, i) * B.getElement(j, col) );
00208 iLoop << jLoop;
00209
00210 iLoop << B.getElement(i, col).get(0, 0) << " = " << sum.getFullName() << " / " << A.getElement(i, i).get(0, 0) << ";\n";
00211
00212 colLoop << iLoop;
00213 solve << colLoop;
00214 solve.release( col ).release( i ).release( j );
00215 }
00216
00217 return SUCCESSFUL_RETURN;
00218 }
00219
00220 returnValue ExportCholeskySolver::getCode( ExportStatementBlock& code )
00221 {
00222 code.addFunction( chol );
00223 code.addFunction( solve );
00224
00225 return SUCCESSFUL_RETURN;
00226 }
00227
00228 returnValue ExportCholeskySolver::getDataDeclarations( ExportStatementBlock& declarations,
00229 ExportStruct dataStruct
00230 ) const
00231 {
00232 return SUCCESSFUL_RETURN;
00233 }
00234
00235 returnValue ExportCholeskySolver::getFunctionDeclarations( ExportStatementBlock& declarations
00236 ) const
00237 {
00238 declarations.addDeclaration( chol );
00239 declarations.addDeclaration( solve );
00240
00241 return SUCCESSFUL_RETURN;
00242 }
00243
00244 const ExportFunction& ExportCholeskySolver::getCholeskyFunction() const
00245 {
00246 return chol;
00247 }
00248
00249 const ExportFunction& ExportCholeskySolver::getSolveFunction() const
00250 {
00251 return solve;
00252 }
00253
00254 returnValue ExportCholeskySolver::appendVariableNames( std::stringstream& string )
00255 {
00256 return SUCCESSFUL_RETURN;
00257 }
00258
00259 CLOSE_NAMESPACE_ACADO
00260
00261