Skip to content

Commit

Permalink
Fixed derivative of perp when u or v is infinite or both are zero.
Browse files Browse the repository at this point in the history
  • Loading branch information
albop committed Jan 29, 2025
1 parent d2c9806 commit 665ab1c
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 2 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Expand All @@ -15,6 +16,7 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Format = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ForwardDiffChainRules = "c9556dd2-1aed-4cfe-8560-1557cf593001"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Expand Down Expand Up @@ -42,8 +44,10 @@ DoloCUDAExt = "CUDA"
DolooneAPIExt = "oneAPI"

[compat]
ChainRulesCore = "1.25.1"
Dolang = "≥3.3.0"
Format = "1.3.7"
ForwardDiffChainRules = "0.2.1"
LabelledArrays = "≥1.16.0"
StaticArrays = "≥1.9.8"
Term = "2.0.6"
44 changes: 42 additions & 2 deletions src/Dolo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,52 @@ module Dolo


(a,b) = min(a,b)
function (u,v)

function (u0,v0)
BIG = 100000
u = min(max(u0,-BIG),BIG)
v = min(max(v0,-BIG),BIG)
sq = sqrt(u^2+v^2)
p = (v<Inf ? (u+v-sq)/2 : u)
p = (u+v-sq)/2
return p
end

# function ⫫(u,v)
# sq = sqrt(u^2+v^2)
# p = (v<Inf ? (u+v-sq)/2 : u)
# return p
# end

using ChainRulesCore
using ForwardDiffChainRules

# define your frule for function f1 as usual
function ChainRulesCore.frule((_, Δu, Δv), ::typeof(), u::Real, v::Real)
BIG = 100000
u = min(max(u0,-BIG),BIG)
v = min(max(v0,-BIG),BIG)
sq = sqrt(u^2+v^2)
Omega = (u+v-sq)/2
if u==v==0.0
Omega, (Δu + Δv)*0
# elseif !(u<Inf)
# return v, Δv
# elseif !(v<Inf)
# return u, Δu
else
Δ1 = (0.5 - u/sq)*Δu
Δ2 = (0.5 - v/sq)*Δv
return Omega, Δ1 + Δ2
end
end


# @ForwardDiff_frule ⫫(u::ForwardDiff.Dual, v::ForwardDiff.Dual)
# @ForwardDiff_frule ⫫(u::ForwardDiff.Dual, v)
# @ForwardDiff_frule ⫫(u, v::ForwardDiff.Dual)


# ⫫(u,v) = fun(u,v)

# #, SDiagonal(J_u), SDiagonal(J_v)
# # J_u = (v<Inf ? (1.0 - u[i]./sq[i])/2 : 1) for i=1:d )
Expand Down

0 comments on commit 665ab1c

Please sign in to comment.