craft_node.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 # -*- coding:utf-8 -*-
3 
4 from __future__ import division
5 
6 from collections import OrderedDict
7 
8 import cv2
9 import cv_bridge
10 from dynamic_reconfigure.server import Server
11 from geometry_msgs.msg import Point32
12 from geometry_msgs.msg import PolygonStamped
13 from jsk_recognition_msgs.msg import ClusterPointIndices
14 from jsk_recognition_msgs.msg import PolygonArray
15 from jsk_recognition_msgs.msg import Rect
16 from jsk_recognition_msgs.msg import RectArray
17 from jsk_topic_tools import ConnectionBasedTransport
18 import numpy as np
19 from pcl_msgs.msg import PointIndices
20 import rospy
21 from sensor_msgs.msg import Image
22 import torch
23 from torch.autograd import Variable
24 
25 from jsk_perception.cfg import CRAFTConfig as Config
26 
27 import craft.craft as craft
28 import craft.craft_utils as craft_utils
29 import craft.imgproc as imgproc
30 from craft.refinenet import RefineNet
31 
32 
def copy_state_dict(state_dict):
    """Return an ordered copy of ``state_dict`` with a ``module`` prefix removed.

    Only the first key is inspected: if it starts with ``"module"``, the
    first dot-separated component is dropped from *every* key; otherwise
    all keys are copied unchanged.  (Presumably the prefix comes from
    checkpoints saved through a wrapper module -- confirm against the
    model files in use.)

    Args:
        state_dict: Mapping from parameter names (dot-separated strings)
            to values.  May be empty.

    Returns:
        collections.OrderedDict: New mapping with the same values, in the
        same iteration order, under the possibly shortened names.
    """
    # Default to "" so that an empty mapping does not raise; an empty
    # input simply produces an empty OrderedDict.
    first_key = next(iter(state_dict.keys()), "")
    start_idx = 1 if first_key.startswith("module") else 0

    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict
