ocr_node.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 # -*- coding:utf-8 -*-
3 
4 from __future__ import division
5 from __future__ import print_function
6 
7 import multiprocessing
8 from multiprocessing.pool import ThreadPool as Pool
9 import os
10 import os.path as osp
11 
12 import std_msgs.msg
13 import cv2
14 import cv_bridge
15 from dynamic_reconfigure.server import Server
16 from jsk_recognition_msgs.msg import Label
17 from jsk_recognition_msgs.msg import LabelArray
18 from jsk_recognition_msgs.msg import PolygonArray
19 from jsk_recognition_msgs.msg import RectArray
20 from jsk_recognition_utils.put_text import put_text_to_image
21 from jsk_recognition_utils import get_tile_image
22 from jsk_topic_tools import ConnectionBasedTransport
23 import message_filters
24 import numpy as np
25 import pytesseract
26 import rospy
27 from sensor_msgs.msg import Image
28 
29 from jsk_perception.cfg import OCRConfig as Config
30 
31 
def crop_img(img, poly):
    """Cut a quadrilateral region out of an image and rectify it.

    The four corners of ``poly`` are mapped by a perspective transform onto
    an axis-aligned rectangle whose size is the polygon's bounding box
    clipped to the image.

    Args:
        img: Image array; ``img.shape[0]`` is read as the height and
            ``img.shape[1]`` as the width.
        poly: Array-like of shape (4, 2) holding the (x, y) corners.
            Coordinates are truncated to ``int32``.

    Returns:
        The warped crop.  When the clipped bounding box has no area, an
        empty array of shape ``(0, 0) + img.shape[2:]`` is returned and
        OpenCV is not called.

    Raises:
        ValueError: If ``poly`` does not have shape (4, 2).
    """
    poly = np.asarray(poly)
    if poly.shape != (4, 2):
        raise ValueError('Not supported shape size {}'.format(poly.shape))
    poly = poly.astype(np.int32)

    # Rotate the corner list so that the corner with the smallest x + y
    # (the one nearest the image origin) comes first.
    start = poly.sum(axis=1).argmin()
    poly = np.roll(poly, len(poly) - start, 0)

    # Make the winding clockwise as seen in image coordinates (y grows
    # downwards), which is the order of the destination corners below.
    # With y pointing down the shoelace sum is positive for that winding.
    # Without this step a counter-clockwise polygon is warped into a
    # mirrored crop.
    xs = poly[:, 0].astype(np.int64)
    ys = poly[:, 1].astype(np.int64)
    signed_area = np.sum(xs * np.roll(ys, -1) - np.roll(xs, -1) * ys)
    if signed_area < 0:
        # Keep the first corner and reverse the remaining three.
        poly = np.concatenate([poly[:1], poly[:0:-1]])

    # Bounding box of the polygon, clipped to the image.
    x_min = max(int(poly[:, 0].min()), 0)
    x_max = min(int(poly[:, 0].max()), img.shape[1])
    y_min = max(int(poly[:, 1].min()), 0)
    y_max = min(int(poly[:, 1].max()), img.shape[0])
    w = x_max - x_min
    h = y_max - y_min
    if w <= 0 or h <= 0:
        # Degenerate polygon or polygon entirely outside the image: a
        # non-positive size is not a valid crop, so hand back an empty
        # image instead of passing it to OpenCV.
        return np.zeros((0, 0) + tuple(img.shape[2:]), dtype=img.dtype)

    src_pts = np.float32(poly)
    dst_pts = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
    warp_mat = cv2.getPerspectiveTransform(src_pts, dst_pts)
    return cv2.warpPerspective(img, warp_mat, (w, h))
