From e8bfd723efc2f64e14c86dbe3c9768083390830c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 18 Jul 2022 23:13:22 -0400 Subject: [PATCH] tidy up --- src/rulesets/Base/mapreduce.jl | 96 +++++++++++----------------------- 1 file changed, 31 insertions(+), 65 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index e72be275c..fa8c1c576 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -465,7 +465,7 @@ function rrule( y = first(last(hobbits)) project = ProjectTo(x) function foldl_pullback_tuple(dy) - trio = accumulate(_reverse1(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) + trio = accumulate(reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) ds, da, db = back(dc) # Don't need to store every `da`, need one for the next iteration + the last. end @@ -501,78 +501,43 @@ end # The implementation was originally for both tuples and arrays, although using accumulate # to carry intermediate results along creates arrays of tuples which could be avoided. -# Using a loop can be a few times faster, this should be replaced. -# Note also that it does not return a gradient for `init`. +# Using a loop can be a few times faster, this should be replaced: +# https://github.com/FluxML/Zygote.jl/issues/644#issuecomment-628762305 + +# Note also that it does not return a gradient for `init`, now marked `@not_implemented`. function rrule( - config::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::G, init, x::Union{AbstractArray, Tuple}; + config::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::G, init, x::Union{AbstractArray, Tuple}; ) where {G} - list, start = if init === _INIT - _drop1(x), first(x) + start, list = if init === Base._InitialValue() + Iterators.peel(x) else # Case with init keyword is simpler to understand first! - _reshape1(x, :), init # (vec is for Julia 1.0, accumulate is fussy) + init, x end - hobbits = accumulate(list; init=(start, nothing)) do (a,_), b - # Here `a` is what we would normally cary forward, and `_` ignores - # the previous iteration's pullback function (needed later), - # while `b` is the fresh input from `list` as usual. - c, back = rrule_via_ad(config, op, a, b) # LHS is just documentation here! - # We don't really need to store every `c`, last one is `foldl` output. - # (The name, BTW, is because "there and back again" is the subtitle of Tolkien's book.) + hobbits = accumulate(list; init=(start, nothing)) do (a, _), b + c, back = rrule_via_ad(config, op, a, b) end y = first(last(hobbits)) axe = axes(x) project = ProjectTo(x) function unfoldl(dy) - trio = accumulate(_reverse1(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) + trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) ds, da, db = back(dc) - # Don't need to store every `da`, need one for the next iteration + maybe last end dop = sum(first, trio) - dx = map(last, _reverse1(trio)) - if init === _INIT - # `hobbits` is one short + dx = map(last, Iterators.reverse(trio)) + if init === Base._InitialValue() # `hobbits` is one short dx = _vcat1(trio[end][2], dx) end d_init = @not_implemented "gradient for foldl does not at present include init, sorry" - return (NoTangent(), NoTangent(), dop, d_init, project(_reshape1(dx, axe))) + return (NoTangent(), NoTangent(), dop, d_init, project(reshape(dx, axe))) end return y, unfoldl end - -##### -##### Iterator-or-Tuple functions -##### - -# This zoo of underscore functions helps `foldl` & `accumulate` handle both tuples and arrays, -# and also provides some alternatives for versions of Julia where iterators weren't supported. -# Inspired by `Base._reverse`, used in defn of `foldr`. - -# To support 2nd derivatives, some may need their own gradient rules. And _drop1 should perhaps -# be replaced by _peel1 like Iterators.peel - -_reverse1(x) = Iterators.reverse(x) -_drop1(x) = Iterators.drop(x, 1) -_zip2(x, y) = zip(x, y) # for `accumulate`, below - -_reverse1(x::Tuple) = reverse(x) -_drop1(x::Tuple) = Base.tail(x) -_zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where N = ntuple(i -> (x[i],y[i]), N) - -const _INIT = Base._InitialValue() - _vcat1(x, ys::AbstractVector) = vcat(x, ys) _vcat1(x::AbstractArray, ys::AbstractVector) = vcat([x], ys) -_vcat1(x, ys::Tuple) = (x, ys...) - -_reshape1(x::AbstractArray, axe) = reshape(x, axe) -_reshape1(x::Tuple, axe) = x - -_no_tuple_tangent(dx::Tangent) = ChainRulesCore.backing(dx) -_no_tuple_tangent(dx) = dx - ##### ##### `accumulate` @@ -584,13 +549,18 @@ _no_tuple_tangent(dx) = dx # Move it down to: `_accumulate!(op, B, A::AbstractVector, dims::Nothing, init::Nothing)` function rrule( - config::RuleConfig{>:HasReverseMode}, ::typeof(Base._accumulate!), op::G, y, x::AbstractVector, dims::Nothing, init, + config::RuleConfig{>:HasReverseMode}, + ::typeof(Base._accumulate!), + op::G, y::AbstractVector, + x::AbstractVector, + dims::Nothing, + init, ) where {G} - list, start = if init === nothing - _drop1(x), first(x) + start, list = if init === nothing + Iterators.peel(x) else - x, something(init) + something(init), x end hobbits = accumulate(list; init = (start, nothing)) do (a, _), b c, back = rrule_via_ad(config, op, a, b) @@ -607,20 +577,16 @@ function rrule( axe = axes(x) project = ProjectTo(x) function decumulate(dy) - dy_plain = _no_tuple_tangent(unthunk(dy)) - rev_list = if init === nothing - # Here we rely on `zip` to stop early. Begin explicit with _reverse1(_drop1(...)) - # gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{" - _zip2(_reverse1(hobbits), _reverse1(dy_plain)) - else - _zip2(_reverse1(hobbits), _reverse1(dy_plain)) - end + dy_plain = unthunk(dy) + rev_list = zip(Iterators.reverse(hobbits), Iterators.reverse(dy_plain)) + # Here we rely on `zip` to stop early when init === nothing. Begin explicit with Iterators.reverse(Iterators.drop(..., 1)) + # gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{" trio = accumulate(rev_list; init=(0, ZeroTangent(), 0)) do (_, dc, _), ((_, back), dz) ds, da, db = back(dc + dz) # Don't need to store every 'da', but need for next iteration, and the last one. end dop = sum(first, trio) - dx = map(last, _reverse1(trio)) + dx = map(last, Iterators.reverse(trio)) if init == nothing # `hobbits` is one short, and the first one is weird dx = _vcat1(trio[end][2] + dy_plain[1], dx) @@ -628,7 +594,7 @@ function rrule( dy = @not_implemented "no gradient for `B` in `accumulate!(f, B, A)`, the rule intends to support `accumulate` only" d_init_not = @not_implemented "gradient for accumulate does not at present include init, sorry" d_init = init === nothing ? NoTangent() : Tangent{typeof(init)}(; value = d_init_not) - return (NoTangent(), dop, dy, project(_reshape1(dx, axe)), NoTangent(), d_init) + return (NoTangent(), dop, dy, project(reshape(dx, axe)), NoTangent(), d_init) end - return _reshape1(y, axe), decumulate + return reshape(y, axe), decumulate end