Add **kwargs to all optimizer classes (#710)
Summary:
Pull Request resolved: #710

Purpose: enable creating custom subclasses of PrivacyEngine that accept additional parameters (see the sketch after the diff below).

Fix prior diff: D67456352

Reviewed By: HuanyuZhang

Differential Revision: D67953655

fbshipit-source-id: 70aef7571e012a370d6a0fd04948eccee06c9a0d
iden-kalemaj authored and facebook-github-bot committed Jan 15, 2025
1 parent: 3934851 · commit: 4b0cd91
Showing 8 changed files with 13 additions and 0 deletions.
1 change: 1 addition & 0 deletions opacus/optimizers/adaclipoptimizer.py
@@ -53,6 +53,7 @@ def __init__(
         loss_reduction: str = "mean",
         generator=None,
         secure_mode: bool = False,
+        **kwargs,
     ):
         super().__init__(
             optimizer,
2 changes: 2 additions & 0 deletions opacus/optimizers/ddp_perlayeroptimizer.py
@@ -48,6 +48,7 @@ def __init__(
         loss_reduction: str = "mean",
         generator=None,
         secure_mode: bool = False,
+        **kwargs,
     ):
         self.rank = torch.distributed.get_rank()
         self.world_size = torch.distributed.get_world_size()
@@ -79,6 +80,7 @@ def __init__(
         loss_reduction: str = "mean",
         generator=None,
         secure_mode: bool = False,
+        **kwargs,
     ):
         self.rank = torch.distributed.get_rank()
         self.world_size = torch.distributed.get_world_size()
2 changes: 2 additions & 0 deletions opacus/optimizers/ddpoptimizer.py
@@ -38,6 +38,7 @@ def __init__(
         loss_reduction: str = "mean",
         generator=None,
         secure_mode: bool = False,
+        **kwargs,
     ):
         super().__init__(
             optimizer,
@@ -47,6 +48,7 @@ def __init__(
             loss_reduction=loss_reduction,
             generator=generator,
             secure_mode=secure_mode,
+            **kwargs,
         )
         self.rank = torch.distributed.get_rank()
         self.world_size = torch.distributed.get_world_size()
2 changes: 2 additions & 0 deletions opacus/optimizers/ddpoptimizer_fast_gradient_clipping.py
@@ -38,6 +38,7 @@ def __init__(
         loss_reduction: str = "mean",
         generator=None,
         secure_mode: bool = False,
+        **kwargs,
     ):
         super().__init__(
             optimizer,
@@ -47,6 +48,7 @@ def __init__(
             loss_reduction=loss_reduction,
             generator=generator,
             secure_mode=secure_mode,
+            **kwargs,
         )
         self.rank = torch.distributed.get_rank()
         self.world_size = torch.distributed.get_world_size()
1 change: 1 addition & 0 deletions opacus/optimizers/optimizer.py
@@ -205,6 +205,7 @@ def __init__(
         loss_reduction: str = "mean",
         generator=None,
         secure_mode: bool = False,
+        **kwargs,
     ):
         """
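
With **kwargs now on the base DPOptimizer.__init__, subclasses can declare extra constructor parameters and shared call sites can pass them through without a TypeError. A minimal sketch; the LoggingDPOptimizer class and its log_every parameter are hypothetical, not part of Opacus:

from opacus.optimizers import DPOptimizer


class LoggingDPOptimizer(DPOptimizer):
    """Hypothetical subclass taking one extra constructor parameter."""

    def __init__(self, *args, log_every: int = 100, **kwargs):
        # Keywords this subclass does not recognize are absorbed by the
        # base class's new **kwargs instead of raising a TypeError.
        super().__init__(*args, **kwargs)
        self.log_every = log_every
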
2 changes: 2 additions & 0 deletions opacus/optimizers/optimizer_fast_gradient_clipping.py
@@ -63,6 +63,7 @@ def __init__(
         loss_reduction: str = "mean",
         generator=None,
         secure_mode: bool = False,
+        **kwargs,
     ):
         """
@@ -91,6 +92,7 @@ def __init__(
             loss_reduction=loss_reduction,
             generator=generator,
             secure_mode=secure_mode,
+            **kwargs,
         )

     @property
2 changes: 2 additions & 0 deletions opacus/optimizers/perlayeroptimizer.py
@@ -39,6 +39,7 @@ def __init__(
         loss_reduction: str = "mean",
         generator=None,
         secure_mode: bool = False,
+        **kwargs,
     ):
         assert len(max_grad_norm) == len(params(optimizer))
         self.max_grad_norms = max_grad_norm
@@ -51,6 +52,7 @@ def __init__(
             loss_reduction=loss_reduction,
             generator=generator,
             secure_mode=secure_mode,
+            **kwargs,
         )

     def clip_and_accumulate(self):
1 change: 1 addition & 0 deletions opacus/privacy_engine.py
@@ -136,6 +136,7 @@ def _prepare_optimizer(
             loss_reduction=loss_reduction,
             generator=generator,
             secure_mode=self.secure_mode,
+            **kwargs,
         )

     def _prepare_data_loader(
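
Taken together, these changes let extra keyword arguments flow from a custom engine down into whichever optimizer class it builds. A minimal sketch, assuming _prepare_optimizer itself accepts **kwargs (its call site above forwards them) and reusing the hypothetical log_every parameter from the earlier sketch:

from opacus import PrivacyEngine


class CustomPrivacyEngine(PrivacyEngine):
    """Hypothetical engine that injects an extra optimizer parameter."""

    def _prepare_optimizer(self, *args, **kwargs):
        # Inject the extra setting before delegating; the base method's
        # call shown above forwards **kwargs into the optimizer constructor,
        # where a subclass such as LoggingDPOptimizer can consume it and
        # the stock optimizer classes simply absorb it.
        kwargs.setdefault("log_every", 50)
        return super()._prepare_optimizer(*args, **kwargs)
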
