Samplers: add support for edge-case policies #241
Changes from 15 commits
@@ -1,13 +1,12 @@
 import random
-from typing import List, Optional
+from typing import Callable, List, Literal, Optional
 
 import torch
 
-from torchcodec.decoders import FrameBatch, VideoDecoder
+from torchcodec.decoders import Frame, FrameBatch, VideoDecoder
 
 
 def _validate_params(
-    *, decoder, num_clips, num_frames_per_clip, num_indices_between_frames
+    *, decoder, num_clips, num_frames_per_clip, num_indices_between_frames, policy
 ):
     if len(decoder) < 1:
         raise ValueError(
@@ -25,21 +24,36 @@ def _validate_params(
             f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive"
         )
 
+    if policy not in _POLICY_FUNCTIONS.keys():
+        raise ValueError(
+            f"Invalid policy ({policy}). Supported values are {_POLICY_FUNCTIONS.keys()}."
+        )
+
 
 def _validate_sampling_range(
-    *, sampling_range_start, sampling_range_end, num_frames, clip_span
+    *,
+    num_indices_between_frames,
+    num_frames_per_clip,
+    sampling_range_start,
+    sampling_range_end,
+    num_frames_in_video,
 ):
     if sampling_range_start < 0:
-        sampling_range_start = num_frames + sampling_range_start
+        sampling_range_start = num_frames_in_video + sampling_range_start
 
-    if sampling_range_start >= num_frames:
+    if sampling_range_start >= num_frames_in_video:
         raise ValueError(
             f"sampling_range_start ({sampling_range_start}) must be smaller than "
-            f"the number of frames ({num_frames})."
+            f"the number of frames ({num_frames_in_video})."
         )
 
+    clip_span = _get_clip_span(
+        num_indices_between_frames=num_indices_between_frames,
+        num_frames_per_clip=num_frames_per_clip,
+    )
+
     if sampling_range_end is None:
-        sampling_range_end = num_frames - clip_span + 1
+        sampling_range_end = max(num_frames_in_video - clip_span + 1, 1)
         if sampling_range_start >= sampling_range_end:
             raise ValueError(
                 f"We determined that sampling_range_end should be {sampling_range_end}, "
@@ -49,8 +63,8 @@ def _validate_sampling_range(
     else:
         if sampling_range_end < 0:
             # Support negative values so that -1 means last frame.
-            sampling_range_end = num_frames + sampling_range_end
-        sampling_range_end = min(sampling_range_end, num_frames)
+            sampling_range_end = num_frames_in_video + sampling_range_end
+        sampling_range_end = min(sampling_range_end, num_frames_in_video)
         if sampling_range_start >= sampling_range_end:
             raise ValueError(
                 f"sampling_range_start ({sampling_range_start}) must be smaller than "
@@ -74,130 +88,211 @@ def _get_clip_span(*, num_indices_between_frames, num_frames_per_clip):
     return num_indices_between_frames * (num_frames_per_clip - 1) + 1
 
 
-def clips_at_random_indices(
+def _repeat_last_policy(
+    frame_indices: list[int], num_frames_per_clip: int
Review comment: I ❤️ the elegance of these policies.
+) -> list[int]:
+    # frame_indices = [1, 2, 3], num_frames_per_clip = 5
+    # output = [1, 2, 3, 3, 3]
+    frame_indices += [frame_indices[-1]] * (num_frames_per_clip - len(frame_indices))
+    return frame_indices
+
+
+def _wrap_policy(frame_indices: list[int], num_frames_per_clip: int) -> list[int]:
+    # frame_indices = [1, 2, 3], num_frames_per_clip = 5
+    # output = [1, 2, 3, 1, 2]
+    return (frame_indices * (num_frames_per_clip // len(frame_indices) + 1))[
+        :num_frames_per_clip
+    ]
+
+
+def _error_policy(frames_indices: list[int], num_frames_per_clip: int) -> list[int]:
+    raise ValueError(
+        "You set the 'error' policy, and the sampler tried the decode a frame "
+        "that is beyond the number of frames in the video. "
+        "Try to leave sampling_range_end to its default value?"
+    )
+
+
+_POLICY_FUNCTION_TYPE = Callable[[list[int], int], list[int]]
+_POLICY_FUNCTIONS: dict[str, _POLICY_FUNCTION_TYPE] = {
+    "repeat_last": _repeat_last_policy,
+    "wrap": _wrap_policy,
+    "error": _error_policy,
+}
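As a sketch of what the "repeat_last" and "wrap" policies in the diff above do, re-implemented standalone (not imported from torchcodec):

```python
def repeat_last_policy(frame_indices, num_frames_per_clip):
    # Pad the clip by repeating the last valid index until it is full.
    return frame_indices + [frame_indices[-1]] * (
        num_frames_per_clip - len(frame_indices)
    )


def wrap_policy(frame_indices, num_frames_per_clip):
    # Tile the valid indices enough times, then truncate to the clip length.
    return (frame_indices * (num_frames_per_clip // len(frame_indices) + 1))[
        :num_frames_per_clip
    ]


print(repeat_last_policy([1, 2, 3], 5))  # [1, 2, 3, 3, 3]
print(wrap_policy([1, 2, 3], 5))         # [1, 2, 3, 1, 2]
```

Both take the valid indices that could be decoded and stretch them to `num_frames_per_clip`, matching the worked examples in the code comments.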
+
+
+def _build_all_clips_indices(
+    *,
+    clip_start_indices: torch.Tensor,  # 1D int tensor
+    num_frames_per_clip: int,
+    num_indices_between_frames: int,
+    num_frames_in_video: int,
+    policy_fun: _POLICY_FUNCTION_TYPE,
+) -> list[int]:
+    # From the clip_start_indices [f_00, f10, f20, ...]
+    # and from the rest of the parameters, return the list of all the frame
+    # indices that make up all the clips.
+    # I.e. the output is [f_00, f_01, f_02, f_03, f_10, f_11, f_12, f_13, ...]
+    # where f_01 is the index of frame 1 in clip 0.
+    #
+    # All clips in the output are of length num_frames_per_clip (=4 in example
+    # above). When the frame indices go beyond num_frames_in_video, we force the
+    # frame indices back to valid values by applying the user's policy (wrap,
+    # repeat, etc.).
+    all_clips_indices: list[int] = []
+
+    clip_span = _get_clip_span(
+        num_indices_between_frames=num_indices_between_frames,
+        num_frames_per_clip=num_frames_per_clip,
+    )
+
+    for start_index in clip_start_indices:
Review comment: The fact that this loop is so easy to implement - we can just call Python's …
+        frame_index_upper_bound = min(start_index + clip_span, num_frames_in_video)
+        frame_indices = list(
+            range(start_index, frame_index_upper_bound, num_indices_between_frames)
+        )
+        if len(frame_indices) < num_frames_per_clip:
+            frame_indices = policy_fun(frame_indices, num_frames_per_clip)
+        all_clips_indices += frame_indices
+    return all_clips_indices
+
+
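A minimal pure-Python sketch of the loop above, using plain ints instead of a torch tensor and with the "repeat_last" policy inlined (a 10-frame video is assumed for the example):

```python
def build_all_clips_indices(clip_start_indices, num_frames_per_clip,
                            num_indices_between_frames, num_frames_in_video):
    # clip_span: number of indices one clip spans, as in _get_clip_span.
    clip_span = num_indices_between_frames * (num_frames_per_clip - 1) + 1
    all_clips_indices = []
    for start_index in clip_start_indices:
        # Clamp the index range to the end of the video...
        upper = min(start_index + clip_span, num_frames_in_video)
        frame_indices = list(range(start_index, upper, num_indices_between_frames))
        # ...then let the policy fill in any missing frames ("repeat_last" here).
        if len(frame_indices) < num_frames_per_clip:
            frame_indices += [frame_indices[-1]] * (
                num_frames_per_clip - len(frame_indices)
            )
        all_clips_indices += frame_indices
    return all_clips_indices


# A clip starting at index 8 in a 10-frame video runs out of frames:
# it gets 8, 9, and then repeats 9.
print(build_all_clips_indices([0, 8], 3, 1, 10))  # [0, 1, 2, 8, 9, 9]
```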
+def _decode_all_clips_indices(
+    decoder: VideoDecoder, all_clips_indices: list[int], num_frames_per_clip: int
+) -> list[FrameBatch]:
+    # This takes the list of all the frames to decode, decode all the frames,
+    # and then packs them into clips of length num_frames_per_clip.
+    # This is slow, unoptimized, and u.g.l.y. It is not meant to stay.
+    # TODO:
+    # - sort the frames to avoid backward seeks, dedup, decode, and re-organize frames.
+    # - write most of this in C++
Review comment on lines +166 to +168: If you want a peek of what this will look like: I'm leaving it for a follow-up PR to ease reviewing. Basic benchmark shows 5X speedups.
+    def chunk_list(lst, chunk_size):
+        # return list of sublists of length chunk_size
+        return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+    def to_framebatch(frames: list[Frame]) -> FrameBatch:
+        data = torch.stack([frame.data for frame in frames])
+        pts_seconds = torch.tensor([frame.pts_seconds for frame in frames])
+        duration_seconds = torch.tensor([frame.duration_seconds for frame in frames])
+        return FrameBatch(
+            data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds
+        )
+
+    all_decoded_frames: list[Frame] = [
+        decoder.get_frame_at(index) for index in all_clips_indices
+    ]
+    all_clips: list[list[Frame]] = chunk_list(
+        all_decoded_frames, chunk_size=num_frames_per_clip
+    )
+
+    return [to_framebatch(clip) for clip in all_clips]
+
+
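The `chunk_list` helper above is the step that regroups the flat decoded-frame list back into clips; a quick standalone check of its behavior:

```python
def chunk_list(lst, chunk_size):
    # Split a flat list into consecutive sublists of length chunk_size
    # (the last chunk is shorter if the length does not divide evenly).
    return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]


print(chunk_list([0, 1, 2, 8, 9, 9], chunk_size=3))  # [[0, 1, 2], [8, 9, 9]]
```

Because every clip was padded to exactly `num_frames_per_clip` indices by the policy, the sampler never hits the uneven-last-chunk case.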
+def _abstract_sampler(
Review comment: Nit: I think a better name is …
+    kind: Literal["random", "regular"],
     decoder: VideoDecoder,
     *,
-    num_clips: int = 1,
-    num_frames_per_clip: int = 1,
-    num_indices_between_frames: int = 1,
-    sampling_range_start: int = 0,
-    sampling_range_end: Optional[int] = None,  # interval is [start, end).
+    num_clips: int,
+    num_frames_per_clip: int,
+    num_indices_between_frames: int,
+    sampling_range_start: int,
+    sampling_range_end: Optional[int],  # interval is [start, end).
+    # Important note: sampling_range_end defines the upper bound of where a clip
+    # can *start*, not where a clip can end.
+    policy: Literal["repeat_last", "wrap", "error"],
Review comment on lines +201 to +202: This is fairly important, and something we haven't made explicit yet in our discussions. The reason it has to be defined as the upper bound of where a clip can start is because if …. If ….

Review comment: Almost entirely agreed, but the wrinkle is that even if we defined …. Agreed that this is the right approach. We'll need to document carefully. :) I also think there's no way around needing to document carefully; whatever behavior we choose will end up with subtle behavior.

Review comment: I was trying to prove to myself that that comment about probabilities made sense, so I wrote a simple simulator. It seems like the distribution is symmetric, in the sense that even the first few frames have a smaller chance of being selected. I think that is because when you pick frame X you pick frame X + 1 as well, but you never pick frame -1, for example. So frames at the start also have a lower chance. Note that this is with 0 dilation, so that may change things too. I don't know if it breaks your assumptions, but you can think about this.

Review comment: It may make sense to have a "slow" test with a simulator to see if the fair policy is within some threshold of an "even" distribution. This comment isn't blocking or anything. It was mostly for my own curiosity, since we don't have a closed-form mathematical expression for the probability of being selected.

Review comment: Ah, it's true that the first frames are also less likely to be sampled. Oh boy. I don't think we can or should do anything about that side though?

Review comment: You could make your code symmetric to pick those "negative frames" and then repeat them in the other direction (i.e. repeat the first frame) with your policy code. That will make it symmetric. Again, not blocking, but worth a consideration if you want to make it symmetric and fair.

Review comment: It's kind of like how Conv1D adds padding to both sides: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html You would have to do something symmetric like repeating frames on the left and right boundaries to make it truly fair. Not sure if it's worth the effort.

Review comment: I did write a test to illustrate the difference of behavior depending on the value of …. Writing an actual statistical test is challenging because in fact when ….

Review comment: Thinking about it a bit more, if you want frames to all be totally equal in terms of output probability, you will have to pick the first and last few frames with higher probability. It's not as simple as repeating the last frame, etc., because the second-to-last frame also needs to be weighted higher, etc. And that's why your test doesn't follow a uniform distribution. For now this implementation is easy and simple. Later on, if there is a need, we can implement something complex and "fair".

Review comment: Yeah, enforcing a uniform distribution across all frames is pretty hard.
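The edge-frame bias discussed in this thread can be seen without any randomness by counting, for each frame, how many clip start positions would include it. This is a hypothetical back-of-the-envelope sketch (not code from the PR), assuming contiguous clips and uniformly sampled starts:

```python
def coverage_counts(num_frames, num_frames_per_clip):
    # For each frame index, count how many valid clip starts cover it.
    # Starts range over [0, num_frames - num_frames_per_clip], and a clip
    # starting at s covers frames s .. s + num_frames_per_clip - 1.
    last_start = num_frames - num_frames_per_clip
    counts = []
    for f in range(num_frames):
        lo = max(0, f - num_frames_per_clip + 1)
        hi = min(f, last_start)
        counts.append(max(0, hi - lo + 1))
    return counts


# 10 frames, clips of 3: middle frames are covered by 3 possible starts,
# while frames near both edges are covered by fewer.
print(coverage_counts(10, 3))  # [1, 2, 3, 3, 3, 3, 3, 3, 2, 1]
```

The symmetric fall-off at both ends is exactly the "first frames are also less likely" observation above: under uniformly random starts, selection probability is proportional to these counts.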
 ) -> List[FrameBatch]:
 
     _validate_params(
         decoder=decoder,
         num_clips=num_clips,
         num_frames_per_clip=num_frames_per_clip,
         num_indices_between_frames=num_indices_between_frames,
+        policy=policy,
     )
 
     clip_span = _get_clip_span(
         num_indices_between_frames=num_indices_between_frames,
         num_frames_per_clip=num_frames_per_clip,
     )
 
     # TODO: We should probably not error.
     if clip_span > len(decoder):
         raise ValueError(
             f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})"
         )
 
     sampling_range_start, sampling_range_end = _validate_sampling_range(
+        num_frames_per_clip=num_frames_per_clip,
+        num_indices_between_frames=num_indices_between_frames,
         sampling_range_start=sampling_range_start,
         sampling_range_end=sampling_range_end,
-        num_frames=len(decoder),
-        clip_span=clip_span,
+        num_frames_in_video=len(decoder),
     )
 
-    clip_start_indices = torch.randint(
-        low=sampling_range_start, high=sampling_range_end, size=(num_clips,)
-    )
-
-    # We want to avoid seeking backwards, so we sort the clip start indices
-    # before decoding the frames, and then re-shuffle the clips afterwards.
-    # Backward seeks may still happen if there are overlapping clips, i.e. if a
-    # clip ends after the next one starts.
-    # TODO: We should use a different strategy to avoid backward seeks:
-    # - flatten all frames indices, irrespective of their clip
-    # - sort the indices and dedup
-    # - decode all frames in index order
-    # - re-arrange the frames back into their original clips
-    clip_start_indices = torch.sort(clip_start_indices).values
-    clips = [
-        decoder.get_frames_at(
-            start=clip_start_index,
-            stop=clip_start_index + clip_span,
-            step=num_indices_between_frames,
-        )
-        for clip_start_index in clip_start_indices
-    ]
+    if kind == "random":
+        clip_start_indices = torch.randint(
+            low=sampling_range_start, high=sampling_range_end, size=(num_clips,)
+        )
+    else:
+        # Note [num clips larger than sampling range]
+        # If we ask for more clips than there are frames in the sampling range or
+        # in the video, we rely on torch.linspace behavior which will return
+        # duplicated indices.
+        # E.g. torch.linspace(0, 10, steps=20, dtype=torch.int) returns
+        # 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10
+        # Alternatively we could wrap around, but the current behavior is closer to
+        # the expected "equally spaced indices" sampling.
+        clip_start_indices = torch.linspace(
+            sampling_range_start,
+            sampling_range_end - 1,
+            steps=num_clips,
+            dtype=torch.int,
+        )
 
-    # This an ugly way to shuffle the clips using pytorch RNG *without*
-    # affecting the python builtin RNG.
-    builtin_random_state = random.getstate()
-    random.seed(torch.randint(0, 2**32, (1,)).item())
-    random.shuffle(clips)
-    random.setstate(builtin_random_state)
-
-    return clips
+    all_clips_indices = _build_all_clips_indices(
+        clip_start_indices=clip_start_indices,
+        num_frames_per_clip=num_frames_per_clip,
+        num_indices_between_frames=num_indices_between_frames,
+        num_frames_in_video=len(decoder),
+        policy_fun=_POLICY_FUNCTIONS[policy],
+    )
+    return _decode_all_clips_indices(
+        decoder,
+        all_clips_indices=all_clips_indices,
+        num_frames_per_clip=num_frames_per_clip,
+    )
 
 
-def clips_at_regular_indices(
+def clips_at_random_indices(
     decoder: VideoDecoder,
     *,
     num_clips: int = 1,
     num_frames_per_clip: int = 1,
     num_indices_between_frames: int = 1,
     sampling_range_start: int = 0,
     sampling_range_end: Optional[int] = None,  # interval is [start, end).
+    policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
 ) -> List[FrameBatch]:
 
-    _validate_params(
+    return _abstract_sampler(
+        kind="random",
         decoder=decoder,
         num_clips=num_clips,
         num_frames_per_clip=num_frames_per_clip,
         num_indices_between_frames=num_indices_between_frames,
+        sampling_range_start=sampling_range_start,
+        sampling_range_end=sampling_range_end,
+        policy=policy,
     )
 
-    clip_span = _get_clip_span(
-        num_indices_between_frames=num_indices_between_frames,
-        num_frames_per_clip=num_frames_per_clip,
-    )
-
-    # TODO: We should probably not error.
-    if clip_span > len(decoder):
-        raise ValueError(
-            f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})"
-        )
+def clips_at_regular_indices(
+    decoder: VideoDecoder,
+    *,
+    num_clips: int = 1,
+    num_frames_per_clip: int = 1,
+    num_indices_between_frames: int = 1,
+    sampling_range_start: int = 0,
+    sampling_range_end: Optional[int] = None,  # interval is [start, end).
+    policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
+) -> List[FrameBatch]:
 
-    sampling_range_start, sampling_range_end = _validate_sampling_range(
+    return _abstract_sampler(
+        kind="regular",
+        decoder=decoder,
+        num_clips=num_clips,
+        num_frames_per_clip=num_frames_per_clip,
+        num_indices_between_frames=num_indices_between_frames,
         sampling_range_start=sampling_range_start,
         sampling_range_end=sampling_range_end,
-        num_frames=len(decoder),
-        clip_span=clip_span,
+        policy=policy,
     )
 
-    # Note [num clips larger than sampling range]
-    # If we ask for more clips than there are frames in the sampling range or
-    # in the video, we rely on torch.linspace behavior which will return
-    # duplicated indices.
-    # E.g. torch.linspace(0, 10, steps=20, dtype=torch.int) returns
-    # 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10
-    # Alternatively we could wrap around, but the current behavior is closer to
-    # the expected "equally spaced indices" sampling.
-    clip_start_indices = torch.linspace(
-        sampling_range_start, sampling_range_end - 1, steps=num_clips, dtype=torch.int
-    )
-
-    # Similarly to clip_at_random_indices, there may be backward seeks if clips overlap.
-    # See other TODO over there, and apply similar changes here.
-    clips = [
-        decoder.get_frames_at(
-            start=clip_start_index,
-            stop=clip_start_index + clip_span,
-            step=num_indices_between_frames,
-        )
-        for clip_start_index in clip_start_indices
-    ]
-
-    return clips
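The "Note [num clips larger than sampling range]" in the diff relies on integer truncation of torch.linspace. The same index computation can be sketched in plain Python (assumed, not verified, to match linspace with dtype=torch.int on these inputs):

```python
def spaced_start_indices(start, end_inclusive, num_clips):
    # Mimic torch.linspace(start, end_inclusive, steps=num_clips, dtype=torch.int):
    # evenly spaced floats over [start, end_inclusive], truncated toward zero.
    if num_clips == 1:
        return [start]
    step = (end_inclusive - start) / (num_clips - 1)
    vals = [start + i * step for i in range(num_clips)]
    vals[-1] = end_inclusive  # linspace includes the exact endpoint
    return [int(v) for v in vals]


# More clips than available start indices -> duplicated starts, as in the note.
print(spaced_start_indices(0, 10, 20))
# [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10]
```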
Review comment: Based on the discussion below, shouldn't this just be num_frames_in_video? That is, the default last index at which a clip can start is the final frame in the video. And if we're worried that num_frames_in_video is not at least 1, maybe we should explicitly check for (and error?) on that. Even if we decide line 56 is correct, I can't reconcile it and 67. There, it looks like we're making sure that sampling_range_end is not larger than the video itself - which makes sense to me. But there we don't subtract off the size of a clip, while here at line 56 we do.

Review comment: I think the lack of documentation is making this significantly harder to review, sorry about that. So, my reasoning is the following: when sampling_range_end is not passed by the user, we subtract clip_span from num_frames_in_video. When it is passed, users would have to know about clip_span (and we don't want them to do that I think?), so for a user-passed sampling_range_end we just make sure it doesn't go above num_frames_in_video to avoid errors on short videos. If we were to subtract clip_span there, it would change the semantic of sampling_range_end from "upper bound of the start" to "upper bound of the end". Basically: with the default, policy doesn't come into play, and there are no "degenerate" clips (degenerate = policy had to be applied).

Review comment: Yes, that all tracks. I think that is consistent with what I surmised below, when I read the test. I do have a slight preference for what I explained, but it's only a slight preference and not blocking.

Review comment: SG, thanks. I don't have a super strong preference either. Let me merge now to unblock, but happy to revisit.
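To make the semantics being debated concrete, here is a hypothetical standalone sketch of the default computed in the diff, sampling_range_end = max(num_frames_in_video - clip_span + 1, 1), i.e. the exclusive upper bound of where a clip may *start* without triggering a policy:

```python
def clip_span(num_indices_between_frames, num_frames_per_clip):
    # Same formula as _get_clip_span in the diff.
    return num_indices_between_frames * (num_frames_per_clip - 1) + 1


def default_sampling_range_end(num_frames_in_video, num_indices_between_frames,
                               num_frames_per_clip):
    span = clip_span(num_indices_between_frames, num_frames_per_clip)
    # Exclusive upper bound of valid clip *starts*; clamped to 1 so that even
    # a video shorter than one clip still yields a start index (index 0).
    return max(num_frames_in_video - span + 1, 1)


# 10-frame video, clips of 3 consecutive frames: starts 0..7 are safe, end is 8.
print(default_sampling_range_end(10, 1, 3))  # 8
# Video shorter than one clip: clamped to 1, clips start at 0 and the policy applies.
print(default_sampling_range_end(2, 1, 5))   # 1
```

This matches the reviewer exchange above: with the default, no clip can run past the end of the video, whereas a user-passed sampling_range_end is only clamped to num_frames_in_video, so the policy can kick in.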