00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00034 #include <acado/code_generation/export_arithmetic_statement.hpp>
00035 #include <acado/code_generation/export_variable_internal.hpp>
00036
00037 #include <iomanip>
00038
00039 using namespace std;
00040 BEGIN_NAMESPACE_ACADO
00041
00042
00043
00044
00045
00046 ExportArithmeticStatement::ExportArithmeticStatement( )
00047 {
00048 op0 = ESO_UNDEFINED;
00049 op1 = ESO_UNDEFINED;
00050 op2 = ESO_UNDEFINED;
00051 }
00052
00053 ExportArithmeticStatement::ExportArithmeticStatement( const ExportVariable& _lhs,
00054 ExportStatementOperator _op0,
00055 const ExportVariable& _rhs1,
00056 ExportStatementOperator _op1,
00057 const ExportVariable& _rhs2,
00058 ExportStatementOperator _op2,
00059 const ExportVariable& _rhs3
00060 ) : ExportStatement( )
00061 {
00062 ASSERT( ( _op0 == ESO_UNDEFINED ) || ( _op0 == ESO_ASSIGN ) || ( _op0 == ESO_ADD_ASSIGN ) || ( _op0 == ESO_SUBTRACT_ASSIGN ) );
00063 ASSERT( ( _op2 == ESO_UNDEFINED ) || ( _op2 == ESO_ADD ) || ( _op2 == ESO_SUBTRACT ) );
00064
00065 lhs = _lhs;
00066 rhs1 = _rhs1;
00067 rhs2 = _rhs2;
00068 rhs3 = _rhs3;
00069
00070 op0 = _op0;
00071 op1 = _op1;
00072 op2 = _op2;
00073 }
00074
00075 ExportArithmeticStatement::~ExportArithmeticStatement( )
00076 {}
00077
00078 ExportStatement* ExportArithmeticStatement::clone( ) const
00079 {
00080 return new ExportArithmeticStatement(*this);
00081 }
00082
00083
00084 uint ExportArithmeticStatement::getNumRows( ) const
00085 {
00086 if ( rhs1.isNull() )
00087 return 0;
00088
00089 if (op1 != ESO_MULTIPLY_TRANSPOSE)
00090 return rhs1->getNumRows( );
00091
00092 return rhs1->getNumCols( );
00093 }
00094
00095
00096 uint ExportArithmeticStatement::getNumCols( ) const
00097 {
00098 if ( rhs1.isNull() )
00099 return 0;
00100
00101 if ( rhs2.isNull() )
00102 return rhs1->getNumCols( );
00103
00104 return rhs2->getNumCols( );
00105 }
00106
00107
00108 returnValue ExportArithmeticStatement::exportDataDeclaration( std::ostream& stream,
00109 const std::string& _realString,
00110 const std::string& _intString,
00111 int _precision
00112 ) const
00113 {
00114 return SUCCESSFUL_RETURN;
00115 }
00116
00117
00118 returnValue ExportArithmeticStatement::exportCode( std::ostream& stream,
00119 const std::string& _realString,
00120 const std::string& _intString,
00121 int _precision
00122 ) const
00123 {
00124 ASSERT(lhs.isNull() == false);
00125
00126 if (lhs.getDim() == 0 || rhs1.getDim() == 0 || rhs2.getDim() == 0)
00127 return SUCCESSFUL_RETURN;
00128
00129 if (lhs->isGiven() == true && lhs->getDim() > 0)
00130 {
00131 LOG( LVL_ERROR ) << "Left hand side ('" << lhs.getFullName() << "') of an arithmetic "
00132 "expression is given." << endl;
00133 return ACADOERROR(RET_INVALID_ARGUMENTS);
00134 }
00135
00136 if (memAllocator == 0)
00137 return ACADOERRORTEXT(RET_INVALID_ARGUMENTS, "Memory allocator is not defined.");
00138
00139 IoFormatter iof( stream );
00140 iof.set(_precision, iof.width, iof.flags);
00141
00142 switch ( op1 )
00143 {
00144 case ESO_ADD:
00145 return exportCodeAddSubtract(stream, "+", _realString, _intString);
00146
00147 case ESO_SUBTRACT:
00148 return exportCodeAddSubtract(stream, "-", _realString, _intString);
00149
00150 case ESO_ADD_ASSIGN:
00151 return exportCodeAssign(stream, "+=", _realString, _intString);
00152
00153 case ESO_SUBTRACT_ASSIGN:
00154 return exportCodeAssign(stream, "-=", _realString, _intString);
00155
00156 case ESO_MULTIPLY:
00157 return exportCodeMultiply(stream, false, _realString, _intString);
00158
00159 case ESO_MULTIPLY_TRANSPOSE:
00160 return exportCodeMultiply(stream, true, _realString, _intString);
00161
00162 case ESO_ASSIGN:
00163 return exportCodeAssign(stream, "=", _realString, _intString);
00164
00165 default:
00166 return ACADOERROR( RET_UNKNOWN_BUG );
00167 }
00168
00169 iof.reset();
00170
00171 return ACADOERROR( RET_UNKNOWN_BUG );
00172 }
00173
00174
00175
00176
00177
00178 returnValue ExportArithmeticStatement::exportCodeAddSubtract( std::ostream& stream,
00179 const std::string& _sign,
00180 const std::string& _realString,
00181 const std::string& _intString
00182 ) const
00183 {
00184 if ( ( rhs1->getNumRows() != rhs2->getNumRows() ) || ( rhs1->getNumCols() != rhs2->getNumCols() ) )
00185 return ACADOERROR( RET_VECTOR_DIMENSION_MISMATCH );
00186
00187 if (rhs1->getNumRows() != lhs->getNumRows() || rhs1->getNumCols() != lhs->getNumCols())
00188 {
00189 LOG( LVL_DEBUG )
00190 << "lhs name is " << lhs.getName() << ", size: " << lhs.getNumRows() << " x " << lhs.getNumCols() << endl
00191 << "rhs1 name is " << rhs1.getName() << ", size: " << rhs1.getNumRows() << " x " << rhs1.getNumCols() << endl;
00192
00193 return ACADOERROR( RET_VECTOR_DIMENSION_MISMATCH );
00194 }
00195
00196
00197
00198
00199 unsigned numberOfFlops = lhs->getNumRows() * lhs->getNumCols();
00200
00201
00202
00203
00204
00205 bool optimizationsAllowed = ( rhs1->isGiven() == false ) && ( rhs2->isGiven() == false );
00206
00207 if (numberOfFlops < 4096 || optimizationsAllowed == false)
00208 {
00209 for( uint i=0; i<getNumRows( ); ++i )
00210 for( uint j=0; j<getNumCols( ); ++j )
00211 {
00212 if ( ( op0 != ESO_ASSIGN ) &&
00213 ( rhs1->isGiven(i,j) == true ) && ( rhs2->isGiven(i,j) == true ) )
00214 {
00215
00216 if ( ( op1 == ESO_ADD ) && ( acadoIsZero(rhs1(i, j) + rhs2(i, j)) == true ) )
00217 continue;
00218
00219 if ( ( op1 == ESO_SUBTRACT ) && ( acadoIsZero( rhs1(i, j) - rhs2(i, j)) == true ) )
00220 continue;
00221 }
00222
00223 stream << lhs.get(i, j) << " " << getAssignString() << " ";
00224
00225 if ( rhs1->isZero(i, j) == false )
00226 {
00227 stream << rhs1->get(i, j);
00228 if ( rhs2->isZero(i,j) == false )
00229 stream << _sign << " " << rhs2->get(i, j) << ";\n";
00230 else
00231 stream << ";" << endl;
00232 }
00233 else
00234 {
00235 if (rhs2->isZero(i, j) == false)
00236 stream << _sign << " " << rhs2->get(i, j) << ";\n";
00237 else
00238 stream << "0.0;\n";
00239 }
00240 }
00241 }
00242 else if ( numberOfFlops < 32768 )
00243 {
00244 ExportIndex ii;
00245 memAllocator->acquire( ii );
00246
00247 stream << "for (" << ii.getName() << " = 0; ";
00248 stream << ii.getName() << " < " << getNumRows() << "; ";
00249 stream << "++" << ii.getName() << ")\n{\n";
00250
00251 for(unsigned j = 0; j < getNumCols( ); ++j)
00252 {
00253 stream << lhs->get(ii, j) << " " << getAssignString();
00254 stream << _sign << " " << rhs2->get(ii, j) << ";\n";
00255 }
00256
00257 stream << "\n{\n";
00258
00259 memAllocator->release( ii );
00260 }
00261 else
00262 {
00263 ExportIndex ii, jj;
00264 memAllocator->acquire( ii );
00265 memAllocator->acquire( jj );
00266
00267 stream << "for (" << ii.getName() << " = 0; "
00268 << ii.getName() << " < " << getNumRows() <<"; "
00269 << "++" << ii.getName() << ")\n{\n";
00270
00271 stream << "for (" << jj.getName() << " = 0; "
00272 << jj.getName() << " < " << getNumCols() <<"; "
00273 << "++" << jj.getName() << ")\n{\n";
00274
00275 stream << lhs->get(ii, jj) << " " << getAssignString()
00276 << _sign << " " << rhs2->get(ii, jj) << ";\n";
00277
00278 stream << "\n}\n"
00279 << "\n}\n";
00280
00281 memAllocator->release( ii );
00282 memAllocator->release( jj );
00283 }
00284
00285 return SUCCESSFUL_RETURN;
00286 }
00287
00288 returnValue ExportArithmeticStatement::exportCodeMultiply( std::ostream& stream,
00289 bool transposeRhs1,
00290 const std::string& _realString,
00291 const std::string& _intString
00292 ) const
00293 {
00294 uint nRowsRhs1;
00295 uint nColsRhs1;
00296
00297 if ( transposeRhs1 == false )
00298 {
00299 nRowsRhs1 = rhs1->getNumRows( );
00300 nColsRhs1 = rhs1->getNumCols( );
00301 }
00302 else
00303 {
00304 nRowsRhs1 = rhs1->getNumCols( );
00305 nColsRhs1 = rhs1->getNumRows( );
00306 }
00307
00308 if ( ( nColsRhs1 != rhs2->getNumRows( ) ) ||
00309 ( nRowsRhs1 != lhs->getNumRows( ) ) ||
00310 ( rhs2->getNumCols( ) != lhs->getNumCols( ) ) )
00311 return ACADOERROR( RET_VECTOR_DIMENSION_MISMATCH );
00312
00313 char sign[2] = "+";
00314
00315 if ( op2 != ESO_UNDEFINED )
00316 {
00317 if ( ( rhs3->getNumRows( ) != lhs->getNumRows( ) ) ||
00318 ( rhs3->getNumCols( ) != lhs->getNumCols( ) ) )
00319 return ACADOERROR( RET_VECTOR_DIMENSION_MISMATCH );
00320
00321 if ( op2 == ESO_SUBTRACT )
00322 sign[0] = '-';
00323 }
00324
00325 bool allZero;
00326
00327 ExportIndex ii, iiRhs1;
00328 ExportIndex jj, jjRhs1;
00329 ExportIndex kk, kkRhs1;
00330
00331
00332
00333
00334 unsigned numberOfFlops = nRowsRhs1 * rhs2->getNumRows( ) * rhs2->getNumCols();
00335
00336
00337
00338
00339
00340 bool optimizationsAllowed =
00341 rhs1->isGiven() == false && rhs2->isGiven() == false;
00342 if (op2 == ESO_ADD || op2 == ESO_SUBTRACT)
00343 optimizationsAllowed &= rhs3.isGiven() == false;
00344
00345
00346
00347
00348 if (numberOfFlops < 4096 || optimizationsAllowed == false)
00349 {
00350
00351
00352
00353
00354 for(uint i = 0; i < getNumRows( ); ++i)
00355 {
00356 ii = i;
00357
00358 for(uint j = 0; j < getNumCols( ); ++j)
00359 {
00360 allZero = true;
00361
00362 stream << lhs->get(ii,j) << " " << getAssignString() << " ";
00363
00364 for(uint k = 0; k < nColsRhs1; ++k)
00365 {
00366 kk = k;
00367 if ( transposeRhs1 == false )
00368 {
00369 iiRhs1 = ii;
00370 kkRhs1 = kk;
00371 }
00372 else
00373 {
00374 iiRhs1 = kk;
00375 kkRhs1 = ii;
00376 }
00377
00378 if ( ( rhs1->isZero(iiRhs1,kkRhs1) == false ) &&
00379 ( rhs2->isZero(kk,j) == false ) )
00380 {
00381 allZero = false;
00382
00383 if ( rhs1->isOne(iiRhs1,kkRhs1) == false )
00384 {
00385 stream << sign << " " << rhs1->get(iiRhs1,kkRhs1);
00386
00387 if ( rhs2->isOne(kk,j) == false )
00388 stream << "*" << rhs2->get(kk, j);
00389 }
00390 else
00391 {
00392 if ( rhs2->isOne(kk,j) == false )
00393 stream << " " << sign << rhs2->get(kk,j);
00394 else
00395 stream << " " << sign << " 1.0";
00396 }
00397 }
00398 }
00399
00400 if (op2 == ESO_ADD && rhs3->isZero(ii, j) == false)
00401 stream << " + " << rhs3->get(ii, j);
00402 if (op2 == ESO_SUBTRACT && rhs3->isZero(ii, j) == false)
00403 stream << " - " << rhs3->get(ii, j);
00404 if (op2 == ESO_UNDEFINED && allZero == true)
00405 stream << " 0.0;\n";
00406
00407 stream << ";\n";
00408 }
00409 }
00410 }
00411
00412
00413
00414
00415
00416
00417
00418
00419
00420
00421
00422
00423
00424
00425
00426
00427
00428
00429
00430
00431
00432
00433
00434
00435
00436
00437
00438
00439
00440
00441
00442
00443
00444
00445
00446
00447
00448
00449
00450
00451
00452
00453
00454
00455
00456
00457
00458
00459
00460
00461
00462
00463
00464
00465
00466
00467
00468
00469
00470
00471
00472
00473
00474
00475
00476
00477
00478
00479 else
00480 {
00481
00482
00483
00484
00485 memAllocator->acquire( ii );
00486 memAllocator->acquire( jj );
00487 memAllocator->acquire( kk );
00488
00489
00490 stream << "for (" << ii.getName() << " = 0; ";
00491 stream << ii.getName() << " < " << getNumRows() <<"; ";
00492 stream << "++" << ii.getName() << ")\n{\n";
00493
00494
00495 stream << "for (" << jj.getName() << " = 0; ";
00496 stream << jj.getName() << " < " << getNumCols() <<"; ";
00497 stream << "++" << jj.getName() << ")\n{\n";
00498
00499 stream << _realString << " t = 0.0;" << endl;
00500
00501
00502 stream << "for (" << kk.getName() << " = 0; ";
00503 stream << kk.getName() << " < " << nColsRhs1 <<"; ";
00504 stream << "++" << kk.getName() << ")\n{\n";
00505
00506 if ( transposeRhs1 == false )
00507 {
00508 iiRhs1 = ii;
00509 kkRhs1 = kk;
00510 }
00511 else
00512 {
00513 iiRhs1 = kk;
00514 kkRhs1 = ii;
00515 }
00516 stream << "t += " << sign << " " << rhs1->get(iiRhs1, kkRhs1) << "*" << rhs2->get(kk, jj) << ";";
00517 stream << "\n}\n";
00518
00519 if (lhs.isCalledByValue() == true)
00520 stream << lhs.getFullName();
00521 else
00522 stream << lhs->get(ii, jj);
00523
00524 stream << " " << getAssignString() << " t";
00525
00526 if (op2 == ESO_ADD)
00527 {
00528 stream << " + " << rhs3->get(ii, jj);
00529 }
00530 else if (op2 == ESO_SUBTRACT)
00531 {
00532 stream << " - " << rhs3->get(ii, jj);
00533 }
00534 stream << ";\n";
00535
00536 stream << "}\n";
00537 stream << "}\n";
00538
00539 memAllocator->release( ii );
00540 memAllocator->release( jj );
00541 memAllocator->release( kk );
00542 }
00543
00544 return SUCCESSFUL_RETURN;
00545 }
00546
00547
00548 returnValue ExportArithmeticStatement::exportCodeAssign( std::ostream& stream,
00549 const std::string& _op,
00550 const std::string& _realString,
00551 const std::string& _intString
00552 ) const
00553 {
00554 if ( ( rhs1.getNumRows( ) != lhs.getNumRows( ) ) || ( rhs1.getNumCols( ) != lhs.getNumCols( ) ) )
00555 {
00556 LOG( LVL_DEBUG ) << "lhs name is " << lhs.getName()
00557 << ", size: " << lhs.getNumRows() << " x " << lhs.getNumCols()
00558 << "rhs1 name is " << rhs1.getName()
00559 << ", size: " << rhs1.getNumRows() << " x " << rhs1.getNumCols() << endl;
00560
00561 return ACADOERROR( RET_VECTOR_DIMENSION_MISMATCH );
00562 }
00563
00564 unsigned numOps = lhs.getNumRows() * lhs.getNumCols();
00565
00566 if ( lhs.isSubMatrix() == false && lhs.getDim() > 1 &&
00567 rhs1.isGiven() == true && rhs1.getGivenMatrix().isZero() == true &&
00568 _op == "=" )
00569 {
00570 stream << "{ int lCopy; for (lCopy = 0; lCopy < "<< lhs.getDim() << "; lCopy++) "
00571 << lhs.getFullName() << "[ lCopy ] = 0; }" << endl;
00572 }
00573 else if ((numOps < 128) || (rhs1.isGiven() == true))
00574 {
00575 for(unsigned i = 0; i < lhs.getNumRows( ); ++i)
00576 for(unsigned j = 0; j < lhs.getNumCols( ); ++j)
00577 if ( ( _op == "=" ) || ( rhs1.isZero(i,j) == false ) )
00578 {
00579 stream << lhs->get(i, j) << " " << _op << " ";
00580 if (rhs1->isGiven() == true)
00581 {
00582 if (lhs->getType() == REAL || lhs->getType() == STATIC_CONST_REAL)
00583 stream << scientific << rhs1(i, j);
00584 else
00585 stream << (int)rhs1(i, j);
00586
00587 stream << ";\n";
00588 }
00589 else
00590 {
00591 stream << rhs1->get(i, j) << ";\n";
00592 }
00593 }
00594 }
00595 else
00596 {
00597 ExportIndex ii, jj;
00598
00599 if (lhs.isVector() && rhs1.isVector())
00600 {
00601 memAllocator->acquire( ii );
00602
00603 stream << "for (" << ii.get() << " = 0; " << ii.get() << " < ";
00604
00605 if (lhs->getNumCols() == 1)
00606 {
00607 stream << lhs->getNumRows() << "; ++" << ii.getName() << ")" << endl
00608 << lhs.get(ii, 0) << " " << _op << " " << rhs1.get(ii, 0)
00609 << ";" << endl << endl;
00610 }
00611 else
00612 {
00613 stream << lhs.getNumCols() << "; ++" << ii.getName() << ")" << endl;
00614 stream << lhs.get(0, ii) << " " << _op << " " << rhs1.get(0, ii)
00615 << ";" << endl << endl;
00616 }
00617
00618 memAllocator->release( ii );
00619 }
00620 else
00621 {
00622 memAllocator->acquire( ii );
00623 memAllocator->acquire( jj );
00624
00625 stream << "for (" << ii.getName() << " = 0;" << ii.getName() << " < "
00626 << lhs->getNumRows() << "; ++" << ii.getName() << ")" << endl;
00627
00628 stream << "for (" << jj.getName() << " = 0;" << jj.getName() << " < "
00629 << lhs->getNumCols() << "; ++" << jj.getName() << ")" << endl;
00630
00631 stream << lhs->get(ii, jj) << " " << _op << " " << rhs1->get(ii, jj) << ";" << endl;
00632
00633 memAllocator->release( ii );
00634 memAllocator->release( jj );
00635 }
00636 }
00637
00638 return SUCCESSFUL_RETURN;
00639 }
00640
00641
00642 std::string ExportArithmeticStatement::getAssignString( ) const
00643 {
00644 switch ( op0 )
00645 {
00646 case ESO_ASSIGN:
00647 return "=";
00648
00649 case ESO_ADD_ASSIGN:
00650 return "+=";
00651
00652 case ESO_SUBTRACT_ASSIGN:
00653 return "-=";
00654
00655 default:
00656 return "foo";
00657 }
00658 }
00659
00660 ExportArithmeticStatement& ExportArithmeticStatement::allocate( MemoryAllocatorPtr allocator )
00661 {
00662 memAllocator = allocator;
00663
00664 return *this;
00665 }
00666
00667 CLOSE_NAMESPACE_ACADO
00668
00669