From 34a43e7e1d45624faeaa10e4b21f8ff54733d801 Mon Sep 17 00:00:00 2001 From: isipalma Date: Tue, 15 Jun 2021 11:17:58 -0400 Subject: [PATCH] positive item sampling and fix infinite loop --- 2 - Triplet sampling (Random).ipynb | 27 +++++++++++++++++++-------- utils/data.py | 21 +++++++++++++++++++++ 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/2 - Triplet sampling (Random).ipynb b/2 - Triplet sampling (Random).ipynb index 27c0834..0146d69 100644 --- a/2 - Triplet sampling (Random).ipynb +++ b/2 - Triplet sampling (Random).ipynb @@ -13,7 +13,7 @@ "import pandas as pd\n", "from tqdm.auto import tqdm\n", "\n", - "from utils.data import extract_embedding, get_interactions_dataframe\n", + "from utils.data import extract_embedding, get_interactions_dataframe, mark_evaluation_rows\n", "from utils.hashing import pre_hash, HashesContainer\n" ] }, @@ -134,6 +134,14 @@ "interactions_df[\"user_id\"] = interactions_df[\"user_id\"].map(user_id2index)\n", "print(f\">> Mapping applied, ({n_missing_ids} values in 'user_id2index')\")\n", "\n", + "# Mark interactions used for evaluation procedure if needed\n", + "if \"evaluation\" not in interactions_df:\n", + " print(\"\\nApply evaluation split...\")\n", + " interactions_df = mark_evaluation_rows(interactions_df)\n", + " # Check if new column exists and has boolean dtype\n", + " assert interactions_df[\"evaluation\"].dtype.name == \"bool\"\n", + " print(f\">> Interactions: {interactions_df.shape}\")\n", + "\n", "# Split interactions data according to evaluation column\n", "evaluation_df = interactions_df[interactions_df[\"evaluation\"]]\n", "interactions_df = interactions_df[~interactions_df[\"evaluation\"]]\n", @@ -202,7 +210,7 @@ "metadata": {}, "outputs": [], "source": [ - "def random_triplet_sampling(samples_per_user, hashes_container, desc=None):\n", + "def random_triplet_sampling(samples_per_user, hashes_container, desc=None, limit_iteration=10000):\n", " interactions = interactions_df.copy()\n", " samples = 
[]\n", "    for ui, group in tqdm(interactions.groupby(\"user_id\"), desc=desc):\n", @@ -210,14 +218,19 @@ "        full_profile = np.hstack(group[\"item_id\"].values).tolist()\n", "        full_profile_set = set(full_profile)\n", "        n = samples_per_user\n", + "        aux_limit = limit_iteration\n", "        while n > 0:\n", + "            if aux_limit == 0:\n", + "                break\n", "            # Sample positive and negative items\n", "            pi_index = random.randrange(len(full_profile))\n", "            pi = full_profile[pi_index]\n", "            # Get profile\n", "            if MAX_PROFILE_SIZE:\n", - "                # \"pi_index + 1\" to include pi in profile\n", - "                profile = full_profile[max(0, pi_index - MAX_PROFILE_SIZE + 1):pi_index + 1]\n", + "                profile = random.sample(full_profile, min(len(full_profile), MAX_PROFILE_SIZE))\n", + "                if pi not in profile:\n", + "                    profile = profile[0: -1]\n", + "                    profile.append(pi)\n", "            else:\n", "                profile = list(full_profile)\n", "            # (While loop is in the sampling method)\n", @@ -231,6 +244,7 @@ "            else:\n", "                triple = (ui, pi, ni)\n", "            if not hashes_container.enroll(pre_hash(triple, contains_iter=MODE_PROFILE)):\n", + "                aux_limit -= 1\n", "                continue\n", "            # If not seen, store sample\n", "            samples.append((profile, pi, ni, ui))\n", @@ -255,9 +269,6 @@ "    desc=\"Random sampling (testing)\"\n", ")\n", "\n", - "assert len(samples_training) >= TOTAL_SAMPLES_TRAIN\n", - "assert len(samples_testing) >= TOTAL_SAMPLES_VALID\n", - "\n", "# Total collected samples\n", "print(f\"Training samples: {len(samples_training)} ({TOTAL_SAMPLES_TRAIN})\")\n", "print(f\"Testing samples: {len(samples_testing)} ({TOTAL_SAMPLES_VALID})\")\n", @@ -367,4 +378,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/utils/data.py b/utils/data.py index ade7608..87b92e4 100644 --- a/utils/data.py +++ b/utils/data.py @@ -34,3 +34,24 @@ def get_interactions_dataframe(interactions_path, display_stats=False): print(f"Interactions - {column}: {interactions_df[column].nunique()} unique values") return interactions_df + +def 
mark_evaluation_rows(interactions_df, threshold=None): +    if threshold is None: +        threshold = 1 + +    def _mark_evaluation_rows(group): +        # Only the last 'threshold' items are used for evaluation, +        # unless fewer items are available (then they're used for training) +        evaluation_series = pd.Series(False, index=group.index) +        if len(group) > threshold: +            evaluation_series.iloc[-threshold:] = True +        return evaluation_series + +    # Mark evaluation rows +    interactions_df["evaluation"] = interactions_df.groupby( +        ["user_id"])["user_id"].apply(_mark_evaluation_rows) +    # Sort interactions by timestamp +    interactions_df = interactions_df.sort_values("timestamp") +    # Reset index according to new order +    interactions_df = interactions_df.reset_index(drop=True) +    return interactions_df