From a5ab6d2855bbc6218fbbeaf32a52ba39ba2b1f57 Mon Sep 17 00:00:00 2001 From: amanturamatov Date: Sun, 26 Jan 2025 01:47:20 +0600 Subject: [PATCH] Remove unnecessary args from extractor --- oml/models/audio/ecapa_tdnn/extractor.py | 30 ++++++++++--------- .../test_oml/test_models/test_audio_models.py | 2 +- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/oml/models/audio/ecapa_tdnn/extractor.py b/oml/models/audio/ecapa_tdnn/extractor.py index 444296f9..bbd80f92 100644 --- a/oml/models/audio/ecapa_tdnn/extractor.py +++ b/oml/models/audio/ecapa_tdnn/extractor.py @@ -27,7 +27,6 @@ class ECAPATDNNExtractor(IExtractor): "init_args": { "arch": "ecapa_tdnn_taoruijie", "normalise_features": False, - "filter_state_dict_prefix": "speaker_encoder.", }, } } @@ -37,14 +36,13 @@ def __init__( weights: Optional[Union[Path, str]], arch: str, normalise_features: bool = False, - filter_state_dict_prefix: Optional[str] = None, ): """ Args: weights: Path to weights or special key for pretrained ones or ``None`` for random initialization. You can check available pretrained checkpoints in ``ECAPATDNNExtractor.pretrained_models``. arch: Model architecture, currently only supports ``ecapa_tdnn_taoruijie``. - normalise_features: Set ``True`` to normalise output features + normalise_features: Set ``True`` to normalise output features. """ super().__init__() @@ -63,12 +61,21 @@ def __init__( pretrained["hash"], # type: ignore fname=pretrained["fname"], # type: ignore ) - ckpt = torch.load(weights, map_location="cpu") - if "state_dict" in ckpt: - ckpt = ckpt["state_dict"] - if filter_state_dict_prefix is not None: - ckpt = filter_state_dict(ckpt, filter_state_dict_prefix) - self.model.load_state_dict(ckpt, strict=True) + state_dict = self.prepare_state_dict(weights, filter_prefix="speaker_encoder.") + else: + state_dict = self.prepare_state_dict(weights, filter_prefix="model.") + + self.model.load_state_dict(state_dict, strict=True) + + def prepare_state_dict(self, weights_path: Union[Path, str], filter_prefix: Optional[str] = None) -> TStateDict: + state_dict = torch.load(weights_path, map_location="cpu") + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + if filter_prefix is not None: + state_dict = OrderedDict( + (k[len(filter_prefix) :], v) for k, v in state_dict.items() if k.startswith(filter_prefix) + ) + return state_dict @property def feat_dim(self) -> int: @@ -83,9 +90,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def filter_state_dict(state_dict: TStateDict, prefix: str) -> TStateDict: - prefix_len = len(prefix) - return OrderedDict((k[prefix_len:], v) for k, v in state_dict.items() if k.startswith(prefix)) - - __all__ = ["ECAPATDNNExtractor"] diff --git a/tests/test_oml/test_models/test_audio_models.py b/tests/test_oml/test_models/test_audio_models.py index c780fccb..fc7bd6df 100644 --- a/tests/test_oml/test_models/test_audio_models.py +++ b/tests/test_oml/test_models/test_audio_models.py @@ -27,7 +27,7 @@ def test_extractor(constructor: IExtractor, args: Dict[str, Any]) -> None: fname = "weights_tmp.pth" torch.save({"state_dict": extractor.state_dict()}, fname) - extractor = ECAPATDNNExtractor(weights=fname, filter_state_dict_prefix="model.", **args).eval() + extractor = constructor(weights=fname, **args).eval() features2 = extractor.extract(signal) Path(fname).unlink()