Skip to content

Commit

Permalink
fix: dispatches
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 29, 2024
1 parent ddeff44 commit 8d2794e
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions src/stdlibs/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@ function materialize_traced_array(
return conj(materialize_traced_array(transpose(parent(x))))
end

function materialize_traced_array(
x::LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}}
) where {T}
return LinearAlgebra.diagm(parent(x))
function materialize_traced_array(x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}) where {T}
return diagm(parent(x))
end

function TracedUtils.materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T}
Expand All @@ -42,7 +40,7 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
uAT = Symbol(:Unit, AT)
@eval begin
function TracedUtils.materialize_traced_array(
x::$(AT){T,TracedRArray{T,2}}
x::$(AT){TracedRNumber{T},TracedRArray{T,2}}
) where {T}
m, n = size(x)
row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
Expand All @@ -52,7 +50,7 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
end

function TracedUtils.materialize_traced_array(
x::$(uAT){T,TracedRArray{T,2}}
x::$(uAT){TracedRNumber{T},TracedRArray{T,2}}
) where {T}
m, n = size(x)
row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
Expand All @@ -64,7 +62,9 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
end
end

function TracedUtils.materialize_traced_array(x::Symmetric{T,TracedRArray{T,2}}) where {T}
function TracedUtils.materialize_traced_array(
x::Symmetric{TracedRNumber{T},TracedRArray{T,2}}
) where {T}
m, n = size(x)
row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2)
Expand Down Expand Up @@ -107,7 +107,9 @@ function TracedUtils.set_mlir_data!(
return x
end

function TracedUtils.set_mlir_data!(x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}, data) where {T}
function TracedUtils.set_mlir_data!(
x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}, data
) where {T}
parent(x).mlir_data = diag(TracedRArray{T}(data)).mlir_data
return x
end
Expand All @@ -119,7 +121,7 @@ for (AT, dcomp, ocomp) in (
(:UnitUpperTriangular, "LT", "GE"),
)
@eval function TracedUtils.set_mlir_data!(
x::LinearAlgebra.$(AT){T,TracedRArray{T,2}}, data
x::$(AT){TracedRNumber{T},TracedRArray{T,2}}, data
) where {T}
tdata = TracedRArray{T}(data)
z = zero(tdata)
Expand All @@ -137,17 +139,19 @@ for (AT, dcomp, ocomp) in (
end

function TracedUtils.set_mlir_data!(
x::LinearAlgebra.Symmetric{T,TracedRArray{T,2}}, data
x::Symmetric{TracedRNumber{T},TracedRArray{T,2}}, data
) where {T}
if x.uplo == 'L'
set_mlir_data!(LinearAlgebra.LowerTriangular(parent(x)), data)
set_mlir_data!(LowerTriangular(parent(x)), data)
else
set_mlir_data!(LinearAlgebra.UpperTriangular(parent(x)), data)
set_mlir_data!(UpperTriangular(parent(x)), data)
end
return x
end

function TracedUtils.set_mlir_data!(x::Tridiagonal{T,TracedRArray{T,1}}, data) where {T}
function TracedUtils.set_mlir_data!(
x::Tridiagonal{TracedRNumber{T},TracedRArray{T,1}}, data
) where {T}
tdata = TracedRArray{T}(data)
set_mlir_data!(x.dl, diag(tdata, -1).mlir_data)
set_mlir_data!(x.d, diag(tdata, 0).mlir_data)
Expand Down

0 comments on commit 8d2794e

Please sign in to comment.