vil_inference_client.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 import abc
3 import base64
4 import json
5 
6 import actionlib
7 import requests
8 import rospy
9 from cv_bridge import CvBridge
10 from dynamic_reconfigure.server import Server
11 from jsk_perception.cfg import ClassificationConfig, VQAConfig
12 from jsk_recognition_msgs.msg import (ClassificationResult,
13  ClassificationTaskAction,
14  ClassificationTaskFeedback,
15  ClassificationTaskResult,
16  QuestionAndAnswerText, VQAResult,
17  VQATaskAction, VQATaskFeedback,
18  VQATaskResult)
19 from requests.exceptions import ConnectionError
20 from sensor_msgs.msg import CompressedImage, Image
21 from std_msgs.msg import String
22 
23 import cv2
24 
25 
def __init__(self, action,
             server_config,
             result_topic,
             action_feedback,
             action_result,
             app_name):
    """Set up the ROS I/O shared by the docker-backed inference clients.

    :param action: actionlib action type served on ``~inference_server``.
    :param server_config: dynamic_reconfigure config type; updates are
        delivered to ``config_cb``.
    :param result_topic: message type published on ``~result``; also kept as
        ``self.result_topic_type`` so subclasses can build result messages.
    :param action_feedback: feedback message type of the action.
    :param action_result: result message type of the action.
    :param app_name: last path component of the inference server URL
        (``http://<host>:<port>/<app_name>``).
    """
    # inference server configuration (ROS params ~host / ~port)
    self.host = rospy.get_param("~host", default="localhost")
    self.port = rospy.get_param("~port", default=8888)
    self.app_name = app_name
    # cv bridge used by ros_img_to_base
    self._bridge = CvBridge()
    # image used by action goals that carry no image of their own
    self.default_img = None
    # The image subscriber below is registered before the reconfigure server
    # is created, so topic_cb can run before config_cb has stored anything.
    # Start from None so callbacks see "no config" instead of AttributeError.
    self.config = None
    # ~image_transport chooses raw (default) or compressed image topics
    self.transport_hint = rospy.get_param('~image_transport', 'raw')
    use_compressed = self.transport_hint == 'compressed'
    if use_compressed:
        self.image_sub = rospy.Subscriber(
            "{}/compressed".format(rospy.resolve_name('~image')),
            CompressedImage,
            callback=self.topic_cb,
            queue_size=1,
            buff_size=2**26  # 64 MiB receive buffer
        )
    else:
        self.image_sub = rospy.Subscriber("~image", Image,
                                          callback=self.topic_cb,
                                          queue_size=1,
                                          buff_size=2**26)  # 64 MiB receive buffer
    self.result_topic_type = result_topic
    self.result_pub = rospy.Publisher("~result", result_topic, queue_size=1)
    # the debug image is republished with the same transport it arrived on
    if use_compressed:
        self.image_pub = rospy.Publisher("~result/image/compressed", CompressedImage, queue_size=1)
    else:
        self.image_pub = rospy.Publisher("~result/image", Image, queue_size=1)
    self.vis_pub = rospy.Publisher("~visualize", String, queue_size=1)
    # created with auto_start=False and started last, after all state exists
    self.action_server = actionlib.SimpleActionServer("~inference_server",
                                                      action,
                                                      execute_cb=self.action_cb,
                                                      auto_start=False)
    self.action_feedback = action_feedback
    self.action_result = action_result
    self.reconfigure_server = Server(server_config, self.config_cb)
    self.action_server.start()
72 
def ros_img_to_base(self, ros_img):
    """Re-encode a ROS image message as a base64 PNG string.

    :param ros_img: a ``sensor_msgs/CompressedImage`` or ``sensor_msgs/Image``.
    :returns: base64 text (``str``) of the image converted to bgr8 and
        encoded as PNG.
    :raises RuntimeError: if ``ros_img`` is neither message type, or if PNG
        encoding fails.
    """
    if isinstance(ros_img, CompressedImage):
        cv_img = self._bridge.compressed_imgmsg_to_cv2(ros_img, desired_encoding="bgr8")
    elif isinstance(ros_img, Image):
        cv_img = self._bridge.imgmsg_to_cv2(ros_img, desired_encoding="bgr8")
    else:
        raise RuntimeError("Unknown type {}".format(type(ros_img)))
    # cv2.imencode returns (success_flag, buffer); do not ignore the flag
    ok, encimg = cv2.imencode(".png", cv_img)
    if not ok:
        raise RuntimeError("Failed to encode image as PNG")
    # ndarray.tobytes(): tostring() is deprecated since NumPy 1.19 and
    # removed in later NumPy releases.
    return base64.b64encode(encimg.tobytes()).decode("utf-8")
