import os
import json
import base64

import numpy as np
import cv2
import torch
from PIL import Image

from fairseq import utils, tasks
from fairseq import checkpoint_utils
from fairseq import options
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from utils.eval_utils import eval_step
from utils.zero_shot_utils import zero_shot_step
from tasks.mm_tasks.caption import CaptionTask
from tasks.mm_tasks.refcoco import RefcocoTask
from tasks.mm_tasks.vqa_gen import VqaGenTask
from models.ofa import OFAModel
from torchvision import transforms
from flask import Flask, request, Response
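
# Flask-based inference server for OFA: it exposes /caption and /vqa_gen endpoints.
# The served task ("all", "caption" or "vqa_gen") and the checkpoint scale
# ("large" or "huge") are selected through the OFA_TASK and OFA_MODEL_SCALE
# environment variables; checkpoints are read from the volume mounted at PARAM_DIR.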
PARAM_DIR = "/var/mount/params"

# Checkpoint file for each (task, model scale) pair.
OFA_PARAM = {
    "caption": {
        "large": "caption_large_best_clean.pt",
        "huge": "caption_huge_best.pt"
    },
    "refcoco": {
        "large": "refcocog_large_best.pt",
        "huge": "refcocog_large_best.pt"
    },
    "vqa_gen": {
        "large": "vqa_large_best.pt",
        "huge": "vqa_large_best.pt"
    }
}

def apply_half(t):
    # Cast float32 tensors to half precision; other tensors are returned unchanged.
    if t.dtype is torch.float32:
        return t.to(dtype=torch.half)
    return t

class Inference:
    def __init__(self, task, model_scale):
        # NOTE: the use_cuda / use_fp16 defaults and the task_name attribute are
        # assumed; their original initialization is not shown.
        self.use_cuda = torch.cuda.is_available()
        self.use_fp16 = self.use_cuda
        self.task_name = task

        param = OFA_PARAM[task][model_scale]
        param_path = os.path.join(PARAM_DIR, param)
        overrides = {"bpe_dir": "utils/BPE",
                     "eval_cider": False,
                     "beam": 5,
                     "max_len_b": 16,
                     "no_repeat_ngram_size": 3,
                     "seed": 7}

        if task == "caption":
            tasks.register_task(task, CaptionTask)
            self.models, self.cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
                utils.split_paths(param_path),
                arg_overrides=overrides)
        elif task == "refcoco":
            tasks.register_task(task, RefcocoTask)
            self.models, self.cfg, self.task = checkpoint_utils.load_model_ensemble_and_task(
                utils.split_paths(param_path),
                arg_overrides=overrides)
            self.cfg.common.seed = 7
            self.cfg.generation.beam = 5
            self.cfg.generation.min_len = 4
            self.cfg.generation.max_len_a = 0
            self.cfg.generation.max_len_b = 4
            self.cfg.generation.no_repeat_ngram_size = 3
            if self.cfg.common.seed is not None and not self.cfg.generation.no_seed_provided:
                np.random.seed(self.cfg.common.seed)
                utils.set_torch_seed(self.cfg.common.seed)
        elif task == "vqa_gen":
            tasks.register_task('vqa_gen', VqaGenTask)
            parser = options.get_generation_parser()
            input_args = ["", "--task=vqa_gen", "--beam=100", "--unnormalized",
                          "--path={}".format(param_path), "--bpe-dir=utils/BPE"]
            args = options.parse_args_and_arch(parser, input_args)
            cfg = convert_namespace_to_omegaconf(args)
            self.task = tasks.setup_task(cfg.task)
            self.models, self.cfg = checkpoint_utils.load_model_ensemble(
                utils.split_paths(cfg.common_eval.path),
                task=self.task)  # the task kwarg is assumed here
        else:
            raise RuntimeError("Please select models from caption, refcoco, vqa_gen")

        # Prepare the ensemble for inference; the eval/half/cuda calls and the
        # generator construction follow the standard fairseq/OFA flow (assumed).
        for model in self.models:
            model.eval()
            if self.use_fp16:
                model.half()
            if self.use_cuda and not self.cfg.distributed_training.pipeline_model_parallel:
                model.cuda()
            model.prepare_for_inference_(self.cfg)

        self.generator = self.task.build_generator(self.models, self.cfg.generation)

        # Image preprocessing used to build patch_images. The attribute name
        # patch_resize_transform and the pad_idx line are assumed.
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert("RGB"),
            transforms.Resize((self.cfg.task.patch_image_size, self.cfg.task.patch_image_size),
                              interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        self.pad_idx = self.task.src_dict.pad()

    # Visual grounding (refcoco) path. The method name, the image handling before
    # eval_step and the rectangle colour/thickness are assumed; only the eval and
    # box-drawing lines appear in the original.
    def infer_grounding(self, img, text):
        image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        sample = self.construct_grounding_sample(image, text)
        sample = utils.move_to_cuda(sample) if self.use_cuda else sample
        sample = utils.apply_to_sample(apply_half, sample) if self.use_fp16 else sample
        with torch.no_grad():
            result, scores = eval_step(self.task, self.generator, self.models, sample)
        # Draw the predicted bounding box onto the image.
        img = np.asarray(image)
        cv2.rectangle(
            img,
            (int(result[0]["box"][0]), int(result[0]["box"][1])),
            (int(result[0]["box"][2]), int(result[0]["box"][3])),
            (0, 255, 0),
            3)
        return img

    def encode_text(self, text, length=None, append_bos=False, append_eos=False):
        bos_item = torch.LongTensor([self.task.src_dict.bos()])
        eos_item = torch.LongTensor([self.task.src_dict.eos()])
        s = self.task.tgt_dict.encode_line(
            line=self.task.bpe.encode(text),
            add_if_not_exist=False,
            append_eos=False).long()
        if length is not None:
            s = s[:length]
        if append_bos:
            s = torch.cat([bos_item, s])
        if append_eos:
            s = torch.cat([s, eos_item])
        return s

    # Builds the generator input for the caption / vqa_gen paths. The method name,
    # the patch_image line and the caption/vqa branch are assumed; only the sample
    # dictionaries appear in the original.
    def construct_sample(self, image, text):
        patch_image = self.patch_resize_transform(image).unsqueeze(0)
        patch_mask = torch.tensor([True])
        src_text = self.encode_text(" " + text, append_bos=True, append_eos=True).unsqueeze(0)
        src_length = torch.LongTensor([s.ne(self.pad_idx).long().sum() for s in src_text])
        if self.task_name == "caption":
            sample = {
                "id": np.array(['42']),
                "net_input": {
                    "src_tokens": src_text,
                    "src_lengths": src_length,
                    "patch_images": patch_image,
                    "patch_masks": patch_mask
                }
            }
        else:
            # vqa_gen: ref_dict is a placeholder entry expected by the sample format.
            ref_dict = np.array([{'yes': 1.0}])
            sample = {
                "id": np.array(['42']),
                "net_input": {
                    "src_tokens": src_text,
                    "src_lengths": src_length,
                    "patch_images": patch_image,
                    "patch_masks": patch_mask
                },
                "ref_dict": ref_dict,
            }
        return sample

    # Builds the visual-grounding (refcoco) input; the method name and the lines
    # around the visible fragments are assumed.
    def construct_grounding_sample(self, image, text):
        patch_image_size = self.cfg.task.patch_image_size
        w, h = image.size
        w_resize_ratio = torch.tensor(patch_image_size / w).unsqueeze(0)
        h_resize_ratio = torch.tensor(patch_image_size / h).unsqueeze(0)
        patch_image = self.patch_resize_transform(image).unsqueeze(0)
        patch_mask = torch.tensor([True])
        src_text = self.encode_text(
            ' which region does the text " {} " describe?'.format(text),
            append_bos=True,
            append_eos=True).unsqueeze(0)
        src_length = torch.LongTensor([s.ne(self.pad_idx).long().sum() for s in src_text])
        sample = {
            "id": np.array(['42']),
            "net_input": {
                "src_tokens": src_text,
                "src_lengths": src_length,
                "patch_images": patch_image,
                "patch_masks": patch_mask,
            },
            "w_resize_ratios": w_resize_ratio,
            "h_resize_ratios": h_resize_ratio,
            "region_coords": torch.randn(1, 4)
        }
        return sample

    def infer(self, img, text):
        # img is an OpenCV BGR array; convert it to a PIL image for preprocessing.
        image = cv2.resize(img, dsize=(640, 480))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
        sample = self.construct_sample(image, text)
        sample = utils.move_to_cuda(sample) if self.use_cuda else sample
        sample = utils.apply_to_sample(apply_half, sample) if self.use_fp16 else sample
        # The caption path uses eval_step; vqa_gen uses zero_shot_step.
        if self.task_name == "caption":
            with torch.no_grad():
                result, scores = eval_step(self.task, self.generator, self.models, sample)
            text = result[0]['caption']
        else:
            with torch.no_grad():
                result, scores = zero_shot_step(self.task, self.generator, self.models, sample)
            text = result[0]['answer']
        return text

if __name__ == "__main__":
    app = Flask(__name__)
    ofa_task = os.environ["OFA_TASK"]
    ofa_model_scale = os.environ["OFA_MODEL_SCALE"]

    if ofa_task == "all":
        # "all" serves both endpoints (branch body assumed).
        caption_infer = Inference("caption", ofa_model_scale)
        vqa_infer = Inference("vqa_gen", ofa_model_scale)
    elif ofa_task == "caption":
        caption_infer = Inference("caption", ofa_model_scale)
    elif ofa_task == "vqa_gen":
        vqa_infer = Inference("vqa_gen", ofa_model_scale)
    else:
        raise RuntimeError("No application is available")
    # Only register an endpoint when the corresponding model was loaded (guard form assumed).
    if ofa_task in ("all", "caption"):
        @app.route("/caption", methods=['POST'])
        def caption():
            data = request.data.decode("utf-8")
            data_json = json.loads(data)

            # Decode the base64-encoded image into a BGR array for OpenCV.
            image_b = data_json['image']
            image_dec = base64.b64decode(image_b)
            data_np = np.frombuffer(image_dec, dtype='uint8')
            img = cv2.imdecode(data_np, 1)

            texts = data_json['queries']
            results = []
            for text in texts:
                answer = caption_infer.infer(img, text)
                results.append({"question": text, "answer": answer})
            return Response(response=json.dumps({"results": results}), status=200)
    else:
        print("Skipping create caption app")
    if ofa_task in ("all", "vqa_gen"):
        @app.route("/vqa_gen", methods=['POST'])
        def vqa_gen():
            data = request.data.decode("utf-8")
            data_json = json.loads(data)

            image_b = data_json['image']
            image_dec = base64.b64decode(image_b)
            data_np = np.frombuffer(image_dec, dtype='uint8')
            img = cv2.imdecode(data_np, 1)

            texts = data_json['queries']
            results = []
            for text in texts:
                answer = vqa_infer.infer(img, text)
                results.append({"question": text, "answer": answer})
            return Response(response=json.dumps({"results": results}), status=200)
    else:
        print("Skipping create vqa_gen app")

    app.run("0.0.0.0", 8080, threaded=True)
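
# Example client request (a sketch: assumes the server above is running locally on
# port 8080 and that the `requests` package is available; "sample.jpg" is a placeholder):
#
#   import base64, json, requests
#
#   with open("sample.jpg", "rb") as f:
#       payload = {
#           "image": base64.b64encode(f.read()).decode("utf-8"),
#           "queries": ["what does the image describe?"],
#       }
#   resp = requests.post("http://localhost:8080/caption", data=json.dumps(payload))
#   print(resp.json()["results"])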