2 GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
3 Atlanta, Georgia 30332-0415
6 See LICENSE for the license information
8 Script for running hybrid estimator on the City10000 dataset.
18 from matplotlib
import pyplot
as plt
21 from gtsam
import (BetweenFactorPose2, HybridNonlinearFactor,
22 HybridNonlinearFactorGraph, HybridSmoother, HybridValues,
23 Pose2, PriorFactorPose2, Values)
27 """Parse command line arguments"""
28 parser = argparse.ArgumentParser()
29 parser.add_argument(
"--data_file",
30 help=
"The path to the City10000 data file",
31 default=
"T1_city10000_04.txt")
37 help=
"The maximum number of loops to run over the dataset")
43 help=
"After how many steps to run the smoother update.")
45 "--max_num_hypotheses",
49 help=
"The maximum number of hypotheses to keep at any time.")
54 help=
"Plot all hypotheses. NOTE: This is exponential, use with caution."
56 return parser.parse_args()
61 open_loop_constant = open_loop_model.negLogConstant()
64 np.asarray([0.0001, 0.0001, 0.0001]))
67 np.asarray([1.0 / 20.0, 1.0 / 20.0, 1.0 / 100.0]))
68 pose_noise_constant = pose_noise_model.negLogConstant()
72 """Class representing the City10000 dataset."""
79 print(f
"Failed to open file: {self.filename_}")
def read_line(self, line: str, delimiter: str = " "):
    """Break one raw dataset `line` into its individual fields.

    Args:
        line: A single line of text read from the dataset file.
        delimiter: The separator between fields (a single space by default).

    Returns:
        The list of string fields, exactly as produced by splitting `line`
        on `delimiter` (empty fields are kept, no stripping is done).
    """
    fields = line.split(sep=delimiter)
    return fields
