Disk (#706)
Summary:
## Types of changes

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [x] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Docs change / refactoring / dependency upgrade

## Motivation and Context / Related issue

This PR introduces a set of new optimizers, DiSK, which use a simplified Kalman filter to improve optimizer performance.

## How Has This Been Tested (if it applies)

It was tested with mnist.py from the examples folder (with modifications for DiSK) to ensure all the functions work.

## Checklist

Not sure whether to add documentation.

- [ ] The documentation is up-to-date with the changes I made.
- [x] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**).
- [x] All tests passed, and additional code has been covered with new tests.

Pull Request resolved: #706

Reviewed By: HuanyuZhang

Differential Revision: D67626897

Pulled By: iden-kalemaj

fbshipit-source-id: 3ac3caf5212920afdae7b4a8ef71bd3868073731
564612540 authored and facebook-github-bot committed Jan 17, 2025
1 parent 4b0cd91 commit b4c075d
Showing 10 changed files with 1,004 additions and 0 deletions.
56 changes: 56 additions & 0 deletions research/disk_optimizer/KFprivacy_engine.py
@@ -0,0 +1,56 @@
from typing import List, Union

from opacus.optimizers import DPOptimizer
from opacus.privacy_engine import PrivacyEngine
from torch import optim

from .optimizers import KF_DPOptimizer, get_optimizer_class


class KF_PrivacyEngine(PrivacyEngine):
    def __init__(self, *, accountant: str = "prv", secure_mode: bool = False):
        super().__init__(accountant=accountant, secure_mode=secure_mode)

    def _prepare_optimizer(
        self,
        *,
        optimizer: optim.Optimizer,
        noise_multiplier: float,
        max_grad_norm: Union[float, List[float]],
        expected_batch_size: int,
        loss_reduction: str = "mean",
        distributed: bool = False,
        clipping: str = "flat",
        noise_generator=None,
        grad_sample_mode="hooks",
        kalman: bool = False,
        **kwargs,
    ) -> DPOptimizer:
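        # If the optimizer is already wrapped by a (KF_)DPOptimizer, unwrap it so it
        # can be re-wrapped with the optimizer class selected below.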
        if kalman and isinstance(optimizer, KF_DPOptimizer):
            optimizer = optimizer.original_optimizer
        elif not kalman and isinstance(optimizer, DPOptimizer):
            optimizer = optimizer.original_optimizer

        generator = None
        if self.secure_mode:
            generator = self.secure_rng
        elif noise_generator is not None:
            generator = noise_generator

        optim_class = get_optimizer_class(
            clipping=clipping,
            distributed=distributed,
            grad_sample_mode=grad_sample_mode,
            kalman=kalman,
        )

        return optim_class(
            optimizer=optimizer,
            noise_multiplier=noise_multiplier,
            max_grad_norm=max_grad_norm,
            expected_batch_size=expected_batch_size,
            loss_reduction=loss_reduction,
            generator=generator,
            secure_mode=self.secure_mode,
            **kwargs,
        )
69 changes: 69 additions & 0 deletions research/disk_optimizer/ReadMe.md
@@ -0,0 +1,69 @@
# DiSK: Differentially Private Optimizer with Simplified Kalman Filter for Noise Reduction

## Introduction
This part of the code introduces a new optimizer component named DiSK. The code uses a simplified Kalman filter to improve the privatized gradient estimate. Specifically, the privatized minibatch gradient is replaced with:

$$g_{t+\frac{1}{2}} = \frac{1}{B}\sum_{\xi \in \mathcal{B}_t} \mathrm{clip}_C\left(\frac{1-\kappa}{\kappa\gamma}\nabla f(x_t + \gamma(x_t-x_{t-1});\xi) + \Big(1- \frac{1-\kappa}{\kappa\gamma}\Big)\nabla f(x_t;\xi)\right) + w_t$$
$$g_{t} = (1-\kappa) g_{t-1} + \kappa g_{t+\frac{1}{2}}$$

A detailed description of the algorithm can be found [here](https://arxiv.org/abs/2410.03883).
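
For intuition, here is a minimal sketch of the two equations above for a single parameter tensor. The helper name is hypothetical and per-sample clipping and the DP noise $w_t$ are omitted (they are handled by the wrapped DP optimizer); this is not the DiSK implementation itself:

```python
import torch


def disk_filtered_grad(
    grad_shifted: torch.Tensor,   # gradient evaluated at x_t + gamma * (x_t - x_{t-1})
    grad_current: torch.Tensor,   # gradient evaluated at x_t
    prev_filtered: torch.Tensor,  # g_{t-1}, the previous filtered gradient
    kappa: float = 0.7,
    gamma: float = 0.5,
) -> torch.Tensor:
    c = (1 - kappa) / (kappa * gamma)
    # combined estimate g_{t+1/2} (clipping and noise omitted in this sketch)
    g_half = c * grad_shifted + (1 - c) * grad_current
    # simplified Kalman update: g_t = (1 - kappa) * g_{t-1} + kappa * g_{t+1/2}
    return (1 - kappa) * prev_filtered + kappa * g_half
```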

## Usage
The code provides a modified privacy engine with three extra arguments:
* kalman: bool=False
* kappa: float=0.7
* gamma: float=0.5

To use DiSK, follow these steps:

**Step I:** Import KF_PrivacyEngine from KFprivacy_engine.py and set ```kalman=True```

**Step II:** Define a closure (see [here](https://pytorch.org/docs/stable/optim.html#optimizer-step-closure) for an example) that computes the loss and calls backward **without** ```zero_grad()```, then perform ```optimizer.step(closure)```

Example of using the DiSK optimizers:

```python
from KFprivacy_engine import KF_PrivacyEngine
# ...
# follow the same steps as original opacus training scripts
privacy_engine = KF_PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    noise_multiplier=args.sigma,
    max_grad_norm=max_grad_norm,
    clipping=clipping,
    grad_sample_mode=args.grad_sample_mode,
    kalman=True,  # need this argument
    kappa=0.7,  # optional
    gamma=0.5,  # optional
)

# ...
# during training:
def closure():  # compute loss and backward; adapted from the example in examples/cifar10.py
    output = model(images)
    loss = criterion(output, target)
    loss.backward()
    return output, loss

output, loss = optimizer.step(closure)
optimizer.zero_grad()
# compute other metrics
# ...
```

## Citation
Consider citing the paper if you use DiSK in your work:

```
@article{zhang2024disk,
title={{DiSK}: Differentially private optimizer with simplified kalman filter for noise reduction},
author={Zhang, Xinwei and Bu, Zhiqi and Balle, Borja and Hong, Mingyi and Razaviyayn, Meisam and Mirrokni, Vahab},
journal={arXiv preprint arXiv:2410.03883},
year={2024}
}
```

Contributor: Xinwei Zhang. Email: [[email protected]](mailto:[email protected])

95 changes: 95 additions & 0 deletions research/disk_optimizer/optimizers/KFadaclipoptimizer.py
@@ -0,0 +1,95 @@
from __future__ import annotations

import logging
import math
from typing import Optional

import torch
from opacus.optimizers.adaclipoptimizer import AdaClipDPOptimizer
from torch.optim import Optimizer
from torch.optim.optimizer import required

from .KFoptimizer import KF_DPOptimizer


logger = logging.getLogger(__name__)


class KF_AdaClipDPOptimizer(AdaClipDPOptimizer, KF_DPOptimizer):
    def __init__(
        self,
        optimizer: Optimizer,
        *,
        noise_multiplier: float,
        target_unclipped_quantile: float,
        clipbound_learning_rate: float,
        max_clipbound: float,
        min_clipbound: float,
        unclipped_num_std: float,
        max_grad_norm: float,
        expected_batch_size: Optional[int],
        loss_reduction: str = "mean",
        generator=None,
        secure_mode: bool = False,
        kappa: float = 0.7,
        gamma: float = 0.5,
    ):
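        # When gamma is (numerically) (1 - kappa) / kappa, the combined gradient reduces
        # to a single evaluation at x_t; otherwise gradients at two points are mixed with
        # weight c = (1 - kappa) / (gamma * kappa), and the noise multiplier is rescaled
        # so the injected noise keeps the same overall standard deviation.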
        if gamma == 0 or abs(gamma - (1 - kappa) / kappa) < 1e-3:
            gamma = (1 - kappa) / kappa
            self.kf_compute_grad_at_original = False
        else:
            self.scaling_factor = (1 - kappa) / (
                gamma * kappa
            )  # (gamma*kappa+kappa-1)/(1-kappa)
            self.kf_compute_grad_at_original = True
            c = (1 - kappa) / (gamma * kappa)
            norm_factor = math.sqrt(c**2 + (1 - c) ** 2)
            noise_multiplier = noise_multiplier / norm_factor
        super().__init__(
            optimizer,
            noise_multiplier=noise_multiplier,
            max_grad_norm=max_grad_norm,
            expected_batch_size=expected_batch_size,
            loss_reduction=loss_reduction,
            generator=generator,
            secure_mode=secure_mode,
            target_unclipped_quantile=target_unclipped_quantile,
            clipbound_learning_rate=clipbound_learning_rate,
            max_clipbound=max_clipbound,
            min_clipbound=min_clipbound,
            unclipped_num_std=unclipped_num_std,
        )
        self.kappa = kappa
        self.gamma = gamma

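    # One optimization step: run the closure (twice when gradients at a shifted point
    # are needed), apply the simplified Kalman filter to the privatized gradient, then
    # step the wrapped optimizer. kf_m_t holds the filtered gradient; kf_d_t records the
    # parameter change made by this step.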
    def step(self, closure=required) -> Optional[float]:
        if self.kf_compute_grad_at_original:
            loss = self._compute_two_closure(closure)
        else:
            loss = self._compute_one_closure(closure)

        if self.pre_step():
            tmp_states = []
            first_step = False
            for p in self.params:
                grad = p.grad
                state = self.state[p]
                if "kf_d_t" not in state:
                    state = dict()
                    first_step = True
                    state["kf_d_t"] = torch.zeros_like(p.data).to(p.data)
                    state["kf_m_t"] = grad.clone().to(p.data)
                state["kf_m_t"].lerp_(grad, weight=self.kappa)
                p.grad = state["kf_m_t"].clone().to(p.data)
                state["kf_d_t"] = -p.data.clone().to(p.data)
                if first_step:
                    tmp_states.append(state)
            self.original_optimizer.step()
            for p in self.params:
                if first_step:
                    tmp_state = tmp_states.pop(0)
                    self.state[p]["kf_d_t"] = tmp_state["kf_d_t"]
                    self.state[p]["kf_m_t"] = tmp_state["kf_m_t"]
                    del tmp_state
                self.state[p]["kf_d_t"].add_(p.data, alpha=1.0)
        return loss
160 changes: 160 additions & 0 deletions research/disk_optimizer/optimizers/KFddp_perlayeroptimizer.py
@@ -0,0 +1,160 @@
from __future__ import annotations

from functools import partial
from typing import Callable, List, Optional

import torch
from opacus.optimizers.ddp_perlayeroptimizer import _clip_and_accumulate_parameter
from opacus.optimizers.optimizer import _generate_noise
from torch import nn
from torch.optim import Optimizer

from .KFddpoptimizer import KF_DistributedDPOptimizer
from .KFoptimizer import KF_DPOptimizer
from .KFperlayeroptimizer import KF_DPPerLayerOptimizer


class KF_SimpleDistributedPerLayerOptimizer(
    KF_DPPerLayerOptimizer, KF_DistributedDPOptimizer
):
    def __init__(
        self,
        optimizer: Optimizer,
        *,
        noise_multiplier: float,
        max_grad_norm: float,
        expected_batch_size: Optional[int],
        loss_reduction: str = "mean",
        generator=None,
        secure_mode: bool = False,
        kappa: float = 0.7,
        gamma: float = 0.5,
    ):
        self.rank = torch.distributed.get_rank()
        self.world_size = torch.distributed.get_world_size()

        super().__init__(
            optimizer,
            noise_multiplier=noise_multiplier,
            max_grad_norm=max_grad_norm,
            expected_batch_size=expected_batch_size,
            loss_reduction=loss_reduction,
            generator=generator,
            secure_mode=secure_mode,
            kappa=kappa,
            gamma=gamma,
        )


class KF_DistributedPerLayerOptimizer(KF_DPOptimizer):
    """
    :class:`~opacus.optimizers.optimizer.DPOptimizer` that implements
    per layer clipping strategy and is compatible with distributed data parallel
    """

    def __init__(
        self,
        optimizer: Optimizer,
        *,
        noise_multiplier: float,
        max_grad_norm: List[float],
        expected_batch_size: Optional[int],
        loss_reduction: str = "mean",
        generator=None,
        secure_mode: bool = False,
        kappa: float = 0.7,
        gamma: float = 0.5,
    ):
        self.rank = torch.distributed.get_rank()
        self.world_size = torch.distributed.get_world_size()
        self.max_grad_norms = max_grad_norm
        max_grad_norm = torch.norm(torch.Tensor(self.max_grad_norms), p=2).item()
        super().__init__(
            optimizer,
            noise_multiplier=noise_multiplier,
            max_grad_norm=max_grad_norm,
            expected_batch_size=expected_batch_size,
            loss_reduction=loss_reduction,
            generator=generator,
            secure_mode=secure_mode,
            kappa=kappa,
            gamma=gamma,
        )
        self._register_hooks()

    def _add_noise_parameter(self, p: nn.Parameter):
        """
        We need ``self`` here because of the generator used for secure_mode.
        """
        noise = _generate_noise(
            std=self.noise_multiplier * self.max_grad_norm,
            reference=p.summed_grad,
            generator=None,
            secure_mode=self.secure_mode,
        )
        p.grad = p.summed_grad + noise

    @property
    def accumulated_iterations(self) -> int:
        return max([p.accumulated_iterations for p in self.params])

    def _scale_grad_parameter(self, p: nn.Parameter):
        if not hasattr(p, "accumulated_iterations"):
            p.accumulated_iterations = 0
        p.accumulated_iterations += 1
        if self.loss_reduction == "mean":
            p.grad /= (
                self.expected_batch_size * p.accumulated_iterations * self.world_size
            )

    def clip_and_accumulate(self):
        raise NotImplementedError(
            "Clip and accumulate is added per layer in DPDDP Per Layer."
        )

    def add_noise(self):
        raise NotImplementedError("Noise is added per layer in DPDDP Per Layer.")

    def pre_step(
        self, closure: Optional[Callable[[], float]] = None
    ) -> Optional[float]:
        if self._check_skip_next_step():
            self._is_last_step_skipped = True
            return False

        if self.step_hook:
            self.step_hook(self)

        for p in self.params:
            p.accumulated_iterations = 0

        self._is_last_step_skipped = False
        return True

    def _ddp_per_layer_hook(
        self, p: nn.Parameter, max_grad_norm: float, _: torch.Tensor
    ):
        _clip_and_accumulate_parameter(p, max_grad_norm)
        # Equivalent to _check_skip_next_step but without popping, because it has to be
        # done for every parameter p
        if self._check_skip_next_step(pop_next=False):
            return

        if self.rank == 0:
            self._add_noise_parameter(p)
        else:
            p.grad = p.summed_grad
        self._scale_grad_parameter(p)

        return p.grad

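    # Register a backward hook on every trainable parameter so that clipping, noising
    # (on rank 0 only), and gradient scaling happen layer by layer during backward.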
    def _register_hooks(self):
        for p, max_grad_norm in zip(self.params, self.max_grad_norms):
            if not p.requires_grad:
                continue

            if not hasattr(p, "ddp_hooks"):
                p.ddp_hooks = []

            p.ddp_hooks.append(
                p.register_hook(partial(self._ddp_per_layer_hook, p, max_grad_norm))
            )
