ofa/server.py
# sys
import os
# inference
import torch
import numpy as np
import cv2
from fairseq import utils, tasks
from fairseq import checkpoint_utils
from fairseq import options
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from utils.eval_utils import eval_step
from utils.zero_shot_utils import zero_shot_step
from tasks.mm_tasks.caption import CaptionTask
from tasks.mm_tasks.refcoco import RefcocoTask
from tasks.mm_tasks.vqa_gen import VqaGenTask
from models.ofa import OFAModel
from torchvision import transforms
from PIL import Image
# web server
from flask import Flask, request, Response
import json
import base64

PARAM_DIR = "/var/mount/params"
# OFA_PARAM[TASK][SCALE]
OFA_PARAM = {
    "caption": {
        "large": "caption_large_best_clean.pt",
        "huge": "caption_huge_best.pt"
    },
    "refcoco": {
        "large": "refcocog_large_best.pt",
        "huge": "refcocog_large_best.pt"
    },
    "vqa_gen": {
        "large": "vqa_large_best.pt",
        "huge": "vqa_large_best.pt"
    }
}


def apply_half(t):
    if t.dtype is torch.float32:
        return t.to(dtype=torch.half)
    return t

class Inference:
    def __init__(self, task, model_scale):
        self.use_cuda = torch.cuda.is_available()
        self.use_fp16 = False

        # set params
        param = OFA_PARAM[task][model_scale]
        param_path = os.path.join(PARAM_DIR, param)
        overrides = {"bpe_dir": "utils/BPE", "eval_cider": False, "beam": 5,
                     "max_len_b": 16, "no_repeat_ngram_size": 3, "seed": 7}

        self.task_name = task
        if task == "caption":
            tasks.register_task(task, CaptionTask)
            self.models, self.cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
                utils.split_paths(param_path),
                arg_overrides=overrides)
        elif task == "refcoco":
            # NOTE: register the task by its name; self.task is not defined yet at this point
            tasks.register_task(task, RefcocoTask)
            self.models, self.cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
                utils.split_paths(param_path),
                arg_overrides=overrides)
            self.cfg.common.seed = 7
            self.cfg.generation.beam = 5
            self.cfg.generation.min_len = 4
            self.cfg.generation.max_len_a = 0
            self.cfg.generation.max_len_b = 4
            self.cfg.generation.no_repeat_ngram_size = 3
            if self.cfg.common.seed is not None and not self.cfg.generation.no_seed_provided:
                np.random.seed(self.cfg.common.seed)
                utils.set_torch_seed(self.cfg.common.seed)
        elif task == "vqa_gen":
            tasks.register_task('vqa_gen', VqaGenTask)
            parser = options.get_generation_parser()
            input_args = ["", "--task=vqa_gen", "--beam=100", "--unnormalized",
                          "--path={}".format(param_path), "--bpe-dir=utils/BPE"]
            args = options.parse_args_and_arch(parser, input_args)
            cfg = convert_namespace_to_omegaconf(args)
            self.task = tasks.setup_task(cfg.task)
            self.models, self.cfg = checkpoint_utils.load_model_ensemble(
                utils.split_paths(cfg.common_eval.path),
                task=self.task)
        else:
            raise RuntimeError("Please select a task from caption, refcoco or vqa_gen")

        # Move models to GPU
        for model in self.models:
            model.eval()
            if self.use_fp16:
                model.half()
            if self.use_cuda and not self.cfg.distributed_training.pipeline_model_parallel:
                model.cuda()
            model.prepare_for_inference_(self.cfg)

        # Image transform
        self.generator = self.task.build_generator(self.models, self.cfg.generation)
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert("RGB"),
            transforms.Resize((self.cfg.task.patch_image_size, self.cfg.task.patch_image_size),
                              interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

        self.pad_idx = self.task.src_dict.pad()

    def visual_grounding(self, image, text):
        sample = self.construct_sample(image, text.lower())
        sample = utils.move_to_cuda(sample) if self.use_cuda else sample
        sample = utils.apply_to_sample(apply_half, sample) if self.use_fp16 else sample
        with torch.no_grad():
            result, scores = eval_step(self.task, self.generator, self.models, sample)
        img = np.asarray(image)
        cv2.rectangle(
            img,
            (int(result[0]["box"][0]), int(result[0]["box"][1])),
            (int(result[0]["box"][2]), int(result[0]["box"][3])),
            (0, 255, 0), 3)
        return img

    def encode_text(self, text, length=None, append_bos=False, append_eos=False):
        bos_item = torch.LongTensor([self.task.src_dict.bos()])
        eos_item = torch.LongTensor([self.task.src_dict.eos()])
        # pad_idx = self.task.src_dict.pad()
        s = self.task.tgt_dict.encode_line(
            line=self.task.bpe.encode(text),
            add_if_not_exist=False,
            append_eos=False).long()
        if length is not None:
            s = s[:length]
        if append_bos:
            s = torch.cat([bos_item, s])
        if append_eos:
            s = torch.cat([s, eos_item])
        return s

    def construct_sample(self, image, text):
        if self.task_name == "caption" or self.task_name == "vqa_gen":
            patch_image = self.patch_resize_transform(image).unsqueeze(0)
            patch_mask = torch.tensor([True])
            src_text = self.encode_text(" " + text, append_bos=True, append_eos=True).unsqueeze(0)
            src_length = torch.LongTensor([s.ne(self.pad_idx).long().sum() for s in src_text])
            if self.task_name == "caption":
                sample = {
                    "id": np.array(['42']),
                    "net_input": {
                        "src_tokens": src_text,
                        "src_lengths": src_length,
                        "patch_images": patch_image,
                        "patch_masks": patch_mask
                    }
                }
            elif self.task_name == "vqa_gen":
                ref_dict = np.array([{'yes': 1.0}])  # just a placeholder
                sample = {
                    "id": np.array(['42']),
                    "net_input": {
                        "src_tokens": src_text,
                        "src_lengths": src_length,
                        "patch_images": patch_image,
                        "patch_masks": patch_mask
                    },
                    "ref_dict": ref_dict,
                }
            return sample
        elif self.task_name == "refcoco":
            patch_image_size = self.cfg.task.patch_image_size
            w, h = image.size
            w_resize_ratio = torch.tensor(patch_image_size / w).unsqueeze(0)
            h_resize_ratio = torch.tensor(patch_image_size / h).unsqueeze(0)
            patch_image = self.patch_resize_transform(image).unsqueeze(0)
            patch_mask = torch.tensor([True])
            src_text = self.encode_text(
                ' which region does the text " {} " describe?'.format(text),
                append_bos=True, append_eos=True).unsqueeze(0)
            src_length = torch.LongTensor([s.ne(self.pad_idx).long().sum() for s in src_text])
            sample = {
                "id": np.array(['42']),
                "net_input": {
                    "src_tokens": src_text,
                    "src_lengths": src_length,
                    "patch_images": patch_image,
                    "patch_masks": patch_mask,
                },
                "w_resize_ratios": w_resize_ratio,
                "h_resize_ratios": h_resize_ratio,
                "region_coords": torch.randn(1, 4)
            }
            return sample

    def infer(self, img, text):
        # get cv2 image
        if self.task_name == "caption" or self.task_name == "vqa_gen":
            image = cv2.resize(img, dsize=(640, 480))  # NOTE: forcibly resized to a fixed size
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(image)
            # Construct input sample & preprocess for GPU if cuda available
            sample = self.construct_sample(image, text)
            sample = utils.move_to_cuda(sample) if self.use_cuda else sample
            sample = utils.apply_to_sample(apply_half, sample) if self.use_fp16 else sample
            if self.task_name == "caption":
                with torch.no_grad():
                    result, scores = eval_step(self.task, self.generator, self.models, sample)
                text = result[0]['caption']
                return text
            elif self.task_name == "vqa_gen":
                with torch.no_grad():
                    result, scores = zero_shot_step(self.task, self.generator, self.models, sample)
                text = result[0]['answer']
                return text
        elif self.task_name == "refcoco":
            pass

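# Example (hypothetical, not part of the HTTP server below): the Inference class
# can also be used directly, assuming the caption weights listed in OFA_PARAM are
# present under PARAM_DIR and that an image exists at the given path.
#
#   captioner = Inference("caption", "large")
#   bgr_img = cv2.imread("/path/to/image.jpg")  # OpenCV loads images as BGR
#   print(captioner.infer(bgr_img, "what does the image describe?"))
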
# run
if __name__ == "__main__":
    app = Flask(__name__)
    ofa_task = os.environ["OFA_TASK"]
    ofa_model_scale = os.environ["OFA_MODEL_SCALE"]
    # TODO add refcoco
    if ofa_task == "all":
        caption_infer = Inference("caption", ofa_model_scale)
        vqa_infer = Inference("vqa_gen", ofa_model_scale)

    elif ofa_task == "caption":
        caption_infer = Inference("caption", ofa_model_scale)

    elif ofa_task == "vqa_gen":
        vqa_infer = Inference("vqa_gen", ofa_model_scale)

    else:
        raise RuntimeError("No application is available")

    try:
        @app.route("/caption", methods=['POST'])
        def caption_request():
            data = request.data.decode("utf-8")
            data_json = json.loads(data)
            # process image
            image_b = data_json['image']
            image_dec = base64.b64decode(image_b)
            data_np = np.frombuffer(image_dec, dtype='uint8')  # np.fromstring is deprecated for binary data
            img = cv2.imdecode(data_np, 1)
            # get text
            texts = data_json['queries']
            results = []
            for text in texts:
                answer = caption_infer.infer(img, text)
                results.append({"question": text, "answer": answer})
            return Response(response=json.dumps({"results": results}), status=200)
    except NameError:
        print("Skipping creation of the caption app")

    try:
        @app.route("/vqa_gen", methods=['POST'])
        def vqa_request():
            data = request.data.decode("utf-8")
            data_json = json.loads(data)
            # process image
            image_b = data_json['image']
            image_dec = base64.b64decode(image_b)
            data_np = np.frombuffer(image_dec, dtype='uint8')  # np.fromstring is deprecated for binary data
            img = cv2.imdecode(data_np, 1)
            # get text
            texts = data_json['queries']
            results = []
            for text in texts:
                answer = vqa_infer.infer(img, text)
                results.append({"question": text, "answer": answer})
            return Response(response=json.dumps({"results": results}), status=200)
    except NameError:
        print("Skipping creation of the vqa_gen app")

    app.run("0.0.0.0", 8080, threaded=True)
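
A minimal client sketch for the endpoints above. It sends the JSON fields that caption_request and vqa_request read: a base64-encoded image under "image" and a list of strings under "queries". The host, port, image path, and the third-party requests package are assumptions for illustration; posting to /vqa_gen with question strings in "queries" follows the same format.

import base64
import json

import cv2
import requests  # assumed to be installed; any HTTP client works

# Read an image with OpenCV and base64-encode it, as the handlers expect.
img = cv2.imread("/path/to/image.jpg")
ok, buf = cv2.imencode(".jpg", img)
payload = {
    "image": base64.b64encode(buf.tobytes()).decode("utf-8"),
    "queries": ["what does the image describe?"],
}
res = requests.post("http://localhost:8080/caption", data=json.dumps(payload))
print(res.json()["results"])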