Finetune meta-llama/Llama-Guard-3-1B #2237

Open · jingzhaoou opened this issue Jan 8, 2025 · 10 comments

Labels: triaged (This issue has been assigned an owner and appropriate label)

@jingzhaoou

I don't see the Llama Guard models listed in the output of tune ls. Since meta-llama/Llama-Guard-3-1B is "a fine-tuned Llama-3.2-1B pretrained model", I wonder if I can use one of the existing recipes like llama3_2/1B_full and derive my own template to fine-tune the Llama Guard models. In my own template, I can follow the instructions here to generate the prompts.

I'd appreciate any suggestions on whether I am on the right track before I spend more time on it.

@jingzhaoou (Author)

I had no luck and ran into the following error:

  File "/srv/source_code/torchtune/recipes/full_finetune_single_device.py", line 803, in <module>
    sys.exit(recipe_main())
             ^^^^^^^^^^^^^
  File "/srv/source_code/torchtune/torchtune/config/_parse.py", line 99, in wrapper
    sys.exit(recipe_main(conf))
             ^^^^^^^^^^^^^^^^^
  File "/srv/source_code/torchtune/recipes/full_finetune_single_device.py", line 797, in recipe_main
    recipe.setup(cfg=cfg)
  File "/srv/source_code/torchtune/recipes/full_finetune_single_device.py", line 268, in setup
    self._model = self._setup_model(
                  ^^^^^^^^^^^^^^^^^^
  File "/srv/source_code/torchtune/recipes/full_finetune_single_device.py", line 432, in _setup_model
    model.load_state_dict(model_state_dict)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 2584, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for TransformerDecoder:
        Unexpected key(s) in state_dict: "output.weight". 

@SalmanMohammadi (Collaborator)

Hey @jingzhaoou! Could you share the config you're using here, please?

@jingzhaoou (Author)

Will share my config soon. I do need to make some changes to the torchtune source code for this to work.

Look here:

if isinstance(prepend_tag, str) and len(prepend_tag) > 0:
    content = [{"type": "text", "content": prepend_tag}] + content

prepend_tag and append_tag are added as type: text. When the template is expanded, I see

if item["type"] == "text":
tokenized_body += self.encode(
item["content"].strip(), add_bos=False, add_eos=False

There is an extra .strip(), which will cause issues when I create a template for Llama Guard. This is a snippet of a sample Llama Guard template:

S12: Sexual Content.
S13: Elections.
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

User: What is the recipe for mayonnaise?

<END CONVERSATION>

With the existing torchtune code, this expands to something like

S12: Sexual Content.
S13: Elections.
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

User:What is the recipe for mayonnaise?<END CONVERSATION>

Everything before "What is" and everything after "mayonnaise?" is added through prepend_tag and append_tag in my custom template.

In order to get things the way I want, I commented out .strip() and installed torchtune from source.

I wonder if this can be addressed in a better way. Thanks.
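
One possible workaround without patching torchtune (a rough sketch only, not tested against the actual torchtune internals) is to merge the prepend tag, the user content, and the append tag into a single string inside a custom PromptTemplate, so the tokenizer's .strip() only touches the outermost whitespace and the internal newlines survive:

from typing import List

from torchtune.data import Message, PromptTemplate


class MergedTagTemplate(PromptTemplate):
    """Hypothetical variant of PromptTemplate that folds the prepend/append
    tags into a single text item per message, so only the outer whitespace
    is affected by the tokenizer's .strip()."""

    def __call__(
        self, messages: List[Message], inference: bool = False
    ) -> List[Message]:
        templated = []
        for message in messages:
            if message.role in self.template:
                prepend_tag, append_tag = self.template[message.role]
                # The internal "\n\n" between the tags and the content is
                # preserved because it is no longer at the edge of a
                # separate text item.
                content = prepend_tag + message.text_content + append_tag
            else:
                content = message.text_content
            templated.append(
                Message(role=message.role, content=content, masked=message.masked)
            )
        return templated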

@jingzhaoou (Author) commented Jan 8, 2025

It's not just custom Llama Guard templates: wouldn't the extra .strip() also cause issues with some predefined templates like

"user": ("Question: ", "\n\nAnswer: "),

"Question: " effectively becomes "Question:", and "\n\nAnswer: " effectively becomes "Answer:". These may cause subtle issues during fine-tuning, IMO.

@jingzhaoou (Author)

Please find my custom template Python file and config YAML file in the following zip file. Sorry, GitHub does not allow me to upload them directly.

llama_guard_sample_config.zip

@SalmanMohammadi (Collaborator) commented Jan 9, 2025

cc @RdoubleA

@jingzhaoou I've pasted your files here for ease of access. Hope that's okay.

# Config for single device full finetuning in full_finetune_single_device.py
# using a Llama3.2 1B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# The default config uses an optimizer from bitsandbytes. If you do not have it installed,
# you can install it with
#   pip install bitsandbytes
#
# To launch on a single device, run the following command from root:
#   tune run full_finetune_single_device --config llama3_2/1B_full_single_device
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
#   tune run full_finetune_single_device --config llama3_2/1B_full_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on single device.

output_dir: /tmp/torchtune/llama3_2_1B/full_single_device # /tmp may be deleted by your system. Change it to your preference.

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.llama3_2_1b

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  # path:  /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
  # max_seq_len: null
  path: /srv/models/meta-llama/Llama-Guard-3-1B/original/tokenizer.model
  max_seq_len: 8192
  prompt_template: my_custom_guard_template.my_custom_guard_template

# Dataset
# dataset:
#   _component_: torchtune.datasets.alpaca_dataset
#   packed: False  # True increases speed
dataset:
  _component_: torchtune.datasets.instruct_dataset
  source: csv
  data_files: /srv/data/llama-guard/llama_guard_1b_wrong_polites.csv
  column_map:
    input: prompt
    output: ground_truth
  train_on_input: False
  packed: False
  split: train
seed: null
shuffle: True

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /srv/models/meta-llama/Llama-Guard-3-1B
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3_2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 4
epochs: 1
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW8bit
  lr: 2e-5
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1  # Use to increase effective batch size
optimizer_in_bwd: True  # True saves memory. Requires gradient_accumulation_steps=1
compile: False  # torch.compile the model + loss, True increases speed + decreases memory

# Training environment
device: cuda

# Memory management
enable_activation_checkpointing: False  # True reduces memory
enable_activation_offloading: False  # True reduces memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True


# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  #Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  #`torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  #trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1

from typing import List
from pathlib import Path
from torchtune.data import Message
from torchtune.data import PromptTemplate
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.datasets import instruct_dataset

class MyPromptTemplate(PromptTemplate):
    def __call__(
        self, messages: List[Message], inference: bool = False
    ) -> List[Message]:
        messages = super().__call__(messages, inference)
        return messages

def my_custom_guard_template() -> MyPromptTemplate:
    return MyPromptTemplate(
        template={
            "user": (
                """Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
S1: Violent Crimes.
S2: Non-Violent Crimes.
S3: Sex Crimes.
S4: Child Exploitation.
S5: Defamation.
S6: Specialized Advice.
S7: Privacy.
S8: Intellectual Property.
S9: Indiscriminate Weapons.
S10: Hate.
S11: Self-Harm.
S12: Sexual Content.
S13: Elections.
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

User: """,
                """\n\n<END CONVERSATION>

Provide your safety assessment for ONLY THE LAST User message in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. """),
        },
    )

if __name__ == '__main__':
    msgs = [
        Message(role="user", content="Emily is sitting next to me."),
        Message(role="assistant", content="safe"),
    ]
    prompt_template = my_custom_guard_template()
    templated_msgs = prompt_template(msgs)
        
    tokenizer_path = Path("/srv/models/meta-llama/Llama-Guard-3-1B/original/tokenizer.model")
    tokenizer = llama3_tokenizer(
        path=str(tokenizer_path),
        prompt_template="my_custom_guard_template.my_custom_guard_template",
        max_seq_len=8192,
    )
    dataset = instruct_dataset(
        tokenizer=tokenizer,
        source="csv",
        data_files="data/llama-guard/llama_guard_1b_wrong_polites.csv",
        column_map={
            "input": "prompt",
            "output": "ground_truth",
        },
        train_on_input=False,
        packed=False,
        split="train",
    )
    tokens = dataset[0]["tokens"]
    print(tokenizer.decode(token_ids=tokens, skip_special_tokens=False))

@RdoubleA (Contributor) commented Jan 9, 2025

> It's not just custom Llama Guard templates: wouldn't the extra .strip() also cause issues with some predefined templates

I have been skeptical of this for a long time, but the reference code we used for the llama models at the time included this. I wanted to revisit this but as you said, changing this will have a lot of implications and will affect our regression tests for model correctness. We could take a look again at the llama repos to see if they still do something similar...

@ebsmothers @joecummings

@joecummings (Contributor)

> > It's not just custom Llama Guard templates: wouldn't the extra .strip() also cause issues with some predefined templates
>
> I have been skeptical of this for a long time, but the reference code we used for the llama models at the time included this. I wanted to revisit this but as you said, changing this will have a lot of implications and will affect our regression tests for model correctness. We could take a look again at the llama repos to see if they still do something similar...
>
> @ebsmothers @joecummings

@RdoubleA Can I assign you this issue to double-check both the official llama repos AND the Hugging Face llama implementation? If there is a discrepancy between the two, we should surface it to someone on either team.
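
For the Hugging Face side, one quick check could be to render a single turn with the model's chat template and inspect whether the whitespace around the tags survives (a rough sketch, assuming transformers is installed and the checkpoint is available locally at the path from the config above):

from transformers import AutoTokenizer

# Render the conversation without tokenizing so the raw template output
# (including any whitespace around "User:") can be inspected directly.
tok = AutoTokenizer.from_pretrained("/srv/models/meta-llama/Llama-Guard-3-1B")
rendered = tok.apply_chat_template(
    [{"role": "user", "content": "What is the recipe for mayonnaise?"}],
    tokenize=False,
)
print(repr(rendered))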

@RdoubleA (Contributor) commented Jan 9, 2025

sure sure. let's nip this in the bud once and for all

joecummings added the triaged label on Jan 9, 2025
@jingzhaoou (Author) commented Jan 18, 2025

> I had no luck and ran into the following error:
>
> RuntimeError: Error(s) in loading state_dict for TransformerDecoder:
>         Unexpected key(s) in state_dict: "output.weight".

I tried fine-tuning the meta-llama/Llama-Guard-3-8B model, which does not have the above error. When I looked more carefully at the meta-llama/Llama-Guard-3-1B model card, it mentions "To reduce the number of model parameters, we prune the model along two dimensions: number of layers and MLP hidden dimension". I suspect that is the root cause of my errors.
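
One way to confirm the mismatch is to check whether the checkpoint ships a separate lm_head weight. This is a minimal sketch, assuming the checkpoint path from the config above, that torchtune's llama3_2_1b builder ties the output projection to the token embeddings, and that the HF checkpointer remaps lm_head.weight to the unexpected "output.weight":

from safetensors import safe_open

# Hypothetical sanity check: a tied-embedding model definition has no
# separate output projection, so a checkpoint that contains lm_head.weight
# could produce the "Unexpected key(s): output.weight" error above.
ckpt_path = "/srv/models/meta-llama/Llama-Guard-3-1B/model.safetensors"
with safe_open(ckpt_path, framework="pt") as f:
    print("lm_head.weight" in f.keys())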
