more stringent test for CPUOffloadOptimizer #1650

Merged · 2 commits · Feb 1, 2025
test/prototype/test_low_bit_optim.py (28 additions, 4 deletions)
@@ -260,11 +260,24 @@ def test_optim_4bit_correctness(self, optim_name):
     @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
     def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
         device = _DEVICES[-1]
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128))
+        # The first two layers are chosen so that they have a terrible arithmetic density.
+        # This means long transfers and comparatively quick computation, increasing the chances
+        # that missing synchronization will lead to test failures.
+        # The third layer is very small; it is there to validate non-trainable parameters
+        # but shouldn't influence the timings.
+        model1 = nn.Sequential(
+            nn.Linear(32, 131072),
+            nn.ReLU(),
+            nn.Linear(131072, 64),
+            nn.ReLU(),
+            nn.Linear(64, 64),
+            nn.ReLU(),
+            nn.Linear(64, 128),
+        )
         model1.to(device)
 
         # make sure it can work in the presence of non-trainable params
-        model1[0].requires_grad_(False)
+        model1[2].requires_grad_(False)
         model2 = copy.deepcopy(model1)
 
         optim1 = torch.optim.AdamW(model1.parameters())
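
Note (not part of the diff): the "arithmetic density" rationale in the comment above can be made concrete with some rough arithmetic. Assuming fp32 weights and ignoring biases, the two large Linear layers hold about 12.6M parameters (~50 MB), while a forward pass over a 4x32 batch costs only ~0.1 GFLOPs, about 2 FLOPs per parameter byte, so any host-device parameter traffic takes far longer than the compute it feeds, which is exactly the window in which a missing synchronization shows up. A quick sketch of the numbers:

# Back-of-the-envelope numbers for the two large layers above (fp32 assumed,
# biases ignored); an illustrative sketch, not part of the test.
bytes_per_param = 4
batch = 4
layers = [(32, 131072), (131072, 64)]  # the two low-arithmetic-density Linear layers

params = sum(fan_in * fan_out for fan_in, fan_out in layers)
transfer_bytes = params * bytes_per_param
forward_flops = sum(2 * batch * fan_in * fan_out for fan_in, fan_out in layers)

print(f"~{transfer_bytes / 1e6:.0f} MB of weights, "
      f"~{forward_flops / 1e6:.0f} MFLOPs of forward compute, "
      f"~{forward_flops / transfer_bytes:.1f} FLOPs per byte")
# -> ~50 MB vs ~101 MFLOPs (~2.0 FLOPs/byte): parameter transfers take far
#    longer than the compute, so a missing sync has a large window to fail in.
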
@@ -274,15 +287,26 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
             offload_gradients=offload_grad,
         )
 
+        rng = torch.Generator(device=device)
+        rng.manual_seed(42)
+
+        # make sure to run both models separately; otherwise, model1 gives additional
+        # time for operations in model2 to complete, masking potential race conditions.
         for _ in range(2):
             for _ in range(grad_accum):
-                x = torch.randn(4, 32, device=device)
+                x = torch.randn(4, 32, device=device, generator=rng)
                 model1(x).sum().backward()
-                model2(x).sum().backward()
 
             optim1.step()
             optim1.zero_grad()
 
+        # reset the rng so model2 sees exactly the same inputs
+        rng.manual_seed(42)
+        for _ in range(2):
+            for _ in range(grad_accum):
+                x = torch.randn(4, 32, device=device, generator=rng)
+                model2(x).sum().backward()
+
             optim2.step()
             optim2.zero_grad()
 
torchao/prototype/low_bit_optim/cpu_offload.py (2 additions, 0 deletions)
@@ -107,6 +107,8 @@ def step(self, closure=None):
             with getattr(torch, self.device).stream(self.stream):
                 p_device.copy_(p_host, non_blocking=True)
 
+        # make sure param H2D finishes before the next forward pass
+        self.stream.synchronize()
         self.queue.clear()
         return loss
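
Note (not part of the diff): the added synchronize() call closes a classic stream hazard. The parameter H2D copies are issued on a side stream with non_blocking=True, so without an explicit sync the next forward pass, launched on the default stream, can read p_device before the copy has landed. A minimal, self-contained sketch of the same pattern (toy example, not torchao's actual code):

import torch

# Sketch of the hazard the fix addresses: a non-blocking H2D copy on a side
# stream must complete before the default stream reads the destination tensor.
if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    p_host = torch.randn(131072, 32, pin_memory=True)   # pinned CPU copy of a weight
    p_device = torch.empty(131072, 32, device="cuda")   # buffer the model reads from

    with torch.cuda.stream(side_stream):
        p_device.copy_(p_host, non_blocking=True)        # copy runs asynchronously on side_stream

    # Without this, a kernel launched next on the default stream may observe
    # stale or partially written values in p_device.
    side_stream.synchronize()
    # (a lighter-weight alternative: torch.cuda.current_stream().wait_stream(side_stream))

    out = p_device.sum()  # safe: the copy is guaranteed to have finished
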
