positive item sampling and fix infinite loop #7

Draft · wants to merge 1 commit into base: main
2 - Triplet sampling (Random).ipynb (27 changes: 19 additions & 8 deletions)
@@ -13,7 +13,7 @@
"import pandas as pd\n",
"from tqdm.auto import tqdm\n",
"\n",
"from utils.data import extract_embedding, get_interactions_dataframe\n",
"from utils.data import extract_embedding, get_interactions_dataframe, mark_evaluation_rows\n",
"from utils.hashing import pre_hash, HashesContainer\n"
]
},
@@ -134,6 +134,14 @@
"interactions_df[\"user_id\"] = interactions_df[\"user_id\"].map(user_id2index)\n",
"print(f\">> Mapping applied, ({n_missing_ids} values in 'user_id2index')\")\n",
"\n",
"# Mark interactions used for evaluation procedure if needed\n",
"if \"evaluation\" not in interactions_df:\n",
" print(\"\\nApply evaluation split...\")\n",
" interactions_df = mark_evaluation_rows(interactions_df)\n",
" # Check if new column exists and has boolean dtype\n",
" assert interactions_df[\"evaluation\"].dtype.name == \"bool\"\n",
" print(f\">> Interactions: {interactions_df.shape}\")\n",
"\n",
Comment on lines +137 to +144

Collaborator: I forgot, why was this needed here?

Collaborator: I just noticed: the code was not present in this repository, but it was in mine, right?

Collaborator (Author): Correct.

"# Split interactions data according to evaluation column\n",
"evaluation_df = interactions_df[interactions_df[\"evaluation\"]]\n",
"interactions_df = interactions_df[~interactions_df[\"evaluation\"]]\n",
@@ -202,22 +210,27 @@
"metadata": {},
"outputs": [],
"source": [
"def random_triplet_sampling(samples_per_user, hashes_container, desc=None):\n",
"def random_triplet_sampling(samples_per_user, hashes_container, desc=None, limit_iteration=10000):\n",
Collaborator: Why 10000? Maybe we could use the number of interactions as the limit, or a proportion of that number. If I have a million records and need to sample a significant share of them, a proportion of len(interactions_df) (or interactions_df.size, not sure which one is better) would be more appropriate than a fixed number.
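A minimal sketch of that suggestion, assuming a hypothetical `limit_fraction` parameter in place of the fixed constant (the 0.01 default is illustrative):

```python
def random_triplet_sampling(samples_per_user, hashes_container, desc=None,
                            limit_fraction=0.01):
    # Scale the retry budget with the data instead of hard-coding 10000;
    # len(interactions_df) counts interaction rows, which is what we sample from.
    limit_iteration = max(1, int(limit_fraction * len(interactions_df)))
    ...
```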

" interactions = interactions_df.copy()\n",
" samples = []\n",
" for ui, group in tqdm(interactions.groupby(\"user_id\"), desc=desc):\n",
" # Get profile artworks\n",
" full_profile = np.hstack(group[\"item_id\"].values).tolist()\n",
" full_profile_set = set(full_profile)\n",
" n = samples_per_user\n",
" aux_limit = limit_iteration\n",
" while n > 0:\n",
" if aux_limit == 0:\n",
" break\n",
Comment on lines +221 to +224

Collaborator: aux_limit never changes its value; on line 247 we should use aux_limit instead of limit_iteration, and that may be the fix.
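A sketch of the fix this comment describes, decrementing the per-user counter that the loop guard actually checks (surrounding sampling code abbreviated):

```python
n = samples_per_user
aux_limit = limit_iteration  # per-user retry budget
while n > 0:
    if aux_limit == 0:
        break
    # ... sample pi / ni and build the triple ...
    if not hashes_container.enroll(pre_hash(triple, contains_iter=MODE_PROFILE)):
        aux_limit -= 1  # was `limit_iteration -= 1`, which the guard above never sees
        continue
    samples.append((profile, pi, ni, ui))
    n -= 1
```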

" # Sample positive and negative items\n",
" pi_index = random.randrange(len(full_profile))\n",
" pi = full_profile[pi_index]\n",
" # Get profile\n",
" if MAX_PROFILE_SIZE:\n",
" # \"pi_index + 1\" to include pi in profile\n",
" profile = full_profile[max(0, pi_index - MAX_PROFILE_SIZE + 1):pi_index + 1]\n",
" profile = random.sample(full_profile, min(len(full_profile), MAX_PROFILE_SIZE))\n",
" if pi not in profile:\n",
" profile = profile[0: -1]\n",
" profile.append(pi)\n",
" else:\n",
" profile = list(full_profile)\n",
" # (While loop is in the sampling method)\n",
@@ -231,6 +244,7 @@
" else:\n",
" triple = (ui, pi, ni)\n",
" if not hashes_container.enroll(pre_hash(triple, contains_iter=MODE_PROFILE)):\n",
" limit_iteration -= 1\n",
" continue\n",
" # If not seen, store sample\n",
" samples.append((profile, pi, ni, ui))\n",
@@ -255,9 +269,6 @@
" desc=\"Random sampling (testing)\"\n",
")\n",
"\n",
"assert len(samples_training) >= TOTAL_SAMPLES_TRAIN\n",
"assert len(samples_testing) >= TOTAL_SAMPLES_VALID\n",
"\n",
Comment on lines -258 to -260

Collaborator: Why was this removed?

"# Total collected samples\n",
"print(f\"Training samples: {len(samples_training)} ({TOTAL_SAMPLES_TRAIN})\")\n",
"print(f\"Testing samples: {len(samples_testing)} ({TOTAL_SAMPLES_VALID})\")\n",
@@ -367,4 +378,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
utils/data.py (21 changes: 21 additions & 0 deletions)
@@ -34,3 +34,24 @@ def get_interactions_dataframe(interactions_path, display_stats=False):
print(f"Interactions - {column}: {interactions_df[column].nunique()} unique values")

return interactions_df

def mark_evaluation_rows(interactions_df, threshold=None):
if threshold is None:
threshold = 1

def _mark_evaluation_rows(group):
        # Only the last 'threshold' items per user are used for evaluation,
        # unless fewer items are available (then all of them go to training)
evaluation_series = pd.Series(False, index=group.index)
if len(group) > threshold:
evaluation_series.iloc[-threshold:] = True
return evaluation_series

# Mark evaluation rows
interactions_df["evaluation"] = interactions_df.groupby(
["user_id"])["user_id"].apply(_mark_evaluation_rows)
# Sort transactions by timestamp
interactions_df = interactions_df.sort_values("timestamp")
# Reset index according to new order
interactions_df = interactions_df.reset_index(drop=True)
return interactions_df
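A small usage sketch of the new helper (toy data is illustrative; it assumes a pandas version where the group-wise assignment aligns on the original index):

```python
import pandas as pd

from utils.data import mark_evaluation_rows

# Toy interactions: user 0 has three rows, user 1 only one.
df = pd.DataFrame({
    "user_id":   [0, 0, 0, 1],
    "item_id":   [10, 11, 12, 13],
    "timestamp": [1, 2, 3, 4],
})

df = mark_evaluation_rows(df)  # default threshold=1, i.e. leave-one-out
# Intended result: only user 0's last interaction is marked True;
# user 1 has too few rows, so all of its rows stay in the training split.
print(df[df["evaluation"]])
```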