85 
def config_cb(self, config, level):
    """dynamic_reconfigure callback: remember and accept the new config.

    :param config: the configuration object delivered by the server.
    :param level: reconfigure level bitmask (unused here).
    :returns: ``config`` unchanged, i.e. the update is accepted as-is.
    """
    self.config = accepted = config
    return accepted
89 
@abc.abstractmethod
def topic_cb(self, msg):
    """Handle an image received on the ``~image`` subscription.

    Subclasses override this; the base implementation does nothing.
    """
93 
def action_cb(self, goal):
    """Execute one inference action goal.

    The goal may carry ``goal.image`` or ``goal.compressed_image`` (a field
    counts as set when its ``data`` is non-empty), but not both. When
    neither is set, ``self.default_img`` is used instead.

    On success ``result.result`` holds the output of ``self.inference`` and
    ``result.done`` is True. If inference fails, the error text is sent as
    feedback and ``result.done`` is False. A goal with both images set is
    aborted.
    """
    result = self.action_result()
    feedback = self.action_feedback()
    has_raw = bool(goal.image.data)
    has_compressed = bool(goal.compressed_image.data)
    if has_raw and has_compressed:
        # Ambiguous input: reject the goal, but still give it a terminal
        # state so the client is not left waiting.
        err_msg = "Both image and compressed image can not be added simultaneously"
        rospy.logerr(err_msg)
        feedback.status = err_msg
        self.action_server.publish_feedback(feedback)
        result.done = False
        self.action_server.set_aborted(result, err_msg)
        return
    if has_raw:
        image = goal.image
    elif has_compressed:
        image = goal.compressed_image
    else:
        rospy.loginfo("No images in goal message, so using subscribed image topic instead")
        image = self.default_img
    queries = self.create_queries(goal)
    success = True
    try:
        if image is None:
            # report a clear reason instead of ros_img_to_base's
            # "Unknown type NoneType"
            raise RuntimeError(
                "No image in goal message and no image has been received on ~image yet")
        result.result = self.inference(image, queries)
    except Exception as e:
        rospy.logerr(str(e))
        feedback.status = str(e)
        success = False
    finally:
        self.action_server.publish_feedback(feedback)
        result.done = success
        self.action_server.set_succeeded(result)
121 
@abc.abstractmethod
def create_queries(self, goal):
    """Extract the query list that ``inference`` should use from an action goal.

    Subclasses override this; the base implementation returns None.
    """
125 
@abc.abstractmethod
def inference(self, img_msg, queries):
    """Run inference on ``img_msg`` for ``queries`` and return a result message.

    Subclasses override this; the base implementation returns None.
    """
129 
def send_request(self, content, timeout=None):
    """POST ``content`` to ``http://<host>:<port>/<app_name>``.

    :param content: request body (callers in this file pass UTF-8 JSON bytes).
    :param timeout: optional ``requests`` timeout in seconds (a float or a
        ``(connect, read)`` tuple). ``None`` (the default) waits indefinitely,
        as before; pass a value so a hung server cannot block the caller
        forever.
    :returns: the ``requests.Response`` when the server answers HTTP 200.
    :raises requests.exceptions.ConnectionError: the server is unreachable.
    :raises RuntimeError: the server answered with a non-200 status code.
    """
    url = "http://{}:{}/{}".format(self.host, self.port, self.app_name)
    try:
        response = requests.post(url, data=content, timeout=timeout)
    except ConnectionError:
        # warn only once: callers may retry at image rate while the server is down
        rospy.logwarn_once("Cannot establish the connection with API server. Is it running?")
        # bare raise keeps the original traceback (``raise e`` drops it on Python 2)
        raise
    if response.status_code != 200:
        err_msg = "Invalid http status code: {}".format(response.status_code)
        rospy.logerr(err_msg)
        raise RuntimeError(err_msg)
    return response