52 
53 
def ocr_image(process_index, poly, img, lang='eng'):
    """Run Tesseract on one polygonal region of an image.

    Args:
        process_index: Opaque index returned unchanged so that callers can
            restore the submission order of asynchronously computed results.
        poly: Quadrilateral passed to ``crop_img``.
        img: Source image; converted with ``cv2.COLOR_RGB2GRAY``, so it is
            presumably 3-channel RGB -- confirm against callers.
        lang: Language code handed to ``pytesseract.image_to_string``.

    Returns:
        Tuple ``(process_index, text, cropped_img, binary_img)``.  ``text``
        is UTF-8 encoded bytes with surrounding whitespace removed; for an
        empty crop it is ``b''`` and ``binary_img`` is a ``(0, 0)`` uint8
        array.
    """
    cropped_img = crop_img(img, poly)
    if cropped_img.shape[0] == 0 or cropped_img.shape[1] == 0:
        # Nothing to recognise.  Return empty *bytes* so the text has the
        # same type as in the OCR branch; callers call .decode('utf-8').
        return (process_index, b'', cropped_img,
                np.zeros((0, 0), dtype=np.uint8))

    gray = cv2.cvtColor(cropped_img, cv2.COLOR_RGB2GRAY)
    # NOTE(review): with THRESH_OTSU the fixed threshold 127 is presumably
    # ignored in favour of the computed one -- confirm in the OpenCV docs.
    _, binary_img = cv2.threshold(
        gray, 127, 255,
        cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
    txt = pytesseract.image_to_string(
        binary_img, lang=lang, config="--psm 6")
    txt = txt.encode('utf-8').strip()
    return process_index, txt, cropped_img, binary_img
69 
70 
def visualize_polygons_with_texts(
        img, polys, texts,
        font_path,
        box_thickness=3,
        font_size=16):
    """Draw polygon outlines and their text labels on a copy of an image.

    Args:
        img: Image to draw on; it is copied with ``np.array`` first, so the
            argument itself is not modified.
        polys: Iterable of polygons, each convertible to an (N, 2) array of
            (x, y) points.
        texts: One label per polygon (``str`` or UTF-8 bytes), or ``None``
            to draw empty labels.
        font_path: Path of the font file given to ``put_text_to_image``.
        box_thickness: Line thickness of the polygon outlines.
        font_size: Font size given to ``put_text_to_image``.

    Returns:
        The annotated copy of ``img``.

    Raises:
        OSError: If ``font_path`` does not exist.
    """
    if not os.path.exists(font_path):
        raise OSError("Font not exists!")
    img = np.array(img)
    if texts is None:
        texts = [''] * len(polys)

    # Normalise every polygon once to an (N, 2) int32 array.
    int_polys = [
        np.array(poly).astype(np.int32).reshape(-1, 2) for poly in polys]

    # Draw all outlines first so that no outline is painted over a label.
    for pts in int_polys:
        cv2.polylines(
            img, [pts.reshape((-1, 1, 2))],
            True, color=(255, 0, 0), thickness=box_thickness)

    # Draw the labels, anchored just inside the first corner.
    for pts, text in zip(int_polys, texts):
        if not isinstance(text, str):
            text = "{}".format(text.decode('utf-8'))
        pos = (pts[0][0] + 1, pts[0][1] + 1)
        put_text_to_image(
            img, text, pos,
            font_path, font_size,
            color=(0, 0, 255),
            background_color=(255, 255, 255),
            loc='center')
    return img
106 
107 
class OCRNode(ConnectionBasedTransport):
    """ROS node that runs OCR on regions of an incoming image.

    Subscribes to ``~input`` (``Image``) synchronised with either
    ``~input/polygons`` (``PolygonArray``) or ``~input/rects``
    (``RectArray``), depending on the ``subscribe_polygon`` setting.

    Publishes:
        ``~output`` (``std_msgs/String``): all recognised texts joined by
            spaces, ordered by distance of the region from the top left.
        ``~output/labels`` (``LabelArray``): one label per region.
        ``~output/viz`` (``Image``): input image with regions and texts.
        ``~output/debug/viz`` (``Image``): tiled cropped regions.
        ``~output/debug/binary_viz`` (``Image``): tiled binarised regions.
    """

    def __init__(self):
        super(OCRNode, self).__init__()
        self.bridge = cv_bridge.CvBridge()

        # Start from the declared defaults so that the first call of
        # config_callback has something to compare against.
        self.valid_font = False
        self.subscribe_polygon = Config.defaults['subscribe_polygon']
        self.approximate_sync = Config.defaults['approximate_sync']
        self.queue_size = Config.defaults['queue_size']

        # dynamic reconfigure
        self.srv = Server(Config, self.config_callback)

        # publish topics
        self.pub_str = self.advertise(
            '~output', std_msgs.msg.String, queue_size=1)
        self.pub_viz = self.advertise(
            '~output/viz', Image, queue_size=1)
        self.pub_labels = self.advertise(
            '~output/labels', LabelArray, queue_size=1)
        self.pub_debug_viz = self.advertise(
            '~output/debug/viz', Image, queue_size=1)
        self.pub_debug_binary_viz = self.advertise(
            '~output/debug/binary_viz', Image, queue_size=1)

    @staticmethod
    def _to_unicode(text):
        """Return ``text`` decoded from UTF-8 if it is a byte string."""
        if isinstance(text, bytes):
            return text.decode('utf-8')
        return text

    def config_callback(self, config, level):
        """Apply a dynamic_reconfigure update.

        Stores the new parameters and re-subscribes when a parameter that
        affects the subscription layout changed while subscribed.

        Args:
            config: New configuration values.
            level: Unused.

        Returns:
            ``config``, unchanged.
        """
        resubscribe = (
            self.subscribe_polygon != config.subscribe_polygon
            or self.approximate_sync != config.approximate_sync
            or self.queue_size != config.queue_size)

        self.language = config.language
        self.font_size = config.font_size
        self.font_path = config.font_path
        self.valid_font = osp.exists(self.font_path)
        if not self.valid_font:
            rospy.logwarn('Not valid font_path: {}'.format(self.font_path))
        self.box_thickness = config.box_thickness

        # All three values are read by subscribe(); they must be stored
        # here, otherwise a change would trigger a re-subscription that
        # still uses the old settings.
        self.subscribe_polygon = config.subscribe_polygon
        self.approximate_sync = config.approximate_sync
        self.queue_size = config.queue_size

        self.n_jobs = config.number_of_jobs
        if self.n_jobs == -1:
            # -1 selects one job per CPU.
            self.n_jobs = multiprocessing.cpu_count()
        self.n_jobs = max(self.n_jobs, 1)

        self.resolution_factor = config.resolution_factor
        self.interpolation_method = config.interpolation_method

        if resubscribe and self.is_subscribed():
            self.unsubscribe()
            self.subscribe()
        return config

    def subscribe(self):
        """Subscribe to the image and region topics and synchronise them."""
        sub_image = message_filters.Subscriber('~input', Image)
        subs = [sub_image]
        if self.subscribe_polygon:
            sub_polygons = message_filters.Subscriber(
                '~input/polygons', PolygonArray)
            subs.append(sub_polygons)
            callback = self.polygons_callback
        else:
            sub_rects = message_filters.Subscriber('~input/rects', RectArray)
            subs.append(sub_rects)
            callback = self.rects_callback

        if self.approximate_sync:
            slop = rospy.get_param('~slop', 0.1)
            sync = message_filters.ApproximateTimeSynchronizer(
                subs,
                queue_size=self.queue_size, slop=slop)
        else:
            sync = message_filters.TimeSynchronizer(
                subs, queue_size=self.queue_size)
        sync.registerCallback(callback)
        self.subs = subs

    def unsubscribe(self):
        """Unregister every subscriber created by ``subscribe``."""
        for sub in self.subs:
            sub.sub.unregister()

    def publish_results(self, img, polys, texts, header, imgs=None,
                        binary_imgs=None):
        """Publish labels, visualisations and the concatenated text.

        Args:
            img: Input image; published with encoding ``rgb8``.
            polys: Array of polygons with shape (num, points, 2).
            texts: One recognised text per polygon (UTF-8 bytes or str).
            header: Header copied into every stamped output message.
            imgs: Optional list of cropped images for ``~output/debug/viz``.
            binary_imgs: Optional list of binarised crops for
                ``~output/debug/binary_viz``.
        """
        texts = [self._to_unicode(text) for text in texts]

        label_array_msg = LabelArray(header=header)
        for i, text in enumerate(texts):
            label_array_msg.labels.append(Label(id=i, name=text))
        self.pub_labels.publish(label_array_msg)

        # The overlay needs a usable font; without one nothing is published
        # on ~output/viz.
        if self.pub_viz.get_num_connections() > 0 and self.valid_font:
            resized = cv2.resize(
                img, None,
                fx=self.resolution_factor,
                fy=self.resolution_factor,
                interpolation=self.interpolation_method)
            # Scale the polygons together with the image.
            viz = visualize_polygons_with_texts(
                resized, np.asarray(polys) * self.resolution_factor,
                texts=texts,
                font_path=self.font_path,
                box_thickness=self.box_thickness,
                font_size=self.font_size)
            msg_viz = self.bridge.cv2_to_imgmsg(viz, encoding='rgb8')
            msg_viz.header = header
            self.pub_viz.publish(msg_viz)

        # A missing debug image list only skips its own topic; it must not
        # prevent the outputs below from being published.
        if self.pub_debug_viz.get_num_connections() > 0 and imgs:
            viz = get_tile_image(imgs)
            msg_viz = self.bridge.cv2_to_imgmsg(viz, encoding='rgb8')
            msg_viz.header = header
            self.pub_debug_viz.publish(msg_viz)

        if self.pub_debug_binary_viz.get_num_connections() > 0 \
                and binary_imgs:
            viz = get_tile_image(binary_imgs)
            msg_viz = self.bridge.cv2_to_imgmsg(viz, encoding='mono8')
            msg_viz.header = header
            self.pub_debug_binary_viz.publish(msg_viz)

        # Sort the polygons in order of distance from the top left.
        if len(polys) > 0:
            polys = np.array(polys, dtype=np.int32)
            indices = np.argsort(polys.sum(axis=2).min(axis=1))
            text = ' '.join([texts[i] for i in indices])
        else:
            text = ''
        self.pub_str.publish(
            std_msgs.msg.String(data=text))

    def rects_callback(self, img_msg, rects_msg):
        """Handle a synchronised image / RectArray pair."""
        img = self.bridge.imgmsg_to_cv2(img_msg, desired_encoding='rgb8')
        polys = []
        for rect in rects_msg.rects:
            # Corners are listed top-left, top-right, bottom-right,
            # bottom-left, matching the destination corner order used by
            # crop_img's perspective warp.
            polys.append(
                np.array([[rect.x, rect.y],
                          [rect.x + rect.width, rect.y],
                          [rect.x + rect.width, rect.y + rect.height],
                          [rect.x, rect.y + rect.height]], dtype=np.int32))
        polys = np.array(polys)
        texts, imgs, binary_imgs = self.process_ocr(img, polys)
        self.publish_results(img, polys, texts, img_msg.header,
                             imgs=imgs, binary_imgs=binary_imgs)

    def polygons_callback(self, img_msg, polygons_msg):
        """Handle a synchronised image / PolygonArray pair."""
        img = self.bridge.imgmsg_to_cv2(img_msg, desired_encoding='rgb8')
        polys = []
        for polygon_stamp_msg in polygons_msg.polygons:
            polys.append(
                np.array([[point.x, point.y]
                          for point in polygon_stamp_msg.polygon.points]))
        polys = np.array(polys, dtype=np.int32)
        texts, imgs, binary_imgs = self.process_ocr(img, polys)
        self.publish_results(img, polys, texts, img_msg.header,
                             imgs=imgs, binary_imgs=binary_imgs)

    def process_ocr(self, img, polys):
        """Run ``ocr_image`` on every polygon using a thread pool.

        Args:
            img: Source image.
            polys: Sequence of polygons.

        Returns:
            Tuple ``(texts, imgs, binary_imgs)``.  ``texts`` has one entry
            per polygon in input order; ``imgs`` and ``binary_imgs`` hold
            copies of the non-empty cropped and binarised images only.
        """
        if len(polys) == 0:
            return [], [], []

        pool = Pool(min(self.n_jobs, len(polys)))
        try:
            pending = [
                pool.apply_async(ocr_image, (i, poly, img, self.language))
                for i, poly in enumerate(polys)]
            pool.close()
            pool.join()
            # Restore submission order via the index each job returns.
            results = sorted(
                [res.get() for res in pending], key=lambda r: r[0])
        finally:
            # Release the worker threads even if a job raised.
            pool.terminate()

        texts = [text for _, text, _, _ in results]
        imgs = [
            crop.copy() for _, _, crop, _ in results
            if crop.shape[0] > 0 and crop.shape[1] > 0]
        binary_imgs = [
            binary.copy() for _, _, _, binary in results
            if binary.shape[0] > 0 and binary.shape[1] > 0]
        return texts, imgs, binary_imgs
292 
if __name__ == '__main__':
    # Initialise the ROS node, build the OCR node object and hand control
    # to ROS until shutdown.
    rospy.init_node('ocr_node')
    node = OCRNode()  # NOQA
    rospy.spin()
node_scripts.ocr_node.OCRNode.__init__
def __init__(self)
Definition: ocr_node.py:110
node_scripts.ocr_node.OCRNode.process_ocr
def process_ocr(self, img, polys)
Definition: ocr_node.py:268
node_scripts.ocr_node.OCRNode.config_callback
def config_callback(self, config, level)
Definition: ocr_node.py:134
node_scripts.ocr_node.OCRNode.pub_viz
pub_viz
Definition: ocr_node.py:125
node_scripts.ocr_node.OCRNode.pub_debug_binary_viz
pub_debug_binary_viz
Definition: ocr_node.py:131
node_scripts.ocr_node.OCRNode.pub_str
pub_str
Definition: ocr_node.py:123
node_scripts.ocr_node.OCRNode.subs
subs
Definition: ocr_node.py:186
node_scripts.ocr_node.crop_img
def crop_img(img, poly)
Definition: ocr_node.py:32
node_scripts.ocr_node.OCRNode.polygons_callback
def polygons_callback(self, img_msg, polygons_msg)
Definition: ocr_node.py:256
node_scripts.ocr_node.OCRNode.interpolation_method
interpolation_method
Definition: ocr_node.py:156
node_scripts.ocr_node.OCRNode.font_size
font_size
Definition: ocr_node.py:142
node_scripts.ocr_node.OCRNode.queue_size
queue_size
Definition: ocr_node.py:117
node_scripts.ocr_node.OCRNode.srv
srv
Definition: ocr_node.py:120
message_filters::Subscriber
node_scripts.ocr_node.visualize_polygons_with_texts
def visualize_polygons_with_texts(img, polys, texts, font_path, box_thickness=3, font_size=16)
Definition: ocr_node.py:71
node_scripts.ocr_node.OCRNode.rects_callback
def rects_callback(self, img_msg, rects_msg)
Definition: ocr_node.py:243
node_scripts.ocr_node.OCRNode.publish_results
def publish_results(self, img, polys, texts, header, imgs=None, binary_imgs=None)
Definition: ocr_node.py:192
node_scripts.ocr_node.OCRNode.font_path
font_path
Definition: ocr_node.py:143
node_scripts.ocr_node.OCRNode.box_thickness
box_thickness
Definition: ocr_node.py:147
node_scripts.ocr_node.OCRNode.pub_labels
pub_labels
Definition: ocr_node.py:127
node_scripts.ocr_node.OCRNode.language
language
Definition: ocr_node.py:141
node_scripts.ocr_node.OCRNode.approximate_sync
approximate_sync
Definition: ocr_node.py:116
node_scripts.ocr_node.OCRNode.valid_font
valid_font
Definition: ocr_node.py:114
node_scripts.ocr_node.OCRNode.unsubscribe
def unsubscribe(self)
Definition: ocr_node.py:188
node_scripts.ocr_node.OCRNode.n_jobs
n_jobs
Definition: ocr_node.py:150
node_scripts.ocr_node.OCRNode.bridge
bridge
Definition: ocr_node.py:112
node_scripts.ocr_node.OCRNode.resolution_factor
resolution_factor
Definition: ocr_node.py:155
jsk_recognition_utils::put_text
message_filters::TimeSynchronizer
node_scripts.ocr_node.ocr_image
def ocr_image(process_index, poly, img, lang='eng')
Definition: ocr_node.py:54
node_scripts.ocr_node.OCRNode.pub_debug_viz
pub_debug_viz
Definition: ocr_node.py:129
node_scripts.ocr_node.OCRNode
Definition: ocr_node.py:108
node_scripts.ocr_node.OCRNode.subscribe
def subscribe(self)
Definition: ocr_node.py:163
node_scripts.ocr_node.OCRNode.subscribe_polygon
subscribe_polygon
Definition: ocr_node.py:115


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Fri May 16 2025 03:11:17