export_cholesky_solver.cpp
Go to the documentation of this file.
1 /*
2  * This file is part of ACADO Toolkit.
3  *
4  * ACADO Toolkit -- A Toolkit for Automatic Control and Dynamic Optimization.
5  * Copyright (C) 2008-2014 by Boris Houska, Hans Joachim Ferreau,
6  * Milan Vukov, Rien Quirynen, KU Leuven.
7  * Developed within the Optimization in Engineering Center (OPTEC)
8  * under supervision of Moritz Diehl. All rights reserved.
9  *
10  * ACADO Toolkit is free software; you can redistribute it and/or
11  * modify it under the terms of the GNU Lesser General Public
12  * License as published by the Free Software Foundation; either
13  * version 3 of the License, or (at your option) any later version.
14  *
15  * ACADO Toolkit is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  * Lesser General Public License for more details.
19  *
20  * You should have received a copy of the GNU Lesser General Public
21  * License along with ACADO Toolkit; if not, write to the Free Software
22  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
23  *
24  */
25 
26 
27 
35 
36 using namespace std;
37 
39 
41  const std::string& _commonHeaderName
42  ) : ExportLinearSolver(_userInteraction, _commonHeaderName)
43 {
44  nColsB = 0;
45 }
46 
48 {}
49 
51  unsigned _numColsB,
52  const std::string& _id
53  )
54 {
55  nRows = nCols = _dimA;
56  nColsB = _numColsB;
57 
58  identifier = _id;
59 
60  A.setup("A", nRows, nCols, REAL, ACADO_LOCAL);
62 
63  chol.setup(identifier + "_chol", A);
64  solve.setup(identifier + "_solve", A, B);
65 
66  REUSE = false;
67 
68  return SUCCESSFUL_RETURN;
69 }
70 
72 {
73  unsigned flopsChol, flopsSolve;
74 
75  if (REUSE == true)
78 
79  if (nRightHandSides > 0)
81 
82  ExportVariable sum("sum", 1, 1, REAL, ACADO_LOCAL, true);
83  ExportVariable div("div", 1, 1, REAL, ACADO_LOCAL, true);
84  ExportVariable ret("ret", 1, 1, INT, ACADO_LOCAL, true);
85 
86  chol.addVariable( sum );
87  chol.addVariable( div );
88  chol.setReturnValue( ret );
89  chol.addStatement( ret == 0 );
90 
91  // Approximate number of flops
92  flopsChol = nRows * nRows * nRows / 3;
93 
94  if (flopsChol < 128)
95  for(int ii = 0; ii < (int)nRows; ++ii)
96  {
97  for (int k = 0; k < ii; ++k)
98  chol.addStatement( A.getElement(ii, k) == 0.0 );
99 
100  /* j == i */
101  // sum = H[ii * nCols + ii];
102  chol.addStatement( sum == A.getElement(ii, ii) );
103  for(int k = (ii - 1); k >= 0; --k)
104  // sum -= A[k*NVMAX + i] * A[k*NVMAX + i];
105  chol.addStatement( sum -= A.getElement(k, ii) * A.getElement(k, ii) );
106 
107  chol << "if (" << sum.getFullName() << "< 0.0) return 1;\n";
108 
109  // if ( sum > 0.0 )
110  // R[i*NVMAX + i] = sqrt( sum );
111  // else
112  // {
113  // hessianType = HST_SEMIDEF;
114  // return THROWERROR( RET_HESSIAN_NOT_SPD );
115  // }
116 
117  chol << A.getElement(ii, ii).get(0, 0) << " = sqrt(" << sum.getFullName() << ");\n";
118  chol << div.getFullName() << " = 1.0 / " << A.getElement(ii, ii).get(0, 0) << ";\n";
119 
120  /* j > i */
121  for(int jj = (ii + 1); jj < (int)nRows; ++jj)
122  {
123  // jj = FR_idx[j];
124  // sum = H[jj*NVMAX + ii];
125  chol.addStatement( sum == A.getElement(jj, ii) );
126 
127  for(int k = (ii - 1); k >= 0; --k)
128  // sum -= R[k * NVMAX + ii] * R[k * NVMAX + jj];
129  chol.addStatement( sum -= A.getElement(k, ii) * A.getElement(k, jj) );
130 
131  // R[ii * NVMAX + jj] = sum / R[ii * NVMAX + ii];
132  chol.addStatement( A.getElement(ii, jj) == sum * div );
133  }
134  }
135  else
136  {
137  ExportIndex ii, jj, k;
138  chol.acquire( ii ).acquire( jj ).acquire( k );
139 
140  ExportForLoop iiLoop(ii, 0, nRows);
141 
142  ExportForLoop kLoop(k, 0, ii);
143  kLoop.addStatement( A.getElement(ii, k) == 0.0 );
144  iiLoop.addStatement( kLoop );
145 
146  iiLoop.addStatement( sum == A.getElement(ii, ii) );
147 
148  ExportForLoop kLoop2(k, ii - 1, -1, -1);
149  kLoop2.addStatement( sum -= A.getElement(k, ii) * A.getElement(k, ii) );
150  iiLoop.addStatement( kLoop2 );
151 
152  iiLoop << "if (" << sum.getFullName() << "< 0.0) return 1;\n";
153  iiLoop << A.getElement(ii, ii).get(0, 0) << " = sqrt(" << sum.getFullName() << ");\n";
154  iiLoop << div.getFullName() << " = 1.0 / " << A.getElement(ii, ii).get(0, 0) << ";\n";
155 
156  ExportForLoop jjLoop(jj, ii + 1, nRows);
157  jjLoop.addStatement( sum == A.getElement(jj, ii) );
158 
159  ExportForLoop kLoop3(k, ii - 1, -1, -1);
160  kLoop3.addStatement( sum -= A.getElement(k, ii) * A.getElement(k, jj) );
161  jjLoop.addStatement( kLoop3 );
162 
163  jjLoop.addStatement( A.getElement(ii, jj) == sum * div );
164 
165  iiLoop.addStatement( jjLoop );
166 
167  chol.addStatement( iiLoop );
168  chol.release( ii ).release( jj ).release( k );
169  }
170 
171  //
172  // Setup evaluation of the solve function
173  // Implements R^T X = B -> X = R^{-T} * B. B is replaced by the solution.
174  //
175 
176  // Approximate number of flops
177  flopsSolve = nRows * nRows * nColsB;
178 
179  solve.addVariable( sum );
180 
181  if (flopsSolve < 128)
182  for (unsigned col = 0; col < nColsB; ++col)
183  for(int i = 0; i < int(nRows); ++i)
184  {
185  // sum = b[i];
186  solve.addStatement( sum == B.getElement(i, col) );
187 
188  for(int j = 0; j < i; ++j)
189  // sum -= R[j*NVMAX + i] * a[j];
190  solve.addStatement( sum-= A.getElement(j, i) * B.getElement(j, col) );
191 
192  // TODO Error checking
193  // if ( getAbs( R[i*NVMAX + i] ) > ZERO )
194  // a[i] = sum / R[i*NVMAX + i];
195  // else
196  // return THROWERROR( RET_DIV_BY_ZERO );
197 
198  solve << B.getElement(i, col).get(0, 0) << " = " << sum.getFullName() << " / " << A.getElement(i, i).get(0, 0) << ";\n";
199  }
200  else
201  {
202  ExportIndex col, i, j;
203  solve.acquire( col ).acquire( i ).acquire( j );
204 
205  ExportForLoop colLoop(col, 0, nColsB);
206 
207  ExportForLoop iLoop(i, 0, nRows);
208  iLoop.addStatement( sum == B.getElement(i, col) );
209 
210  ExportForLoop jLoop(j, 0, i);
211  jLoop.addStatement( sum-= A.getElement(j, i) * B.getElement(j, col) );
212  iLoop << jLoop;
213 
214  iLoop << B.getElement(i, col).get(0, 0) << " = " << sum.getFullName() << " / " << A.getElement(i, i).get(0, 0) << ";\n";
215 
216  colLoop << iLoop;
217  solve << colLoop;
218  solve.release( col ).release( i ).release( j );
219  }
220 
221  return SUCCESSFUL_RETURN;
222 }
223 
225 {
226  code.addFunction( chol );
227  code.addFunction( solve );
228 
229  return SUCCESSFUL_RETURN;
230 }
231 
233  ExportStruct dataStruct
234  ) const
235 {
236  return SUCCESSFUL_RETURN;
237 }
238 
240  ) const
241 {
242  declarations.addDeclaration( chol );
243  declarations.addDeclaration( solve );
244 
245  return SUCCESSFUL_RETURN;
246 }
247 
249 {
250  return chol;
251 }
252 
254 {
255  return solve;
256 }
257 
259 {
260  return SUCCESSFUL_RETURN;
261 }
262 
264 
265 // end of file.
virtual returnValue getCode(ExportStatementBlock &code)
virtual returnValue setup()
ExportCholeskySolver(UserInteraction *_userInteraction=0, const std::string &_commonHeaderName="")
ExportVariable & setup(const std::string &_name, uint _nRows=1, uint _nCols=1, ExportType _type=REAL, ExportStruct _dataStruct=ACADO_LOCAL, bool _callItByValue=false, const std::string &_prefix=std::string())
const ExportFunction & getCholeskyFunction() const
Allows passing messages back to the calling function.
virtual returnValue appendVariableNames(std::stringstream &string)
Allows exporting code for a for-loop.
ExportVariable getElement(const ExportIndex &rowIdx, const ExportIndex &colIdx) const
#define CLOSE_NAMESPACE_ACADO
Defines a scalar-valued index variable to be used for exporting code.
ExportFunction & setup(const std::string &_name="defaultFunctionName", const ExportArgument &_argument1=emptyConstExportArgument, const ExportArgument &_argument2=emptyConstExportArgument, const ExportArgument &_argument3=emptyConstExportArgument, const ExportArgument &_argument4=emptyConstExportArgument, const ExportArgument &_argument5=emptyConstExportArgument, const ExportArgument &_argument6=emptyConstExportArgument, const ExportArgument &_argument7=emptyConstExportArgument, const ExportArgument &_argument8=emptyConstExportArgument, const ExportArgument &_argument9=emptyConstExportArgument)
const ExportFunction & getSolveFunction() const
ExportStruct
returnValue init(unsigned _dimA, unsigned _numColsB, const std::string &_id)
const std::string get(const ExportIndex &rowIdx, const ExportIndex &colIdx) const
virtual returnValue getDataDeclarations(ExportStatementBlock &declarations, ExportStruct dataStruct=ACADO_ANY) const
Encapsulates all user interaction for setting options, logging data and plotting results.
Allows exporting code for an arbitrary function.
returnValue addStatement(const ExportStatement &_statement)
std::string getFullName() const
ExportFunction & setReturnValue(const ExportVariable &_functionReturnValue, bool _returnAsPointer=false)
virtual ExportFunction & acquire(ExportIndex &obj)
ExportFunction & addVariable(const ExportVariable &_var)
returnValue addDeclaration(const ExportVariable &_data, ExportStruct _dataStruct=ACADO_ANY)
virtual ExportFunction & release(const ExportIndex &obj)
#define BEGIN_NAMESPACE_ACADO
returnValue addFunction(const ExportFunction &_function)
ColXpr col(Index i)
Definition: BlockMethods.h:708
Allows exporting automatically generated algorithms for solving linear systems of specific dimensions...
virtual returnValue getFunctionDeclarations(ExportStatementBlock &declarations) const
Allows exporting code for a block of statements.
#define ACADOERROR(retval)
Defines a matrix-valued variable to be used for exporting code.


acado
Author(s): Milan Vukov, Rien Quirynen
autogenerated on Mon Jun 10 2019 12:34:33