Skip to content

Commit

Permalink
Pre-release improvements (#9)
Browse files Browse the repository at this point in the history
* Parse patience and min_delta arguments

* Reset metrics at the start of each validation and test epoch

* Validate positive integer and float arguments

* Update README.md with validation functions

* Validate positive integer arguments

* Update README.md with validation functions

---------

Co-authored-by: Ehssan <>
  • Loading branch information
ekintel authored Jan 22, 2025
1 parent 9685793 commit 605f22b
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 16 deletions.
7 changes: 7 additions & 0 deletions data-generator/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ The file is expected to be located in the same directory as the script. If the f
#### Raises:
- `SystemExit`: If the token file is not found, permission is denied, or any other error occurs while reading the file.

### `validate_positive_integer(value: str) -> int`

Validates that the input string, provided via command-line arguments, represents a positive integer.

#### Raises:
- `argparse.ArgumentTypeError`: If the input is not a positive integer.

### `parse_string(input_string: str) -> Tuple[str, str]`

Parses a string containing `OUTPUT:` and `REASONING:` sections and extracts their values.
Expand Down
30 changes: 27 additions & 3 deletions data-generator/sdg.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,30 @@ def read_token() -> None:
login(token)


def validate_positive_integer(value: str) -> int:
    """
    Validate that a command-line argument represents a positive integer.

    Args:
        value: The raw string supplied by argparse.

    Returns:
        int: The validated positive integer.

    Raises:
        argparse.ArgumentTypeError: If the string is not an integer or is <= 0.
    """
    # Keep the try body minimal: only the conversion can raise ValueError.
    # The original also wrapped the positivity raise in the same try, which
    # worked only because ArgumentTypeError is not a ValueError subclass.
    try:
        int_value = int(value)
    except ValueError:
        # `from None` suppresses the implicit exception chain so argparse
        # reports a single clean error message.
        raise argparse.ArgumentTypeError(f"Invalid integer value: {value}") from None
    if int_value <= 0:
        raise argparse.ArgumentTypeError(
            f"The input value must be positive, got {int_value}"
        )
    return int_value


def parse_string(input_string: str) -> Tuple[str, str]:
"""
Parses a string containing `OUTPUT:` and `REASONING:` sections and extracts their values.
Expand Down Expand Up @@ -224,7 +248,7 @@ def main() -> None:

parser.add_argument(
"--sample_size",
type=int,
type=validate_positive_integer,
default=100,
help="The number of samples generated by the language model (default: 100)",
)
Expand All @@ -236,13 +260,13 @@ def main() -> None:
)
parser.add_argument(
"--max_new_tokens",
type=int,
type=validate_positive_integer,
default=256,
help="The maximum number of new tokens to generate for each sample (default: 256)",
)
parser.add_argument(
"--batch_size",
type=int,
type=validate_positive_integer,
default=20,
help="The batch size for saving generated samples to file (default: 20)",
)
Expand Down
19 changes: 14 additions & 5 deletions fine-tuner/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ python fine-tune.py --train_data path/to/train.csv --val_data path/to/val.csv --
- `--train_data`: Path to the training CSV file (required)
- `--val_data`: Path to the validation CSV file (required)
- `--test_data`: Path to the test CSV file (required)
- `--batch_size`: Batch size for training and evaluation (default: 16)
- `--batch_size`: Batch size for training and evaluation (default: 32)
- `--learning_rate`: Learning rate for the optimizer (default: 5e-5)
- `--weight_decay`: Weight decay for the optimizer (default: 0.01)
- `--max_epochs`: Number of epochs for training (default: 6)
- `--max_epochs`: Number of epochs for training (default: 2)
- `--precision`: Precision for training (e.g., '16-mixed') (default: '16-mixed')
- `--patience`: Number of epochs with no improvement in the monitored metric after which training will be stopped (default: 3)
- `--min_delta`: Minimum change in the monitored metric to qualify as an improvement (default: 0.0)
- `--num_workers`: Number of worker threads for DataLoader (default: 6)
- `--accelerator`: Type of accelerator to use for training. Options include 'cpu', 'gpu', 'hpu', 'tpu', 'mps', and 'auto' (default: 'auto')
- `--devices`: Number of devices to use for training (default: 'auto')
Expand Down Expand Up @@ -87,7 +89,8 @@ python fine-tune.py \
--num_labels 2 \
--batch_size 32 \
--learning_rate 3e-5 \
--max_epochs 10 \
--max_epochs 3 \
    --min_delta 0.005 \
--log_dir ./logs \
--experiment_name my_experiment
```
Expand All @@ -96,12 +99,18 @@ This command will fine-tune a BERT model on the specified training data, validat

## Functions and Classes

### `parse_devices(value: str) -> str | int`

Parses the devices argument for the number of devices to use for training. The argument can either be an integer, representing the number of devices, or the string 'auto', which automatically selects the available devices.

### `validate_positive_integer(value: str) -> int`

Validates that the input is a positive integer.

### `validate_positive_float(value: str) -> float`

Validates that the input is a positive float.

### `parse_args() -> argparse.Namespace`

Parses command-line arguments and returns them as an `argparse.Namespace` object.
Expand Down
106 changes: 98 additions & 8 deletions fine-tuner/fine-tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,54 @@ def parse_devices(value: str) -> str | int:
)


def validate_positive_integer(value: str) -> int:
    """
    Validate that a command-line argument represents a positive integer.

    Args:
        value: The raw string supplied by argparse.

    Returns:
        int: The validated positive integer.

    Raises:
        argparse.ArgumentTypeError: If the string is not an integer or is <= 0.
    """
    # Keep the try body minimal: only the conversion can raise ValueError.
    # The original also wrapped the positivity raise in the same try, which
    # worked only because ArgumentTypeError is not a ValueError subclass.
    try:
        int_value = int(value)
    except ValueError:
        # `from None` suppresses the implicit exception chain so argparse
        # reports a single clean error message.
        raise argparse.ArgumentTypeError(f"Invalid integer value: {value}") from None
    if int_value <= 0:
        raise argparse.ArgumentTypeError(
            f"The input value must be positive, got {int_value}"
        )
    return int_value


def validate_positive_float(value: str) -> float:
    """
    Validate that a command-line argument represents a positive float.

    Args:
        value: The raw string supplied by argparse.

    Returns:
        float: The validated positive float.

    Raises:
        argparse.ArgumentTypeError: If the string is not a float, is <= 0,
            or is NaN.
    """
    # Keep the try body minimal: only the conversion can raise ValueError.
    try:
        float_value = float(value)
    except ValueError:
        # `from None` suppresses the implicit exception chain so argparse
        # reports a single clean error message.
        raise argparse.ArgumentTypeError(f"Invalid float value: {value}") from None
    # `not (x > 0)` instead of `x <= 0`: NaN compares False to everything,
    # so this also rejects inputs like "nan", which `x <= 0` silently accepts.
    if not float_value > 0:
        raise argparse.ArgumentTypeError(
            f"The input value must be positive, got {float_value}"
        )
    return float_value


def parse_args() -> argparse.Namespace:
"""
Parse command-line arguments.
Expand All @@ -66,7 +114,7 @@ def parse_args() -> argparse.Namespace:
)
parser.add_argument(
"--num_labels",
type=int,
type=validate_positive_integer,
default=4,
help="Number of labels in classification task",
)
Expand All @@ -85,18 +133,39 @@ def parse_args() -> argparse.Namespace:
# Training parameters
parser.add_argument(
"--batch_size",
type=int,
default=16,
type=validate_positive_integer,
default=32,
help="Batch size for training and evaluation",
)
parser.add_argument(
"--learning_rate", type=float, default=5e-5, help="Learning rate for optimizer"
"--learning_rate",
type=validate_positive_float,
default=5e-5,
help="Learning rate for optimizer",
)
parser.add_argument(
"--weight_decay",
type=validate_positive_float,
default=0.01,
help="Weight decay for optimizer",
)
parser.add_argument(
"--max_epochs",
type=validate_positive_integer,
default=2,
help="Number of epochs for training",
)
parser.add_argument(
"--weight_decay", type=float, default=0.01, help="Weight decay for optimizer"
"--patience",
type=validate_positive_integer,
default=3,
help="The number of epochs with no improvement in the monitored metric after which training will be stopped.",
)
parser.add_argument(
"--max_epochs", type=int, default=6, help="Number of epochs for training"
"--min_delta",
type=float,
default=0.0,
help="The minimum change in the monitored metric to qualify as an improvement.",
)
parser.add_argument(
"--precision",
Expand All @@ -106,13 +175,14 @@ def parse_args() -> argparse.Namespace:
)
parser.add_argument(
"--num_workers",
type=int,
type=validate_positive_integer,
default=6,
help="Number of worker threads for DataLoader",
)
parser.add_argument(
"--accelerator",
type=str,
choices=["cpu", "gpu", "hpu", "tpu", "mps", "auto"],
default="auto",
help="Type of accelerator to use for training. Options include 'cpu', 'gpu', 'hpu', 'tpu', 'mps', or 'auto' to automatically select the available hardware.",
)
Expand Down Expand Up @@ -520,6 +590,20 @@ def configure_optimizers(self) -> Tuple[List[torch.optim.AdamW], List[Dict]]:
}
return [optimizer], [lr_scheduler]

def on_validation_epoch_start(self) -> None:
    """
    Clear the accumulated validation metrics so each validation epoch
    is scored independently of previous epochs.
    """
    for metric in (self.val_acc, self.val_f1):
        metric.reset()

def on_test_epoch_start(self) -> None:
    """
    Clear the accumulated test metrics so each test epoch is scored
    independently of previous epochs.
    """
    for metric in (self.test_acc, self.test_f1):
        metric.reset()


def main() -> None:
"""
Expand Down Expand Up @@ -558,7 +642,13 @@ def main() -> None:
# Setup callbacks and logger
callbacks = [
ModelCheckpoint(save_top_k=1, mode="max", monitor="val_f1"),
EarlyStopping(monitor="val_f1", patience=3, mode="max", verbose=True),
EarlyStopping(
monitor="val_f1",
patience=args.patience,
min_delta=args.min_delta,
mode="max",
verbose=True,
),
]
logger = TensorBoardLogger(save_dir=args.log_dir, name=args.experiment_name)

Expand Down

0 comments on commit 605f22b

Please sign in to comment.