Incorporate code review feedback on PR #694
pagarwl committed Jan 12, 2025
1 parent 888e435 commit 5444ff1
Showing 4 changed files with 13 additions and 23 deletions.
2 changes: 1 addition & 1 deletion opacus/grad_sample/embedding.py
@@ -15,10 +15,10 @@

from typing import Dict, List

- from opacus.grad_sample import embedding_norm_sample
import torch
import torch.nn as nn

+ from opacus.grad_sample import embedding_norm_sample
from .utils import register_grad_sampler, register_norm_sampler


10 changes: 4 additions & 6 deletions opacus/grad_sample/embedding_norm_sample.py
@@ -46,12 +46,10 @@ def compute_embedding_norm_sample(
activations: [tensor([[1, 1],
[2, 0],
[2, 0]])]
- backprops: tensor([[0.2000],
- [0.2000],
- [0.3000],
- [0.1000],
- [0.3000],
- [0.1000]])
+ backprops: tensor([[[0.2], [0.2]],
+ [[0.3], [0.1]],
+ [[0.3], [0.1]]])
+ backprops.shape: torch.Size([3, 2, 1])
Intermediate values:
input_ids: tensor([[1, 1],
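A note on the corrected docstring values: below is a minimal sketch (not part of this commit) of why the example backprops tensor has shape [3, 2, 1] rather than the earlier flattened [6, 1]. The gradient captured with respect to the embedding output keeps the output's [batch, sequence, embedding_dim] shape; the layer size and the use of sum() as a stand-in loss are assumptions made only to run the example.

import torch
import torch.nn as nn

# Toy layer matching the docstring example: 3 embedding rows, embedding_dim=1.
embedding = nn.Embedding(num_embeddings=3, embedding_dim=1)
input_ids = torch.tensor([[1, 1], [2, 0], [2, 0]])  # shape [3, 2]

out = embedding(input_ids)  # shape [3, 2, 1]
out.retain_grad()
out.sum().backward()

# The gradient w.r.t. the layer output -- the "backprops" in the docstring --
# has the same shape as the output itself.
print(out.grad.shape)  # torch.Size([3, 2, 1])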
2 changes: 1 addition & 1 deletion
@@ -351,7 +351,7 @@ def test_norm_calculation(self):
diff = flat_norms_normal - flat_norms_gc

logging.info(f"Diff = {diff}")
msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
msg = "Fail: Gradient norms from vanilla DP-SGD and from fast gradient clipping are different"
assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg

def test_gradient_calculation(self):
22 changes: 7 additions & 15 deletions opacus/tests/grad_samples/embedding_norm_sample_test.py
@@ -15,9 +15,9 @@

import unittest

- from opacus.grad_sample import embedding_norm_sample
import torch
import torch.nn as nn
+ from opacus.grad_sample import embedding_norm_sample


class TestComputeEmbeddingNormSample(unittest.TestCase):
@@ -36,15 +36,11 @@ def test_compute_embedding_norm_sample(self):
# Example input ids (activations). Shape: [3, 2]
input_ids = torch.tensor([[1, 1], [2, 0], [2, 0]], dtype=torch.long)

- # Example gradients with respect to the embedding output (backprops).
- # Shape: [6, 1]
- grad_values = torch.tensor(
- [[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]], dtype=torch.float32
+ # Example backprops. Shape: [3, 2, 1]
+ backprops = torch.tensor(
+ [[[0.2], [0.2]], [[0.3], [0.1]], [[0.3], [0.1]]], dtype=torch.float32
)

- # Simulate backprop through embedding layer
- backprops = grad_values

# Wrap input_ids in a list as expected by the norm sample function
activations = [input_ids]

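For reference, the per-example gradient norms implied by this setup can be computed by hand. The sketch below is illustrative only and assumes the usual nn.Embedding gradient: each example's backprops rows are scatter-added into the weight rows selected by its input ids, and the per-example norm is the Frobenius norm of that per-example weight gradient.

import torch

input_ids = torch.tensor([[1, 1], [2, 0], [2, 0]])
backprops = torch.tensor([[[0.2], [0.2]], [[0.3], [0.1]], [[0.3], [0.1]]])

norms = []
for ids, grad_out in zip(input_ids, backprops):
    grad_weight = torch.zeros(3, 1)  # [num_embeddings, embedding_dim]
    grad_weight.index_add_(0, ids, grad_out)  # accumulate grad rows per input id
    norms.append(grad_weight.norm())

# Example 0: weight row 1 accumulates 0.2 + 0.2 = 0.4, so its norm is 0.4.
# Examples 1 and 2: rows 2 and 0 receive 0.3 and 0.1, so the norm is sqrt(0.1) ~ 0.316.
print(torch.stack(norms))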
@@ -70,17 +66,17 @@ def test_compute_embedding_norm_sample_with_non_one_embedding_dim(self):

# Manually set weights for the embedding layer for testing
embedding_layer.weight = nn.Parameter(
- torch.tensor([[0.1], [0.2], [0.3]], dtype=torch.float32)
+ torch.tensor([[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]], dtype=torch.float32)
)

# Example input ids (activations). Shape: [6, 1, 1].
input_ids = torch.tensor(
[[[1]], [[1]], [[2]], [[0]], [[2]], [[0]]], dtype=torch.long
)

- # Example gradients per input id, with embedding_dim=2.
+ # Example backprops per input id, with embedding_dim=2.
# Shape: [6, 1, 1, 2]
- grad_values = torch.tensor(
+ backprops = torch.tensor(
[
[[[0.2, 0.2]]],
[[[0.2, 0.2]]],
@@ -92,9 +88,6 @@ def test_compute_embedding_norm_sample_with_non_one_embedding_dim(self):
dtype=torch.float32,
)

- # Simulate backprop through embedding layer
- backprops = grad_values

# Wrap input_ids in a list as expected by the grad norm function
activations = [input_ids]

@@ -211,7 +204,6 @@ def test_compute_embedding_norm_sample_with_extra_activations_per_example(self):
expected_norms = torch.tensor(
[0.0150, 0.0071, 0.0005, 0.0081, 0.0039], dtype=torch.float32
)
print("expected_norms: ", expected_norms)
computed_norms = result[embedding_layer.weight]

# Verify the computed norms match the expected norms