89 line: str) -> tuple[list[Pose2], tuple[int, int], bool]:
90 """Parse line from file"""
96 is_ambiguous_loop = bool(
int(parts[4]))
98 num_measurements =
int(parts[5])
99 pose_array = [
Pose2()] * num_measurements
101 for i
in range(num_measurements):
102 x =
float(parts[6 + 3 * i])
103 y =
float(parts[7 + 3 * i])
104 rad =
float(parts[8 + 3 * i])
105 pose_array[i] =
Pose2(x, y, rad)
107 return pose_array, (key_s, key_t), is_ambiguous_loop
110 """Read and parse the next line."""
111 line = self.
f_.readline()
115 return None,
None,
None
121 estimate_color=(0.1, 0.1, 0.9, 0.4),
122 estimate_label=
"Hybrid Factor Graphs",
124 filename=
"city10000_results.svg"):
125 """Plot the City10000 estimates against the ground truth.
128 ground_truth: The ground truth trajectory as xy values.
all_results (List[Tuple[np.ndarray, str, float]]): All the estimated trajectories,
    each with its assignment string and probability.
131 estimate_color (tuple, optional): The color to use for the graph of estimates.
132 Defaults to (0.1, 0.1, 0.9, 0.4).
133 estimate_label (str, optional): Label for the estimates, used in the legend.
134 Defaults to "Hybrid Factor Graphs".
136 if len(all_results) == 1:
137 fig, axes = plt.subplots(1, 1)
140 fig, axes = plt.subplots(
int(np.ceil(
len(all_results) / 2)), 2)
141 axes = axes.flatten()
143 for i, (estimates, s, prob)
in enumerate(all_results):
146 ax.axis((-75.0, 100.0, -75.0, 75.0))
148 gt = ground_truth[:estimates.shape[0]]
153 color=(0.1, 0.7, 0.1, 0.5),
154 label=
"Ground Truth")
155 ax.plot(estimates[:, 0],
159 color=estimate_color,
160 label=estimate_label)
162 ax.set_title(f
"P={prob:.3f}\n{s}", fontdict={
'fontsize': 10})
164 fig.suptitle(f
"After {iters} iterations")
# Wrap `text` into lines of at most 60 characters for the figure caption.
# The chunk count must be computed with the same width used for slicing;
# dividing by a larger width (e.g. 90) would drop the tail of `text`.
num_chunks = int(np.ceil(len(text) / 60))
text = "\n".join(text[i * 60:(i + 1) * 60] for i in range(num_chunks))
172 horizontalalignment=
'center',
175 fig.savefig(filename, format=
"svg")
179 """Experiment Class"""
183 marginal_threshold: float = 0.9999,
184 max_loop_count: int = 150,
185 update_frequency: int = 3,
186 max_num_hypotheses: int = 10,
187 relinearization_frequency: int = 10,
188 plot_hypotheses: bool =
False):
204 Create a hybrid loop closure factor where
205 0 - loose noise model and 1 - loop noise model.
207 l = (
L(loop_counter), 2)
212 factors = [(f0, open_loop_constant), (f1, pose_noise_constant)]
214 return mixture_factor
217 pose_array) -> HybridNonlinearFactor:
218 """Create hybrid odometry factor with discrete measurement choices."""
224 factors = [(f0, pose_noise_constant), (f1, pose_noise_constant)]
227 return mixture_factor
230 """Perform smoother update and optimize the graph."""
231 print(f
"Smoother update: {self.new_factors_.size()}")
232 before_update = time.time()
236 after_update = time.time()
237 return after_update - before_update
240 """Re-linearize, solve ALL, and re-initialize smoother."""
241 print(f
"================= Re-Initialize: {self.smoother_.allFactors().size()}")
242 before_update = time.time()
245 after_update = time.time()
246 print(f
"Took {after_update - before_update} seconds.")
247 return after_update - before_update
250 """Run the main experiment with a given max_loop_count."""
260 priorPose =
Pose2(0, 0, 0)
267 smoother_update_times = []
268 smoother_update_times.append((index, update_time))
271 number_of_hybrid_factors = 0
275 start_time = time.time()
278 pose_array, keys, is_ambiguous_loop = self.
dataset_.next()
279 if pose_array
is None:
284 num_measurements =
len(pose_array)
288 odom_pose = pose_array[0]
289 if key_s == key_t - 1:
291 if num_measurements > 1:
293 m = (
M(discrete_count), num_measurements)
295 key_s, key_t, m, pose_array)
299 number_of_hybrid_factors += 1
300 print(f
"mixture_factor: {key_s} {key_t}")
309 self.
initial_.atPose2(
X(key_s)) * odom_pose)
312 if is_ambiguous_loop:
314 loop_count, key_s, key_t, odom_pose)
322 print(f
"Loop closure: {key_s} {key_t}")
324 number_of_hybrid_factors += 1
329 smoother_update_times.append((index, update_time))
330 number_of_hybrid_factors = 0
337 if key_s == key_t - 1:
338 cur_time = time.time()
339 time_list.append(cur_time - start_time)
343 print(f
"Index: {index}")
345 if len(time_list) != 0:
346 print(f
"Accumulate time: {time_list[-1]} seconds")
352 smoother_update_times.append((index, update_time))
359 print(f
"Final error: {self.smoother_.hybridBayesNet().error(delta)}")
361 end_time = time.time()
362 total_time = end_time - start_time
363 print(f
"Total time: {total_time} seconds")
370 for key
in delta.discrete().
keys():
372 discrete_keys.push_back((key, 2))
373 print(
"plotting all hypotheses")
377 """Plot all possible hypotheses."""
384 for i
in range(discrete_keys.size()):
385 key, cardinality = discrete_keys.at(i)
387 dkeys.push_back((key, cardinality))
388 fixed_values_str =
" ".join(
389 f
"{gtsam.DefaultKeyFormatter(k)}:{v}"
390 for k, v
in self.
smoother_.fixedValues().items())
395 for assignment
in all_assignments:
400 is_invalid_gbn =
False
401 for i
in range(gbn.size()):
402 if gbn.at(i)
is None:
403 is_invalid_gbn =
True
411 poses = np.zeros((num_poses, 3))
412 for i
in range(num_poses):
413 pose = result.atPose2(
X(i))
414 poses[i] = np.asarray((pose.x(), pose.y(), pose.theta()))
416 assignment_string =
" ".join([
417 f
"{gtsam.DefaultKeyFormatter(k)}={v}"
418 for k, v
in assignment.items()
421 conditional = self.
smoother_.hybridBayesNet().at(
423 discrete_values = self.
smoother_.fixedValues()
424 for k, v
in assignment.items():
425 discrete_values[k] = v
427 if conditional
is None:
430 probability = conditional.evaluate(discrete_values)
432 all_results.append((poses, assignment_string, probability))
437 text=fixed_values_str,
438 filename=f
"city10000_results_{num_iters}.svg")
441 """Save results to file."""
443 self.
write_result(result, final_key,
"Hybrid_City10000.txt")
448 def write_result(self, result, num_poses, filename="Hybrid_city10000.txt"):
450 Write the result of optimization to file.
result (Values): The Values object with the final result.
454 num_poses (int): The number of poses to write to the file.
455 filename (str): The file name to save the result to.
457 with open(filename,
'w')
as outfile:
459 for i
in range(num_poses):
460 out_pose = result.atPose2(
X(i))
462 f
"{out_pose.x()} {out_pose.y()} {out_pose.theta()}\n")
464 print(f
"Output written to {filename}")
468 time_filename="Hybrid_City10000_time.txt"):
469 """Log all the timing information to a file"""
471 with open(time_filename,
'w')
as out_file_time:
473 for acc_time
in time_list:
474 out_file_time.write(f
"{acc_time}\n")
476 print(f
"Output {time_filename} file.")
484 max_loop_count=args.max_loop_count,
485 update_frequency=args.update_frequency,
486 max_num_hypotheses=args.max_num_hypotheses,
487 plot_hypotheses=args.plot_hypotheses)
491 if __name__ ==
"__main__":