ceres_solver.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2018 Simbe Robotics, Inc.
3  * Author: Steve Macenski (stevenmacenski@gmail.com)
4  */
5 
6 #include "ceres_solver.hpp"
7 #include <karto_sdk/Karto.h>
8 
9 #include "ros/console.h"
11 
13 
14 namespace solver_plugins
15 {
16 
17 /*****************************************************************************/
18 CeresSolver::CeresSolver() :
19  nodes_(new std::unordered_map<int, Eigen::Vector3d>()),
20  blocks_(new std::unordered_map<std::size_t,
21  ceres::ResidualBlockId>()),
22  problem_(NULL), was_constant_set_(false)
23 /*****************************************************************************/
24 {
25  ros::NodeHandle nh("~");
26  std::string solver_type, preconditioner_type, dogleg_type,
27  trust_strategy, loss_fn, mode;
28  nh.getParam("ceres_linear_solver", solver_type);
29  nh.getParam("ceres_preconditioner", preconditioner_type);
30  nh.getParam("ceres_dogleg_type", dogleg_type);
31  nh.getParam("ceres_trust_strategy", trust_strategy);
32  nh.getParam("ceres_loss_function", loss_fn);
33  nh.getParam("mode", mode);
34  nh.getParam("debug_logging", debug_logging_);
35 
36  corrections_.clear();
37  first_node_ = nodes_->end();
38 
39  // formulate problem
41 
42  // choose loss function default squared loss (NULL)
43  loss_function_ = NULL;
44  if (loss_fn == "HuberLoss")
45  {
46  ROS_INFO("CeresSolver: Using HuberLoss loss function.");
47  loss_function_ = new ceres::HuberLoss(0.7);
48  }
49  else if (loss_fn == "CauchyLoss")
50  {
51  ROS_INFO("CeresSolver: Using CauchyLoss loss function.");
52  loss_function_ = new ceres::CauchyLoss(0.7);
53  }
54 
55  // choose linear solver default CHOL
56  options_.linear_solver_type = ceres::SPARSE_NORMAL_CHOLESKY;
57  if (solver_type == "SPARSE_SCHUR")
58  {
59  ROS_INFO("CeresSolver: Using SPARSE_SCHUR solver.");
60  options_.linear_solver_type = ceres::SPARSE_SCHUR;
61  }
62  else if (solver_type == "ITERATIVE_SCHUR")
63  {
64  ROS_INFO("CeresSolver: Using ITERATIVE_SCHUR solver.");
65  options_.linear_solver_type = ceres::ITERATIVE_SCHUR;
66  }
67  else if (solver_type == "CGNR")
68  {
69  ROS_INFO("CeresSolver: Using CGNR solver.");
70  options_.linear_solver_type = ceres::CGNR;
71  }
72 
73  // choose preconditioner default Jacobi
74  options_.preconditioner_type = ceres::JACOBI;
75  if (preconditioner_type == "IDENTITY")
76  {
77  ROS_INFO("CeresSolver: Using IDENTITY preconditioner.");
78  options_.preconditioner_type = ceres::IDENTITY;
79  }
80  else if (preconditioner_type == "SCHUR_JACOBI")
81  {
82  ROS_INFO("CeresSolver: Using SCHUR_JACOBI preconditioner.");
83  options_.preconditioner_type = ceres::SCHUR_JACOBI;
84  }
85 
86  if (options_.preconditioner_type == ceres::CLUSTER_JACOBI ||
87  options_.preconditioner_type == ceres::CLUSTER_TRIDIAGONAL)
88  {
89  //default canonical view is O(n^2) which is unacceptable for
90  // problems of this size
91  options_.visibility_clustering_type = ceres::SINGLE_LINKAGE;
92  }
93 
94  // choose trust region strategy default LM
95  options_.trust_region_strategy_type = ceres::LEVENBERG_MARQUARDT;
96  if (trust_strategy == "DOGLEG")
97  {
98  ROS_INFO("CeresSolver: Using DOGLEG trust region strategy.");
99  options_.trust_region_strategy_type = ceres::DOGLEG;
100  }
101 
102  // choose dogleg type default traditional
103  if(options_.trust_region_strategy_type == ceres::DOGLEG)
104  {
105  options_.dogleg_type = ceres::TRADITIONAL_DOGLEG;
106  if (dogleg_type == "SUBSPACE_DOGLEG")
107  {
108  ROS_INFO("CeresSolver: Using SUBSPACE_DOGLEG dogleg type.");
109  options_.dogleg_type = ceres::SUBSPACE_DOGLEG;
110  }
111  }
112 
113  // a typical ros map is 5cm, this is 0.001, 50x the resolution
114  options_.function_tolerance = 1e-3;
115  options_.gradient_tolerance = 1e-6;
116  options_.parameter_tolerance = 1e-3;
117 
118  options_.sparse_linear_algebra_library_type = ceres::SUITE_SPARSE;
119  options_.max_num_consecutive_invalid_steps = 3;
120  options_.max_consecutive_nonmonotonic_steps =
121  options_.max_num_consecutive_invalid_steps;
122  options_.num_threads = 50;
123  options_.use_nonmonotonic_steps = true;
124  options_.jacobi_scaling = true;
125 
126  options_.min_relative_decrease = 1e-3;
127 
128  options_.initial_trust_region_radius = 1e4;
129  options_.max_trust_region_radius = 1e8;
130  options_.min_trust_region_radius = 1e-16;
131 
132  options_.min_lm_diagonal = 1e-6;
133  options_.max_lm_diagonal = 1e32;
134 
135  if(options_.linear_solver_type == ceres::SPARSE_NORMAL_CHOLESKY)
136  {
137  options_.dynamic_sparsity = true;
138  }
139 
140  if (mode == std::string("localization"))
141  {
142  // doubles the memory footprint, but lets us remove contraints faster
143  options_problem_.enable_fast_removal = true;
144  }
145 
146  problem_ = new ceres::Problem(options_problem_);
147 
148  return;
149 }
150 
151 /*****************************************************************************/
153 /*****************************************************************************/
154 {
155  if ( loss_function_ != NULL)
156  {
157  delete loss_function_;
158  }
159  if (nodes_ != NULL)
160  {
161  delete nodes_;
162  }
163  if (problem_ != NULL)
164  {
165  delete problem_;
166  }
167 }
168 
169 /*****************************************************************************/
171 /*****************************************************************************/
172 {
173  boost::mutex::scoped_lock lock(nodes_mutex_);
174 
175  if (nodes_->size() == 0)
176  {
177  ROS_ERROR("CeresSolver: Ceres was called when there are no nodes."
178  " This shouldn't happen.");
179  return;
180  }
181 
182  // populate contraint for static initial pose
183  if (!was_constant_set_ && first_node_ != nodes_->end())
184  {
185  ROS_DEBUG("CeresSolver: Setting first node as a constant pose:"
186  "%0.2f, %0.2f, %0.2f.", first_node_->second(0),
187  first_node_->second(1), first_node_->second(2));
188  problem_->SetParameterBlockConstant(&first_node_->second(0));
189  problem_->SetParameterBlockConstant(&first_node_->second(1));
190  problem_->SetParameterBlockConstant(&first_node_->second(2));
192  }
193 
194  const ros::Time start_time = ros::Time::now();
195  ceres::Solver::Summary summary;
196  ceres::Solve(options_, problem_, &summary);
197  if (debug_logging_)
198  {
199  std::cout << summary.FullReport() << '\n';
200  }
201 
202  if (!summary.IsSolutionUsable())
203  {
204  ROS_WARN("CeresSolver: "
205  "Ceres could not find a usable solution to optimize.");
206  return;
207  }
208 
209  // store corrected poses
210  if (!corrections_.empty())
211  {
212  corrections_.clear();
213  }
214  corrections_.reserve(nodes_->size());
215  karto::Pose2 pose;
216  ConstGraphIterator iter = nodes_->begin();
217  for ( iter; iter != nodes_->end(); ++iter )
218  {
219  pose.SetX(iter->second(0));
220  pose.SetY(iter->second(1));
221  pose.SetHeading(iter->second(2));
222  corrections_.push_back(std::make_pair(iter->first, pose));
223  }
224 
225  return;
226 }
227 
228 /*****************************************************************************/
230 /*****************************************************************************/
231 {
232  return corrections_;
233 }
234 
235 /*****************************************************************************/
237 /*****************************************************************************/
238 {
239  corrections_.clear();
240 }
241 
242 /*****************************************************************************/
244 /*****************************************************************************/
245 {
246  boost::mutex::scoped_lock lock(nodes_mutex_);
247 
248  corrections_.clear();
249  was_constant_set_ = false;
250 
251  if (problem_)
252  {
253  delete problem_;
254  }
255 
256  if (nodes_)
257  {
258  delete nodes_;
259  }
260 
261  if (blocks_)
262  {
263  delete blocks_;
264  }
265 
266  nodes_ = new std::unordered_map<int, Eigen::Vector3d>();
267  blocks_ = new std::unordered_map<std::size_t, ceres::ResidualBlockId>();
268  problem_ = new ceres::Problem(options_problem_);
269  first_node_ = nodes_->end();
270 
272 }
273 
274 /*****************************************************************************/
276 /*****************************************************************************/
277 {
278  // store nodes
279  if (!pVertex)
280  {
281  return;
282  }
283 
284  karto::Pose2 pose = pVertex->GetObject()->GetCorrectedPose();
285  Eigen::Vector3d pose2d(pose.GetX(), pose.GetY(), pose.GetHeading());
286 
287  const int id = pVertex->GetObject()->GetUniqueId();
288 
289  boost::mutex::scoped_lock lock(nodes_mutex_);
290  nodes_->insert(std::pair<int,Eigen::Vector3d>(id,pose2d));
291 
292  if (nodes_->size() == 1)
293  {
294  first_node_ = nodes_->find(id);
295  }
296 }
297 
298 /*****************************************************************************/
300 /*****************************************************************************/
301 {
302  // get IDs in graph for this edge
303  boost::mutex::scoped_lock lock(nodes_mutex_);
304 
305  if (!pEdge)
306  {
307  return;
308  }
309 
310  const int node1 = pEdge->GetSource()->GetObject()->GetUniqueId();
311  GraphIterator node1it = nodes_->find(node1);
312  const int node2 = pEdge->GetTarget()->GetObject()->GetUniqueId();
313  GraphIterator node2it = nodes_->find(node2);
314 
315  if (node1it == nodes_->end() ||
316  node2it == nodes_->end() || node1it == node2it)
317  {
318  ROS_WARN("CeresSolver: Failed to add constraint, could not find nodes.");
319  return;
320  }
321 
322  // extract transformation
323  karto::LinkInfo* pLinkInfo = (karto::LinkInfo*)(pEdge->GetLabel());
324  karto::Pose2 diff = pLinkInfo->GetPoseDifference();
325  Eigen::Vector3d pose2d(diff.GetX(), diff.GetY(), diff.GetHeading());
326 
327  karto::Matrix3 precisionMatrix = pLinkInfo->GetCovariance().Inverse();
328  Eigen::Matrix3d information;
329  information(0, 0) = precisionMatrix(0, 0);
330  information(0, 1) = information(1, 0) = precisionMatrix(0, 1);
331  information(0, 2) = information(2, 0) = precisionMatrix(0, 2);
332  information(1, 1) = precisionMatrix(1, 1);
333  information(1, 2) = information(2, 1) = precisionMatrix(1, 2);
334  information(2, 2) = precisionMatrix(2, 2);
335  Eigen::Matrix3d sqrt_information = information.llt().matrixU();
336 
337  // populate residual and parameterization for heading normalization
338  ceres::CostFunction* cost_function = PoseGraph2dErrorTerm::Create(pose2d(0),
339  pose2d(1), pose2d(2), sqrt_information);
340  ceres::ResidualBlockId block = problem_->AddResidualBlock(
341  cost_function, loss_function_,
342  &node1it->second(0), &node1it->second(1), &node1it->second(2),
343  &node2it->second(0), &node2it->second(1), &node2it->second(2));
344  problem_->SetParameterization(&node1it->second(2),
346  problem_->SetParameterization(&node2it->second(2),
348 
349  blocks_->insert(std::pair<std::size_t, ceres::ResidualBlockId>(
350  GetHash(node1, node2), block));
351  return;
352 }
353 
354 /*****************************************************************************/
356 /*****************************************************************************/
357 {
358  boost::mutex::scoped_lock lock(nodes_mutex_);
359  GraphIterator nodeit = nodes_->find(id);
360  if (nodeit != nodes_->end())
361  {
362  if (problem_->HasParameterBlock(&nodeit->second(0)) &&
363  problem_->HasParameterBlock(&nodeit->second(1)) &&
364  problem_->HasParameterBlock(&nodeit->second(2)))
365  {
366  problem_->RemoveParameterBlock(&nodeit->second(0));
367  problem_->RemoveParameterBlock(&nodeit->second(1));
368  problem_->RemoveParameterBlock(&nodeit->second(2));
369  ROS_DEBUG("RemoveNode: Removed node id %d", nodeit->first);
370  }
371  else
372  {
373  ROS_DEBUG("RemoveNode: Failed to remove parameter blocks for node id %d", nodeit->first);
374  }
375  nodes_->erase(nodeit);
376  }
377  else
378  {
379  ROS_ERROR("RemoveNode: Failed to find node matching id %i", (int)id);
380  }
381 }
382 
383 /*****************************************************************************/
385 /*****************************************************************************/
386 {
387  boost::mutex::scoped_lock lock(nodes_mutex_);
388  std::unordered_map<std::size_t, ceres::ResidualBlockId>::iterator it_a =
389  blocks_->find(GetHash(sourceId, targetId));
390  std::unordered_map<std::size_t, ceres::ResidualBlockId>::iterator it_b =
391  blocks_->find(GetHash(targetId, sourceId));
392  if (it_a != blocks_->end())
393  {
394  problem_->RemoveResidualBlock(it_a->second);
395  blocks_->erase(it_a);
396  }
397  else if (it_b != blocks_->end())
398  {
399  problem_->RemoveResidualBlock(it_b->second);
400  blocks_->erase(it_b);
401  }
402  else
403  {
404  ROS_ERROR("RemoveConstraint: Failed to find residual block for %i %i",
405  (int)sourceId, (int)targetId);
406  }
407 }
408 
409 /*****************************************************************************/
410 void CeresSolver::ModifyNode(const int& unique_id, Eigen::Vector3d pose)
411 /*****************************************************************************/
412 {
413  boost::mutex::scoped_lock lock(nodes_mutex_);
414  GraphIterator it = nodes_->find(unique_id);
415  if (it != nodes_->end())
416  {
417  double yaw_init = it->second(2);
418  it->second = pose;
419  it->second(2) += yaw_init;
420  }
421 }
422 
423 /*****************************************************************************/
424 void CeresSolver::GetNodeOrientation(const int& unique_id, double& pose)
425 /*****************************************************************************/
426 {
427  boost::mutex::scoped_lock lock(nodes_mutex_);
428  GraphIterator it = nodes_->find(unique_id);
429  if (it != nodes_->end())
430  {
431  pose = it->second(2);
432  }
433 }
434 
435 /*****************************************************************************/
436 std::unordered_map<int, Eigen::Vector3d>* CeresSolver::getGraph()
437 /*****************************************************************************/
438 {
439  boost::mutex::scoped_lock lock(nodes_mutex_);
440  return nodes_;
441 }
442 
443 } // end namespace
solver_plugins::CeresSolver::RemoveNode
virtual void RemoveNode(kt_int32s id)
Definition: ceres_solver.cpp:355
karto::ScanSolver
Definition: Mapper.h:947
solver_plugins::CeresSolver::blocks_
std::unordered_map< size_t, ceres::ResidualBlockId > * blocks_
Definition: ceres_solver.hpp:65
Eigen
solver_plugins::CeresSolver::GetCorrections
virtual const karto::ScanSolver::IdPoseVector & GetCorrections() const
Definition: ceres_solver.cpp:229
solver_plugins::CeresSolver::nodes_mutex_
boost::mutex nodes_mutex_
Definition: ceres_solver.hpp:67
ros::NodeHandle::getParam
bool getParam(const std::string &key, bool &b) const
solver_plugins::CeresSolver::Compute
virtual void Compute()
Definition: ceres_solver.cpp:170
karto::Pose2::GetY
kt_double GetY() const
Definition: Karto.h:2117
solver_plugins::CeresSolver::debug_logging_
bool debug_logging_
Definition: ceres_solver.hpp:61
solver_plugins::CeresSolver::first_node_
std::unordered_map< int, Eigen::Vector3d >::iterator first_node_
Definition: ceres_solver.hpp:66
solver_plugins::CeresSolver::getGraph
virtual std::unordered_map< int, Eigen::Vector3d > * getGraph()
Definition: ceres_solver.cpp:436
solver_plugins::CeresSolver::AddConstraint
virtual void AddConstraint(karto::Edge< karto::LocalizedRangeScan > *pEdge)
Definition: ceres_solver.cpp:299
solver_plugins::CeresSolver::ModifyNode
virtual void ModifyNode(const int &unique_id, Eigen::Vector3d pose)
Definition: ceres_solver.cpp:410
solver_plugins::CeresSolver::corrections_
karto::ScanSolver::IdPoseVector corrections_
Definition: ceres_solver.hpp:53
solver_plugins::CeresSolver::~CeresSolver
virtual ~CeresSolver()
Definition: ceres_solver.cpp:152
karto::LinkInfo::GetCovariance
const Matrix3 & GetCovariance()
Definition: Mapper.h:219
kt_int32s
int32_t kt_int32s
Definition: Types.h:51
solver_plugins::CeresSolver::nodes_
std::unordered_map< int, Eigen::Vector3d > * nodes_
Definition: ceres_solver.hpp:64
karto::Matrix3
Definition: Karto.h:2444
karto::ScanSolver::IdPoseVector
std::vector< std::pair< kt_int32s, Pose2 > > IdPoseVector
Definition: Mapper.h:953
class_list_macros.h
console.h
karto::Matrix3::Inverse
Matrix3 Inverse() const
Definition: Karto.h:2545
solver_plugins::CeresSolver::Clear
virtual void Clear()
Definition: ceres_solver.cpp:236
solver_plugins::CeresSolver::RemoveConstraint
virtual void RemoveConstraint(kt_int32s sourceId, kt_int32s targetId)
Definition: ceres_solver.cpp:384
solver_plugins::CeresSolver::angle_local_parameterization_
ceres::LocalParameterization * angle_local_parameterization_
Definition: ceres_solver.hpp:60
karto::Vertex::GetObject
T * GetObject() const
Definition: Mapper.h:319
Karto.h
solver_plugins::CeresSolver::GetNodeOrientation
virtual void GetNodeOrientation(const int &unique_id, double &pose)
Definition: ceres_solver.cpp:424
karto::LinkInfo
Definition: Mapper.h:141
GetHash
std::size_t GetHash(const int &x, const int &y)
Definition: ceres_utils.h:14
PLUGINLIB_EXPORT_CLASS
#define PLUGINLIB_EXPORT_CLASS(class_type, base_class_type)
solver_plugins::CeresSolver::options_
ceres::Solver::Options options_
Definition: ceres_solver.hpp:56
solver_plugins::CeresSolver::AddNode
virtual void AddNode(karto::Vertex< karto::LocalizedRangeScan > *pVertex)
Definition: ceres_solver.cpp:275
ROS_DEBUG
#define ROS_DEBUG(...)
karto::Pose2::SetY
void SetY(kt_double y)
Definition: Karto.h:2126
karto::Edge::GetLabel
EdgeLabel * GetLabel()
Definition: Mapper.h:457
ceres_solver.hpp
ROS_WARN
#define ROS_WARN(...)
karto::LocalizedRangeScan::GetCorrectedPose
const Pose2 & GetCorrectedPose() const
Definition: Karto.h:5562
PoseGraph2dErrorTerm::Create
static ceres::CostFunction * Create(double x_ab, double y_ab, double yaw_ab_radians, const Eigen::Matrix3d &sqrt_information)
Definition: ceres_utils.h:89
solver_plugins::CeresSolver::loss_function_
ceres::LossFunction * loss_function_
Definition: ceres_solver.hpp:58
karto::Edge::GetTarget
Vertex< T > * GetTarget() const
Definition: Mapper.h:448
solver_plugins::CeresSolver::Reset
virtual void Reset()
Definition: ceres_solver.cpp:243
ros::Time
karto::SensorData::GetUniqueId
kt_int32s GetUniqueId() const
Definition: Karto.h:5163
std
karto::Pose2::GetX
kt_double GetX() const
Definition: Karto.h:2099
ROS_ERROR
#define ROS_ERROR(...)
karto::Vertex< karto::LocalizedRangeScan >
solver_plugins::CeresSolver::problem_
ceres::Problem * problem_
Definition: ceres_solver.hpp:59
karto::Pose2::SetHeading
void SetHeading(kt_double heading)
Definition: Karto.h:2162
AngleLocalParameterization::Create
static ceres::LocalParameterization * Create()
Definition: ceres_utils.h:44
solver_plugins::CeresSolver
Definition: ceres_solver.hpp:30
karto::Pose2
Definition: Karto.h:2046
solver_plugins
Definition: ceres_solver.cpp:14
ROS_INFO
#define ROS_INFO(...)
karto::Edge::GetSource
Vertex< T > * GetSource() const
Definition: Mapper.h:439
solver_plugins::CeresSolver::options_problem_
ceres::Problem::Options options_problem_
Definition: ceres_solver.hpp:57
solver_plugins::CeresSolver::was_constant_set_
bool was_constant_set_
Definition: ceres_solver.hpp:61
karto::Pose2::SetX
void SetX(kt_double x)
Definition: Karto.h:2108
karto::Edge
Definition: Mapper.h:247
karto::Pose2::GetHeading
kt_double GetHeading() const
Definition: Karto.h:2153
toolbox_types::GraphIterator
std::unordered_map< int, Eigen::Vector3d >::iterator GraphIterator
Definition: toolbox_types.hpp:124
ros::NodeHandle
ros::Time::now
static Time now()
toolbox_types::ConstGraphIterator
std::unordered_map< int, Eigen::Vector3d >::const_iterator ConstGraphIterator
Definition: toolbox_types.hpp:125
karto::LinkInfo::GetPoseDifference
const Pose2 & GetPoseDifference()
Definition: Mapper.h:210


slam_toolbox
Author(s): Steve Macenski
autogenerated on Thu Jan 11 2024 03:37:55