Finetune meta-llama/Llama-Guard-3-1B #2237
I had no luck and ran into the following error:
Hey @jingzhaoou! Could you share the config you're using here, please?
Will share my config soon. I do need to make some changes; look here:

torchtune/torchtune/data/_prompt_templates.py, lines 116 to 117 (at 27fd3a1)
torchtune/torchtune/models/llama3/_tokenizer.py, lines 222 to 224 (at 27fd3a1)

There is extra whitespace stripping happening there. With the existing code, things before "What is" and things after "mayonnaise?" are added through the prompt template. In order to get things the way I want, I commented out the lines above. I wonder if this can be addressed in a better way. Thanks.
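For context, here is a minimal sketch of the prefix/suffix wrapping being discussed, assuming the Message and PromptTemplate interfaces exactly as they are used in the custom template pasted further down in this thread. The "Question: "/"\n\nAnswer: " tags are illustrative placeholders, not the actual Llama Guard text:

from torchtune.data import Message, PromptTemplate

# The template dict maps a role to a (prefix, suffix) pair that gets wrapped
# around that role's message content when the template is applied.
template = PromptTemplate(template={"user": ("Question: ", "\n\nAnswer: ")})
msgs = template([Message(role="user", content="What is in mayonnaise?")])

# The templated user message is now roughly:
#   "Question: What is in mayonnaise?\n\nAnswer: "
# Whatever whitespace is later trimmed from these added tags changes the exact
# tokens the model is fine-tuned on.
print(msgs[0].text_content)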
Not just for the custom Llama Guard template: would the extra whitespace stripping also affect the built-in templates? See:

torchtune/torchtune/data/_prompt_templates.py, line 248 (at 27fd3a1)

"Question: " is effectively "Question:". "\n\nAnswer: " is effectively "Answer: ". These may cause subtle issues during fine-tuning, IMO.
Please find my custom template Python file and config YAML file in the attached zip file. Sorry, GitHub does not allow me to upload them directly.
cc @RdoubleA

@jingzhaoou I've pasted your files here for ease of access. Hope that's okay.

# Config for single device full finetuning in full_finetune_single_device.py
# using a Llama3.2 1B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# The default config uses an optimizer from bitsandbytes. If you do not have it installed,
# you can install it with
# pip install bitsandbytes
#
# To launch on a single device, run the following command from root:
# tune run full_finetune_single_device --config llama3_2/1B_full_single_device
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run full_finetune_single_device --config llama3_2/1B_full_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on single device.
output_dir: /tmp/torchtune/llama3_2_1B/full_single_device # /tmp may be deleted by your system. Change it to your preference.
# Model Arguments
model:
  _component_: torchtune.models.llama3_2.llama3_2_1b
# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  # path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
  # max_seq_len: null
  path: /srv/models/meta-llama/Llama-Guard-3-1B/original/tokenizer.model
  max_seq_len: 8192
  prompt_template: my_custom_guard_template.my_custom_guard_template
# Dataset
# dataset:
#   _component_: torchtune.datasets.alpaca_dataset
#   packed: False # True increases speed
dataset:
  _component_: torchtune.datasets.instruct_dataset
  source: csv
  data_files: /srv/data/llama-guard/llama_guard_1b_wrong_polites.csv
  column_map:
    input: prompt
    output: ground_truth
  train_on_input: False
  packed: False
  split: train
seed: null
shuffle: True
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /srv/models/meta-llama/Llama-Guard-3-1B
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3_2
resume_from_checkpoint: False
# Fine-tuning arguments
batch_size: 4
epochs: 1
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW8bit
  lr: 2e-5
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase effective batch size
optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1
compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training environment
device: cuda
# Memory management
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory
# Reduced precision
dtype: bf16
# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1

from typing import List
from pathlib import Path
from torchtune.data import Message
from torchtune.data import PromptTemplate
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.datasets import instruct_dataset
class MyPromptTemplate(PromptTemplate):
    def __call__(
        self, messages: List[Message], inference: bool = False
    ) -> List[Message]:
        messages = super().__call__(messages, inference)
        return messages
def my_custom_guard_template() -> MyPromptTemplate:
    return MyPromptTemplate(
        template={
            "user": (
                """Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
S1: Violent Crimes.
S2: Non-Violent Crimes.
S3: Sex Crimes.
S4: Child Exploitation.
S5: Defamation.
S6: Specialized Advice.
S7: Privacy.
S8: Intellectual Property.
S9: Indiscriminate Weapons.
S10: Hate.
S11: Self-Harm.
S12: Sexual Content.
S13: Elections.
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
User: """,
                """\n\n<END CONVERSATION>
Provide your safety assessment for ONLY THE LAST User message in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. """),
        },
    )
if __name__ == '__main__':
    msgs = [
        Message(role="user", content="Emily is sitting next to me."),
        Message(role="assistant", content="safe"),
    ]
    prompt_template = my_custom_guard_template()
    templated_msgs = prompt_template(msgs)

    tokenizer_path = Path("/srv/models/meta-llama/Llama-Guard-3-1B/original/tokenizer.model")
    tokenizer = llama3_tokenizer(
        path=str(tokenizer_path),
        prompt_template="my_custom_guard_template.my_custom_guard_template",
        max_seq_len=8192,
    )
    dataset = instruct_dataset(
        tokenizer=tokenizer,
        source="csv",
        data_files="data/llama-guard/llama_guard_1b_wrong_polites.csv",
        column_map={
            "input": "prompt",
            "output": "ground_truth",
        },
        train_on_input=False,
        packed=False,
        split="train",
    )
    tokens = dataset[0]["tokens"]
    print(tokenizer.decode(token_ids=tokens, skip_special_tokens=False))
I have been skeptical of this for a long time, but the reference code we used for the llama models at the time included this. I wanted to revisit this, but as you said, changing it will have a lot of implications and will affect our regression tests for model correctness. We could take another look at the llama repos to see if they still do something similar...
@RdoubleA Can I assign you this issue to double-check both the official llama repos AND the Hugging Face llama implementation? If there is a discrepancy between the two, we should surface it to someone on either team.
Sure, sure. Let's nip this in the bud once and for all.
I tried fine-tuning the meta-llama/Llama-Guard-3-1B model. I don't see Llama Guard models listed in the output of tune ls. Since meta-llama/Llama-Guard-3-1B is "a fine-tuned Llama-3.2-1B pretrained model", I wonder if I can use one of the existing recipes like llama3_2/1B_full and derive my own template to fine-tune the Llama Guard models. In my own template, I can follow the instructions here to generate the prompts. I appreciate any suggestions to tell me whether I am on the right track before I spend more time on it.