Unthunk tangents (if any) before returning gradient #1551

Merged: 6 commits, Jan 21, 2025

@pxl-th pxl-th commented Jan 16, 2025

@CarloLucibello

Can you come up with a test?

@pxl-th

pxl-th commented Jan 16, 2025

Hm... While this does fix the issue with Flux.Dense, it still returns thunks for a plain struct:

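using Zygote

# Plain struct (not a Flux layer) holding a single weight matrix.
struct Dense
    w::Matrix{Float32}
end

(d::Dense)(x) = d.w * x

function main()
    layers = [Dense(rand(Float32, 3, 3))]
    x = ones(Float32, 3)
    # Differentiate w.r.t. the vector of layers; the gradient shown below still contains a thunk.
    g = gradient(layers -> sum(layers[1](x)), layers)[1]
    @show g
end
main()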

Changing @_adjoint_keepthunks to @adjoint here "fixes" it, but that is probably not the right way, since it will unthunk every getfield.
The thunks leak only because the pullback for getfield is last in the chain; otherwise it works fine.

So maybe we unthunk the gradients before returning?
Or someone has a better idea...

CC @mcabbott @ToucheSir

@ToucheSir

Can you remind me what, if anything, we lose by unthunking at the top level (before gradient returns)? I think the problem is that unthunk_tangent was not taught about Tuples and NamedTuples in https://github.com/FluxML/Zygote.jl/pull/966/files#diff-e0bc7da8f1a33a59f5ecfa67257c04038f0b4915b3f74bdf39780818fd0010a2R3-R8.

@mcabbott

mcabbott commented Jan 17, 2025

I think what's happening is that the pullback for m.W*x returns dm1 = (W=Thunk(), b=nothing) while the one for .+m.b returns dm2 = (W=nothing, b=Thunk()), and these are added by accum(dm1, dm2). That's why altering accum works here, but it seems like the wrong place: it will un-thunk things which aren't returned, and probably won't work for structs where only one field is used.

I agree that making unthunk_tangent recurse not just into arrays but also NamedTuples/Tuples seems like the right path.

It might be worth having accum(dx::AbstractThunk, ::Nothing) = dx, etc., or otherwise making sure that this returns dx instead of making a new thunk. I can't prove this matters for performance, but it seems worth trying a little to avoid making thunks of thunks of thunks.
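
The two suggestions above can be sketched roughly as follows. This is a minimal illustration built only on ChainRulesCore's `@thunk`, `unthunk`, and `AbstractThunk`; the names `unthunk_all` and `toy_accum` are hypothetical stand-ins, not the actual `unthunk_tangent` or `accum` definitions in Zygote or ZygoteRules.

```julia
using ChainRulesCore: AbstractThunk, @thunk, unthunk

# Hypothetical helper for illustration; not Zygote's or ZygoteRules' actual code.
unthunk_all(x) = x                                       # leaves (numbers, nothing, ...) pass through
unthunk_all(x::AbstractThunk) = unthunk_all(unthunk(x))  # force, then keep recursing
unthunk_all(x::Tuple) = map(unthunk_all, x)
unthunk_all(x::NamedTuple) = map(unthunk_all, x)
unthunk_all(x::AbstractArray{<:Number}) = x              # numeric arrays cannot hold thunks
unthunk_all(x::AbstractArray) = map(unthunk_all, x)

# Hypothetical accumulation rule in the spirit of `accum(dx::AbstractThunk, ::Nothing) = dx`:
# pass the existing thunk through instead of wrapping it in a new one.
toy_accum(dx::AbstractThunk, ::Nothing) = dx
toy_accum(::Nothing, dx::AbstractThunk) = dx

# Structural tangents shaped like dm1 and dm2 in the comment above.
dm1 = (W = @thunk(ones(2, 2)), b = nothing)
dm2 = (W = nothing, b = @thunk([1.0, 2.0]))

merged = map(toy_accum, dm1, dm2)   # fieldwise accumulation, still lazy
g = unthunk_all(merged)             # forced once, at the boundary
nested = unthunk_all((dm1, [dm2]))  # recursion through Tuple, Array, NamedTuple
```

Here `merged` still holds the original thunks, and only the call to `unthunk_all` evaluates them.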

@pxl-th

pxl-th commented Jan 17, 2025

Hm... it was taught: unthunk_tangent is imported from ZygoteRules, which already defines a method for NamedTuple:
https://github.com/FluxML/ZygoteRules.jl/blob/f9bf0e367fa259c5aa68f0e14ccbf2125d734bd6/src/adjoint.jl#L39

@ToucheSir

I don't think it is imported or used, per the comment in the code diff I linked.

@pxl-th

pxl-th commented Jan 17, 2025

It is imported here.
And defining unthunk_tangent for Tuple in Zygote triggers a warning that it overwrites the method from ZygoteRules:

WARNING: Method definition unthunk_tangent(Tuple) in module ZygoteRules at /home/pxlth/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:38 overwritten in module Zygote at /home/pxlth/.julia/dev/Zygote/src/compiler/chainrules.jl:7.

@ToucheSir

That's strange, because the code in #966 is definitely defining an unthunk_tangent function from scratch. I wonder if that's because it's in one of the submodules.

@pxl-th

pxl-th commented Jan 17, 2025

I added tests, but I'm a bit out of ideas, so I made it unthunk before returning gradients.
Other approaches I've tried broke laziness...
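
To illustrate why forcing thunks only at the very end keeps laziness intact, here is a toy sketch that does not use Zygote at all. `toy_pullback` and the `EVALS` counter are hypothetical names introduced for this example; only `@thunk` and `unthunk` come from ChainRulesCore.

```julia
using ChainRulesCore: @thunk, unthunk

const EVALS = Ref(0)   # counts how many thunks were actually forced

# Toy pullback for y = sum(W * x); both input gradients are wrapped lazily.
function toy_pullback(W, x)
    function back(dy)
        dW = @thunk begin
            EVALS[] += 1
            (dy .* ones(size(W, 1))) * x'
        end
        dx = @thunk begin
            EVALS[] += 1
            W' * (dy .* ones(size(W, 1)))
        end
        return dW, dx
    end
    return sum(W * x), back
end

W = [1.0 2.0; 3.0 4.0]
x = [1.0, 2.0]
y, back = toy_pullback(W, x)
dW, dx = back(1.0)

# Only the gradient for W is requested, so only that thunk is forced
# on the way out; the thunk for x is dropped without being computed.
grad_W = unthunk(dW)
```

After this runs, `EVALS[]` is 1: the unused gradient was never evaluated, which is the property an eager unthunk inside every pullback would lose.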

@pxl-th pxl-th changed the title from "Do not at-thunk with mixed-type accum" to "Unthunk tangents (if any) before returning gradient" on Jan 18, 2025

@pxl-th

pxl-th commented Jan 18, 2025

I've tested with Flux and all tests pass (CPU + AMDGPU). Maybe this is fine for now?

@ToucheSir

The current approach LGTM, but perhaps it would make sense to have the new overloads for unthunk_tangent in ZygoteRules.jl? If nobody has any strong feelings about that though, happy to approve.

@pxl-th

pxl-th commented Jan 19, 2025

Agreed, moved them: FluxML/ZygoteRules.jl#28

@pxl-th

pxl-th commented Jan 20, 2025

> The current approach LGTM, but perhaps it would make sense to have the new overloads for unthunk_tangent in ZygoteRules.jl?

Based on FluxML/ZygoteRules.jl#28 (comment) and other comments, maybe it'd be better to move on with this PR and have a separate PR that will resolve this.

@ToucheSir

It also turns out that defining function unthunk_tangent end twice is a no-op? Can you remove that line as well to spare the same type of confusion that caused FluxML/ZygoteRules.jl#30?
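
The behaviour described here can be reproduced with two toy modules. `Upstream`, `Downstream`, and `unthunk_toy` are hypothetical names standing in for the real packages and function; the sketch only assumes standard Julia semantics for `import` and `function name end`.

```julia
module Upstream            # stands in for the package that owns the function
unthunk_toy(x) = x
end

module Downstream          # stands in for a package that imports and extends it
import ..Upstream: unthunk_toy

# No-op: the name is already bound to Upstream's function by the import,
# so this does not create a separate Downstream function.
function unthunk_toy end

# This adds a method to Upstream's function. Downstream owns neither the
# function nor the Tuple type, which is what makes it type piracy.
unthunk_toy(x::Tuple) = map(unthunk_toy, x)
end

same_function = Downstream.unthunk_toy === Upstream.unthunk_toy
method_count = length(methods(Upstream.unthunk_toy))
via_upstream = Upstream.unthunk_toy((1, (2, 3)))
```

Both names refer to a single generic function, so the Tuple method defined downstream is visible to every caller of the upstream function.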

@pxl-th

pxl-th commented Jan 20, 2025

> It also turns out that defining function unthunk_tangent end twice is a no-op? Can you remove that line

Already did: https://github.com/FluxML/Zygote.jl/pull/1551/files#diff-e0bc7da8f1a33a59f5ecfa67257c04038f0b4915b3f74bdf39780818fd0010a2R3


@ToucheSir ToucheSir left a comment


In hindsight, #966 should not have been merged with the unthunk_tangent changes because it was accidentally committing type piracy. But undoing that is going to take some work, and it's not clear if ZygoteRules.jl will still exist in its current form by the time someone gets around to doing said work because of how complex the phasing needs to be.

Project.toml (outdated)
@@ -57,7 +57,7 @@ Requires = "1.1"
 SpecialFunctions = "1.6, 2"
 Statistics = "1"
 Tracker = "0.2"
-ZygoteRules = "0.2.5"
+ZygoteRules = "=0.2.5"

This bound needs to be updated after FluxML/ZygoteRules.jl#31.

@CarloLucibello CarloLucibello merged commit 1b111d8 into FluxML:master Jan 21, 2025
11 of 13 checks passed
@pxl-th pxl-th deleted the pxl-th/thunk branch January 21, 2025 12:08