144 
145 
def __init__(self):
    """Configure the base client for CLIP classification.

    The client serves ``ClassificationTask`` actions, publishes
    ``ClassificationResult`` messages and posts to the server's
    ``inference`` endpoint.
    """
    DockerInferenceClientBase.__init__(
        self,
        action=ClassificationTaskAction,
        server_config=ClassificationConfig,
        result_topic=ClassificationResult,
        action_feedback=ClassificationTaskFeedback,
        action_result=ClassificationTaskResult,
        app_name="inference",
    )
155 
def topic_cb(self, data):
    """Classify an incoming image against the reconfigurable queries.

    The queries come from ``config.queries``, split on ``;``. On success:
    - ``data`` is republished on ``~result/image``;
    - the ``ClassificationResult`` (with ``data.header``) goes to ``~result``;
    - a text summary of probabilities and cosine similarities goes to
      ``~visualize``.

    :param data: the received ``Image`` / ``CompressedImage`` message.
    """
    # Keep the latest frame: action_cb falls back to self.default_img when a
    # goal carries no image.
    self.default_img = data
    config = getattr(self, "config", None)
    if not config or not config.queries:
        rospy.logwarn("No queries")
        return
    queries = config.queries.split(";")
    try:
        msg = self.inference(data, queries)
    except Exception as e:
        # best effort per frame, but do not hide failures completely;
        # throttled so a failing server does not flood the log at image rate
        rospy.logwarn_throttle(10.0, "Inference failed: {}".format(e))
        return
    # publish debug image
    self.image_pub.publish(data)
    # publish classification result
    msg.header = data.header
    self.result_pub.publish(msg)
    # publish probabilities and cosine similarities as a readable string
    proba_text = "".join(
        "{}: {:.2f}% ".format(label, proba * 100)
        for label, proba in zip(msg.label_names, msg.probabilities))
    sim_text = "".join(
        "{}: {:.4f} ".format(label, sim)
        for label, sim in zip(msg.label_names, msg.label_proba))
    self.vis_pub.publish(proba_text + "\n\nCosine Similarity\n" + sim_text)
176 
def create_queries(self, goal):
    """Return the candidate label texts carried in ``goal.queries``."""
    labels = goal.queries
    return labels
179 
def inference(self, img_msg, queries):
    """Ask the CLIP server to score ``queries`` against ``img_msg``.

    :param img_msg: ROS image message (see ``ros_img_to_base``).
    :param queries: list of candidate label texts.
    :returns: a ``self.result_topic_type`` message whose ``label_names``,
        ``probabilities`` and ``label_proba`` (cosine similarities) are
        ordered by descending probability. Entries with equal probability
        keep the server's order. ``labels`` are the ranks 0..n-1, and
        ``target_names`` is ``queries`` in its original order.
    """
    img_byte = self.ros_img_to_base(img_msg)
    req = json.dumps({"image": img_byte,
                      "queries": queries}).encode("utf-8")
    response = self.send_request(req)
    result_dic = json.loads(response.text)["results"]
    # Keep (label, probability, similarity) together while sorting so that
    # each similarity stays attached to its own label.
    ranked = sorted(
        ((r["question"], float(r["probability"]), float(r["similarity"]))
         for r in result_dic),
        key=lambda entry: entry[1],
        reverse=True)
    # build ClassificationResult message
    msg = self.result_topic_type()
    msg.labels = list(range(len(ranked)))
    msg.label_names = [label for label, _, _ in ranked]
    msg.label_proba = [sim for _, _, sim in ranked]  # cosine similarities
    # probabilities as returned by the server (expected to sum to 1)
    msg.probabilities = [proba for _, proba, _ in ranked]
    msg.classifier = 'clip'
    msg.target_names = queries
    return msg
