
Samplers: add support for edge-case policies #241

Merged · 20 commits · Oct 7, 2024
293 changes: 194 additions & 99 deletions src/torchcodec/samplers/_implem.py
@@ -1,13 +1,12 @@
-import random
-from typing import List, Optional
+from typing import Callable, List, Literal, Optional
 
 import torch
 
-from torchcodec.decoders import FrameBatch, VideoDecoder
+from torchcodec.decoders import Frame, FrameBatch, VideoDecoder
 
 
 def _validate_params(
-    *, decoder, num_clips, num_frames_per_clip, num_indices_between_frames
+    *, decoder, num_clips, num_frames_per_clip, num_indices_between_frames, policy
 ):
     if len(decoder) < 1:
         raise ValueError(
@@ -25,21 +24,36 @@ def _validate_params(
         f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive"
     )
 
+    if policy not in _POLICY_FUNCTIONS.keys():
+        raise ValueError(
+            f"Invalid policy ({policy}). Supported values are {_POLICY_FUNCTIONS.keys()}."
+        )
+
 
 def _validate_sampling_range(
-    *, sampling_range_start, sampling_range_end, num_frames, clip_span
+    *,
+    num_indices_between_frames,
+    num_frames_per_clip,
+    sampling_range_start,
+    sampling_range_end,
+    num_frames_in_video,
 ):
     if sampling_range_start < 0:
-        sampling_range_start = num_frames + sampling_range_start
+        sampling_range_start = num_frames_in_video + sampling_range_start
 
-    if sampling_range_start >= num_frames:
+    if sampling_range_start >= num_frames_in_video:
         raise ValueError(
             f"sampling_range_start ({sampling_range_start}) must be smaller than "
-            f"the number of frames ({num_frames})."
+            f"the number of frames ({num_frames_in_video})."
         )
 
+    clip_span = _get_clip_span(
+        num_indices_between_frames=num_indices_between_frames,
+        num_frames_per_clip=num_frames_per_clip,
+    )
+
     if sampling_range_end is None:
-        sampling_range_end = num_frames - clip_span + 1
+        sampling_range_end = max(num_frames_in_video - clip_span + 1, 1)
         if sampling_range_start >= sampling_range_end:
Contributor:
Based on the discussion below, shouldn't this just be num_frames_in_video? That is, the default last index at which a clip can start is the final frame in the video. And if we're worried that num_frames_in_video is not at least 1, maybe we should explicitly check for (and error on?) that.

Even if we decide line 56 is correct, I can't reconcile it with line 67. There, it looks like we're making sure that sampling_range_end is not larger than the video itself, which makes sense to me. But there we don't subtract off the size of a clip, while here at line 56 we do.

@NicolasHug (Member Author, Oct 7, 2024):
I think the lack of documentation is making this significantly harder to review, sorry about that.

So, my reasoning is the following:

  • What happens at line 56 is the default behavior. At line 67 it isn't the default behavior anymore; sampling_range_end is passed by the user.
  • The default behavior (56) is: "automatically set the upper bound of where a clip can start to a value such that we never try to sample beyond the last frame." That's why we have to deduct clip_span from num_frames_in_video.
  • The reason this is the default behavior is that there is no way for users to explicitly set such a value, unless they do the math of computing the clip_span (and we don't want them to do that, I think?).
  • Now, at line 67, when the user explicitly passes sampling_range_end, we just make sure it doesn't go above num_frames_in_video to avoid errors on short videos. If we were to subtract clip_span there, it would change the semantics of sampling_range_end from "upper bound of the start" to "upper bound of the end".

Basically:

  • The default behavior means that the last few frames are less likely to be sampled and the policy doesn't come into play, but there are no "degenerate" clips (degenerate = the policy had to be applied).
  • If the user explicitly passes a value, then the sampling probability of the last few frames may increase (depending on the value), and policies become relevant. Clips may be degenerate.

@scotts (Contributor, Oct 7, 2024):
Yes, that all tracks. I think that is consistent with what I surmised below, when I read the test. I do have a slight preference for what I explained, but it's only a slight preference and not blocking.

@NicolasHug (Member Author):
SG, thanks. I don't have a super strong preference either. Let me merge now to unblock, but happy to revisit.
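To make the default computation under discussion concrete, here is a small worked example (hypothetical values, using the formulas from this diff):

```python
# Illustrative only: the default sampling_range_end for a short video.
num_frames_in_video = 10
num_frames_per_clip = 4
num_indices_between_frames = 2

clip_span = num_indices_between_frames * (num_frames_per_clip - 1) + 1  # 7
sampling_range_end = max(num_frames_in_video - clip_span + 1, 1)  # 4

# Clip starts are drawn from [0, 4). The largest allowed start, 3, yields
# frame indices 3, 5, 7, 9: still inside the video, so no policy is needed.
# A start of 4 would require frame index 10, which does not exist.
```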

             raise ValueError(
                 f"We determined that sampling_range_end should be {sampling_range_end}, "
@@ -49,8 +63,8 @@ def _validate_sampling_range(
     else:
         if sampling_range_end < 0:
             # Support negative values so that -1 means last frame.
-            sampling_range_end = num_frames + sampling_range_end
-            sampling_range_end = min(sampling_range_end, num_frames)
+            sampling_range_end = num_frames_in_video + sampling_range_end
+            sampling_range_end = min(sampling_range_end, num_frames_in_video)
         if sampling_range_start >= sampling_range_end:
             raise ValueError(
                 f"sampling_range_start ({sampling_range_start}) must be smaller than "
@@ -74,130 +88,211 @@ def _get_clip_span(*, num_indices_between_frames, num_frames_per_clip):
     return num_indices_between_frames * (num_frames_per_clip - 1) + 1
 
 
-def clips_at_random_indices(
+def _repeat_last_policy(
+    frame_indices: list[int], num_frames_per_clip: int
Contributor:
I ❤️ the elegance of these policies.
+) -> list[int]:
+    # frame_indices = [1, 2, 3], num_frames_per_clip = 5
+    # output = [1, 2, 3, 3, 3]
+    frame_indices += [frame_indices[-1]] * (num_frames_per_clip - len(frame_indices))
+    return frame_indices
+
+
+def _wrap_policy(frame_indices: list[int], num_frames_per_clip: int) -> list[int]:
+    # frame_indices = [1, 2, 3], num_frames_per_clip = 5
+    # output = [1, 2, 3, 1, 2]
+    return (frame_indices * (num_frames_per_clip // len(frame_indices) + 1))[
+        :num_frames_per_clip
+    ]
+
+
+def _error_policy(frames_indices: list[int], num_frames_per_clip: int) -> list[int]:
+    raise ValueError(
+        "You set the 'error' policy, and the sampler tried to decode a frame "
+        "that is beyond the number of frames in the video. "
+        "Try to leave sampling_range_end to its default value?"
+    )
+
+
+_POLICY_FUNCTION_TYPE = Callable[[list[int], int], list[int]]
+_POLICY_FUNCTIONS: dict[str, _POLICY_FUNCTION_TYPE] = {
+    "repeat_last": _repeat_last_policy,
+    "wrap": _wrap_policy,
+    "error": _error_policy,
+}
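The two padding policies are easiest to see on a concrete input. A minimal sketch using the helpers defined just above (the index values are made up):

```python
# Illustrative only: how each policy pads a 3-frame clip up to 5 frames.
frame_indices = [7, 8, 9]

print(_repeat_last_policy(list(frame_indices), 5))  # [7, 8, 9, 9, 9]
print(_wrap_policy(list(frame_indices), 5))         # [7, 8, 9, 7, 8]
# _error_policy(list(frame_indices), 5) raises a ValueError instead.
```

Note that _repeat_last_policy extends its input list in place (via +=), which is why the copies above are made with list().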


+def _build_all_clips_indices(
+    *,
+    clip_start_indices: torch.Tensor,  # 1D int tensor
+    num_frames_per_clip: int,
+    num_indices_between_frames: int,
+    num_frames_in_video: int,
+    policy_fun: _POLICY_FUNCTION_TYPE,
+) -> list[int]:
+    # From the clip_start_indices [f_00, f_10, f_20, ...]
+    # and from the rest of the parameters, return the list of all the frame
+    # indices that make up all the clips.
+    # I.e. the output is [f_00, f_01, f_02, f_03, f_10, f_11, f_12, f_13, ...]
+    # where f_01 is the index of frame 1 in clip 0.
+    #
+    # All clips in the output are of length num_frames_per_clip (=4 in the
+    # example above). When the frame indices go beyond num_frames_in_video, we
+    # force the frame indices back to valid values by applying the user's
+    # policy (wrap, repeat, etc.).
+    all_clips_indices: list[int] = []
+
+    clip_span = _get_clip_span(
+        num_indices_between_frames=num_indices_between_frames,
+        num_frames_per_clip=num_frames_per_clip,
+    )
+
+    for start_index in clip_start_indices:
Contributor:
The fact that this loop is so easy to implement (we can just call Python's range() directly, and how we're using our policy is obvious and still generic) makes me more confident we're doing the right thing.
+        frame_index_upper_bound = min(start_index + clip_span, num_frames_in_video)
+        frame_indices = list(
+            range(start_index, frame_index_upper_bound, num_indices_between_frames)
+        )
+        if len(frame_indices) < num_frames_per_clip:
+            frame_indices = policy_fun(frame_indices, num_frames_per_clip)
+        all_clips_indices += frame_indices
+    return all_clips_indices
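A small worked example of the whole helper (illustrative only; the values are made up):

```python
import torch

# Two clips of 3 frames each, step 2, in a 10-frame video; clip_span is 5.
indices = _build_all_clips_indices(
    clip_start_indices=torch.tensor([0, 8]),
    num_frames_per_clip=3,
    num_indices_between_frames=2,
    num_frames_in_video=10,
    policy_fun=_repeat_last_policy,
)
# Clip starting at 0: range(0, 5, 2) -> [0, 2, 4], complete, no policy needed.
# Clip starting at 8: range(8, 10, 2) -> [8], padded to [8, 8, 8] by the policy.
assert indices == [0, 2, 4, 8, 8, 8]
```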


+def _decode_all_clips_indices(
+    decoder: VideoDecoder, all_clips_indices: list[int], num_frames_per_clip: int
+) -> list[FrameBatch]:
+    # This takes the list of all the frames to decode, decodes all the frames,
+    # and then packs them into clips of length num_frames_per_clip.
+    # This is slow, unoptimized, and u.g.l.y. It is not meant to stay.
+    # TODO:
+    # - sort the frames to avoid backward seeks, dedup, decode, and re-organize frames.
+    # - write most of this in C++

Comment on lines +166 to +168:

@NicolasHug (Member Author, Oct 7, 2024):
If you want a peek at what this will look like: 3dcbe1e (#245).

I'm leaving it for a follow-up PR to ease reviewing. A basic benchmark shows 5X speedups.

+    def chunk_list(lst, chunk_size):
+        # return list of sublists of length chunk_size
+        return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+    def to_framebatch(frames: list[Frame]) -> FrameBatch:
+        data = torch.stack([frame.data for frame in frames])
+        pts_seconds = torch.tensor([frame.pts_seconds for frame in frames])
+        duration_seconds = torch.tensor([frame.duration_seconds for frame in frames])
+        return FrameBatch(
+            data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds
+        )
+
+    all_decoded_frames: list[Frame] = [
+        decoder.get_frame_at(index) for index in all_clips_indices
+    ]
+    all_clips: list[list[Frame]] = chunk_list(
+        all_decoded_frames, chunk_size=num_frames_per_clip
+    )
+
+    return [to_framebatch(clip) for clip in all_clips]
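Since _build_all_clips_indices always emits num_frames_per_clip indices per clip, chunk_list recovers even clips. A quick standalone illustration:

```python
def chunk_list(lst, chunk_size):
    # Same helper as in the diff: split lst into sublists of chunk_size items.
    return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]

print(chunk_list([0, 2, 4, 8, 8, 8], chunk_size=3))  # [[0, 2, 4], [8, 8, 8]]
```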


+def _generic_sampler(
@scotts (Contributor, Oct 7, 2024):
Nit: I think a better name is _generic_sampler(). I tend to think of things called "abstract" as defining desired behavior without an actual implementation. I think of things called "generic" as having an implementation that is generic over one or several parameters.
+    kind: Literal["random", "regular"],
     decoder: VideoDecoder,
     *,
-    num_clips: int = 1,
-    num_frames_per_clip: int = 1,
-    num_indices_between_frames: int = 1,
-    sampling_range_start: int = 0,
-    sampling_range_end: Optional[int] = None,  # interval is [start, end).
+    num_clips: int,
+    num_frames_per_clip: int,
+    num_indices_between_frames: int,
+    sampling_range_start: int,
+    sampling_range_end: Optional[int],  # interval is [start, end).
+    # Important note: sampling_range_end defines the upper bound of where a clip
+    # can *start*, not where a clip can end.
+    policy: Literal["repeat_last", "wrap", "error"],
Comment on lines +201 to +202:

@NicolasHug (Member Author, Oct 7, 2024):
This is fairly important, and something we haven't made explicit yet in our discussions. The reason it has to be defined as the upper bound of where a clip can start is that if sampling_range_end defined the upper bound of where a clip can end, then the last few frames in that sampling range would always have a lower probability of being sampled. I.e. we would always be in the scenario that this comment pointed out: #221 (comment).

If sampling_range_end defines the upper bound of where a clip can end, then technically there really is no need for a policy. Hope that makes sense.

@scotts (Contributor, Oct 7, 2024):
Almost entirely agreed, but the wrinkle is that even if we defined sampling_range_end to be the upper bound of where the last clip can end, we'd still need a policy unless we also defined len(decoder) as the default last possible frame for the end of a clip. If we did that for both, we'd also ensure only whole clips, and then we would not need a policy.

Agreed that this is the right approach. We'll need to document carefully. :) I also think there's no way around needing to document carefully; whatever behavior we choose will end up with subtle behavior.

@ahmadsharif1 (Contributor, Oct 7, 2024):
I was trying to prove to myself that that comment about probabilities made sense, so I wrote a simple simulator:

[simulation plot: per-frame selection frequency]

It seems like the distribution is symmetric, in the sense that even the first few frames have a smaller chance of being selected.

I think that is because when you pick frame X you pick frame X + 1 as well, but you never pick frame -1, for example. So frames at the start also have a lower chance. Note that this was with 0 dilation, so that may change things too.

I don't know if it breaks your assumptions, but you can think about this.

Contributor:
It may make sense to have a "slow" test with a simulator to see if the fair policy is within some threshold of an "even" distribution.

This comment isn't blocking or anything. It was mostly for my own curiosity, since we don't have a closed-form mathematical expression for the probability of being selected.

Member Author:
Ah, it's true that the first frames are also less likely to be sampled. Oh boy. I don't think we can or should do anything about that side, though?

Contributor:
You could make your code symmetric to pick those "negative frames" and then repeat them in the other direction (i.e. repeat the first frame) with your policy code. That will make it symmetric.

Again, not blocking, but worth a consideration if you want to make it symmetric and fair.

@ahmadsharif1 (Contributor, Oct 7, 2024):
It's kind of like how Conv1d adds padding to both sides: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html

You would have to do something symmetric, like repeating frames on the left and right boundaries, to make it truly fair. Not sure if it's worth the effort.

@NicolasHug (Member Author, Oct 7, 2024):
> It may make sense to have a "slow" test with a simulator to see if the fair policy is within some threshold of an "even" distribution.

I did write a test to illustrate the difference in behavior depending on the value of sampling_range_end.

Writing an actual statistical test is challenging because, in fact, when sampling_range_end=len(decoder) the last frames tend to be over-represented (because they are repeated). I.e. the frames still don't follow a uniform probability, so there's no distribution to check against.

Contributor:
Thinking about it a bit more: if you want all frames to be totally equal in terms of output probability, you will have to pick the first few and last few frames with higher probability. It's not as simple as repeating the last frame, because the second-to-last frame also needs to be weighted higher, and so on.

And that's why your test doesn't follow a uniform distribution.

For now this implementation is easy and simple. Later on, if there is a need, we can implement something complex and "fair".

Member Author:
Yeah, enforcing a uniform distribution across all frames is pretty hard.
That being said, we can already enforce a uniform distribution for clip starts: all the user has to do is set sampling_range_end=len(decoder). With that, all frames have an equal probability of being picked to start a clip.
 ) -> List[FrameBatch]:
 
     _validate_params(
         decoder=decoder,
         num_clips=num_clips,
         num_frames_per_clip=num_frames_per_clip,
         num_indices_between_frames=num_indices_between_frames,
+        policy=policy,
     )
 
-    clip_span = _get_clip_span(
-        num_indices_between_frames=num_indices_between_frames,
-        num_frames_per_clip=num_frames_per_clip,
-    )
-
-    # TODO: We should probably not error.
-    if clip_span > len(decoder):
-        raise ValueError(
-            f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})"
-        )
-
     sampling_range_start, sampling_range_end = _validate_sampling_range(
+        num_frames_per_clip=num_frames_per_clip,
+        num_indices_between_frames=num_indices_between_frames,
         sampling_range_start=sampling_range_start,
         sampling_range_end=sampling_range_end,
-        num_frames=len(decoder),
-        clip_span=clip_span,
+        num_frames_in_video=len(decoder),
     )
 
-    clip_start_indices = torch.randint(
-        low=sampling_range_start, high=sampling_range_end, size=(num_clips,)
-    )
-
-    # We want to avoid seeking backwards, so we sort the clip start indices
-    # before decoding the frames, and then re-shuffle the clips afterwards.
-    # Backward seeks may still happen if there are overlapping clips, i.e. if a
-    # clip ends after the next one starts.
-    # TODO: We should use a different strategy to avoid backward seeks:
-    # - flatten all frames indices, irrespective of their clip
-    # - sort the indices and dedup
-    # - decode all frames in index order
-    # - re-arrange the frames back into their original clips
-    clip_start_indices = torch.sort(clip_start_indices).values
-    clips = [
-        decoder.get_frames_at(
-            start=clip_start_index,
-            stop=clip_start_index + clip_span,
-            step=num_indices_between_frames,
-        )
-        for clip_start_index in clip_start_indices
-    ]
-
-    # This an ugly way to shuffle the clips using pytorch RNG *without*
-    # affecting the python builtin RNG.
-    builtin_random_state = random.getstate()
-    random.seed(torch.randint(0, 2**32, (1,)).item())
-    random.shuffle(clips)
-    random.setstate(builtin_random_state)
-
-    return clips
+    if kind == "random":
+        clip_start_indices = torch.randint(
+            low=sampling_range_start, high=sampling_range_end, size=(num_clips,)
+        )
+    else:
+        # Note [num clips larger than sampling range]
+        # If we ask for more clips than there are frames in the sampling range or
+        # in the video, we rely on torch.linspace behavior which will return
+        # duplicated indices.
+        # E.g. torch.linspace(0, 10, steps=20, dtype=torch.int) returns
+        # 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10
+        # Alternatively we could wrap around, but the current behavior is closer to
+        # the expected "equally spaced indices" sampling.
+        clip_start_indices = torch.linspace(
+            sampling_range_start,
+            sampling_range_end - 1,
+            steps=num_clips,
+            dtype=torch.int,
+        )
 
+    all_clips_indices = _build_all_clips_indices(
+        clip_start_indices=clip_start_indices,
+        num_frames_per_clip=num_frames_per_clip,
+        num_indices_between_frames=num_indices_between_frames,
+        num_frames_in_video=len(decoder),
+        policy_fun=_POLICY_FUNCTIONS[policy],
+    )
+    return _decode_all_clips_indices(
+        decoder,
+        all_clips_indices=all_clips_indices,
+        num_frames_per_clip=num_frames_per_clip,
+    )
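As a quick check of the Note above, the claimed torch.linspace behavior can be reproduced directly (illustrative snippet, not part of the diff):

```python
import torch

# More requested clips (20) than distinct start indices (11): linspace
# duplicates values instead of erroring, matching the comment above.
starts = torch.linspace(0, 10, steps=20, dtype=torch.int)
print(starts.tolist())
# [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10]
```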


-def clips_at_regular_indices(
+def clips_at_random_indices(
     decoder: VideoDecoder,
     *,
     num_clips: int = 1,
     num_frames_per_clip: int = 1,
     num_indices_between_frames: int = 1,
     sampling_range_start: int = 0,
     sampling_range_end: Optional[int] = None,  # interval is [start, end).
+    policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
 ) -> List[FrameBatch]:
 
-    _validate_params(
-        decoder=decoder,
-        num_clips=num_clips,
-        num_frames_per_clip=num_frames_per_clip,
-        num_indices_between_frames=num_indices_between_frames,
-    )
-
-    clip_span = _get_clip_span(
-        num_indices_between_frames=num_indices_between_frames,
-        num_frames_per_clip=num_frames_per_clip,
-    )
-
-    # TODO: We should probably not error.
-    if clip_span > len(decoder):
-        raise ValueError(
-            f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})"
-        )
-
-    sampling_range_start, sampling_range_end = _validate_sampling_range(
-        sampling_range_start=sampling_range_start,
-        sampling_range_end=sampling_range_end,
-        num_frames=len(decoder),
-        clip_span=clip_span,
-    )
-
-    # Note [num clips larger than sampling range]
-    # If we ask for more clips than there are frames in the sampling range or
-    # in the video, we rely on torch.linspace behavior which will return
-    # duplicated indices.
-    # E.g. torch.linspace(0, 10, steps=20, dtype=torch.int) returns
-    # 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10
-    # Alternatively we could wrap around, but the current behavior is closer to
-    # the expected "equally spaced indices" sampling.
-    clip_start_indices = torch.linspace(
-        sampling_range_start, sampling_range_end - 1, steps=num_clips, dtype=torch.int
-    )
-
-    # Similarly to clip_at_random_indices, there may be backward seeks if clips overlap.
-    # See other TODO over there, and apply similar changes here.
-    clips = [
-        decoder.get_frames_at(
-            start=clip_start_index,
-            stop=clip_start_index + clip_span,
-            step=num_indices_between_frames,
-        )
-        for clip_start_index in clip_start_indices
-    ]
-
-    return clips
+    return _generic_sampler(
+        kind="random",
+        decoder=decoder,
+        num_clips=num_clips,
+        num_frames_per_clip=num_frames_per_clip,
+        num_indices_between_frames=num_indices_between_frames,
+        sampling_range_start=sampling_range_start,
+        sampling_range_end=sampling_range_end,
+        policy=policy,
+    )
+
+
+def clips_at_regular_indices(
+    decoder: VideoDecoder,
+    *,
+    num_clips: int = 1,
+    num_frames_per_clip: int = 1,
+    num_indices_between_frames: int = 1,
+    sampling_range_start: int = 0,
+    sampling_range_end: Optional[int] = None,  # interval is [start, end).
+    policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
+) -> List[FrameBatch]:
+
+    return _generic_sampler(
+        kind="regular",
+        decoder=decoder,
+        num_clips=num_clips,
+        num_frames_per_clip=num_frames_per_clip,
+        num_indices_between_frames=num_indices_between_frames,
+        sampling_range_start=sampling_range_start,
+        sampling_range_end=sampling_range_end,
+        policy=policy,
+    )
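After this refactor, both public samplers delegate to _generic_sampler. A usage sketch (assuming the samplers are exposed from torchcodec.samplers; the file path and parameter values are hypothetical):

```python
from torchcodec.decoders import VideoDecoder
from torchcodec.samplers import clips_at_random_indices, clips_at_regular_indices

decoder = VideoDecoder("video.mp4")  # hypothetical input file

# 4 random clips of 8 frames each; frames past the end repeat the last frame.
random_clips = clips_at_random_indices(
    decoder, num_clips=4, num_frames_per_clip=8, policy="repeat_last"
)

# 4 evenly spaced clips; letting clips start anywhere makes clip starts
# uniform (see the discussion above), with "wrap" handling the tail.
regular_clips = clips_at_regular_indices(
    decoder,
    num_clips=4,
    num_frames_per_clip=8,
    sampling_range_end=len(decoder),
    policy="wrap",
)
```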