diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 3a8c4629..0844c5fd 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -1,13 +1,12 @@ -import random -from typing import List, Optional +from typing import Callable, List, Literal, Optional import torch -from torchcodec.decoders import FrameBatch, VideoDecoder +from torchcodec.decoders import Frame, FrameBatch, VideoDecoder def _validate_params( - *, decoder, num_clips, num_frames_per_clip, num_indices_between_frames + *, decoder, num_clips, num_frames_per_clip, num_indices_between_frames, policy ): if len(decoder) < 1: raise ValueError( @@ -25,21 +24,36 @@ def _validate_params( f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive" ) + if policy not in _POLICY_FUNCTIONS.keys(): + raise ValueError( + f"Invalid policy ({policy}). Supported values are {_POLICY_FUNCTIONS.keys()}." + ) + def _validate_sampling_range( - *, sampling_range_start, sampling_range_end, num_frames, clip_span + *, + num_indices_between_frames, + num_frames_per_clip, + sampling_range_start, + sampling_range_end, + num_frames_in_video, ): if sampling_range_start < 0: - sampling_range_start = num_frames + sampling_range_start + sampling_range_start = num_frames_in_video + sampling_range_start - if sampling_range_start >= num_frames: + if sampling_range_start >= num_frames_in_video: raise ValueError( f"sampling_range_start ({sampling_range_start}) must be smaller than " - f"the number of frames ({num_frames})." + f"the number of frames ({num_frames_in_video})." 
) + clip_span = _get_clip_span( + num_indices_between_frames=num_indices_between_frames, + num_frames_per_clip=num_frames_per_clip, + ) + if sampling_range_end is None: - sampling_range_end = num_frames - clip_span + 1 + sampling_range_end = max(num_frames_in_video - clip_span + 1, 1) if sampling_range_start >= sampling_range_end: raise ValueError( f"We determined that sampling_range_end should be {sampling_range_end}, " @@ -49,8 +63,8 @@ def _validate_sampling_range( else: if sampling_range_end < 0: # Support negative values so that -1 means last frame. - sampling_range_end = num_frames + sampling_range_end - sampling_range_end = min(sampling_range_end, num_frames) + sampling_range_end = num_frames_in_video + sampling_range_end + sampling_range_end = min(sampling_range_end, num_frames_in_video) if sampling_range_start >= sampling_range_end: raise ValueError( f"sampling_range_start ({sampling_range_start}) must be smaller than " @@ -74,14 +88,119 @@ def _get_clip_span(*, num_indices_between_frames, num_frames_per_clip): return num_indices_between_frames * (num_frames_per_clip - 1) + 1 -def clips_at_random_indices( +def _repeat_last_policy( + frame_indices: list[int], num_frames_per_clip: int +) -> list[int]: + # frame_indices = [1, 2, 3], num_frames_per_clip = 5 + # output = [1, 2, 3, 3, 3] + frame_indices += [frame_indices[-1]] * (num_frames_per_clip - len(frame_indices)) + return frame_indices + + +def _wrap_policy(frame_indices: list[int], num_frames_per_clip: int) -> list[int]: + # frame_indices = [1, 2, 3], num_frames_per_clip = 5 + # output = [1, 2, 3, 1, 2] + return (frame_indices * (num_frames_per_clip // len(frame_indices) + 1))[ + :num_frames_per_clip + ] + + +def _error_policy(frames_indices: list[int], num_frames_per_clip: int) -> list[int]: + raise ValueError( + "You set the 'error' policy, and the sampler tried to decode a frame " + "that is beyond the number of frames in the video. " + "Try to leave sampling_range_end to its default value?" 
+ ) + + +_POLICY_FUNCTION_TYPE = Callable[[list[int], int], list[int]] +_POLICY_FUNCTIONS: dict[str, _POLICY_FUNCTION_TYPE] = { + "repeat_last": _repeat_last_policy, + "wrap": _wrap_policy, + "error": _error_policy, +} + + +def _build_all_clips_indices( + *, + clip_start_indices: torch.Tensor, # 1D int tensor + num_frames_per_clip: int, + num_indices_between_frames: int, + num_frames_in_video: int, + policy_fun: _POLICY_FUNCTION_TYPE, +) -> list[int]: + # From the clip_start_indices [f_00, f_10, f_20, ...] + # and from the rest of the parameters, return the list of all the frame + # indices that make up all the clips. + # I.e. the output is [f_00, f_01, f_02, f_03, f_10, f_11, f_12, f_13, ...] + # where f_01 is the index of frame 1 in clip 0. + # + # All clips in the output are of length num_frames_per_clip (=4 in example + # above). When the frame indices go beyond num_frames_in_video, we force the + # frame indices back to valid values by applying the user's policy (wrap, + # repeat, etc.). + all_clips_indices: list[int] = [] + + clip_span = _get_clip_span( + num_indices_between_frames=num_indices_between_frames, + num_frames_per_clip=num_frames_per_clip, + ) + + for start_index in clip_start_indices: + frame_index_upper_bound = min(start_index + clip_span, num_frames_in_video) + frame_indices = list( + range(start_index, frame_index_upper_bound, num_indices_between_frames) + ) + if len(frame_indices) < num_frames_per_clip: + frame_indices = policy_fun(frame_indices, num_frames_per_clip) + all_clips_indices += frame_indices + return all_clips_indices + + +def _decode_all_clips_indices( + decoder: VideoDecoder, all_clips_indices: list[int], num_frames_per_clip: int +) -> list[FrameBatch]: + # This takes the list of all the frames to decode, decodes all the frames, + # and then packs them into clips of length num_frames_per_clip. + # This is slow, unoptimized, and u.g.l.y. It is not meant to stay. 
+ # TODO: + # - sort the frames to avoid backward seeks, dedup, decode, and re-organize frames. + # - write most of this in C++ + + def chunk_list(lst, chunk_size): + # return list of sublists of length chunk_size + return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] + + def to_framebatch(frames: list[Frame]) -> FrameBatch: + data = torch.stack([frame.data for frame in frames]) + pts_seconds = torch.tensor([frame.pts_seconds for frame in frames]) + duration_seconds = torch.tensor([frame.duration_seconds for frame in frames]) + return FrameBatch( + data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds + ) + + all_decoded_frames: list[Frame] = [ + decoder.get_frame_at(index) for index in all_clips_indices + ] + all_clips: list[list[Frame]] = chunk_list( + all_decoded_frames, chunk_size=num_frames_per_clip + ) + + return [to_framebatch(clip) for clip in all_clips] + + +def _generic_sampler( + kind: Literal["random", "regular"], decoder: VideoDecoder, *, - num_clips: int = 1, - num_frames_per_clip: int = 1, - num_indices_between_frames: int = 1, - sampling_range_start: int = 0, - sampling_range_end: Optional[int] = None, # interval is [start, end). + num_clips: int, + num_frames_per_clip: int, + num_indices_between_frames: int, + sampling_range_start: int, + sampling_range_end: Optional[int], # interval is [start, end). + # Important note: sampling_range_end defines the upper bound of where a clip + # can *start*, not where a clip can end. + policy: Literal["repeat_last", "wrap", "error"], ) -> List[FrameBatch]: _validate_params( @@ -89,60 +208,52 @@ def clips_at_random_indices( num_clips=num_clips, num_frames_per_clip=num_frames_per_clip, num_indices_between_frames=num_indices_between_frames, + policy=policy, ) - clip_span = _get_clip_span( - num_indices_between_frames=num_indices_between_frames, - num_frames_per_clip=num_frames_per_clip, - ) - - # TODO: We should probably not error. 
- if clip_span > len(decoder): - raise ValueError( - f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})" - ) - sampling_range_start, sampling_range_end = _validate_sampling_range( + num_frames_per_clip=num_frames_per_clip, + num_indices_between_frames=num_indices_between_frames, sampling_range_start=sampling_range_start, sampling_range_end=sampling_range_end, - num_frames=len(decoder), - clip_span=clip_span, + num_frames_in_video=len(decoder), ) - clip_start_indices = torch.randint( - low=sampling_range_start, high=sampling_range_end, size=(num_clips,) - ) - - # We want to avoid seeking backwards, so we sort the clip start indices - # before decoding the frames, and then re-shuffle the clips afterwards. - # Backward seeks may still happen if there are overlapping clips, i.e. if a - # clip ends after the next one starts. - # TODO: We should use a different strategy to avoid backward seeks: - # - flatten all frames indices, irrespective of their clip - # - sort the indices and dedup - # - decode all frames in index order - # - re-arrange the frames back into their original clips - clip_start_indices = torch.sort(clip_start_indices).values - clips = [ - decoder.get_frames_at( - start=clip_start_index, - stop=clip_start_index + clip_span, - step=num_indices_between_frames, + if kind == "random": + clip_start_indices = torch.randint( + low=sampling_range_start, high=sampling_range_end, size=(num_clips,) + ) + else: + # Note [num clips larger than sampling range] + # If we ask for more clips than there are frames in the sampling range or + # in the video, we rely on torch.linspace behavior which will return + # duplicated indices. + # E.g. torch.linspace(0, 10, steps=20, dtype=torch.int) returns + # 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10 + # Alternatively we could wrap around, but the current behavior is closer to + # the expected "equally spaced indices" sampling. 
+ clip_start_indices = torch.linspace( + sampling_range_start, + sampling_range_end - 1, + steps=num_clips, + dtype=torch.int, ) - for clip_start_index in clip_start_indices - ] - - # This an ugly way to shuffle the clips using pytorch RNG *without* - # affecting the python builtin RNG. - builtin_random_state = random.getstate() - random.seed(torch.randint(0, 2**32, (1,)).item()) - random.shuffle(clips) - random.setstate(builtin_random_state) - return clips + all_clips_indices = _build_all_clips_indices( + clip_start_indices=clip_start_indices, + num_frames_per_clip=num_frames_per_clip, + num_indices_between_frames=num_indices_between_frames, + num_frames_in_video=len(decoder), + policy_fun=_POLICY_FUNCTIONS[policy], + ) + return _decode_all_clips_indices( + decoder, + all_clips_indices=all_clips_indices, + num_frames_per_clip=num_frames_per_clip, + ) -def clips_at_regular_indices( +def clips_at_random_indices( decoder: VideoDecoder, *, num_clips: int = 1, @@ -150,54 +261,38 @@ def clips_at_regular_indices( num_indices_between_frames: int = 1, sampling_range_start: int = 0, sampling_range_end: Optional[int] = None, # interval is [start, end). + policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> List[FrameBatch]: - - _validate_params( + return _generic_sampler( + kind="random", decoder=decoder, num_clips=num_clips, num_frames_per_clip=num_frames_per_clip, num_indices_between_frames=num_indices_between_frames, + sampling_range_start=sampling_range_start, + sampling_range_end=sampling_range_end, + policy=policy, ) - clip_span = _get_clip_span( - num_indices_between_frames=num_indices_between_frames, - num_frames_per_clip=num_frames_per_clip, - ) - # TODO: We should probably not error. 
- if clip_span > len(decoder): - raise ValueError( - f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})" - ) +def clips_at_regular_indices( + decoder: VideoDecoder, + *, + num_clips: int = 1, + num_frames_per_clip: int = 1, + num_indices_between_frames: int = 1, + sampling_range_start: int = 0, + sampling_range_end: Optional[int] = None, # interval is [start, end). + policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", +) -> List[FrameBatch]: - sampling_range_start, sampling_range_end = _validate_sampling_range( + return _generic_sampler( + kind="regular", + decoder=decoder, + num_clips=num_clips, + num_frames_per_clip=num_frames_per_clip, + num_indices_between_frames=num_indices_between_frames, sampling_range_start=sampling_range_start, sampling_range_end=sampling_range_end, - num_frames=len(decoder), - clip_span=clip_span, + policy=policy, ) - - # Note [num clips larger than sampling range] - # If we ask for more clips than there are frames in the sampling range or - # in the video, we rely on torch.linspace behavior which will return - # duplicated indices. - # E.g. torch.linspace(0, 10, steps=20, dtype=torch.int) returns - # 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10 - # Alternatively we could wrap around, but the current behavior is closer to - # the expected "equally spaced indices" sampling. - clip_start_indices = torch.linspace( - sampling_range_start, sampling_range_end - 1, steps=num_clips, dtype=torch.int - ) - - # Similarly to clip_at_random_indices, there may be backward seeks if clips overlap. - # See other TODO over there, and apply similar changes here. 
- clips = [ - decoder.get_frames_at( - start=clip_start_index, - stop=clip_start_index + clip_span, - step=num_indices_between_frames, - ) - for clip_start_index in clip_start_indices - ] - - return clips diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 2df7bf79..3496a002 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -7,6 +7,7 @@ from torchcodec.decoders import FrameBatch, VideoDecoder from torchcodec.samplers import clips_at_random_indices, clips_at_regular_indices +from torchcodec.samplers._implem import _build_all_clips_indices, _POLICY_FUNCTIONS from ..utils import assert_tensor_equal, NASA_VIDEO @@ -50,11 +51,12 @@ def test_sampler(sampler, num_indices_between_frames): # Check the num_indices_between_frames parameter by asserting that the # "time" difference between frames in a clip is the same as the "index" # distance. + avg_distance_between_frames_seconds = torch.concat( [clip.pts_seconds.diff() for clip in clips] ).mean() assert avg_distance_between_frames_seconds == pytest.approx( - num_indices_between_frames / decoder.metadata.average_fps + num_indices_between_frames / decoder.metadata.average_fps, abs=1e-5 ) @@ -131,6 +133,67 @@ def test_sampling_range_negative(sampler): assert_tensor_equal(clip.data, clips_1[0].data) +@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) +def test_sampling_range_default_behavior(sampler): + # This is a functional test for the default behavior of the + # sampling_range_end parameter. By default it's None, which means the + # sampler automatically sets its value such that we never sample "beyond" + # the number of frames in the video. That means that the last few frames of + # the video are less likely to be part of a clip. + # When sampling_range_end is set manually to e.g. 
len(decoder), the last + # frames are way more likely to be part of a clip, since there is no + # restriction on the sampling range (and the user-defined policy comes into + # action, potentially repeating that last frame). + # + # In this test we assert that the last clip starts significantly earlier + # when sampling_range_end=None than when sampling_range_end=len(decoder). + # This is only a proxy, for lack of better testing opportunities. + + torch.manual_seed(0) + + decoder = VideoDecoder(NASA_VIDEO.path) + + num_clips = 20 + num_frames_per_clip = 15 + sampling_range_start = -20 + + # with default sampling_range_end value + clips_default = sampler( + decoder, + num_clips=num_clips, + num_frames_per_clip=num_frames_per_clip, + sampling_range_start=sampling_range_start, + sampling_range_end=None, + ) + + last_clip_start_default = max([clip.pts_seconds[0] for clip in clips_default]) + + # with manual sampling_range_end value set to last frame + clips_manual = sampler( + decoder, + num_clips=num_clips, + num_frames_per_clip=num_frames_per_clip, + sampling_range_start=sampling_range_start, + sampling_range_end=len(decoder), + ) + last_clip_start_manual = max([clip.pts_seconds[0] for clip in clips_manual]) + + assert last_clip_start_manual - last_clip_start_default > 0.3 + + +@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) +def test_sampling_range_error_policy(sampler): + decoder = VideoDecoder(NASA_VIDEO.path) + with pytest.raises(ValueError, match="beyond the number of frames"): + sampler( + decoder, + num_frames_per_clip=10, + sampling_range_start=-1, + sampling_range_end=len(decoder), + policy="error", + ) + + def test_random_sampler_randomness(): decoder = VideoDecoder(NASA_VIDEO.path) num_clips = 5 @@ -215,18 +278,6 @@ def test_random_sampler_errors(sampler): ): sampler(decoder, num_indices_between_frames=0) - with pytest.raises( - ValueError, - match=re.escape("Clip span (1000) is larger than the number of frames"), - 
): - sampler(decoder, num_frames_per_clip=1000) - - with pytest.raises( - ValueError, - match=re.escape("Clip span (1001) is larger than the number of frames"), - ): - sampler(decoder, num_frames_per_clip=2, num_indices_between_frames=1000) - with pytest.raises( ValueError, match=re.escape("sampling_range_start (1000) must be smaller than") ): @@ -251,3 +302,80 @@ def test_random_sampler_errors(sampler): sampling_range_start=len(decoder) - 1, sampling_range_end=None, ) + + with pytest.raises(ValueError, match="Invalid policy"): + sampler(decoder, policy="BAD") + + +class TestPolicy: + @pytest.mark.parametrize( + "policy, frame_indices, expected_frame_indices", + ( + ("repeat_last", [1, 2, 3], [1, 2, 3, 3, 3]), + ("repeat_last", [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]), + ("wrap", [1, 2, 3], [1, 2, 3, 1, 2]), + ("wrap", [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]), + ), + ) + def test_policy(self, policy, frame_indices, expected_frame_indices): + policy_fun = _POLICY_FUNCTIONS[policy] + assert ( + policy_fun(frame_indices, num_frames_per_clip=5) == expected_frame_indices + ) + + def test_error_policy(self): + with pytest.raises(ValueError, match="beyond the number of frames"): + _POLICY_FUNCTIONS["error"]([1, 2, 3], num_frames_per_clip=5) + + +@pytest.mark.parametrize( + "clip_start_indices, num_indices_between_frames, policy, expected_all_clips_indices", + ( + ( + [0, 1, 2], # clip_start_indices + 1, # num_indices_between_frames + "repeat_last", # policy + # expected_all_clips_indices = + [0, 1, 2, 3, 4] + [1, 2, 3, 4, 4] + [2, 3, 4, 4, 4], + ), + # Same as above but with num_indices_between_frames=2 + ( + [0, 1, 2], # clip_start_indices + 2, # num_indices_between_frames + "repeat_last", # policy + # expected_all_clips_indices = + [0, 2, 4, 4, 4] + [1, 3, 3, 3, 3] + [2, 4, 4, 4, 4], + ), + # Same tests as above, for wrap policy + ( + [0, 1, 2], # clip_start_indices + 1, # num_indices_between_frames + "wrap", # policy + # expected_all_clips_indices = + [0, 1, 2, 3, 4] + [1, 2, 3, 4, 
1] + [2, 3, 4, 2, 3], + ), + ( + [0, 1, 2], # clip_start_indices + 2, # num_indices_between_frames + "wrap", # policy + # expected_all_clips_indices = + [0, 2, 4, 0, 2] + [1, 3, 1, 3, 1] + [2, 4, 2, 4, 2], + ), + ), +) +def test_build_all_clips_indices( + clip_start_indices, num_indices_between_frames, policy, expected_all_clips_indices +): + NUM_FRAMES_PER_CLIP = 5 + all_clips_indices = _build_all_clips_indices( + clip_start_indices=clip_start_indices, + num_frames_per_clip=5, + num_indices_between_frames=num_indices_between_frames, + num_frames_in_video=5, + policy_fun=_POLICY_FUNCTIONS[policy], + ) + + assert isinstance(all_clips_indices, list) + assert all(isinstance(index, int) for index in all_clips_indices) + assert len(all_clips_indices) == len(clip_start_indices) * NUM_FRAMES_PER_CLIP + assert all_clips_indices == expected_all_clips_indices