Remove unnecessary args from extractor
amanteur committed Jan 25, 2025
1 parent e44ea0e commit a5ab6d2
Showing 2 changed files with 17 additions and 15 deletions.
30 changes: 16 additions & 14 deletions oml/models/audio/ecapa_tdnn/extractor.py
@@ -27,7 +27,6 @@ class ECAPATDNNExtractor(IExtractor):
             "init_args": {
                 "arch": "ecapa_tdnn_taoruijie",
                 "normalise_features": False,
-                "filter_state_dict_prefix": "speaker_encoder.",
             },
         }
     }
@@ -37,14 +36,13 @@ def __init__(
         weights: Optional[Union[Path, str]],
         arch: str,
         normalise_features: bool = False,
-        filter_state_dict_prefix: Optional[str] = None,
     ):
         """
         Args:
             weights: Path to weights or special key for pretrained ones or ``None`` for random initialization.
                 You can check available pretrained checkpoints in ``ECAPATDNNExtractor.pretrained_models``.
             arch: Model architecture, currently only supports ``ecapa_tdnn_taoruijie``.
-            normalise_features: Set ``True`` to normalise output features
+            normalise_features: Set ``True`` to normalise output features.
         """
         super().__init__()
@@ -63,12 +61,21 @@ def __init__(
                 pretrained["hash"],  # type: ignore
                 fname=pretrained["fname"],  # type: ignore
             )
-        ckpt = torch.load(weights, map_location="cpu")
-        if "state_dict" in ckpt:
-            ckpt = ckpt["state_dict"]
-        if filter_state_dict_prefix is not None:
-            ckpt = filter_state_dict(ckpt, filter_state_dict_prefix)
-        self.model.load_state_dict(ckpt, strict=True)
+            state_dict = self.prepare_state_dict(weights, filter_prefix="speaker_encoder.")
+        else:
+            state_dict = self.prepare_state_dict(weights, filter_prefix="model.")
+
+        self.model.load_state_dict(state_dict, strict=True)
+
+    def prepare_state_dict(self, weights_path: Union[Path, str], filter_prefix: Optional[str] = None) -> TStateDict:
+        state_dict = torch.load(weights_path, map_location="cpu")
+        if "state_dict" in state_dict:
+            state_dict = state_dict["state_dict"]
+        if filter_prefix is not None:
+            state_dict = OrderedDict(
+                (k[len(filter_prefix) :], v) for k, v in state_dict.items() if k.startswith(filter_prefix)
+            )
+        return state_dict

     @property
     def feat_dim(self) -> int:
@@ -83,9 +90,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


-def filter_state_dict(state_dict: TStateDict, prefix: str) -> TStateDict:
-    prefix_len = len(prefix)
-    return OrderedDict((k[prefix_len:], v) for k, v in state_dict.items() if k.startswith(prefix))
-
-
 __all__ = ["ECAPATDNNExtractor"]
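
For reference, the prefix filtering that prepare_state_dict now performs inline can be reproduced in isolation. A minimal sketch (the checkpoint keys below are invented for illustration):

from collections import OrderedDict

import torch

# A checkpoint whose keys carry a wrapper prefix, e.g. as saved from the OML module.
ckpt = OrderedDict(
    [
        ("model.layer1.weight", torch.zeros(4)),
        ("model.layer1.bias", torch.zeros(4)),
        ("head.weight", torch.zeros(4)),  # key without the prefix, dropped by the filter
    ]
)

prefix = "model."
# Keep only keys starting with the prefix and strip it, as prepare_state_dict does.
filtered = OrderedDict((k[len(prefix):], v) for k, v in ckpt.items() if k.startswith(prefix))
print(list(filtered))  # ['layer1.weight', 'layer1.bias']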
2 changes: 1 addition & 1 deletion tests/test_oml/test_models/test_audio_models.py
@@ -27,7 +27,7 @@ def test_extractor(constructor: IExtractor, args: Dict[str, Any]) -> None:
     fname = "weights_tmp.pth"
     torch.save({"state_dict": extractor.state_dict()}, fname)

-    extractor = ECAPATDNNExtractor(weights=fname, filter_state_dict_prefix="model.", **args).eval()
+    extractor = constructor(weights=fname, **args).eval()
     features2 = extractor.extract(signal)
     Path(fname).unlink()

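The updated test exercises the round trip the new defaults enable: keys saved from the wrapper start with "model.", which prepare_state_dict(..., filter_prefix="model.") strips on reload. A hedged sketch of that round trip (the import path is inferred from this diff's file path, and weights=None for random initialization follows the docstring; both are assumptions):

import torch

from oml.models.audio.ecapa_tdnn.extractor import ECAPATDNNExtractor

# Save a checkpoint from the wrapper: its state_dict keys start with "model.".
extractor = ECAPATDNNExtractor(weights=None, arch="ecapa_tdnn_taoruijie")
torch.save({"state_dict": extractor.state_dict()}, "weights_tmp.pth")

# Reload: the non-pretrained branch strips the "model." prefix automatically.
extractor2 = ECAPATDNNExtractor(weights="weights_tmp.pth", arch="ecapa_tdnn_taoruijie")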
