From 9a3f3aa5445b5c87f197147d9b944d128aaf1b4f Mon Sep 17 00:00:00 2001 From: Erik Faulhaber <44124897+efaulhaber@users.noreply.github.com> Date: Tue, 14 Jan 2025 10:55:44 +0100 Subject: [PATCH] Add `trixi_include_changeprecision` --- Project.toml | 2 ++ src/TrixiBase.jl | 1 + src/trixi_include.jl | 45 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+) diff --git a/Project.toml b/Project.toml index bc524f8..2da2b87 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Michael Schlottke-Lakemper "] version = "0.1.5" [deps] +ChangePrecision = "3cb15238-376d-56a3-8042-d33272777c9a" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" [weakdeps] @@ -13,6 +14,7 @@ MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" TrixiBaseMPIExt = "MPI" [compat] +ChangePrecision = "1.1.0" MPI = "0.20" TimerOutputs = "0.5.25" julia = "1.8" diff --git a/src/TrixiBase.jl b/src/TrixiBase.jl index daac03c..dad5559 100644 --- a/src/TrixiBase.jl +++ b/src/TrixiBase.jl @@ -1,5 +1,6 @@ module TrixiBase +using ChangePrecision: ChangePrecision using TimerOutputs: TimerOutput, TimerOutputs include("trixi_include.jl") diff --git a/src/trixi_include.jl b/src/trixi_include.jl index 65297cf..23413a1 100644 --- a/src/trixi_include.jl +++ b/src/trixi_include.jl @@ -61,6 +61,51 @@ function trixi_include(elixir::AbstractString; kwargs...) trixi_include(Main, elixir; kwargs...) end +""" + trixi_include_changeprecision(T, [mod::Module=Main,] elixir::AbstractString; kwargs...) + +`include` the elixir `elixir` and evaluate its content in the global scope of module `mod`. +You can override specific assignments in `elixir` by supplying keyword arguments, +similar to [`trixi_include`](@ref). + +The only difference to [`trixi_include`](@ref) is that the precision of floating-point +numbers in the included elixir is changed to `T`. +More precisely, the package [ChangePrecision.jl](https://github.com/JuliaMath/ChangePrecision.jl) +is used to convert all `Float64` literals, operations like `/` that produce `Float64` results, +and functions like `ones` that return `Float64` arrays by default, to the desired type `T`. +See the documentation of ChangePrecision.jl for more details. + +The purpose of this function is to conveniently run a full simulation with `Float32`, +which is orders of magnitude faster on most GPUs than `Float64`, by just including +the elixir with `trixi_include_changeprecision(Float32, elixir)`. +Most code in the Trixi framework is written in a way that changing all floating-point +numbers in the elixir to `Float32` manually will run the full simulation with single precision. + +See [the docs on GPU support](@ref gpu_support) for more information. +""" +function trixi_include_changeprecision(T, mod::Module, filename::AbstractString; kwargs...) + trixi_include(expr -> ChangePrecision.changeprecision(T, replace_trixi_include(T, expr)), + mod, filename; kwargs...) +end + +function trixi_include_changeprecision(T, filename::AbstractString; kwargs...) + trixi_include_changeprecision(T, Main, filename; kwargs...) +end + +function replace_trixi_include(T, expr) + expr = TrixiBase.walkexpr(expr) do x + if x isa Expr + if x.head === :call && x.args[1] === :trixi_include + x.args[1] = :trixi_include_changeprecision + insert!(x.args, 2, :($T)) + end + end + return x + end + + return expr +end + # Insert the keyword argument `maxiters` into calls to `solve` and `Trixi.solve` # with default value `10^5` if it is not already present. function insert_maxiters(expr)