Skip to content

Commit

Permalink
Add trixi_include_changeprecision
Browse files Browse the repository at this point in the history
  • Loading branch information
efaulhaber committed Jan 14, 2025
1 parent d294f66 commit 9a3f3aa
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Michael Schlottke-Lakemper <[email protected]>"]
version = "0.1.5"

[deps]
ChangePrecision = "3cb15238-376d-56a3-8042-d33272777c9a"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

[weakdeps]
Expand All @@ -13,6 +14,7 @@ MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
TrixiBaseMPIExt = "MPI"

[compat]
ChangePrecision = "1.1.0"
MPI = "0.20"
TimerOutputs = "0.5.25"
julia = "1.8"
Expand Down
1 change: 1 addition & 0 deletions src/TrixiBase.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module TrixiBase

using ChangePrecision: ChangePrecision
using TimerOutputs: TimerOutput, TimerOutputs

include("trixi_include.jl")
Expand Down
45 changes: 45 additions & 0 deletions src/trixi_include.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,51 @@ function trixi_include(elixir::AbstractString; kwargs...)
trixi_include(Main, elixir; kwargs...)
end

"""
trixi_include_changeprecision(T, [mod::Module=Main,] elixir::AbstractString; kwargs...)
`include` the elixir `elixir` and evaluate its content in the global scope of module `mod`.
You can override specific assignments in `elixir` by supplying keyword arguments,
similar to [`trixi_include`](@ref).
The only difference to [`trixi_include`](@ref) is that the precision of floating-point
numbers in the included elixir is changed to `T`.
More precisely, the package [ChangePrecision.jl](https://github.com/JuliaMath/ChangePrecision.jl)
is used to convert all `Float64` literals, operations like `/` that produce `Float64` results,
and functions like `ones` that return `Float64` arrays by default, to the desired type `T`.
See the documentation of ChangePrecision.jl for more details.
The purpose of this function is to conveniently run a full simulation with `Float32`,
which is orders of magnitude faster on most GPUs than `Float64`, by just including
the elixir with `trixi_include_changeprecision(Float32, elixir)`.
Most code in the Trixi framework is written in a way that changing all floating-point
numbers in the elixir to `Float32` manually will run the full simulation with single precision.
See [the docs on GPU support](@ref gpu_support) for more information.
"""
function trixi_include_changeprecision(T, mod::Module, filename::AbstractString; kwargs...)
trixi_include(expr -> ChangePrecision.changeprecision(T, replace_trixi_include(T, expr)),
mod, filename; kwargs...)
end

function trixi_include_changeprecision(T, filename::AbstractString; kwargs...)
trixi_include_changeprecision(T, Main, filename; kwargs...)
end

function replace_trixi_include(T, expr)
expr = TrixiBase.walkexpr(expr) do x
if x isa Expr
if x.head === :call && x.args[1] === :trixi_include
x.args[1] = :trixi_include_changeprecision
insert!(x.args, 2, :($T))
end
end
return x
end

return expr
end

# Insert the keyword argument `maxiters` into calls to `solve` and `Trixi.solve`
# with default value `10^5` if it is not already present.
function insert_maxiters(expr)
Expand Down

0 comments on commit 9a3f3aa

Please sign in to comment.