205 
206 
def __init__(self):
    """Configure the base client for OFA visual question answering.

    The ROS param ``~vqa_type`` ("caption" by default, or "vqa_gen") selects
    the server endpoint. Any other value raises RuntimeError.
    """
    # caption, vqa_gen. caption is better than vqa_gen in OFA
    self.vqa_type = rospy.get_param("~vqa_type", default="caption")
    supported_types = ("caption", "vqa_gen")
    if self.vqa_type not in supported_types:
        raise RuntimeError("VQA type must be caption or vqa_gen")
    DockerInferenceClientBase.__init__(
        self,
        action=VQATaskAction,
        server_config=VQAConfig,
        result_topic=VQAResult,
        action_feedback=VQATaskFeedback,
        action_result=VQATaskResult,
        app_name=self.vqa_type,
    )
219 
def create_queries(self, goal):
    """Return the questions carried in ``goal.questions``."""
    questions = goal.questions
    return questions
222 
def inference(self, img_msg, queries):
    """Send ``img_msg`` and ``queries`` to the OFA server and collect answers.

    :param img_msg: ROS image message (see ``ros_img_to_base``).
    :param queries: list of question strings.
    :returns: a ``self.result_topic_type`` message whose ``result`` list holds
        one ``QuestionAndAnswerText`` per entry of the server's "results",
        in server order.
    """
    payload = {"image": self.ros_img_to_base(img_msg),
               "queries": queries}
    reply = self.send_request(json.dumps(payload).encode("utf-8"))
    answers = json.loads(reply.text)["results"]
    msg = self.result_topic_type()
    for item in answers:
        qa = QuestionAndAnswerText()
        qa.question = item["question"]
        qa.answer = item["answer"]
        msg.result.append(qa)
    return msg
236 
def topic_cb(self, data):
    """Answer the reconfigurable questions about an incoming image.

    The questions come from ``config.questions``, split on ``;``. On success:
    - ``data`` is republished on ``~result/image``;
    - the VQA result goes to ``~result``;
    - a "Q:/A:" text summary goes to ``~visualize``.

    :param data: the received ``Image`` / ``CompressedImage`` message.
    """
    # Keep the latest frame: action_cb falls back to self.default_img when a
    # goal carries no image.
    self.default_img = data
    # config may not have been delivered yet (same guard as the CLIP client)
    config = getattr(self, "config", None)
    if not config or not config.questions:
        rospy.logwarn("No questions")
        return
    queries = config.questions.split(";")
    try:
        msg = self.inference(data, queries)
    except Exception as e:
        # best effort per frame, but do not hide failures completely;
        # throttled so a failing server does not flood the log at image rate
        rospy.logwarn_throttle(10.0, "Inference failed: {}".format(e))
        return
    self.image_pub.publish(data)
    self.result_pub.publish(msg)
    vis = "".join("Q:{}\n A:{}\n".format(qa.question, qa.answer)
                  for qa in msg.result)
    self.vis_pub.publish(vis)
