1 """Unit tests for track generation using a Disjoint Set Forest data structure.
7 from typing
import Dict, List, Tuple
14 from gtsam
import IndexPair, Point2, SfmTrack2d
18 """Tests for DsfTrackGenerator."""
23 """Tests DSF for non-transitive matches.
25 Test will result in no tracks since nontransitive tracks are naively discarded by DSF.
33 for (i1,i2), corr_idxs
in nontransitive_matches_dict.items():
34 matches_dict[
IndexPair(i1, i2)] = corr_idxs
41 self.assertEqual(
len(tracks), 0,
"Tracks not filtered correctly")
44 """Ensures that DSF generates three tracks from measurements
45 in 3 images (H=200,W=400)."""
46 kps_i0 =
Keypoints(np.array([[10.0, 20], [30, 40]]))
47 kps_i1 =
Keypoints(np.array([[50.0, 60], [70, 80], [90, 100]]))
48 kps_i2 =
Keypoints(np.array([[110.0, 120], [130, 140]]))
51 keypoints_list.append(kps_i0)
52 keypoints_list.append(kps_i1)
53 keypoints_list.append(kps_i2)
58 matches_dict[
IndexPair(0, 1)] = np.array([[0, 0], [1, 1]])
59 matches_dict[
IndexPair(1, 2)] = np.array([[2, 0], [1, 1]])
66 assert len(tracks) == 3
70 assert track0.numberMeasurements() == 2
71 np.testing.assert_allclose(track0.measurements[0][1],
Point2(10, 20))
72 np.testing.assert_allclose(track0.measurements[1][1],
Point2(50, 60))
73 assert track0.measurements[0][0] == 0
74 assert track0.measurements[1][0] == 1
75 np.testing.assert_allclose(
76 track0.measurementMatrix(),
82 np.testing.assert_allclose(track0.indexVector(), [0, 1])
86 np.testing.assert_allclose(
87 track1.measurementMatrix(),
94 np.testing.assert_allclose(track1.indexVector(), [0, 1, 2])
98 np.testing.assert_allclose(
99 track2.measurementMatrix(),
105 np.testing.assert_allclose(track2.indexVector(), [1, 2])
109 """Tests for SfmTrack2d."""
112 """Test construction of 2D SfM track."""
114 measurements.append((0,
Point2(10, 20)))
117 assert track.numberMeasurements() == 1
122 img1_kp_coords = np.array([[1, 1], [2, 2], [3, 3.]])
123 img1_kp_scale = np.array([6.0, 9.0, 8.5])
124 img2_kp_coords = np.array(
136 img3_kp_coords = np.array(
150 img4_kp_coords = np.array(
165 return keypoints_list
169 """Set up correspondences for each (i1,i2) pair that violates transitivity.
171 (i=0, k=0) (i=0, k=1)
174 (i=1, k=2)--(i=2,k=3)--(i=3, k=4)
176 Transitivity is violated due to the match between frames 0 and 3.
178 nontransitive_matches_dict = {
179 (0, 1): np.array([[0, 2]]),
180 (1, 2): np.array([[2, 3]]),
181 (0, 2): np.array([[0, 3]]),
182 (0, 3): np.array([[1, 4]]),
183 (2, 3): np.array([[3, 4]]),
185 return nontransitive_matches_dict
188 if __name__ ==
"__main__":