From 7679bb7c1251af66ac8f9dc0686f73f7f715a88b Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Sun, 19 Jan 2025 11:30:16 +0200 Subject: [PATCH 1/3] Expand unthunk-tangent to more methods --- src/adjoint.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/adjoint.jl b/src/adjoint.jl index 7fef6ea..9f99639 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -1,6 +1,6 @@ using MacroTools using MacroTools: @q, combinedef -using ChainRulesCore: AbstractZero +using ChainRulesCore: AbstractZero, @non_differentiable function named(arg) if isexpr(arg, :(::)) && length(arg.args) == 1 @@ -37,6 +37,12 @@ function unthunk_tangent end @inline unthunk_tangent(x) = x @inline unthunk_tangent(x::Tuple) = map(unthunk_tangent, x) @inline unthunk_tangent(x::NamedTuple) = map(unthunk_tangent, x) +@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) +@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x +@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x +@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x) +unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d]) +@non_differentiable unthunk_tangent(::IdDict) function gradm(ex, mut = false, keepthunks = false) From 6b898b1d1373639621593e3f088a87613d641a99 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Sun, 19 Jan 2025 11:31:52 +0200 Subject: [PATCH 2/3] Fix --- src/adjoint.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/adjoint.jl b/src/adjoint.jl index 9f99639..a5883dc 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -1,6 +1,6 @@ using MacroTools using MacroTools: @q, combinedef -using ChainRulesCore: AbstractZero, @non_differentiable +using ChainRulesCore: AbstractZero, AbstractThunk, @non_differentiable function named(arg) if isexpr(arg, :(::)) && length(arg.args) == 1 From b0a235c1ee51db353f1331d2631a58bafdd5caee Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Sun, 19 Jan 2025 11:36:35 +0200 Subject: [PATCH 3/3] Bump to 0.2.6 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 57e80e2..c20837c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ZygoteRules" uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.5" +version = "0.2.6" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"