-
Notifications
You must be signed in to change notification settings - Fork 973
/
Copy pathconvert.py
91 lines (73 loc) · 2.66 KB
/
convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import argparse
import json
import shutil
from pathlib import Path
from typing import Dict, Union
import mlx.core as mx
from huggingface_hub import snapshot_download
def save_weights(save_path: Union[str, Path], weights: Dict[str, mx.array]) -> None:
    """Write ``weights`` to a single safetensors file plus a JSON index.

    Creates ``save_path`` (including missing parents), stores every array in
    ``model.safetensors`` and writes ``model.safetensors.index.json``. The
    index holds the summed byte size under ``metadata.total_size`` and a
    ``weight_map`` assigning each weight name, in sorted order, to that one
    file.

    Args:
        save_path: Destination directory, given as a string or ``Path``.
        weights: Mapping of parameter name to MLX array.
    """
    out_dir = Path(save_path) if isinstance(save_path, str) else save_path
    out_dir.mkdir(parents=True, exist_ok=True)

    shard_name = "model.safetensors"
    index = {
        "metadata": {"total_size": sum(arr.nbytes for arr in weights.values())},
        "weight_map": {},
    }

    mx.save_safetensors(str(out_dir / shard_name), weights)

    # Everything lives in the single shard; names are listed alphabetically.
    index["weight_map"] = {name: shard_name for name in sorted(weights)}

    with open(out_dir / "model.safetensors.index.json", "w") as fh:
        json.dump(index, fh, indent=4)
def download(hf_repo: str) -> Path:
    """Download the safetensors weights and JSON files of a Hub repository.

    Only files matching ``*.safetensors`` or ``*.json`` are fetched.

    Args:
        hf_repo: Hugging Face repo id, e.g. ``"facebook/sam-vit-base"``.

    Returns:
        Local directory of the downloaded snapshot.
    """
    # `resume_download` was deprecated in huggingface_hub 0.23 (interrupted
    # downloads are always resumed now); passing it only raises a FutureWarning.
    local_dir = snapshot_download(
        repo_id=hf_repo,
        allow_patterns=["*.safetensors", "*.json"],
    )
    return Path(local_dir)
# Axis permutations applied to specific checkpoint tensors, keyed by name.
# Presumably these are PyTorch kernels being moved to MLX's channels-last
# layout (confirm against the HF SAM model definition):
#   Conv2d          (out, in, kh, kw) -> (out, kh, kw, in)  via (0, 2, 3, 1)
#   ConvTranspose2d (in, out, kh, kw) -> (out, kh, kw, in)  via (1, 2, 3, 0)
_CONV_AXES = (0, 2, 3, 1)
_CONV_TRANSPOSE_AXES = (1, 2, 3, 0)
_WEIGHT_PERMUTATIONS = {
    "vision_encoder.patch_embed.projection.weight": _CONV_AXES,
    "vision_encoder.neck.conv1.weight": _CONV_AXES,
    "vision_encoder.neck.conv2.weight": _CONV_AXES,
    "prompt_encoder.mask_embed.conv1.weight": _CONV_AXES,
    "prompt_encoder.mask_embed.conv2.weight": _CONV_AXES,
    "prompt_encoder.mask_embed.conv3.weight": _CONV_AXES,
    "mask_decoder.upscale_conv1.weight": _CONV_TRANSPOSE_AXES,
    "mask_decoder.upscale_conv2.weight": _CONV_TRANSPOSE_AXES,
}


def convert(model_path):
    """Load ``model.safetensors`` and re-layout the tensors that MLX needs permuted.

    Args:
        model_path: Directory containing ``model.safetensors``, as a ``str``
            or ``Path``.

    Returns:
        Dict mapping every parameter name to its array. Names listed in
        ``_WEIGHT_PERMUTATIONS`` are transposed with the given axis order;
        all other arrays are passed through unchanged.
    """
    # Coerce to Path so plain-string directories work with the `/` operator.
    weight_file = str(Path(model_path) / "model.safetensors")
    weights = mx.load(weight_file)

    mlx_weights = {}
    for name, value in weights.items():
        axes = _WEIGHT_PERMUTATIONS.get(name)
        mlx_weights[name] = value if axes is None else value.transpose(*axes)
    return mlx_weights
if __name__ == "__main__":
    # Command-line entry point: download a SAM checkpoint, convert it and
    # write the MLX weights plus the original config to --mlx-path.
    cli = argparse.ArgumentParser(description="Convert Meta SAM weights to MLX")
    cli.add_argument(
        "--hf-path",
        type=str,
        default="facebook/sam-vit-base",
        help="Path to the Hugging Face model repo.",
    )
    cli.add_argument(
        "--mlx-path",
        type=str,
        default="sam-vit-base",
        help="Path to save the MLX model.",
    )
    opts = cli.parse_args()

    src_dir = download(opts.hf_path)

    out_dir = Path(opts.mlx_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Convert and save the weights, then carry the model config over unchanged.
    save_weights(out_dir, convert(src_dir))
    shutil.copy(src_dir / "config.json", out_dir / "config.json")