object
jsk_perception.vil_inference_client.DockerInferenceClientBase.__init__
def __init__(self, action, server_config, result_topic, action_feedback, action_result, app_name)
Definition: vil_inference_client.py:27
jsk_perception.vil_inference_client.OFAClientNode.create_queries
def create_queries(self, goal)
Definition: vil_inference_client.py:220
jsk_perception.vil_inference_client.DockerInferenceClientBase.app_name
app_name
Definition: vil_inference_client.py:31
jsk_perception.vil_inference_client.ClipClientNode.inference
def inference(self, img_msg, queries)
Definition: vil_inference_client.py:180
ssd_train_dataset.float
float
Definition: ssd_train_dataset.py:180
jsk_perception.vil_inference_client.DockerInferenceClientBase.transport_hint
transport_hint
Definition: vil_inference_client.py:37
jsk_perception.vil_inference_client.DockerInferenceClientBase.result_topic_type
result_topic_type
Definition: vil_inference_client.py:52
ssd_train_dataset.str
str
Definition: ssd_train_dataset.py:178
jsk_perception.vil_inference_client.DockerInferenceClientBase.vis_pub
vis_pub
Definition: vil_inference_client.py:58
jsk_perception.vil_inference_client.ClipClientNode.create_queries
def create_queries(self, goal)
Definition: vil_inference_client.py:177
jsk_perception.vil_inference_client.DockerInferenceClientBase.action_feedback
action_feedback
Definition: vil_inference_client.py:63
jsk_perception.vil_inference_client.DockerInferenceClientBase.result_pub
result_pub
Definition: vil_inference_client.py:53
jsk_perception.vil_inference_client.DockerInferenceClientBase.action_server
action_server
Definition: vil_inference_client.py:59
jsk_perception.vil_inference_client.DockerInferenceClientBase.inference
def inference(self, img_msg, queries)
Definition: vil_inference_client.py:127
jsk_perception.vil_inference_client.DockerInferenceClientBase.action_result
action_result
Definition: vil_inference_client.py:64
jsk_perception.vil_inference_client.DockerInferenceClientBase.reconfigure_server
reconfigure_server
Definition: vil_inference_client.py:65
jsk_perception.vil_inference_client.DockerInferenceClientBase
Definition: vil_inference_client.py:26
jsk_perception.vil_inference_client.OFAClientNode
Definition: vil_inference_client.py:207
jsk_perception.vil_inference_client.DockerInferenceClientBase.config_cb
def config_cb(self, config, level)
Definition: vil_inference_client.py:86
jsk_perception.vil_inference_client.DockerInferenceClientBase.default_img
default_img
Definition: vil_inference_client.py:35
jsk_perception.vil_inference_client.OFAClientNode.__init__
def __init__(self)
Definition: vil_inference_client.py:208
jsk_perception.vil_inference_client.OFAClientNode.topic_cb
def topic_cb(self, data)
Definition: vil_inference_client.py:237
jsk_perception.vil_inference_client.ClipClientNode
Definition: vil_inference_client.py:146
jsk_perception.vil_inference_client.ClipClientNode.__init__
def __init__(self)
Definition: vil_inference_client.py:147
jsk_perception.vil_inference_client.DockerInferenceClientBase.send_request
def send_request(self, content)
Definition: vil_inference_client.py:130
jsk_perception.vil_inference_client.DockerInferenceClientBase.ros_img_to_base
def ros_img_to_base(self, ros_img)
Definition: vil_inference_client.py:73
jsk_perception.vil_inference_client.ClipClientNode.topic_cb
def topic_cb(self, data)
Definition: vil_inference_client.py:156
jsk_perception.vil_inference_client.DockerInferenceClientBase.topic_cb
def topic_cb(self, msg)
Definition: vil_inference_client.py:91
jsk_perception.vil_inference_client.DockerInferenceClientBase.config
config
Definition: vil_inference_client.py:87
actionlib::SimpleActionServer
jsk_perception.vil_inference_client.OFAClientNode.vqa_type
vqa_type
Definition: vil_inference_client.py:209
jsk_perception.vil_inference_client.DockerInferenceClientBase.image_sub
image_sub
Definition: vil_inference_client.py:39
jsk_perception.vil_inference_client.DockerInferenceClientBase.port
port
Definition: vil_inference_client.py:30
jsk_perception.vil_inference_client.DockerInferenceClientBase.create_queries
def create_queries(self, goal)
Definition: vil_inference_client.py:123
jsk_perception.vil_inference_client.DockerInferenceClientBase._bridge
_bridge
Definition: vil_inference_client.py:33
jsk_perception.vil_inference_client.DockerInferenceClientBase.action_cb
def action_cb(self, goal)
Definition: vil_inference_client.py:94
jsk_perception.vil_inference_client.DockerInferenceClientBase.host
host
Definition: vil_inference_client.py:29
jsk_perception.vil_inference_client.OFAClientNode.inference
def inference(self, img_msg, queries)
Definition: vil_inference_client.py:223
jsk_perception.vil_inference_client.DockerInferenceClientBase.image_pub
image_pub
Definition: vil_inference_client.py:55


jsk_perception
Author(s): Manabu Saito, Ryohei Ueda
autogenerated on Fri May 16 2025 03:11:17