Summary:

## Types of changes
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [x] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Docs change / refactoring / dependency upgrade

## Motivation and Context / Related issue
It introduces a set of new optimizers called DiSK, which uses a simplified Kalman filter to improve optimizer performance.

## How Has This Been Tested (if it applies)
It is tested with the mnist.py from the example folder (with modifications for DiSK) to ensure all the functions work.

## Checklist
Not sure whether to add documents.
- [ ] The documentation is up-to-date with the changes I made.
- [x] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**).
- [x] All tests passed, and additional code has been covered with new tests.

Pull Request resolved: #706

Reviewed By: HuanyuZhang

Differential Revision: D67626897

Pulled By: iden-kalemaj

fbshipit-source-id: 3ac3caf5212920afdae7b4a8ef71bd3868073731
1 parent 4b0cd91 · commit b4c075d · Showing 10 changed files with 1,004 additions and 0 deletions.
research/disk_optimizer/KFprivacy_engine.py (56 additions)
from typing import List, Union

from opacus.optimizers import DPOptimizer
from opacus.privacy_engine import PrivacyEngine
from torch import optim

from .optimizers import KF_DPOptimizer, get_optimizer_class


class KF_PrivacyEngine(PrivacyEngine):
    def __init__(self, *, accountant: str = "prv", secure_mode: bool = False):
        super().__init__(accountant=accountant, secure_mode=secure_mode)

    def _prepare_optimizer(
        self,
        *,
        optimizer: optim.Optimizer,
        noise_multiplier: float,
        max_grad_norm: Union[float, List[float]],
        expected_batch_size: int,
        loss_reduction: str = "mean",
        distributed: bool = False,
        clipping: str = "flat",
        noise_generator=None,
        grad_sample_mode="hooks",
        kalman: bool = False,
        **kwargs,
    ) -> DPOptimizer:
        if kalman and isinstance(optimizer, KF_DPOptimizer):
            optimizer = optimizer.original_optimizer
        elif not kalman and isinstance(optimizer, DPOptimizer):
            optimizer = optimizer.original_optimizer

        generator = None
        if self.secure_mode:
            generator = self.secure_rng
        elif noise_generator is not None:
            generator = noise_generator

        optim_class = get_optimizer_class(
            clipping=clipping,
            distributed=distributed,
            grad_sample_mode=grad_sample_mode,
            kalman=kalman,
        )

        return optim_class(
            optimizer=optimizer,
            noise_multiplier=noise_multiplier,
            max_grad_norm=max_grad_norm,
            expected_batch_size=expected_batch_size,
            loss_reduction=loss_reduction,
            generator=generator,
            secure_mode=self.secure_mode,
            **kwargs,
        )
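
For orientation, a minimal sketch of what the `kalman` flag changes: it is forwarded to `get_optimizer_class`, which picks the optimizer class that `_prepare_optimizer` instantiates. The import path below and the exact class names printed are assumptions for a standalone script, not something this commit defines.

```python
# Illustrative sketch only: inspect which optimizer class the engine would pick
# for a given combination of flags. Assumes the research/disk_optimizer package
# from this commit is importable as `disk_optimizer`.
from disk_optimizer.optimizers import get_optimizer_class

for kalman in (False, True):
    cls = get_optimizer_class(
        clipping="flat",
        distributed=False,
        grad_sample_mode="hooks",
        kalman=kalman,
    )
    # expected: a stock DPOptimizer variant without kalman, a KF_* variant with it
    print(f"kalman={kalman}: {cls.__name__}")
```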
research/disk_optimizer/README.md (69 additions)
# DiSK: Differentially Private Optimizer with Simplified Kalman Filter for Noise Reduction

## Introduction
This code introduces a new optimizer component, DiSK. It uses a simplified Kalman filter to improve the privatized gradient estimate. Specifically, the privatized minibatch gradient is replaced with:

$$g_{t+\frac{1}{2}} = \frac{1}{B}\sum_{\xi \in \mathcal{B}_t} \mathrm{clip}_C\left(\frac{1-\kappa}{\kappa\gamma}\nabla f(x_t + \gamma(x_t-x_{t-1});\xi) + \Big(1- \frac{1-\kappa}{\kappa\gamma}\Big)\nabla f(x_t;\xi)\right) + w_t$$

$$g_{t}= (1-\kappa)g_{t-1} + \kappa g_{t+\frac{1}{2}}$$

A detailed description of the algorithm can be found [here](https://arxiv.org/abs/2410.03883).
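
Read literally, the first formula clips and privatizes a combination of per-sample gradients taken at a shifted point and at the current point, and the second keeps a Kalman-style running estimate. Below is a small, self-contained sketch of the two updates on plain tensors, purely for illustration; all values are made up, and the shipped optimizers implement this inside Opacus' per-sample clipping and noise machinery rather than like this.

```python
import torch

kappa, gamma, C, sigma, B, d = 0.7, 0.5, 1.0, 1.0, 64, 10
c = (1 - kappa) / (kappa * gamma)

def clip(g, C):
    # standard flat clipping: rescale g so its norm is at most C
    return g * min(1.0, C / (g.norm().item() + 1e-12))

# stand-ins for per-sample gradients at x_t + gamma*(x_t - x_{t-1}) and at x_t
grad_shifted = torch.randn(B, d)
grad_current = torch.randn(B, d)

clipped = torch.stack(
    [clip(c * gs + (1 - c) * gc, C) for gs, gc in zip(grad_shifted, grad_current)]
)
w_t = (sigma * C / B) * torch.randn(d)       # Gaussian DP noise
g_half = clipped.mean(dim=0) + w_t           # g_{t+1/2}

g_prev = torch.zeros(d)                      # g_{t-1}
g_t = (1 - kappa) * g_prev + kappa * g_half  # filtered gradient g_t
```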

## Usage
The code provides a modified privacy engine with three extra arguments:
* kalman: bool=False
* kappa: float=0.7
* gamma: float=0.5

To use DiSK, follow these steps:

**Step I:** Import KF_PrivacyEngine from KFprivacy_engine.py and set ```kalman=True```

**Step II:** Define a closure (see [here](https://pytorch.org/docs/stable/optim.html#optimizer-step-closure) for an example) that computes the loss and calls backward **without** ```zero_grad()```, then call ```optimizer.step(closure)```

Example of using the DiSK optimizers:

```python
from KFprivacy_engine import KF_PrivacyEngine
# ...
# follow the same steps as in the original opacus training scripts
privacy_engine = KF_PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    noise_multiplier=args.sigma,
    max_grad_norm=max_grad_norm,
    clipping=clipping,
    grad_sample_mode=args.grad_sample_mode,
    kalman=True,  # required for DiSK
    kappa=0.7,  # optional
    gamma=0.5,  # optional
)

# ...
# during training:
def closure():  # compute the loss and backward; adapted from examples/cifar10.py
    output = model(images)
    loss = criterion(output, target)
    loss.backward()
    return output, loss

output, loss = optimizer.step(closure)
optimizer.zero_grad()
# compute other metrics
# ...
```

## Citation
Consider citing the paper if you use DiSK in your work:

```
@article{zhang2024disk,
  title={{DiSK}: Differentially private optimizer with simplified kalman filter for noise reduction},
  author={Zhang, Xinwei and Bu, Zhiqi and Balle, Borja and Hong, Mingyi and Razaviyayn, Meisam and Mirrokni, Vahab},
  journal={arXiv preprint arXiv:2410.03883},
  year={2024}
}
```

Contributor: Xinwei Zhang. Email: [[email protected]](mailto:[email protected])
research/disk_optimizer/optimizers/KFadaclipoptimizer.py (95 additions)
from __future__ import annotations

import logging
import math
from typing import Optional

import torch
from opacus.optimizers.adaclipoptimizer import AdaClipDPOptimizer
from torch.optim import Optimizer
from torch.optim.optimizer import required

from .KFoptimizer import KF_DPOptimizer


logger = logging.getLogger(__name__)


class KF_AdaClipDPOptimizer(AdaClipDPOptimizer, KF_DPOptimizer):
    def __init__(
        self,
        optimizer: Optimizer,
        *,
        noise_multiplier: float,
        target_unclipped_quantile: float,
        clipbound_learning_rate: float,
        max_clipbound: float,
        min_clipbound: float,
        unclipped_num_std: float,
        max_grad_norm: float,
        expected_batch_size: Optional[int],
        loss_reduction: str = "mean",
        generator=None,
        secure_mode: bool = False,
        kappa: float = 0.7,
        gamma: float = 0.5,
    ):
        if gamma == 0 or abs(gamma - (1 - kappa) / kappa) < 1e-3:
            # with gamma == (1 - kappa) / kappa the combination coefficient equals 1,
            # so only one gradient evaluation per step is needed
            gamma = (1 - kappa) / kappa
            self.kf_compute_grad_at_original = False
        else:
            self.scaling_factor = (1 - kappa) / (
                gamma * kappa
            )  # (gamma*kappa+kappa-1)/(1-kappa)
            self.kf_compute_grad_at_original = True
            # adjust the noise multiplier by the norm of the combination coefficients
            c = (1 - kappa) / (gamma * kappa)
            norm_factor = math.sqrt(c**2 + (1 - c) ** 2)
            noise_multiplier = noise_multiplier / norm_factor
        super(AdaClipDPOptimizer).__init__(
            optimizer,
            noise_multiplier=noise_multiplier,
            max_grad_norm=max_grad_norm,
            expected_batch_size=expected_batch_size,
            loss_reduction=loss_reduction,
            generator=generator,
            secure_mode=secure_mode,
            target_unclipped_quantile=target_unclipped_quantile,
            clipbound_learning_rate=clipbound_learning_rate,
            max_clipbound=max_clipbound,
            min_clipbound=min_clipbound,
            unclipped_num_std=unclipped_num_std,
        )
        self.kappa = kappa
        self.gamma = gamma

    def step(self, closure=required) -> Optional[float]:
        if self.kf_compute_grad_at_original:
            loss = self._compute_two_closure(closure)
        else:
            loss = self._compute_one_closure(closure)

        if self.pre_step():
            tmp_states = []
            first_step = False
            for p in self.params:
                grad = p.grad
                state = self.state[p]
                if "kf_d_t" not in state:
                    state = dict()
                    first_step = True
                    state["kf_d_t"] = torch.zeros_like(p.data).to(p.data)
                    state["kf_m_t"] = grad.clone().to(p.data)
                # kf_m_t <- (1 - kappa) * kf_m_t + kappa * grad (Kalman/EMA update)
                state["kf_m_t"].lerp_(grad, weight=self.kappa)
                p.grad = state["kf_m_t"].clone().to(p.data)
                # kf_d_t will hold the parameter displacement produced by this step
                state["kf_d_t"] = -p.data.clone().to(p.data)
                if first_step:
                    tmp_states.append(state)
            self.original_optimizer.step()
            for p in self.params:
                if first_step:
                    tmp_state = tmp_states.pop(0)
                    self.state[p]["kf_d_t"] = tmp_state["kf_d_t"]
                    self.state[p]["kf_m_t"] = tmp_state["kf_m_t"]
                    del tmp_state
                self.state[p]["kf_d_t"].add_(p.data, alpha=1.0)
        return loss
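
For intuition, the per-parameter state maintained in `step()` reduces to an exponential moving average of the privatized gradient (`kf_m_t`, via `lerp_`) plus the displacement produced by the wrapped optimizer (`kf_d_t`). A toy, standalone illustration of those two updates with made-up values; it does not use the optimizer class itself.

```python
import torch

kappa = 0.7
p = torch.nn.Parameter(torch.randn(5))
grad = torch.randn(5)             # stands in for the privatized gradient

kf_m_t = grad.clone()             # first step: filter state starts at the gradient
kf_m_t.lerp_(grad, weight=kappa)  # m_t = (1 - kappa) * m_{t-1} + kappa * grad

kf_d_t = -p.data.clone()          # snapshot parameters before the step
with torch.no_grad():
    p -= 0.1 * kf_m_t             # stand-in for original_optimizer.step()
kf_d_t.add_(p.data, alpha=1.0)    # kf_d_t = p_after - p_before
```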
research/disk_optimizer/optimizers/KFddp_perlayeroptimizer.py (160 additions)
from __future__ import annotations

from functools import partial
from typing import Callable, List, Optional

import torch
from opacus.optimizers.ddp_perlayeroptimizer import _clip_and_accumulate_parameter
from opacus.optimizers.optimizer import _generate_noise
from torch import nn
from torch.optim import Optimizer

from .KFddpoptimizer import KF_DistributedDPOptimizer
from .KFoptimizer import KF_DPOptimizer
from .KFperlayeroptimizer import KF_DPPerLayerOptimizer


class KF_SimpleDistributedPerLayerOptimizer(
    KF_DPPerLayerOptimizer, KF_DistributedDPOptimizer
):
    def __init__(
        self,
        optimizer: Optimizer,
        *,
        noise_multiplier: float,
        max_grad_norm: float,
        expected_batch_size: Optional[int],
        loss_reduction: str = "mean",
        generator=None,
        secure_mode: bool = False,
        kappa: float = 0.7,
        gamma: float = 0.5,
    ):
        self.rank = torch.distributed.get_rank()
        self.world_size = torch.distributed.get_world_size()

        super().__init__(
            optimizer,
            noise_multiplier=noise_multiplier,
            max_grad_norm=max_grad_norm,
            expected_batch_size=expected_batch_size,
            loss_reduction=loss_reduction,
            generator=generator,
            secure_mode=secure_mode,
            kappa=kappa,
            gamma=gamma,
        )


class KF_DistributedPerLayerOptimizer(KF_DPOptimizer):
    """
    :class:`~opacus.optimizers.optimizer.DPOptimizer` that implements the
    per-layer clipping strategy and is compatible with distributed data parallel.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        *,
        noise_multiplier: float,
        max_grad_norm: List[float],
        expected_batch_size: Optional[int],
        loss_reduction: str = "mean",
        generator=None,
        secure_mode: bool = False,
        kappa: float = 0.7,
        gamma: float = 0.5,
    ):
        self.rank = torch.distributed.get_rank()
        self.world_size = torch.distributed.get_world_size()
        self.max_grad_norms = max_grad_norm
        # the flat bound passed to the base class is the L2 combination of the
        # per-layer clipping bounds
        max_grad_norm = torch.norm(torch.Tensor(self.max_grad_norms), p=2).item()
        super().__init__(
            optimizer,
            noise_multiplier=noise_multiplier,
            max_grad_norm=max_grad_norm,
            expected_batch_size=expected_batch_size,
            loss_reduction=loss_reduction,
            generator=generator,
            secure_mode=secure_mode,
            kappa=kappa,
            gamma=gamma,
        )
        self._register_hooks()

    def _add_noise_parameter(self, p: nn.Parameter):
        """
        Defined as a method so that the generator/secure_mode settings on ``self``
        are used when generating noise.
        """
        noise = _generate_noise(
            std=self.noise_multiplier * self.max_grad_norm,
            reference=p.summed_grad,
            generator=None,
            secure_mode=self.secure_mode,
        )
        p.grad = p.summed_grad + noise

    @property
    def accumulated_iterations(self) -> int:
        return max([p.accumulated_iterations for p in self.params])

    def _scale_grad_parameter(self, p: nn.Parameter):
        if not hasattr(p, "accumulated_iterations"):
            p.accumulated_iterations = 0
        p.accumulated_iterations += 1
        if self.loss_reduction == "mean":
            p.grad /= (
                self.expected_batch_size * p.accumulated_iterations * self.world_size
            )

    def clip_and_accumulate(self):
        raise NotImplementedError(
            "Clip and accumulate is added per layer in DPDDP Per Layer."
        )

    def add_noise(self):
        raise NotImplementedError("Noise is added per layer in DPDDP Per Layer.")

    def pre_step(
        self, closure: Optional[Callable[[], float]] = None
    ) -> Optional[float]:
        if self._check_skip_next_step():
            self._is_last_step_skipped = True
            return False

        if self.step_hook:
            self.step_hook(self)

        for p in self.params:
            p.accumulated_iterations = 0

        self._is_last_step_skipped = False
        return True

    def _ddp_per_layer_hook(
        self, p: nn.Parameter, max_grad_norm: float, _: torch.Tensor
    ):
        _clip_and_accumulate_parameter(p, max_grad_norm)
        # Equivalent to _check_skip_next_step but without popping because it has
        # to be done for every parameter p
        if self._check_skip_next_step(pop_next=False):
            return

        # only rank 0 adds noise; the other ranks keep the clipped sum so the noise
        # is added exactly once when gradients are reduced across workers
        if self.rank == 0:
            self._add_noise_parameter(p)
        else:
            p.grad = p.summed_grad
        self._scale_grad_parameter(p)

        return p.grad

    def _register_hooks(self):
        for p, max_grad_norm in zip(self.params, self.max_grad_norms):
            if not p.requires_grad:
                continue

            if not hasattr(p, "ddp_hooks"):
                p.ddp_hooks = []

            p.ddp_hooks.append(
                p.register_hook(partial(self._ddp_per_layer_hook, p, max_grad_norm))
            )
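
One relationship worth noting from `KF_DistributedPerLayerOptimizer.__init__`: the per-layer bounds are kept in `max_grad_norms`, while the base class receives their L2 combination as the flat bound. A quick standalone check of that arithmetic, with arbitrary example values:

```python
import math

import torch

per_layer_bounds = [1.0, 0.5, 2.0]  # one clipping bound per parameter tensor
flat_bound = torch.norm(torch.Tensor(per_layer_bounds), p=2).item()
assert math.isclose(flat_bound, math.sqrt(1.0**2 + 0.5**2 + 2.0**2), rel_tol=1e-6)
print(flat_bound)  # ~2.2913
```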