Migrate internal brrr to nanotron #186

Draft · wants to merge 5 commits into main
16 changes: 13 additions & 3 deletions src/nanotron/config/config.py
@@ -337,6 +337,9 @@ class Config:
    data_stages: Optional[List[DatasetStageArgs]] = None
    profiler: Optional[ProfilerArgs] = None
    lighteval: Optional[LightEvalConfig] = None
    # To signal the training script to stop, just touch the following file.
    # We force users to set one so that it can be removed programmatically.
    kill_switch_path: Optional[Path] = None

    @classmethod
    def create_empty(cls):
@@ -345,6 +348,9 @@ def create_empty(cls):

    def __post_init__(self):
        # Some final sanity checks across separate arguments sections:
        if self.general is not None and os.environ.get("SLURM_JOB_ID", None) is not None:
            self.general.run = self.general.run.replace("%j", os.environ["SLURM_JOB_ID"])

        if self.profiler is not None and self.profiler.profiler_export_path is not None:
            assert self.tokens.train_steps < 10

@@ -376,9 +382,13 @@ def __post_init__(self):
                for i in range(len(self.data_stages) - 1)
            ), "The stages are not sorted by start_training_step in increasing order"

        # # if lighteval, we need tokenizer to be defined
        # if self.checkpoints.lighteval is not None:
        #     assert self.tokenizer.tokenizer_name_or_path is not None
        # if self.lighteval is not None:
        #     # assert self.tokenizer.tokenizer_name_or_path is not None
        #     if self.lighteval.checkpoints_path is None:
        #         self.lighteval.checkpoints_path = self.checkpoints.checkpoints_path

        if isinstance(self.kill_switch_path, str):
            self.kill_switch_path = Path(self.kill_switch_path)

    @property
    def global_batch_size(self):
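Reviewer note: a minimal sketch of how the new kill switch is meant to be driven from outside the trainer. The file path below is illustrative (not taken from this PR) and must match whatever `kill_switch_path` is set to in the config.

```python
# Hypothetical example: requesting a graceful stop of a running nanotron job by
# touching the file configured as `kill_switch_path`. The path is an assumption.
from pathlib import Path

kill_switch = Path("/tmp/my-run-kill-switch")  # must match config.kill_switch_path

# The trainer checks for this file after every training step (see trainer.py below),
# saves a checkpoint, and exits with code 0.
kill_switch.touch()

# Before relaunching, remove the file so the new run does not stop immediately.
# Requiring an explicit path is what makes this cleanup scriptable.
kill_switch.unlink(missing_ok=True)
```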
23 changes: 22 additions & 1 deletion src/nanotron/trainer.py
@@ -2,6 +2,7 @@
import json
import os
import shutil
import sys
import time
from dataclasses import asdict
from pathlib import Path
@@ -281,7 +282,12 @@ def pre_training(self, *args, **kwargs):
)

    def post_train_step(self):
        pass
        # Kill switch
        self.check_kill_switch(save_ckpt=True)

        # # Update our background upload/removal of checkpoints
        # if self.s3_mover is not None:
        #     self.s3_mover.update()

    def post_training(self):
        pass
@@ -895,6 +901,21 @@ def _mark_tied_parameters(
    ):
        mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config)

    def check_kill_switch(self, save_ckpt: bool):
        if self.config.kill_switch_path and self.config.kill_switch_path.exists():
            log_rank(
                f"Detected kill switch at {self.config.kill_switch_path}. Exiting",
                logger=logger,
                level=logging.INFO,
                rank=0,
            )

            # Save checkpoint
            if save_ckpt:
                self.save_checkpoint()
            dist.barrier()
            sys.exit(0)


def mark_tied_parameters(
    model: NanotronModel, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs] = None
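For readers skimming the diff, here is a self-contained sketch of how `post_train_step` and `check_kill_switch` cooperate inside a training loop. `ToyTrainer`, its methods, and the path are invented for illustration and are not part of nanotron or this PR.

```python
# Illustrative only: a stripped-down stand-in for DistributedTrainer showing the
# kill-switch flow added in this PR. All "toy" names are invented.
import sys
from pathlib import Path
from typing import Optional


class ToyTrainer:
    def __init__(self, kill_switch_path: Optional[Path] = None):
        self.kill_switch_path = kill_switch_path

    def save_checkpoint(self) -> None:
        print("checkpoint saved")  # stand-in for the real checkpointing logic

    def check_kill_switch(self, save_ckpt: bool) -> None:
        # Mirrors DistributedTrainer.check_kill_switch: if the file exists,
        # optionally save a checkpoint, then exit cleanly.
        if self.kill_switch_path and self.kill_switch_path.exists():
            print(f"Detected kill switch at {self.kill_switch_path}. Exiting")
            if save_ckpt:
                self.save_checkpoint()
            sys.exit(0)

    def train(self, steps: int) -> None:
        for _ in range(steps):
            ...  # forward / backward / optimizer step
            # Mirrors post_train_step: the check runs once per training step.
            self.check_kill_switch(save_ckpt=True)


if __name__ == "__main__":
    ToyTrainer(kill_switch_path=Path("/tmp/my-run-kill-switch")).train(steps=1000)
```

The real implementation additionally calls `dist.barrier()` before `sys.exit(0)` so that all ranks reach the same point before the process group shuts down.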