-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
positive item sampling and fix infinite loop #7
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,7 +13,7 @@ | |
"import pandas as pd\n", | ||
"from tqdm.auto import tqdm\n", | ||
"\n", | ||
"from utils.data import extract_embedding, get_interactions_dataframe\n", | ||
"from utils.data import extract_embedding, get_interactions_dataframe, mark_evaluation_rows\n", | ||
"from utils.hashing import pre_hash, HashesContainer\n" | ||
] | ||
}, | ||
|
@@ -134,6 +134,14 @@ | |
"interactions_df[\"user_id\"] = interactions_df[\"user_id\"].map(user_id2index)\n", | ||
"print(f\">> Mapping applied, ({n_missing_ids} values in 'user_id2index')\")\n", | ||
"\n", | ||
"# Mark interactions used for evaluation procedure if needed\n", | ||
"if \"evaluation\" not in interactions_df:\n", | ||
" print(\"\\nApply evaluation split...\")\n", | ||
" interactions_df = mark_evaluation_rows(interactions_df)\n", | ||
" # Check if new column exists and has boolean dtype\n", | ||
" assert interactions_df[\"evaluation\"].dtype.name == \"bool\"\n", | ||
" print(f\">> Interactions: {interactions_df.shape}\")\n", | ||
"\n", | ||
"# Split interactions data according to evaluation column\n", | ||
"evaluation_df = interactions_df[interactions_df[\"evaluation\"]]\n", | ||
"interactions_df = interactions_df[~interactions_df[\"evaluation\"]]\n", | ||
|
@@ -202,22 +210,27 @@ | |
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def random_triplet_sampling(samples_per_user, hashes_container, desc=None):\n", | ||
"def random_triplet_sampling(samples_per_user, hashes_container, desc=None, limit_iteration=10000):\n", | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why 10000? Maybe we could use the number of interactions as the limit, or a proportion of that number. If I have a million records and need to sample a large number of them, a proportion of |
||
" interactions = interactions_df.copy()\n", | ||
" samples = []\n", | ||
" for ui, group in tqdm(interactions.groupby(\"user_id\"), desc=desc):\n", | ||
" # Get profile artworks\n", | ||
" full_profile = np.hstack(group[\"item_id\"].values).tolist()\n", | ||
" full_profile_set = set(full_profile)\n", | ||
" n = samples_per_user\n", | ||
" aux_limit = limit_iteration\n", | ||
" while n > 0:\n", | ||
" if aux_limit == 0:\n", | ||
" break\n", | ||
Comment on lines
+221
to
+224
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
" # Sample positive and negative items\n", | ||
" pi_index = random.randrange(len(full_profile))\n", | ||
" pi = full_profile[pi_index]\n", | ||
" # Get profile\n", | ||
" if MAX_PROFILE_SIZE:\n", | ||
" # \"pi_index + 1\" to include pi in profile\n", | ||
" profile = full_profile[max(0, pi_index - MAX_PROFILE_SIZE + 1):pi_index + 1]\n", | ||
" profile = random.sample(full_profile, min(len(full_profile), MAX_PROFILE_SIZE))\n", | ||
" if pi not in profile:\n", | ||
" profile = profile[0: -1]\n", | ||
" profile.append(pi)\n", | ||
" else:\n", | ||
" profile = list(full_profile)\n", | ||
" # (While loop is in the sampling method)\n", | ||
|
@@ -231,6 +244,7 @@ | |
" else:\n", | ||
" triple = (ui, pi, ni)\n", | ||
" if not hashes_container.enroll(pre_hash(triple, contains_iter=MODE_PROFILE)):\n", | ||
" limit_iteration -= 1\n", | ||
" continue\n", | ||
" # If not seen, store sample\n", | ||
" samples.append((profile, pi, ni, ui))\n", | ||
|
@@ -255,9 +269,6 @@ | |
" desc=\"Random sampling (testing)\"\n", | ||
")\n", | ||
"\n", | ||
"assert len(samples_training) >= TOTAL_SAMPLES_TRAIN\n", | ||
"assert len(samples_testing) >= TOTAL_SAMPLES_VALID\n", | ||
"\n", | ||
Comment on lines
-258
to
-260
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why was this removed? |
||
"# Total collected samples\n", | ||
"print(f\"Training samples: {len(samples_training)} ({TOTAL_SAMPLES_TRAIN})\")\n", | ||
"print(f\"Testing samples: {len(samples_testing)} ({TOTAL_SAMPLES_VALID})\")\n", | ||
|
@@ -367,4 +378,4 @@ | |
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I forgot, why was this needed here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just noticed, the code was not present in this repository but it was in mine, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct.