9 from cv_bridge
import CvBridge
10 from dynamic_reconfigure.server
import Server
11 from jsk_perception.cfg
import ClassificationConfig, VQAConfig
12 from jsk_recognition_msgs.msg
import (ClassificationResult,
13 ClassificationTaskAction,
14 ClassificationTaskFeedback,
15 ClassificationTaskResult,
16 QuestionAndAnswerText, VQAResult,
17 VQATaskAction, VQATaskFeedback,
19 from requests.exceptions
import ConnectionError
20 from sensor_msgs.msg
import CompressedImage, Image
21 from std_msgs.msg
import String
34 self.
host = rospy.get_param(
"~host", default=
"localhost")
35 self.
port = rospy.get_param(
"~port", default=8888)
45 "{}/compressed".format(rospy.resolve_name(
'~image')),
58 self.
result_pub = rospy.Publisher(
"~result", result_topic, queue_size=1)
60 self.
image_pub = rospy.Publisher(
"~result/image/compressed", CompressedImage, queue_size=1)
62 self.
image_pub = rospy.Publisher(
"~result/image", Image, queue_size=1)
63 self.
vis_pub = rospy.Publisher(
"~visualize", String, queue_size=1)
74 if type(ros_img)
is CompressedImage:
75 cv_img = self.
_bridge.compressed_imgmsg_to_cv2(ros_img, desired_encoding=
"bgr8")
76 elif type(ros_img)
is Image:
77 cv_img = self.
_bridge.imgmsg_to_cv2(ros_img, desired_encoding=
"bgr8")
79 raise RuntimeError(
"Unknown type {}".format(type(ros_img)))
81 encimg = cv2.imencode(
".png", cv_img)[1]
82 img_str = encimg.tostring()
83 img_byte = base64.b64encode(img_str).decode(
"utf-8")
96 result = self.action_result()
97 feedback = self.action_feedback()
98 if goal.image.data
and (
not goal.compressed_image.data):
101 elif (
not goal.image.data)
and goal.compressed_image.data:
102 image = goal.compressed_image
104 elif goal.image.data
and goal.image.compressed_image.data:
105 rospy.logerr(
"Both image and compressed image can not be added simultaneously")
108 rospy.loginfo(
"No images in goal message, so using subscribed image topic instead")
109 image = self.default_img
110 queries = self.create_queries(goal)
112 result.result = self.inference(image, queries)
113 except Exception
as e:
115 feedback.status =
str(e)
118 self.action_server.publish_feedback(feedback)
119 result.done = success
120 self.action_server.set_succeeded(result)
131 url =
"http://{}:{}/{}".format(self.host,
str(self.port), self.app_name)
133 response = requests.post(url, data=content)
134 except ConnectionError
as e:
135 rospy.logwarn_once(
"Cannot establish the connection with API server. Is it running?")
138 if response.status_code == 200:
141 err_msg =
"Invalid http status code: {}".format(
str(response.status_code))
142 rospy.logerr(err_msg)
143 raise RuntimeError(err_msg)
148 DockerInferenceClientBase.__init__(self,
149 ClassificationTaskAction,
150 ClassificationConfig,
151 ClassificationResult,
152 ClassificationTaskFeedback,
153 ClassificationTaskResult,
157 if not self.
config: rospy.logwarn(
"No queries");
return
158 if not self.
config.queries: rospy.logwarn(
"No queries");
return
159 queries = self.
config.queries.split(
";")
162 except Exception:
return
166 msg.header = data.header
170 for i, label
in enumerate(msg.label_names):
171 vis_msg +=
"{}: {:.2f}% ".format(label, msg.probabilities[i]*100)
172 vis_msg +=
"\n\nCosine Similarity\n"
173 for i, label
in enumerate(msg.label_names):
174 vis_msg +=
"{}: {:.4f} ".format(label, msg.label_proba[i])
182 req = json.dumps({
"image": img_byte,
183 "queries": queries}).encode(
"utf-8")
185 result_dic = json.loads(response.text)[
"results"]
190 labels.append(r[
"question"])
191 probabilities.append(
float(r[
"probability"]))
192 similarities.append(
float(r[
"similarity"]))
193 labels = [label
for _,label
in sorted(zip(probabilities, labels), reverse=
True)]
194 probabilities.sort(reverse=
True)
195 similarities.sort(reverse=
True)
198 msg.labels = list(range(len(labels)))
199 msg.label_names = labels
200 msg.label_proba = similarities
201 msg.probabilities = probabilities
202 msg.classifier =
'clip'
203 msg.target_names = queries
209 self.
vqa_type = rospy.get_param(
"~vqa_type", default=
"caption")
210 if self.
vqa_type not in [
"caption",
"vqa_gen"]:
211 raise RuntimeError(
"VQA type must be caption or vqa_gen")
212 DockerInferenceClientBase.__init__(self,
221 return goal.questions
225 req = json.dumps({
"image": img_byte,
226 "queries": queries}).encode(
"utf-8")
228 json_result = json.loads(response.text)
230 for result
in json_result[
"results"]:
231 result_msg = QuestionAndAnswerText()
232 result_msg.question = result[
"question"]
233 result_msg.answer = result[
"answer"]
234 msg.result.append(result_msg)
238 if not self.
config.questions: rospy.logwarn(
"No questions");
return
239 queries = self.
config.questions.split(
";")
242 except Exception:
return
246 for qa
in msg.result:
247 vis +=
"Q:{}\n A:{}\n".format(qa.question,