added some synchronicity #533

Open · wants to merge 9 commits into base: main
93 changes: 48 additions & 45 deletions inference/convert.py
@@ -1,9 +1,9 @@
import os
import shutil
from argparse import ArgumentParser
from parser import Parser
from glob import glob
from tqdm import tqdm, trange

import asyncio as sync
import torch
from safetensors.torch import safe_open, save_file

@@ -29,6 +29,32 @@
"scale": ("scale", None),
}

async def set_param(param, name, i, n_local_experts, mp, state_dicts, dim):
Author comment: moved the code into functions so it can be dispatched with asyncio.gather or asyncio.to_thread (a sketch of this pattern follows this file's diff).

new_param = param
if "experts" in name and "shared_experts" not in name:
idx = int(name.split(".")[-3])
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
return
elif dim is not None:
assert param.size(dim) % mp == 0
shard_size = param.size(dim) // mp
new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
state_dicts[i][name] = new_param
async def inner_safe_open(name, f, state_dicts, mp, n_local_experts):
if "model.layers.61" not in name:
param: torch.Tensor = f.get_tensor(name)
if name.startswith("model."):
name = name[len("model."):]
name = name.replace("self_attn", "attn")
name = name.replace("mlp", "ffn")
name = name.replace("weight_scale_inv", "scale")
name = name.replace("e_score_correction_bias", "bias")
key = name.split(".")[-2]
assert key in mapping
new_key, dim = mapping[key]
name = name.replace(key, new_key)
await sync.gather(*(set_param(param, name, i, n_local_experts, mp, state_dicts, dim) for i in range(mp)))


def main(hf_ckpt_path, save_path, n_experts, mp):
"""
@@ -44,53 +44,30 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
None
"""
torch.set_num_threads(8)
n_local_experts = n_experts // mp
state_dicts = [{} for _ in range(mp)]

for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
with safe_open(file_path, framework="pt", device="cpu") as f:
for name in f.keys():
if "model.layers.61" in name:
continue
param: torch.Tensor = f.get_tensor(name)
if name.startswith("model."):
name = name[len("model."):]
name = name.replace("self_attn", "attn")
name = name.replace("mlp", "ffn")
name = name.replace("weight_scale_inv", "scale")
name = name.replace("e_score_correction_bias", "bias")
key = name.split(".")[-2]
assert key in mapping
new_key, dim = mapping[key]
name = name.replace(key, new_key)
for i in range(mp):
new_param = param
if "experts" in name and "shared_experts" not in name:
idx = int(name.split(".")[-3])
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
continue
elif dim is not None:
assert param.size(dim) % mp == 0
shard_size = param.size(dim) // mp
new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
state_dicts[i][name] = new_param
n_local_experts,state_dicts = n_experts // mp, [{} for _ in range(mp)]
tensor_dir, token_dir = list(glob(os.path.join(hf_ckpt_path, "*.safetensors"))),list(glob(os.path.join(hf_ckpt_path, "*token*")))
for file_path in tqdm(tensor_dir):
cm = await sync.to_thread(safe_open, file_path, framework="pt", device="cpu")
async with cm as f:
await sync.gather(*(inner_safe_open(name, f, state_dicts, mp, n_local_experts) for name in f.keys()))

os.makedirs(save_path, exist_ok=True)

for i in trange(mp):
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
    await sync.gather(*(sync.to_thread(save_file, state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors")) for i in trange(mp)))  # pass the callable and its args to to_thread

async def set_file_path(file_path):
await sync.to_thread(shutil.copyfile, file_path, os.path.join(save_path, os.path.basename(file_path)))

await sync.gather(*(set_file_path(file_path) for file_path in token_dir))

for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
new_file_path = os.path.join(save_path, os.path.basename(file_path))
shutil.copyfile(file_path, new_file_path)


if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--hf-ckpt-path", type=str, required=True)
parser.add_argument("--save-path", type=str, required=True)
parser.add_argument("--n-experts", type=int, required=True)
parser.add_argument("--model-parallel", type=int, required=True)
args = parser.parse_args()
assert args.n_experts % args.model_parallel == 0
main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
arg_list = [
("--hf-ckpt-path", type:=str, required:=True),
("--save-path", type:=str, required:=True),
("--n-experts", type:=int, required:=True),
("--model-parallel", type:=int, required:=True)
]
args = Parser(arg_list).apply_args().assert_model_parallel().return_args()
sync.run(main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel))
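
For reference, the review comment above describes dispatching blocking work through asyncio.to_thread and fanning the results out with asyncio.gather. A minimal, self-contained sketch of that pattern, using hypothetical file names and a stand-in copy_one helper that are not part of this PR:

import asyncio
import os
import shutil

def copy_one(src: str, dst_dir: str) -> str:
    # Plain blocking file copy; safe to hand off to a worker thread.
    dst = os.path.join(dst_dir, os.path.basename(src))
    shutil.copyfile(src, dst)
    return dst

async def copy_all(paths, dst_dir: str):
    os.makedirs(dst_dir, exist_ok=True)
    # asyncio.to_thread takes the callable and its arguments separately;
    # asyncio.gather then runs the resulting awaitables concurrently.
    return await asyncio.gather(*(asyncio.to_thread(copy_one, p, dst_dir) for p in paths))

# Usage sketch (the file names are placeholders):
# asyncio.run(copy_all(["model0.safetensors", "model1.safetensors"], "converted"))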
80 changes: 38 additions & 42 deletions inference/fp8_cast_bf16.py
@@ -3,13 +3,41 @@
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm

from asyncio import gather, to_thread, run
import torch
from safetensors.torch import load_file, save_file

from kernel import weight_dequant

def main(fp8_path, bf16_path):
def inner_tensor_file(safetensor_file):
file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cuda")
loaded_files[file_name] = current_state_dict
new_state_dict = {}
for weight_name, weight in current_state_dict.items():
if weight_name.endswith("_scale_inv"):
continue
elif weight.element_size() == 1: # FP8 weight
scale_inv_name = f"{weight_name}_scale_inv"
try:
# Get scale_inv from the correct file
scale_inv = get_tensor(scale_inv_name)
fp8_weight_names.append(weight_name)
new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
except KeyError:
print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
new_state_dict[weight_name] = weight
else:
new_state_dict[weight_name] = weight
new_safetensor_file = os.path.join(bf16_path, file_name)
save_file(new_state_dict, new_safetensor_file)

# Memory management: keep only the 2 most recently used files
if len(loaded_files) > 2:
oldest_file = next(iter(loaded_files))
del loaded_files[oldest_file]
torch.cuda.empty_cache()

async def main(fp8_path, bf16_path):
"""
Converts FP8 weights to BF16 and saves the converted weights.

@@ -32,13 +60,11 @@ def main(fp8_path, bf16_path):
torch.set_default_dtype(torch.bfloat16)
os.makedirs(bf16_path, exist_ok=True)
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f:
model_index = json.load(f)
with open(model_index_file, "r") as f: model_index = json.load(f)
weight_map = model_index["weight_map"]

# Cache for loaded safetensor files
loaded_files = {}
fp8_weight_names = []
loaded_files, fp8_weight_names = {}, []

# Helper function to get tensor from the correct file
def get_tensor(tensor_name):
@@ -62,51 +88,21 @@ def get_tensor(tensor_name):

safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
safetensor_files.sort()
for safetensor_file in tqdm(safetensor_files):
file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cuda")
loaded_files[file_name] = current_state_dict

new_state_dict = {}
for weight_name, weight in current_state_dict.items():
if weight_name.endswith("_scale_inv"):
continue
elif weight.element_size() == 1: # FP8 weight
scale_inv_name = f"{weight_name}_scale_inv"
try:
# Get scale_inv from the correct file
scale_inv = get_tensor(scale_inv_name)
fp8_weight_names.append(weight_name)
new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
except KeyError:
print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
new_state_dict[weight_name] = weight
else:
new_state_dict[weight_name] = weight

new_safetensor_file = os.path.join(bf16_path, file_name)
save_file(new_state_dict, new_safetensor_file)

# Memory management: keep only the 2 most recently used files
if len(loaded_files) > 2:
oldest_file = next(iter(loaded_files))
del loaded_files[oldest_file]
torch.cuda.empty_cache()
    await gather(*(to_thread(inner_tensor_file, safetensor_file) for safetensor_file in tqdm(safetensor_files)))


# Update model index
new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
for weight_name in fp8_weight_names:
scale_inv_name = f"{weight_name}_scale_inv"
if scale_inv_name in weight_map:
weight_map.pop(scale_inv_name)
with open(new_model_index_file, "w") as f:
json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
if scale_inv_name in weight_map: weight_map.pop(scale_inv_name)
with open(new_model_index_file, "w") as f: json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)


if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--input-fp8-hf-path", type=str, required=True)
parser.add_argument("--output-bf16-hf-path", type=str, required=True)
args = parser.parse_args()
main(args.input_fp8_hf_path, args.output_bf16_hf_path)
run(main(args.input_fp8_hf_path, args.output_bf16_hf_path))
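
The dequantization loop above caches loaded shards and keeps "only the 2 most recently used files", relying on dict insertion order to find the oldest entry. A minimal sketch of that caching idea, written independently of this PR; the loader callable and file name below are stand-ins:

from collections import OrderedDict

class ShardCache:
    """Keep at most `capacity` loaded shard dicts in memory, evicting the oldest."""
    def __init__(self, loader, capacity: int = 2):
        self.loader = loader        # e.g. lambda p: load_file(p, device="cuda")
        self.capacity = capacity
        self._cache = OrderedDict()

    def get(self, path: str) -> dict:
        if path not in self._cache:
            self._cache[path] = self.loader(path)
            while len(self._cache) > self.capacity:
                self._cache.popitem(last=False)   # drop the oldest loaded shard
        return self._cache[path]

# Usage sketch (loader and file name are placeholders):
# cache = ShardCache(lambda p: load_file(p, device="cuda"))
# weight = cache.get("model-00001-of-00163.safetensors")["some.weight"]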

55 changes: 25 additions & 30 deletions inference/generate.py
@@ -1,13 +1,13 @@
import os
import json
from parser import Parser
from argparse import ArgumentParser
from typing import List

import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model

from asyncio import gather, to_thread, run
from model import Transformer, ModelArgs


@@ -36,6 +36,7 @@ def generate(
temperature: float = 1.0
) -> List[List[int]]:
"""

Generates new tokens based on the given prompt tokens using the specified model.

Args:
@@ -47,38 +48,35 @@

Returns:
List[List[int]]: A list of lists containing the generated tokens for each sequence.

"""
prompt_lens = [len(t) for t in prompt_tokens]
assert max(prompt_lens) <= model.max_seq_len
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
for i, t in enumerate(prompt_tokens):
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
for i, t in enumerate(prompt_tokens): tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
prev_pos = 0
finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
prompt_mask = tokens != -1
for cur_pos in range(min(prompt_lens), total_len):
def inner_cur_pos():
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
next_token = sample(logits, temperature)
else:
next_token = logits.argmax(dim=-1)
if temperature > 0: next_token = sample(logits, temperature)
else: next_token = logits.argmax(dim=-1)
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
prev_pos = cur_pos
if finished.all():
break
if finished.all(): return
gather(*(to_thread(cur_pos) for cur_pos in range(min(prompt_lens), total_len)))
completion_tokens = []
for i, toks in enumerate(tokens.tolist()):
toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
if eos_id in toks:
toks = toks[:toks.index(eos_id)]
if eos_id in toks: toks = toks[:toks.index(eos_id)]
completion_tokens.append(toks)
return completion_tokens


def main(
async def main(
ckpt_path: str,
config: str,
input_file: str = "",
@@ -131,8 +129,7 @@ def main(
objects = [None]
dist.broadcast_object_list(objects, 0)
prompt = objects[0]
if prompt == "/exit":
break
if prompt == "/exit": break
elif prompt == "/clear":
messages.clear()
continue
@@ -143,8 +140,7 @@
print(completion)
messages.append({"role": "assistant", "content": completion})
else:
with open(input_file) as f:
prompts = [line.strip() for line in f.readlines()]
with open(input_file) as f: prompts = [line.strip() for line in f.readlines()]
assert len(prompts) <= args.max_batch_size
prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
@@ -154,8 +150,7 @@
print("Completion:", completion)
print()

if world_size > 1:
dist.destroy_process_group()
if world_size > 1: dist.destroy_process_group()


if __name__ == "__main__":
@@ -173,13 +168,13 @@ def main(
Raises:
AssertionError: If neither input-file nor interactive mode is specified.
"""
parser = ArgumentParser()
parser.add_argument("--ckpt-path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--input-file", type=str, default="")
parser.add_argument("--interactive", action="store_true")
parser.add_argument("--max-new-tokens", type=int, default=200)
parser.add_argument("--temperature", type=float, default=0.2)
args = parser.parse_args()
assert args.input_file or args.interactive
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
arg_list = [
("--ckpt-path", type:=str, required:=True),
("--config", type:=str, required:=True),
("--input-file", type:=str, default:=""),
("--interactive", action:="store_true"),
("--max-new-tokens", type:=int, default:=200),
("--temperature", type:=float, default:=0.2)
]
args = Parser(arg_list).apply_args().assert_interactive().return_args()
run(main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature))
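
The generate docstring above defines a prompt-tokens-in, token-lists-out contract. Purely as a usage sketch that mirrors the batch path in main (the model object, checkpoint path, and prompt string are placeholders, and generate is assumed to be in scope):

from transformers import AutoTokenizer

def run_single_prompt(model, ckpt_path: str, prompt: str, max_new_tokens: int = 200) -> str:
    # Build chat-formatted prompt tokens, call generate, then decode the completion.
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    prompt_tokens = [tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}], add_generation_prompt=True)]
    completion_tokens = generate(model, prompt_tokens, max_new_tokens,
                                 tokenizer.eos_token_id, temperature=0.2)
    return tokenizer.decode(completion_tokens[0], skip_special_tokens=True)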
18 changes: 18 additions & 0 deletions inference/parser.py
@@ -0,0 +1,18 @@
from argparse import ArgumentParser

class Parser():
def __init__(self, parser = ArgumentParser(), arg_list = []):
self.parser = parser
self.arg_list = arg_list
def apply_args(self):
for arg in self.arg_list: self.parser.add_argument(*arg)
return self
def assert_model_parallel(self):
        assert self.return_args().n_experts % self.return_args().model_parallel == 0
return self
    def assert_interactive(self):
assert self.return_args().input_file or self.return_args().interactive
return self
def return_args(self):
return self.parser.parse_args()