MinxuanQin
committed on
Commit
·
0c9e22d
1
Parent(s):
a5ab0ec
fix error in visualbert
Browse files - model_loader.py +16 -8
model_loader.py
CHANGED
|
@@ -62,13 +62,20 @@ def load_dataset(type):
|
|
| 62 |
raise ValueError("invalid dataset: ", type)
|
| 63 |
'''
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
-
|
| 67 |
-
sample = {}
|
| 68 |
-
sample['inputs'] = processor(images=examples['image'], text=examples['question'], return_tensors="pt")
|
| 69 |
-
sample['outputs'] = examples['multiple_choice_answer']
|
| 70 |
-
return sample
|
| 71 |
-
|
| 72 |
|
| 73 |
def label_count_list(labels):
|
| 74 |
res = {}
|
|
@@ -88,7 +95,7 @@ def get_item(image, question, tokenizer, image_model, model_name):
|
|
| 88 |
)
|
| 89 |
visual_embeds = get_img_feats(image, image_model=image_model, name=model_name)\
|
| 90 |
.squeeze(2, 3).unsqueeze(0)
|
| 91 |
-
|
| 92 |
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
|
| 93 |
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
|
| 94 |
upd_dict = {
|
|
@@ -192,7 +199,8 @@ def get_answer(model_loader_args, img, question, model_name):
|
|
| 192 |
|
| 193 |
# load question and image (processor = tokenizer)
|
| 194 |
## MOD Minxuan: fix error
|
| 195 |
-
|
|
|
|
| 196 |
outputs = model(**inputs)
|
| 197 |
#except Exception:
|
| 198 |
# return err_msg()
|
|
|
|
| 62 |
raise ValueError("invalid dataset: ", type)
|
| 63 |
'''
|
| 64 |
|
| 65 |
+
def load_img_model(name):
|
| 66 |
+
"""
|
| 67 |
+
loads image models for feature extraction
|
| 68 |
+
returns model name and the loaded model
|
| 69 |
+
"""
|
| 70 |
+
if name == "resnet50":
|
| 71 |
+
model = resnet50(weights='DEFAULT')
|
| 72 |
+
elif name == "vitb16":
|
| 73 |
+
## MOD Minxuan: add param
|
| 74 |
+
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
|
| 75 |
+
else:
|
| 76 |
+
raise ValueError("undefined model name: ", name)
|
| 77 |
|
| 78 |
+
return model, name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
def label_count_list(labels):
|
| 81 |
res = {}
|
|
|
|
| 95 |
)
|
| 96 |
visual_embeds = get_img_feats(image, image_model=image_model, name=model_name)\
|
| 97 |
.squeeze(2, 3).unsqueeze(0)
|
| 98 |
+
|
| 99 |
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
|
| 100 |
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
|
| 101 |
upd_dict = {
|
|
|
|
| 199 |
|
| 200 |
# load question and image (processor = tokenizer)
|
| 201 |
## MOD Minxuan: fix error
|
| 202 |
+
img_model, name = load_img_model("resnet50")
|
| 203 |
+
_, inputs = get_item(img, question, processor, img_model, name)
|
| 204 |
outputs = model(**inputs)
|
| 205 |
#except Exception:
|
| 206 |
# return err_msg()
|