-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathlinucb_w_transformer_tester.py
155 lines (128 loc) · 3.78 KB
/
linucb_w_transformer_tester.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from core import ucb, logservice, item, util
import numpy as np
import sys
import random
from time import sleep
import signal
import transformer
import memoryDB
class user:
    """User record holding an id and four demographic attributes."""

    # Class-level fallbacks; every instance overrides them in __init__.
    id = None
    age = 0
    gender = None
    occupation = None
    zip = None

    def __init__(self, id, age, gender, occupation, zip):
        """Store the five fields exactly as they are passed in."""
        self.id, self.age, self.gender = id, age, gender
        self.occupation, self.zip = occupation, zip

    def getFeatures(self):
        """Return the demographic fields as a dict (the id is not included)."""
        return {
            'age': self.age,
            'occupation': self.occupation,
            'gender': self.gender,
            'zip': self.zip,
        }
class movie:
    """Movie record: id, name, release date and a sequence of category values."""

    # Class-level fallbacks; every instance overrides them in __init__.
    id = None
    name = None
    release_date = None
    categories = None

    def __init__(self, id, name, release_date, categories):
        """Store the four fields exactly as they are passed in."""
        self.id, self.name = id, name
        self.release_date, self.categories = release_date, categories

    def getFeatures(self):
        """Return a dict mapping str(position) to each entry of categories."""
        return dict(
            (str(position), value)
            for position, value in enumerate(self.categories)
        )
def signal_handler(signal, frame):
    """Announce the interrupt on stdout, then exit the process with status 1.

    Neither argument is used by the handler body.
    """
    print('You pressed Ctrl+C!')
    # Same effect as sys.exit(1): SystemExit carries the exit status.
    raise SystemExit(1)
def _read_items(path):
    """Parse ``u.item`` under *path* into a list of ``movie`` objects."""
    parsed = []
    with open(path + 'u.item') as item_file:
        for line in item_file:
            if not line.strip():
                continue  # tolerate blank/trailing lines
            # movie id | movie title | release date | video release date |
            # IMDb URL | unknown | Action | Adventure | Animation |
            # Children's | Comedy | Crime | Documentary | Drama | Fantasy |
            # Film-Noir | Horror | Musical | Mystery | Romance | Sci-Fi |
            # Thriller | War | Western |
            movie_info = line.split('|')
            # Leading constant 1, then every field from index 5 on as floats
            # (presumably an intercept feature plus genre flags -- confirm).
            descriptor = [1]
            descriptor.extend(float(flag) for flag in movie_info[5:])
            parsed.append(movie(movie_info[0], movie_info[1], None, descriptor))
    return parsed


def _read_users(path):
    """Parse ``u.user`` under *path* into a list of ``user`` objects."""
    parsed = []
    with open(path + 'u.user') as user_file:
        for line in user_file:
            if not line.strip():
                continue  # tolerate blank/trailing lines
            # user id | age | gender | occupation | zip code
            user_info = line.split('|')
            parsed.append(user(int(user_info[0]), int(user_info[1]),
                               user_info[2], user_info[3],
                               user_info[4].strip()))
    return parsed


def _read_ratings(path):
    """Parse ``u1.base`` under *path*; return (rating tuples, lookup dict)."""
    ratings = []
    user_ratings = dict()
    with open(path + 'u1.base') as rating_file:
        for line in rating_file:
            if not line.strip():
                continue  # tolerate blank/trailing lines
            # user id | item id | rating | timestamp (tab separated)
            userid, itemid, rating, timestamp = line.split('\t')
            user_ratings[str(userid) + "_" + str(itemid)] = float(rating)
            ratings.append((userid, itemid, rating, timestamp))
    return ratings, user_ratings


def _mean_rating(total_rating, ratings_count):
    """Return total_rating / ratings_count, or 0.0 when nothing was counted."""
    if not ratings_count:
        return 0.0
    return float(total_rating) / ratings_count


def run_test(logservice, trainer, transformer, cv=False):
    """Replay rating lines through *trainer* and print the average rating hit.

    Items, users and ratings are read from ``testing/dataset/ml-100k/``.
    Every second rating line is used: the user's transformed features are
    passed to ``trainer.get``; if that user has a recorded rating for the
    returned item, the rating is accumulated and the trainer is rewarded
    with +1 (rating > 3) or -1 (otherwise).  A running average is printed
    every 100 consumed lines and once more at the end, followed by the list
    of all reported averages.

    Args:
        logservice: accepted for interface compatibility; not used here.
        trainer: object providing ``setItems``, ``get`` and ``reward``.
        transformer: object providing ``transform(obj, 20)``.
        cv: must be falsy; cross-validation is not implemented.

    Raises:
        NotImplementedError: if *cv* is truthy.
    """
    if cv:
        raise NotImplementedError('CV not implemented')

    path = 'testing/dataset/ml-100k/'
    # Cap on consumed rating lines.
    # NOTE(review): the `c > total` test below lets the loop run to 10002
    # before stopping, so the final report reads "10002/10000" -- confirm
    # whether that overshoot is intended.
    total = 10000

    # Load everything up front so a missing file fails before any training.
    items = _read_items(path)
    users = _read_users(path)
    ratings, user_ratings = _read_ratings(path)

    print('Setting items')
    trainer.setItems([(o.id, transformer.transform(o, 20)) for o in items])
    print('Set items')

    print('Running...')
    c = 0
    total_rating = 0
    ratings_count = 0
    avg_ratings = []
    for userid, itemid, rating, timestamp in ratings:
        # Only odd-numbered lines are evaluated; even ones are skipped.
        if c % 2 == 0:
            c += 1
            continue
        # Relies on users being stored in id order starting at 1.
        context = transformer.transform(users[int(userid) - 1], 20)
        recommended_item = trainer.get(context)
        key = str(userid) + "_" + str(recommended_item.id)
        if key in user_ratings:
            rated = user_ratings[key]
            total_rating += rated
            ratings_count += 1
            reward = 1 if rated > 3 else -1
            trainer.reward(recommended_item, context, reward)
        c += 1
        if c > total:
            break
        if c % 100 == 0:
            # Guarded mean: no recommendation may have matched a rating yet.
            average = _mean_rating(total_rating, ratings_count)
            print('\n\n ' + "Evaluated %d/%d lines." % (c, total))
            print("Avg. Recommended Rating = %f" % average)
            avg_ratings.append(average)

    average = _mean_rating(total_rating, ratings_count)
    print('')
    print('\n\n ' + "Evaluated %d/%d lines." % (c, total))
    print("Avg. Recommended Rating = %f" % average)
    avg_ratings.append(average)
    print(avg_ratings)
# Script body: build the ucb trainer, the log service, a memoryDB instance
# and a transformer backed by that DB, then run the evaluation.
trainer = ucb()
logger = logservice()
db = memoryDB.memoryDB()
# NOTE: this rebinds the name `transformer` from the imported module to an
# instance, so the module is no longer reachable under that name afterwards.
transformer = transformer.transformer(db)
# Route SIGINT (Ctrl+C) to signal_handler before starting the run.
signal.signal(signal.SIGINT, signal_handler)
run_test(logger, trainer,transformer)