# test.py (forked from AIVIETNAMResearch/AI-City-2023-Track2)
import json
import math
import os
import sys
from datetime import datetime
import argparse
import torch
import torch.distributed as dist
import torch.multiprocessing
import torch.multiprocessing as mp
from absl import flags
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
import os.path as osp
import refile
from config import get_default_config
from models import build_model
from utils_ import TqdmToLogger, get_logger, AverageMeter, accuracy, ProgressMeter
from datasets import CityFlowNLDataset
from datasets import CityFlowNLInferenceDataset
from torch.optim.lr_scheduler import _LRScheduler
import torchvision
import time
import torch.nn.functional as F
from transformers import BertTokenizer, RobertaTokenizer
import pickle
from collections import OrderedDict
from main import prepare_start
from utils_ import get_mrr, MgvSaveHelper
import IPython
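

# Extract and cache test-time features for text-to-vehicle retrieval on the
# CityFlow-NL test split: encode every natural-language query and every vehicle
# track (frame embeddings mean-pooled per track) with the model named in a
# config file, then save both embedding dicts to a single .pth file.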
def inference_vis_and_lang(config_name, args, enforced=False):
    cfg = get_default_config()
    path = 'configs/' + config_name + '.yaml'
    cfg.merge_from_file(path)

    checkpoint_name = cfg.TEST.RESTORE_FROM.split('/')[-1].split('.')[0]
    save_dir = 'extracted_feats'
    feat_pth_path = save_dir + '/img_lang_feat_%s.pth' % checkpoint_name
    # Reuse a previously extracted feature file (on S3 or on local disk)
    # unless extraction is enforced.
    if args.ossSaver.check_s3_path(feat_pth_path):
        if not enforced and refile.s3_isfile(feat_pth_path):
            return feat_pth_path
    else:
        if not enforced and osp.isfile(feat_pth_path):
            return feat_pth_path
    print(f"====> Generating {feat_pth_path}")

    transform_test = torchvision.transforms.Compose([
        torchvision.transforms.Resize((cfg.DATA.SIZE, cfg.DATA.SIZE)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
print("USE FRAME CONCAT: ", cfg.DATA.FRAMES_CONCAT)
print("USE MOTION HEATMAP: ", cfg.DATA.USE_HEATMAP)
test_data = CityFlowNLInferenceDataset(cfg.DATA, transform=transform_test, frames_concat=cfg.DATA.FRAMES_CONCAT,
use_multi_frames=cfg.DATA.MULTI_FRAMES)
testloader = DataLoader(dataset=test_data, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=False,
num_workers=cfg.TRAIN.NUM_WORKERS, pin_memory=True)
args.resume = True
args.use_cuda = True
cfg.MODEL.NUM_CLASS = 2155
model = build_model(cfg, args)
if cfg.MODEL.BERT_TYPE == "BERT":
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
elif cfg.MODEL.BERT_TYPE == "ROBERTA":
tokenizer = RobertaTokenizer.from_pretrained(cfg.MODEL.BERT_NAME)
else:
assert False
model.eval()
index = cfg.MODEL.MAIN_FEAT_IDX
all_lang_embeds = dict()
with open(cfg.TEST.QUERY_JSON_PATH) as f:
print(f"====> Query {cfg.TEST.QUERY_JSON_PATH} load")
queries = json.load(f)
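
    # Each query entry is expected to look roughly like (inferred from the
    # indexing below, not taken from dataset documentation):
    #   "<query-uuid>": {"nl": ["description 1", ..., "car description"]}
    # The last sentence is treated as the car-specific description (car_text).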
    with torch.no_grad():
        for text_id in tqdm(queries):
            text = queries[text_id]['nl'][:-1]
            car_text = queries[text_id]['nl'][-1:]
            # same dual text
            if cfg.MODEL.SAME_TEXT:
                car_text = text

            tokens = tokenizer.batch_encode_plus(text, padding='longest', return_tensors='pt')
            if 'dual-text' in cfg.MODEL.NAME:
                if cfg.DATA.USE_CLIP_FEATS:
                    car_tokens = tokenizer.batch_encode_plus(car_text, padding='longest', return_tensors='pt')
                    clip_feats = torch.load(cfg.DATA.CLIP_PATH + "/%s.pth" % text_id)
                    clip_feats_text = clip_feats['text']
                    lang_embeds_list = model.module.encode_text(tokens['input_ids'].cuda(), tokens['attention_mask'].cuda(),
                                                                car_tokens['input_ids'].cuda(),
                                                                car_tokens['attention_mask'].cuda(),
                                                                clip_feats_text=clip_feats_text.cuda())
                else:
                    car_tokens = tokenizer.batch_encode_plus(car_text, padding='longest', return_tensors='pt')
                    lang_embeds_list = model.module.encode_text(tokens['input_ids'].cuda(), tokens['attention_mask'].cuda(),
                                                                car_tokens['input_ids'].cuda(),
                                                                car_tokens['attention_mask'].cuda())
            else:
                if cfg.DATA.USE_MULTI_QUERIES:
                    lang_embeds_list = model.module.encode_text(torch.unsqueeze(tokens['input_ids'].cuda(), dim=1),
                                                                torch.unsqueeze(tokens['attention_mask'].cuda(), dim=1))
                else:
                    lang_embeds_list = model.module.encode_text(tokens['input_ids'].cuda(), tokens['attention_mask'].cuda())
            lang_embeds = lang_embeds_list[index]
            all_lang_embeds[text_id] = lang_embeds.data.cpu().numpy()
    # Encode every track's frames, then mean-pool the per-frame embeddings
    # into one visual embedding per track.
    all_visual_embeds = dict()
    out = dict()
    with torch.no_grad():
        if cfg.DATA.MULTI_FRAMES:
            for batch_idx, (image, motion, track_id, frames_id) in tqdm(enumerate(testloader)):
                vis_embed_list = model.module.encode_images(image.cuda(), motion.cuda())
                vis_embed = vis_embed_list[index]
                for i in range(len(track_id)):
                    if track_id[i] not in out:
                        out[track_id[i]] = dict()
                    out[track_id[i]][frames_id[i].item()] = vis_embed[i, :]
        else:
            if cfg.DATA.USE_CLIP_FEATS:
                for batch_idx, (image, motion, track_id, frames_id, clip_feats_vis) in tqdm(enumerate(testloader)):
                    vis_embed_list = model.module.encode_images(image.cuda(), motion.cuda(), clip_feats_vis=clip_feats_vis.cuda())
                    vis_embed = vis_embed_list[index]
                    for i in range(len(track_id)):
                        if track_id[i] not in out:
                            out[track_id[i]] = dict()
                        out[track_id[i]][frames_id[i].item()] = vis_embed[i, :]
            else:
                for batch_idx, (image, motion, track_id, frames_id) in tqdm(enumerate(testloader)):
                    vis_embed_list = model.module.encode_images(image.cuda(), motion.cuda())
                    vis_embed = vis_embed_list[index]
                    for i in range(len(track_id)):
                        if track_id[i] not in out:
                            out[track_id[i]] = dict()
                        out[track_id[i]][frames_id[i].item()] = vis_embed[i, :]

    for track_id, img_feat in out.items():
        tmp = []
        for fid in img_feat:
            tmp.append(img_feat[fid])
        tmp = torch.stack(tmp)
        tmp = torch.mean(tmp, 0)
        all_visual_embeds[track_id] = tmp.data.cpu().numpy()

    feats = (all_visual_embeds, all_lang_embeds)
    torch.save(feats, feat_pth_path)
    return feat_pth_path
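

# A minimal sketch of how the saved feature file could be consumed for
# retrieval. The repo's actual ranking / submission code is not in this file,
# so the scheme below is an assumption rather than the authors' method: a
# query's sentence embeddings are mean-pooled and tracks are ranked by cosine
# similarity against their mean-pooled frame embeddings. `_rank_tracks_for_query`
# and `top_k` are illustrative names, not part of the original code.
def _rank_tracks_for_query(feat_pth_path, text_id, top_k=5):
    import numpy as np

    # The file stores a tuple of dicts: ({track_id: feat}, {text_id: feats}).
    all_visual_embeds, all_lang_embeds = torch.load(feat_pth_path)

    query = np.asarray(all_lang_embeds[text_id])
    if query.ndim > 1:
        # Mean-pool the per-sentence embeddings into a single query vector.
        query = query.mean(axis=0)
    query = query / (np.linalg.norm(query) + 1e-12)

    scores = {}
    for track_id, track_feat in all_visual_embeds.items():
        track = np.asarray(track_feat).reshape(-1)
        track = track / (np.linalg.norm(track) + 1e-12)
        scores[track_id] = float(query @ track)

    # Highest cosine similarity first.
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:top_k]

# Example (hypothetical path and query id):
#   print(_rank_tracks_for_query('extracted_feats/img_lang_feat_checkpoint.pth', '<query-uuid>'))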


def main():
    args, cfg = prepare_start()
    config_dict = {
        "single_baseline_aug1_plus": 1.,
    }
    config_file_list = list(config_dict.keys())
    merge_weights = list(config_dict.values())
    for config_name in config_file_list:
        # inference_vis_and_lang() returns the path of the saved feature file.
        feat_pth_path = inference_vis_and_lang(config_name, args, enforced=False)


if __name__ == '__main__':
    main()