Unthunk tangents (if any) before returning gradient #1551
Can you come up with a test?
Hm... While it does fix the issue with …

using Zygote

struct Dense
    w::Matrix{Float32}
end

(d::Dense)(x) = d.w * x

function main()
    layers = [Dense(rand(Float32, 3, 3))]
    x = ones(Float32, 3)
    g = gradient(layers -> sum(layers[1](x)), layers)[1]
    @show g
end

main()

Changing … So maybe we unthunk the … before returning?
Can you remind me what, if anything, we lose by unthunking at the top level (before …
I think what's happening is that … I agree that making … It might be worth having …
Hm... it was taught, …
I don't think it is imported or used, per the comment in the code diff I linked.
It is imported here.
That's strange, because the code in #966 is definitely defining an …
I added tests, but I'm a bit out of ideas, so I made it unthunk tangents before returning gradients.
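To illustrate what "unthunk before returning gradients" means for a gradient shaped like the one in the `Dense` example (a `Vector` holding one `NamedTuple` per layer), here is a minimal, dependency-free Julia sketch. It is not Zygote's implementation: `LazyTangent`, `force`, and `unthunk_all` are hypothetical stand-ins for ChainRulesCore's `Thunk`/`unthunk` and for whatever walk Zygote actually performs. The walk recurses because a lazy tangent can sit inside a container rather than at the top level.

```julia
# Hypothetical stand-in for ChainRulesCore.Thunk: holds a zero-argument
# closure that computes the tangent only when it is forced.
struct LazyTangent{F}
    f::F
end

# Force one lazy tangent (repeatedly, in case it yields another one);
# any other value is returned unchanged.
force(x) = x
force(t::LazyTangent) = force(t.f())

# Recursively force lazy tangents nested inside common tangent containers,
# so the caller never sees a LazyTangent in the returned gradient.
unthunk_all(x) = force(x)
unthunk_all(t::LazyTangent) = unthunk_all(force(t))
unthunk_all(x::Tuple) = map(unthunk_all, x)
unthunk_all(x::NamedTuple) = map(unthunk_all, x)
unthunk_all(x::AbstractArray{<:Number}) = x   # plain numeric arrays: nothing to force
unthunk_all(x::AbstractArray) = map(unthunk_all, x)

# A gradient shaped like the one for `layers`: one NamedTuple per layer,
# whose `w` tangent was returned lazily (as a rule for `*` might do).
x = ones(Float32, 3)
g_lazy = [(w = LazyTangent(() -> ones(Float32, 3) * x'),)]

g = unthunk_all(g_lazy)
@show typeof(g[1].w)   # Matrix{Float32}: no LazyTangent left
```

Forcing lazy tangents once, at the point where the gradient is handed back to the caller, means code that consumes the gradient never has to handle them; in this sketch the trade-off is that every tangent in the result gets computed, whether or not the caller uses it.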
I've tested with Flux and all tests pass (CPU + AMDGPU). Maybe this is fine for now?
The current approach LGTM, but perhaps it would make sense to have the new overloads for …
Agreed, moved them: FluxML/ZygoteRules.jl#28
Based on FluxML/ZygoteRules.jl#28 (comment) and other comments, it may be better to move forward with this PR and resolve this in a separate PR.
It also turns out that defining …
In hindsight, #966 should not have been merged with the unthunk_tangent changes, because they accidentally committed type piracy. Undoing that will take some work, and because the phasing needs to be so complex, it's not clear whether ZygoteRules.jl will still exist in its current form by the time someone gets around to it.
Project.toml
@@ -57,7 +57,7 @@ Requires = "1.1"
 SpecialFunctions = "1.6, 2"
 Statistics = "1"
 Tracker = "0.2"
-ZygoteRules = "0.2.5"
+ZygoteRules = "=0.2.5"
This bound needs to be updated after FluxML/ZygoteRules.jl#31.
Fixes: FluxML/Flux.jl#2574 (comment)