Adds fast gradient clipping support for the Embedding layer. #694

Open · wants to merge 3 commits into main
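For context, here is a minimal usage sketch (not part of the diff) showing how the new Embedding norm sampler is exercised end to end through ghost clipping. The import paths, toy model, and hyperparameters are assumptions modeled on the test module added in this PR.

import torch
import torch.nn as nn

from opacus.grad_sample import GradSampleModuleFastGradientClipping
from opacus.optimizers import DPOptimizerFastGradientClipping
from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping

# Hypothetical toy model: a single embedding layer, as in the tests below.
model = nn.Embedding(num_embeddings=100, embedding_dim=8)
criterion = nn.CrossEntropyLoss()

# Wrap the model so per-sample gradient norms (rather than full per-sample
# gradients) are computed for nn.Embedding via the new norm sampler.
gs_model = GradSampleModuleFastGradientClipping(
    model, max_grad_norm=1.0, use_ghost_clipping=True
)
optimizer = DPOptimizerFastGradientClipping(
    torch.optim.SGD(gs_model.parameters(), lr=0.1),
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    expected_batch_size=4,
)
# The loss wrapper runs the two backward passes required by ghost clipping.
dp_criterion = DPLossFastGradientClipping(gs_model, optimizer, criterion)

input_ids = torch.randint(0, 100, (4, 6))  # (batch_size, sequence_length)
target = torch.rand(4, 6, 8)               # soft targets with the same shape as the output
loss = dp_criterion(gs_model(input_ids), target)
loss.backward()
optimizer.step()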
1 change: 1 addition & 0 deletions opacus/grad_sample/__init__.py
@@ -17,6 +17,7 @@
from .dp_multihead_attention import compute_sequence_bias_grad_sample # noqa
from .dp_rnn import compute_rnn_linear_grad_sample # noqa
from .embedding import compute_embedding_grad_sample # noqa
from .embedding_norm_sample import compute_embedding_norm_sample # noqa
from .grad_sample_module import GradSampleModule, create_or_accumulate_grad_sample
from .grad_sample_module_fast_gradient_clipping import ( # noqa
GradSampleModuleFastGradientClipping,
26 changes: 24 additions & 2 deletions opacus/grad_sample/embedding.py
@@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict
from typing import Dict, List

import torch
import torch.nn as nn

from .utils import register_grad_sampler
from opacus.grad_sample import embedding_norm_sample
from .utils import register_grad_sampler, register_norm_sampler


@register_grad_sampler(nn.Embedding)
@@ -82,3 +83,24 @@ def compute_embeddingbag_gradsampler(layer, inputs, backprops):
ret[layer.weight] = gsm

return ret


@register_norm_sampler(nn.Embedding)
def compute_embedding_norm_sample(
layer: nn.Embedding,
activations: List[torch.Tensor],
backprops: torch.Tensor,
) -> Dict[nn.Parameter, torch.Tensor]:
"""Computes gradient norms for ``nn.Embedding`` layer.

Args:
layer: Layer
activations: Activations
backprops: Backpropagations

Returns:
A dictionary of parameter gradients
"""
return embedding_norm_sample.compute_embedding_norm_sample(
layer, activations, backprops
)
148 changes: 148 additions & 0 deletions opacus/grad_sample/embedding_norm_sample.py
@@ -0,0 +1,148 @@
#!/usr/bin/env python3
# Copyright 2024, The Opacus authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility for computing gradient norm for the embedding layer.

Based on the algorithm from the paper:
https://proceedings.neurips.cc/paper_files/paper/2023/file/a45d344b28179c8da7646bc38ff50ad8-Paper-Conference.pdf.
"""
from typing import Dict, List

import torch
from torch import nn


def compute_embedding_norm_sample(
layer: nn.Embedding,
activations: List[torch.Tensor],
backprops: torch.Tensor,
) -> Dict[nn.Parameter, torch.Tensor]:
"""Computes per sample gradient norms for ``nn.Embedding`` layer.

Args:
layer: Layer
activations: Activations
backprops: Backpropagations

Returns:
A dictionary of parameter gradients

NOTE: Here is an example input, and the expected intermediate values. This
is provided to help in understanding the algorithm:
Inputs:
layer: Embedding(3, 1) # (vocab_size, embedding_dim)
activations: [tensor([[1, 1],
[2, 0],
[2, 0]])]
backprops: tensor([[[0.2], [0.2]],
[[0.3], [0.1]],
[[0.3], [0.1]]])
backprops.shape: torch.Size([3, 2, 1])

Intermediate values:
input_ids: tensor([[1, 1],
[2, 0],
[2, 0]])
input_ids.shape: torch.Size([3, 2])
grad_values: tensor([[0.2000],
[0.2000],
[0.3000],
[0.1000],
[0.3000],
[0.1000]])
grad_values.shape: torch.Size([6, 1])
nrows: 3
ncols: 2
row_indices: tensor([[0],
[0],
[1],
[1],
[2],
[2]])
flattened_indices: tensor([[1],
[1],
[2],
[0],
[2],
[0]])
paired_indices: tensor([[0, 1],
[0, 1],
[1, 2],
[1, 0],
[2, 2],
[2, 0]])
unique_paired_indices: tensor([[0, 1],
[1, 0],
[1, 2],
[2, 0],
[2, 2]])
new_index_positions: tensor([0, 0, 2, 1, 4, 3])
num_unique_paired_indices: 5
summed_gradients: tensor([[0.4000],
[0.1000],
[0.3000],
[0.1000],
[0.3000]])
sqr_gradient_sum: tensor([0.1600, 0.0100, 0.0900, 0.0100, 0.0900])
unique_batch_ids: tensor([0, 1, 1, 2, 2])
result: tensor([0.1600, 0.1000, 0.1000])
result_sqrt: tensor([0.4000, 0.3162, 0.3162])
"""
device = activations[0].device
input_ids = activations[0].to(device)
grad_values = backprops.to(device)

# Reshape input_ids preserving the batch size as the first dimension
input_ids = input_ids.reshape(input_ids.shape[0], -1)

# Reshape grad_values preserving the embedding dimension as the last dimension
grad_values = grad_values.reshape(-1, grad_values.size(-1))

# Create 1D tensor of row indices
nrows = input_ids.size(0)
ncols = input_ids.size(1)
row_indices = (
torch.repeat_interleave(torch.arange(nrows).to(device), ncols)
.unsqueeze(-1)
.to(device)
)

# Pair the input IDs with the row indices
flattened_indices = input_ids.view(-1, 1)
paired_indices = torch.cat([row_indices, flattened_indices], dim=1).to(device)

# Get unique paired indices and new index positions for aggregation
unique_paired_indices, new_index_positions = torch.unique(
paired_indices, dim=0, return_inverse=True, sorted=True
)

# Sum gradients over new index positions and compute squared gradient norms
num_unique_paired_indices = unique_paired_indices.size(0)
summed_gradients = torch.zeros(
num_unique_paired_indices, grad_values.size(-1), device=device
)
summed_gradients = summed_gradients.index_add(
0, new_index_positions.to(device), grad_values
)
sqr_gradient_sum = torch.sum(summed_gradients**2, dim=1)

# Scatter add the squared sums back to their respective rows
result = torch.zeros(nrows, device=device)
unique_batch_ids = unique_paired_indices[:, 0].to(device)
result.scatter_add_(0, unique_batch_ids, sqr_gradient_sum)

# Compute the square root for the final result (norm)
result_sqrt = torch.sqrt(result)
return {layer.weight: result_sqrt}
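As a quick sanity check of the docstring example above, the function can be called standalone. This is a sketch, not part of the PR, assuming the new module is importable from the location added in this diff.

import torch
from torch import nn

from opacus.grad_sample.embedding_norm_sample import compute_embedding_norm_sample

layer = nn.Embedding(3, 1)  # (vocab_size, embedding_dim)
activations = [torch.tensor([[1, 1], [2, 0], [2, 0]])]
backprops = torch.tensor([[[0.2], [0.2]], [[0.3], [0.1]], [[0.3], [0.1]]])

norms = compute_embedding_norm_sample(layer, activations, backprops)
print(norms[layer.weight])  # expected: tensor([0.4000, 0.3162, 0.3162])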
164 changes: 164 additions & 0 deletions opacus/tests/grad_sample_module_fast_gradient_clipping_test.py
@@ -14,6 +14,7 @@
# limitations under the License.

import logging
import unittest

import hypothesis.strategies as st
import torch
@@ -67,6 +68,21 @@ def forward(self, x):
return x


class SampleEmbeddingModule(nn.Module):
def __init__(self, vocab_size, embedding_dim):
super(SampleEmbeddingModule, self).__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)

# Manually set weights for the embedding layer for testing
self.embedding.weight = nn.Parameter(
torch.tensor([[0.1], [0.2], [0.3]], dtype=torch.float32)
)

def forward(self, x):
x = self.embedding(x)
return x


class GradSampleModuleFastGradientClippingTest(GradSampleModuleTest):
CLS = GradSampleModuleFastGradientClipping

@@ -260,3 +276,151 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
logging.info(f"Diff = {diff}")
msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg


class GradSampleModuleFastGradientClippingEmbeddingLayerTest(unittest.TestCase):

def test_norm_calculation(self):
"""
Tests if norm calculation for embedding layer is the same between
standard (Opacus) and fast gradient clipping"
"""
vocab_size = 3
embedding_dim = 1

criterion = torch.nn.CrossEntropyLoss(reduction="none")
noise_multiplier = 0.0
input_data = torch.tensor([[1, 1], [2, 0], [2, 0]], dtype=torch.long)
batch_size = 3
max_grad_norm = 1.0
sample_module = SampleEmbeddingModule(vocab_size, embedding_dim)
model_normal = GradSampleModule(clone_module(sample_module))
optimizer_normal = torch.optim.SGD(model_normal.parameters(), lr=1)
optimizer_normal = DPOptimizer(
optimizer_normal,
noise_multiplier=noise_multiplier,
max_grad_norm=max_grad_norm,
expected_batch_size=batch_size,
)

grad_sample_module = GradSampleModuleFastGradientClipping(
clone_module(sample_module),
max_grad_norm=max_grad_norm,
use_ghost_clipping=True,
)
optimizer_gc = torch.optim.SGD(grad_sample_module.parameters(), lr=1)
optimizer_gc = DPOptimizerFastGradientClipping(
optimizer_gc,
noise_multiplier=noise_multiplier,
max_grad_norm=max_grad_norm,
expected_batch_size=batch_size,
)

optimizer_normal.zero_grad()
output_normal = model_normal(input_data)
target_data = torch.rand_like(output_normal)

loss_normal = torch.mean(criterion(output_normal, target_data), dim=0)
loss_normal.backward()
all_norms_normal = torch.stack(
[
torch.stack([g.norm() for g in param.grad_sample], dim=0)
for param in model_normal.parameters()
],
dim=0,
)
flat_norms_normal = torch.cat([p.flatten() for p in all_norms_normal])

grad_sample_module.enable_hooks()
output_gc = grad_sample_module(input_data)

first_loss_per_sample = criterion(output_gc, target_data)
first_loss = torch.mean(first_loss_per_sample)
first_loss.backward(retain_graph=True)

optimizer_gc.zero_grad()
coeff = grad_sample_module.get_clipping_coef()
second_loss_per_sample = coeff * first_loss_per_sample
second_loss = torch.sum(second_loss_per_sample)
grad_sample_module.disable_hooks()
second_loss.backward()

all_norms_gc = [param._norm_sample for param in grad_sample_module.parameters()]
flat_norms_gc = torch.cat([p.flatten() for p in all_norms_gc])

diff = flat_norms_normal - flat_norms_gc

logging.info(f"Diff = {diff}")
msg = "Fail: Gradient norms from vanilla DP-SGD and from fast gradient clipping are different"
assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg

def test_gradient_calculation(self):
"""Tests if gradients for embedding layer are the same between standard
(Opacus) and fast gradient clipping."""

noise_multiplier = 0.0
vocab_size = 3
embedding_dim = 1
batch_size = 3
input_data = torch.tensor([[1, 1], [2, 0], [2, 0]], dtype=torch.long)
max_grad_norm = 1.0
criterion = torch.nn.CrossEntropyLoss()

sample_module = SampleEmbeddingModule(vocab_size, embedding_dim)
model_normal = GradSampleModule(clone_module(sample_module))
grad_sample_module = GradSampleModuleFastGradientClipping(
clone_module(sample_module),
max_grad_norm=max_grad_norm,
use_ghost_clipping=True,
)

optimizer_normal = torch.optim.SGD(model_normal.parameters(), lr=1)
optimizer_normal = DPOptimizer(
optimizer_normal,
noise_multiplier=noise_multiplier,
max_grad_norm=max_grad_norm,
expected_batch_size=batch_size,
)

optimizer_gc = torch.optim.SGD(grad_sample_module.parameters(), lr=1)
optimizer_gc = DPOptimizerFastGradientClipping(
optimizer_gc,
noise_multiplier=noise_multiplier,
max_grad_norm=max_grad_norm,
expected_batch_size=batch_size,
)

criterion_gc = DPLossFastGradientClipping(
grad_sample_module, optimizer_gc, criterion
)

optimizer_normal.zero_grad()
output_normal = model_normal(input_data)
target_data = torch.tensor([[[0.1], [0.1]], [[0.2], [0.3]], [[0.2], [0.3]]])
loss_normal = torch.mean(criterion(output_normal, target_data), dim=0)
loss_normal.backward()
optimizer_normal.step()

all_grads_normal = [param.summed_grad for param in model_normal.parameters()]
flat_grads_normal = torch.cat([p.flatten() for p in all_grads_normal])

optimizer_gc.zero_grad()
grad_sample_module.enable_hooks()
output_gc = grad_sample_module(input_data)

loss_gc = criterion_gc(output_gc, target_data)
loss_gc.backward()
optimizer_gc.step()

all_grads_gc = [param.grad for param in grad_sample_module.parameters()]
flat_grads_gc = torch.cat([p.flatten() for p in all_grads_gc])
diff = torch.tensor(
[
(g_gc - g_normal).norm()
for (g_gc, g_normal) in zip(flat_grads_gc, flat_grads_normal)
]
)

logging.info(f"Diff = {diff}")
msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg