Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add shuffling to examples in example features #136

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions tensorflow_ranking/python/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,20 @@ def _get_params(mode, params):
raise ValueError('Invalid mode: {}.'.format(mode))
return num_shuffles

def _get_params_shuffle_peritem(mode, params):
params = params or {}
# 'shuffle_peritem' should be bool
_SHUFFLE_PERITEM = 'shuffle_peritem'
if mode == tf.estimator.ModeKeys.TRAIN:
shuffle_peritem = bool(params.get(_SHUFFLE_PERITEM, None))
elif mode == tf.estimator.ModeKeys.EVAL:
shuffle_peritem = False
elif mode == tf.estimator.ModeKeys.PREDICT:
shuffle_peritem = False
else:
raise ValueError('Invalid mode: {}.'.format(mode))
return shuffle_peritem


class _RankingModel(object):
"""Interface for a ranking model."""
Expand Down Expand Up @@ -335,6 +349,32 @@ def _update_scatter_gather_indices(self, is_valid, mode, params):

def _compute_logits_impl(self, context_features, example_features, labels,
mode, params, config):
if _get_params_shuffle_peritem(mode, params):
with tf.compat.v1.name_scope("shuffle_peritem"):
# Shuffle labels and example features along list_size
# example_features are shape (batch, list_size, feature_space)

first_example = next(iter(example_features.values()))
cur_list_size = tf.shape(input=first_example)[1]

indicies = tf.range(start=0, limit=cur_list_size, dtype=tf.int32)
shuffled_indicies = tf.random.shuffle(indicies)

for name, value in six.iteritems(example_features):
# Transpose to expose LIST_SIZE dimension on the 0th axis
transposed = tf.transpose(value, perm=[1,0,2])

# Shuffle along the new LIST_SIZE axis
shuffled_feature = tf.gather(transposed, shuffled_indicies)

# Revert back to (Batch, LIST_SIZE, feature_space)
reverted = tf.transpose(shuffled_feature, perm=[1,0,2])
example_features[name] = reverted

transposed_label = tf.transpose(labels, perm=[1,0])
shuffled_label = tf.gather(transposed_label, shuffled_indicies)
labels = tf.transpose(shuffled_label, perm=[1,0])

# Scatter/Gather per-example scores through groupwise comparison. Each
# instance in a mini-batch will form a number of groups. Each group of
# examples are scored by `_score_fn` and scores for individual examples are
Expand Down