-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathrun_openai.py
50 lines (39 loc) · 1.43 KB
/
run_openai.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
from tqdm import tqdm
from openai_utils import predict, predict_chatgpt
from utils import (
HitsMetric,
adjust_top_k,
get_args,
get_filename,
load_data,
prepare_input,
update_history,
update_metric,
write_results,
)
def select_search_space(direction: str, head_search_space, tail_search_space):
    """Return the candidate search space to use for an example's query direction.

    Mirrors the mapping used by the evaluation loop: a ``"tail"`` example is
    paired with ``head_search_space`` and a ``"head"`` example with
    ``tail_search_space`` (presumably each space is keyed by the *known* side
    of the triple — TODO confirm against ``load_data``).

    Args:
        direction: Query direction reported by ``load_data``; must be
            ``"head"`` or ``"tail"``.
        head_search_space: Search space returned first by ``load_data``.
        tail_search_space: Search space returned second by ``load_data``.

    Returns:
        The selected search space object (returned as-is, not copied).

    Raises:
        ValueError: If ``direction`` is neither ``"head"`` nor ``"tail"``.
    """
    if direction == "tail":
        return head_search_space
    if direction == "head":
        return tail_search_space
    raise ValueError(f"Unknown direction {direction!r}; expected 'head' or 'tail'")


def main():
    """Run the OpenAI-backed evaluation loop and report running Hits metrics.

    Parses CLI args, loads the test data and both search spaces, calls
    ``adjust_top_k(test_data, args)`` (presumably adapts the top-k setting
    in ``args`` — TODO confirm), then, for this worker's shard of the test
    data:

    - builds a prompt;
    - queries the model;
    - updates history;
    - writes each result to the file named by ``get_filename(args)``;
    - updates a ``HitsMetric`` that is shown on the progress bar.
    """
    args = get_args()
    test_data, head_search_space, tail_search_space = load_data(args)
    adjust_top_k(test_data, args)
    metric = HitsMetric()
    filename = get_filename(args)

    # torch.no_grad() is kept as in the original; it is likely a no-op for
    # remote API calls but may matter if the utils helpers run torch ops.
    # NOTE(review): the output file is opened in "w" mode; if several ranks
    # run concurrently, get_filename(args) presumably includes the rank —
    # verify, or workers will overwrite each other's output.
    with torch.no_grad(), open(filename, "w", encoding="utf-8") as writer, tqdm(test_data) as pbar:
        for i, (x, direction) in enumerate(pbar):
            # Round-robin sharding: this process only handles examples whose
            # index modulo world_size equals its rank. The progress bar still
            # advances over (and counts) the examples skipped here.
            if i % args.world_size != args.rank:
                continue

            search_space = select_search_space(direction, head_search_space, tail_search_space)
            model_input, candidates = prepare_input(x, search_space, args, return_prompt=True)

            # args.model == "chatgpt" is routed to predict_chatgpt; every
            # other model name goes through the generic predict helper.
            if args.model == "chatgpt":
                predictions = predict_chatgpt(model_input, args)
            else:
                predictions = predict(model_input, args)

            update_history(x, search_space, predictions, candidates, args)
            example = write_results(x, predictions, candidates, direction, writer, args)
            update_metric(example, metric, args)
            pbar.set_postfix(metric.dump())


if __name__ == "__main__":
    main()