43 
44 
def test_net(net, image, text_threshold, link_threshold, text_low_bound_score,
             device, poly=False, refine_net=None,
             mag_ratio=1.5,
             max_image_size=1280):
    """Run the detection network on one image and post-process its output.

    Args:
        net: Network called as ``net(x)``; must return ``(y, feature)``.
        image: Input image array handed to ``imgproc.resize_aspect_ratio``.
        text_threshold: Forwarded to ``craft_utils.getDetBoxes``.
        link_threshold: Forwarded to ``craft_utils.getDetBoxes``.
        text_low_bound_score: Forwarded to ``craft_utils.getDetBoxes``.
        device: Torch device the input tensor is moved to.
        poly: Forwarded to ``craft_utils.getDetBoxes``.
        refine_net: Optional network called as ``refine_net(y, feature)``;
            when given, its output replaces the link score map.
        mag_ratio: Magnification ratio used when resizing.
        max_image_size: Size limit used when resizing.

    Returns:
        tuple: ``(boxes, polys, ret_score_text)`` where ``boxes`` and
        ``polys`` are rescaled by ``1 / target_ratio`` and
        ``ret_score_text`` is the heatmap image of the text and link score
        maps stacked horizontally.
    """
    # Resize while keeping the aspect ratio; the same inverse ratio is
    # used for both axes when mapping detections back.
    resized, target_ratio, _size_heatmap = imgproc.resize_aspect_ratio(
        image,
        max_image_size,
        interpolation=cv2.INTER_LINEAR,
        mag_ratio=mag_ratio)
    inverse_ratio = 1 / target_ratio

    # Normalise, reorder [h, w, c] -> [c, h, w], then add a batch axis.
    normalized = imgproc.normalizeMeanVariance(resized)
    chw = torch.from_numpy(normalized).permute(2, 0, 1)
    batch = Variable(chw.unsqueeze(0))
    batch = batch.to(device)

    # Forward pass without gradient tracking.
    with torch.no_grad():
        y, feature = net(batch)

    # Last axis of the output: index 0 -> text score, index 1 -> link score.
    score_text = y[0, :, :, 0].cpu().data.numpy()
    score_link = y[0, :, :, 1].cpu().data.numpy()

    # Optionally replace the link map with the refiner's output.
    if refine_net is not None:
        with torch.no_grad():
            y_refiner = refine_net(y, feature)
        score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

    # Post-processing: extract boxes / polygons from the score maps.
    boxes, polys = craft_utils.getDetBoxes(
        score_text, score_link,
        text_threshold, link_threshold,
        text_low_bound_score, poly)

    # Map coordinates back to the original image scale.
    boxes = craft_utils.adjustResultCoordinates(
        boxes, inverse_ratio, inverse_ratio)
    polys = craft_utils.adjustResultCoordinates(
        polys, inverse_ratio, inverse_ratio)

    # Where no polygon is available, fall back to the matching box.
    for idx, candidate in enumerate(polys):
        if candidate is None:
            polys[idx] = boxes[idx]

    side_by_side = np.hstack((score_text.copy(), score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(side_by_side)
    return boxes, polys, ret_score_text
96 
97 
class CRAFTNode(ConnectionBasedTransport):
    """ROS node that runs CRAFT text detection on incoming images.

    Subscribes to ``~input`` (``sensor_msgs/Image``) and, for every image,
    publishes the detected regions as:

    * ``~output/polygons`` (``PolygonArray``),
    * ``~output/rects`` (``RectArray``),
    * ``~output/cluster_indices`` (``ClusterPointIndices``).

    ROS parameters read at start-up: ``~model_path``,
    ``~refine_model_path`` and ``~gpu`` (default ``-1``; a value ``>= 0``
    selects that CUDA device when CUDA is available, otherwise the CPU is
    used).  Detection thresholds and resize settings are supplied through
    dynamic reconfigure.
    """

    def __init__(self):
        super(CRAFTNode, self).__init__()

        # init neural networks.
        net = craft.CRAFT()
        model_path = rospy.get_param('~model_path')
        refine_model_path = rospy.get_param('~refine_model_path')

        gpu = rospy.get_param('~gpu', -1)
        use_cuda = torch.cuda.is_available() and gpu >= 0
        if use_cuda:
            device = torch.device('cuda:{}'.format(gpu))
        else:
            device = torch.device('cpu')

        net.load_state_dict(
            copy_state_dict(torch.load(model_path, map_location=device)))
        net = net.to(device)
        net.eval()

        refine_net = RefineNet()
        refine_net.load_state_dict(
            copy_state_dict(
                torch.load(refine_model_path, map_location=device)))
        refine_net = refine_net.to(device)
        if use_cuda:
            # Pin DataParallel to the selected GPU.  With the default
            # device list the wrapper expects the parameters on cuda:0 and
            # fails when ``~gpu`` selects another device.
            refine_net = torch.nn.DataParallel(refine_net, device_ids=[gpu])
        refine_net.eval()

        self.net = net
        self.refine_net = refine_net
        self.device = device

        # One converter is enough; there is no need to rebuild it per frame.
        self.bridge = cv_bridge.CvBridge()

        # dynamic reconfigure
        self.srv = Server(Config, self.config_callback)

        # publish topics
        self.pub_polygons = self.advertise(
            "~output/polygons", PolygonArray,
            queue_size=1)
        self.pub_rects = self.advertise(
            "~output/rects", RectArray,
            queue_size=1)
        self.pub_indices = self.advertise(
            '~output/cluster_indices', ClusterPointIndices, queue_size=1)

    def config_callback(self, config, level):
        """Store the dynamic-reconfigure values used by ``callback``.

        Returns the ``config`` object unchanged.
        """
        self.text_threshold = config.text_threshold
        self.link_threshold = config.link_threshold
        self.text_low_bound_score = config.text_low_bound_score
        self.mag_ratio = config.mag_ratio
        self.max_image_size = config.max_image_size
        return config

    def subscribe(self):
        """Start listening on ``~input``."""
        self.sub = rospy.Subscriber(
            '~input', Image, self.callback, queue_size=1, buff_size=2**24)

    def unsubscribe(self):
        """Stop listening on ``~input``."""
        self.sub.unregister()

    def callback(self, img_msg):
        """Detect text regions in ``img_msg`` and publish the results.

        Every output message reuses the header of the input image.
        """
        img = self.bridge.imgmsg_to_cv2(img_msg, desired_encoding='rgb8')
        img_height, img_width = img.shape[0], img.shape[1]

        # Only the polygons are needed here; boxes and the score heatmap
        # returned by test_net are not published.
        _, polys, _ = test_net(
            self.net, img,
            self.text_threshold,
            self.link_threshold,
            self.text_low_bound_score,
            self.device, False,
            self.refine_net,
            mag_ratio=self.mag_ratio,
            max_image_size=self.max_image_size)

        msg_indices = ClusterPointIndices(header=img_msg.header)
        polygon_array_msg = PolygonArray(header=img_msg.header)
        rect_array_msg = RectArray(header=img_msg.header)
        for poly in polys:
            # Integer (x, y) vertices, one per row.
            poly = np.array(poly).astype(np.int32).reshape(-1, 2)

            # Rasterise the polygon into a blank mask and collect the flat
            # (row-major) indices of the pixels it covers.
            indices_img = np.zeros((img_height, img_width), dtype=np.uint8)
            cv2.fillPoly(
                indices_img, [poly.reshape((-1, 1, 2))], color=255)
            indices = np.flatnonzero(indices_img)
            msg_indices.cluster_indices.append(
                PointIndices(header=img_msg.header, indices=indices))

            polygon_stamp_msg = PolygonStamped(header=img_msg.header)
            for x, y in poly:
                polygon_stamp_msg.polygon.points.append(Point32(x=x, y=y))
            polygon_array_msg.polygons.append(polygon_stamp_msg)

            # Axis-aligned bounding rectangle, clamped to the image extent.
            x_min = max(np.min(poly[:, 0]), 0)
            x_max = min(np.max(poly[:, 0]), img_width)
            y_min = max(np.min(poly[:, 1]), 0)
            y_max = min(np.max(poly[:, 1]), img_height)
            rect_array_msg.rects.append(
                Rect(x=x_min, y=y_min,
                     width=x_max - x_min, height=y_max - y_min))

        self.pub_polygons.publish(polygon_array_msg)
        self.pub_rects.publish(rect_array_msg)
        self.pub_indices.publish(msg_indices)
207 
208 
# Script entry point: register the ROS node, build the detector and then
# hand control to rospy until shutdown.
if __name__ == '__main__':
    rospy.init_node('craft_node')
    # Keep a reference so the node object stays alive while spinning.
    craft_node = CRAFTNode()  # NOQA
    rospy.spin()
node_scripts.craft_node.CRAFTNode.text_low_bound_score
text_low_bound_score
Definition: craft_node.py:148
node_scripts.craft_node.CRAFTNode
Definition: craft_node.py:98
node_scripts.craft_node.CRAFTNode.pub_polygons
pub_polygons
Definition: craft_node.py:136
node_scripts.craft_node.CRAFTNode.__init__
def __init__(self)
Definition: craft_node.py:100
msg
node_scripts.craft_node.CRAFTNode.refine_net
refine_net
Definition: craft_node.py:129
node_scripts.craft_node.CRAFTNode.max_image_size
max_image_size
Definition: craft_node.py:150
node_scripts.craft_node.CRAFTNode.text_threshold
text_threshold
Definition: craft_node.py:146
node_scripts.craft_node.CRAFTNode.callback
def callback(self, img_msg)
Definition: craft_node.py:160
node_scripts.craft_node.copy_state_dict
def copy_state_dict(state_dict)
Definition: craft_node.py:33
node_scripts.craft_node.test_net
def test_net(net, image, text_threshold, link_threshold, text_low_bound_score, device, poly=False, refine_net=None, mag_ratio=1.5, max_image_size=1280)
Definition: craft_node.py:45
node_scripts.craft_node.CRAFTNode.net
net
Definition: craft_node.py:128
node_scripts.craft_node.CRAFTNode.pub_indices
pub_indices
Definition: craft_node.py:142
node_scripts.craft_node.CRAFTNode.config_callback
def config_callback(self, config, level)
Definition: craft_node.py:145
node_scripts.craft_node.CRAFTNode.srv
srv
Definition: craft_node.py:133
node_scripts.craft_node.CRAFTNode.pub_rects
pub_rects
Definition: craft_node.py:139
node_scripts.craft_node.CRAFTNode.device
device
Definition: craft_node.py:130
node_scripts.craft_node.CRAFTNode.subscribe
def subscribe(self)
Definition: craft_node.py:153
node_scripts.craft_node.CRAFTNode.link_threshold
link_threshold
Definition: craft_node.py:147
node_scripts.craft_node.CRAFTNode.unsubscribe
def unsubscribe(self)
Definition: craft_node.py:157
node_scripts.craft_node.CRAFTNode.sub
sub
Definition: craft_node.py:154
node_scripts.craft.refinenet.RefineNet
Definition: refinenet.py:15
node_scripts.craft_node.CRAFTNode.mag_ratio
mag_ratio
Definition: craft_node.py:149


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Fri May 16 2025 03:11:16