diff --git a/src/algos/time_iteration.jl b/src/algos/time_iteration.jl index c9b361a..2f22188 100644 --- a/src/algos/time_iteration.jl +++ b/src/algos/time_iteration.jl @@ -147,12 +147,12 @@ function time_iteration_workspace(dmodel; interp_mode=:linear, improve=false, de vars = variables(dmodel.model.controls) φ = DFun(dmodel.model.states, x0, vars; interp_mode=interp_mode) - # if improve + if improve L = Dolo.dF_2(dmodel, x1, φ) tt = (;x0, x1, x2, r0, dx, J, L, φ) - # else - # tt = (;x0, x1, x2, r0, dx, J, φ) - # end + else + tt = (;x0, x1, x2, r0, dx, J, φ) + end return adapt(dest, tt) diff --git a/src/funs.jl b/src/funs.jl index 8be608e..6165969 100644 --- a/src/funs.jl +++ b/src/funs.jl @@ -109,13 +109,19 @@ function fit!(φ::DFun, x::GVector{G}) where G<:CGrid end -## PGrid +## PGrid ( SGrid × CGrid ) + + +function (f::DFun{A,B,I,vars})(x::QP) where A where B<:GArray{G,V} where V where I where G<:PGrid{G1,G2} where G1<:SGrid where G2<:CGrid where vars + f(x.loc...) +end + function (f::DFun{A,B,I,vars})(i::Int, x::SVector{d2, U}) where A where B<:GArray{G,V} where V where I where G<:PGrid{G1,G2} where G1<:SGrid where G2<:CGrid where vars where d2 where U f.itp[i](x) end -function (f::DFun{A,B,I,vars})(x::QP) where A where B<:GArray{G,V} where V where I where G<:PGrid{G1,G2} where G1<:SGrid where G2<:CGrid where vars - f(x.loc...) +function (f::DFun{A,B,I,vars})(i::Int, j::Int) where A where B<:GArray{G,V} where V where I where G<:PGrid{G1,G2} where G1<:SGrid where G2<:CGrid where vars where d2 where U + f.values[i,j] end function (f::DFun{A,B,I,vars})(x::Tuple) where A where B<:GArray{G,V} where V where I where G<:PGrid{G1,G2} where G1<:SGrid where G2<:CGrid where vars @@ -148,9 +154,9 @@ end # Compatibility calls -(f::DFun)(x::Real) = f(SVector(x)) -(f::DFun)(x::Real, y::Real) = f(SVector(x,y)) -(f::DFun)(x::Vector{SVector{d,<:Real}}) where d = [f(e) for e in x] +# (f::DFun)(x::Real) = f(SVector(x)) +# (f::DFun)(x::Real, y::Real) = f(SVector(x,y)) +# (f::DFun)(x::Vector{SVector{d,<:Real}}) where d = [f(e) for e in x] ndims(df::DFun) = ndims(df.domain)