4 from __future__
import division
6 from collections
import OrderedDict
10 from dynamic_reconfigure.server
import Server
13 from jsk_recognition_msgs.msg
import ClusterPointIndices
14 from jsk_recognition_msgs.msg
import PolygonArray
15 from jsk_recognition_msgs.msg
import Rect
16 from jsk_recognition_msgs.msg
import RectArray
17 from jsk_topic_tools
import ConnectionBasedTransport
19 from pcl_msgs.msg
import PointIndices
21 from sensor_msgs.msg
import Image
23 from torch.autograd
import Variable
25 from jsk_perception.cfg
import CRAFTConfig
as Config
27 import craft.craft
as craft
28 import craft.craft_utils
as craft_utils
29 import craft.imgproc
as imgproc
30 from craft.refinenet
import RefineNet
34 if list(state_dict.keys())[0].startswith(
"module"):
38 new_state_dict = OrderedDict()
39 for k, v
in state_dict.items():
40 name =
".".join(k.split(
".")[start_idx:])
41 new_state_dict[name] = v
45 def test_net(net, image, text_threshold, link_threshold, text_low_bound_score,
46 device, poly=False, refine_net=None,
50 img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
53 interpolation=cv2.INTER_LINEAR,
55 ratio_h = ratio_w = 1 / target_ratio
58 x = imgproc.normalizeMeanVariance(img_resized)
60 x = torch.from_numpy(x).permute(2, 0, 1)
62 x = Variable(x.unsqueeze(0))
70 score_text = y[0, :, :, 0].cpu().data.numpy()
71 score_link = y[0, :, :, 1].cpu().data.numpy()
74 if refine_net
is not None:
76 y_refiner = refine_net(y, feature)
77 score_link = y_refiner[0, :, :, 0].cpu().data.numpy()
80 boxes, polys = craft_utils.getDetBoxes(
81 score_text, score_link,
82 text_threshold, link_threshold,
83 text_low_bound_score, poly)
86 boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
87 polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
88 for k
in range(len(polys)):
92 render_img = score_text.copy()
93 render_img = np.hstack((render_img, score_link))
94 ret_score_text = imgproc.cvt2HeatmapImg(render_img)
95 return boxes, polys, ret_score_text
105 model_path = rospy.get_param(
'~model_path')
106 refine_model_path = rospy.get_param(
'~refine_model_path')
108 gpu = rospy.get_param(
'~gpu', -1)
109 if torch.cuda.is_available()
and gpu >= 0:
110 device = torch.device(
'cuda:{}'.format(gpu))
112 device = torch.device(
'cpu')
120 refine_net.load_state_dict(
122 torch.load(refine_model_path, map_location=device)))
123 refine_net = refine_net.to(device)
124 if torch.cuda.is_available()
and gpu >= 0:
125 refine_net = torch.nn.DataParallel(refine_net)
137 "~output/polygons", PolygonArray,
140 "~output/rects", RectArray,
143 '~output/cluster_indices', ClusterPointIndices, queue_size=1)
154 self.
sub = rospy.Subscriber(
155 '~input', Image, self.
callback, queue_size=1, buff_size=2**24)
158 self.
sub.unregister()
161 bridge = cv_bridge.CvBridge()
162 img = bridge.imgmsg_to_cv2(img_msg, desired_encoding=
'rgb8')
164 bboxes, polys, score_text =
test_net(
174 msg_indices = ClusterPointIndices(header=img_msg.header)
175 polygon_array_msg = PolygonArray(header=img_msg.header)
177 indices_img = np.zeros(
178 (img.shape[0], img.shape[1]), dtype=np.uint8)
179 poly = np.array(poly).astype(np.int32).reshape((-1))
180 poly = poly.reshape(-1, 2)
182 indices_img, [poly.reshape((-1, 1, 2))], color=255)
183 indices = np.where(indices_img.reshape(-1))[0]
184 indices_msg = PointIndices(
185 header=img_msg.header, indices=indices)
186 msg_indices.cluster_indices.append(indices_msg)
188 polygon_stamp_msg = PolygonStamped(header=img_msg.header)
190 polygon_stamp_msg.polygon.points.append(Point32(x=x, y=y))
191 polygon_array_msg.polygons.append(polygon_stamp_msg)
193 rect_array_msg = RectArray(header=img_msg.header)
195 poly = np.array(poly, dtype=np.int32)
196 x_min = max(np.min(poly[:, 0]), 0)
197 x_max = min(np.max(poly[:, 0]), img.shape[1])
198 y_min = max(np.min(poly[:, 1]), 0)
199 y_max = min(np.max(poly[:, 1]), img.shape[0])
200 width = x_max - x_min
201 height = y_max - y_min
202 rect_array_msg.rects.append(
203 Rect(x=x_min, y=y_min, width=width, height=height))
209 if __name__ ==
'__main__':
210 rospy.init_node(
'craft_node')