From 773aab4a817d07f36bd917f2392ca18772198771 Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 12 May 2023 00:09:12 +0200 Subject: [PATCH 01/11] wip: add ir module --- examples/brutus-14.jl | 263 ++++++++++ examples/brutus.jl | 276 ++++++++++ lib/14/libMLIR_h.jl | 2 +- lib/15/libMLIR_h.jl | 2 +- src/Dialects.jl | 204 ++++++++ src/IR.jl | 1143 +++++++++++++++++++++++++++++++++++++++++ src/MLIR.jl | 5 +- 7 files changed, 1892 insertions(+), 3 deletions(-) create mode 100644 examples/brutus-14.jl create mode 100644 examples/brutus.jl create mode 100644 src/Dialects.jl create mode 100644 src/IR.jl diff --git a/examples/brutus-14.jl b/examples/brutus-14.jl new file mode 100644 index 00000000..53c3c640 --- /dev/null +++ b/examples/brutus-14.jl @@ -0,0 +1,263 @@ +module Brutus + +using MLIR.IR +using MLIR.Dialects: arith, std +using Core: PhiNode, GotoNode, GotoIfNot, SSAValue, Argument, ReturnNode, PiNode + + +const BrutusScalar = Union{Bool,Int64,Int32,Float32,Float64} + +function cmpi_pred(predicate) + function(ctx, ops; loc=Location(ctx)) + arith.cmpi(ctx, predicate, ops; loc) + end +end + +function single_op_wrapper(fop) + (ctx::Context, block::Block, args::Vector{Value}; loc=Location(ctx)) -> push!(block, fop(ctx, args; loc)) +end + +const intrinsics_to_mlir = Dict([ + Base.add_int => single_op_wrapper(arith.addi), + Base.sle_int => single_op_wrapper(cmpi_pred(arith.Predicates.sle)), + Base.slt_int => single_op_wrapper(cmpi_pred(arith.Predicates.slt)), + Base.:(===) => single_op_wrapper(cmpi_pred(arith.Predicates.eq)), + Base.mul_int => single_op_wrapper(arith.muli), + Base.mul_float => single_op_wrapper(arith.mulf), + # TODO: i don't know how to do a bitwise negation in any other way + Base.not_int => function(ctx, block, args; loc=Location(ctx)) + arg = only(args) + ones = push!(block, arith.constant(ctx, -1, IR.get_type(arg); loc)) |> IR.get_result + push!(block, arith.xori(ctx, Value[arg, ones]; loc)) + end, +]) + +"Generates a block argument for each phi node present in the block." +function prepare_block(ctx, ir, bb) + b = Block() + + for sidx in bb.stmts + stmt = ir.stmts[sidx] + inst = stmt[:inst] + inst isa Core.PhiNode || continue + + type = stmt[:type] + IR.push_argument!(b, MType(ctx, type), Location(ctx)) + end + + return b +end + +"Values to populate the Phi Node when jumping from `from` to `to`." +function collect_value_arguments(ir, from, to) + to = ir.cfg.blocks[to] + values = [] + for s in to.stmts + stmt = ir.stmts[s] + inst = stmt[:inst] + inst isa Core.PhiNode || continue + + edge = findfirst(==(from), inst.edges) + if isnothing(edge) # use dummy scalar val instead + val = zero(stmt[:type]) + push!(values, val) + else + push!(values, inst.values[edge]) + end + end + values +end + +""" + code_mlir(f, types::Type{Tuple}; ctx=Context()) -> IR.Operation + +Returns a `builtin.func` operation corresponding to the ircode of the provided method. +This only supports a few Julia Core primitives and scalar types of type $BrutusScalar. + +!!! note + The Julia SSAIR to MLIR conversion implemented is very primitive and only supports a + handful of primitives. A better to perform this conversion would to create a dialect + representing Julia IR and progressively lower it to base MLIR dialects. +""" +function code_mlir(f, types; ctx=Context()) + ir, ret = Core.Compiler.code_ircode(f, types) |> only + @assert first(ir.argtypes) isa Core.Const + + values = Vector{Value}(undef, length(ir.stmts)) + @show length(ir.stmts) length(values) + + for dialect in ("std",) + IR.get_or_load_dialect!(ctx, dialect) + end + + + blocks = [ + prepare_block(ctx, ir, bb) + for bb in ir.cfg.blocks + ] + + current_block = entry_block = blocks[begin] + + for argtype in types.parameters + IR.push_argument!(entry_block, MType(ctx, argtype), Location(ctx)) + end + + function get_value(x)::Value + if x isa Core.SSAValue + @assert isassigned(values, x.id) "value $x was not assigned" + values[x.id] + elseif x isa Core.Argument + IR.get_argument(entry_block, x.n - 1) + elseif x isa Number + IR.get_result(push!(current_block, arith.constant(ctx, x))) + else + error("could not use value $x inside MLIR") + end + end + + for (block_id, (b, bb)) in enumerate(zip(blocks, ir.cfg.blocks)) + current_block = b + n_phi_nodes = 0 + + for sidx in bb.stmts + stmt = ir.stmts[sidx] + inst = stmt[:inst] + line = ir.linetable[stmt[:line]] + + if Meta.isexpr(inst, :call) + line = ir.linetable[stmt[:line]] + val_type = stmt[:type] + if !(val_type <: BrutusScalar) + error("type $val_type is not supported") + end + out_type = MType(ctx, val_type) + + called_func = first(inst.args) + if called_func isa GlobalRef # TODO: should probably use something else here + called_func = getproperty(called_func.mod, called_func.name) + end + + fop! = intrinsics_to_mlir[called_func] + args = get_value.(@view inst.args[begin+1:end]) + + res = IR.get_result(fop!(ctx, current_block, args; loc=Location(ctx, line))) + + values[sidx] = res + elseif inst isa PhiNode + values[sidx] = IR.get_argument(current_block, n_phi_nodes += 1) + elseif inst isa PiNode + values[sidx] = get_value(inst.val) + elseif inst isa GotoNode + args = get_value.(collect_value_arguments(ir, block_id, inst.label)) + dest = blocks[inst.label] + push!(current_block, std.br(ctx, dest, args; loc=Location(ctx, line))) + elseif inst isa GotoIfNot + false_args = get_value.(collect_value_arguments(ir, block_id, inst.dest)) + cond = get_value(inst.cond) + @assert length(bb.succs) == 2 # NOTE: We assume that length(bb.succs) == 2, this might be wrong + other_dest = setdiff(bb.succs, inst.dest) |> only + true_args = get_value.(collect_value_arguments(ir, block_id, other_dest)) + other_dest = blocks[other_dest] + dest = blocks[inst.dest] + + cond_br = std.cond_br(ctx, cond, other_dest, dest, true_args, false_args; loc=Location(ctx, line)) + push!(current_block, cond_br) + elseif inst isa ReturnNode + line = ir.linetable[stmt[:line]] + push!(current_block, std.return_(ctx, [get_value(inst.val)]; loc=Location(ctx, line))) + else + error("unhandled ir $(inst)") + end + end + end + + func_name = nameof(f) + + region = Region() + for b in blocks + push!(region, b) + end + + state = OperationState("builtin.func", Location(ctx)) + + input_types = MType[ + IR.get_type(IR.get_argument(entry_block, i)) + for i in 1:IR.num_arguments(entry_block) + ] + result_types = [MType(ctx, ret)] + + ftype = MType(ctx, input_types => result_types) + IR.add_attributes!(state, [ + NamedAttribute(ctx, "sym_name", IR.Attribute(ctx, string(func_name))), + NamedAttribute(ctx, "type", IR.Attribute(ftype)), + ]) + IR.add_owned_regions!(state, Region[region]) + + op = Operation(state) + + IR.verifyall(op) + + op +end + +""" + @code_mlir f(args...) +""" +macro code_mlir(call) + @assert Meta.isexpr(call, :call) "only calls are supported" + + f = first(call.args) |> esc + args = Expr(:curly, + Tuple, + map(arg -> :($(Core.Typeof)($arg)), + call.args[begin+1:end])..., + ) |> esc + + quote + code_mlir($f, $args) + end +end + +end # module Brutus + +# --- + +function pow(x::F, n) where {F} + p = one(F) + for _ in 1:n + p *= x + end + p +end + +function f(x) + if x == 1 + 2 + else + 3 + end +end + +# --- + +using MLIR.IR, MLIR + +ctx = Context() +MLIR.API.mlirContextEnableMultithreading(ctx, false) + +op = Brutus.code_mlir(pow, Tuple{Float64, Int}) + +mod = MModule(ctx, Location(ctx)) +body = IR.get_body(mod) +push!(body, op) + +pm = IR.PassManager(ctx) +opm = IR.OpPassManager(pm, "builtin.func") + +# TODO: make high-level API for these +MLIR.API.mlirPassManagerEnableIRPrinting(pm) +MLIR.API.mlirPassManagerEnableVerifier(pm, true) +MLIR.API.mlirOpPassManagerAddOwnedPass(opm, MLIR.API.mlirCreateConversionConvertArithmeticToLLVM()) +MLIR.API.mlirPassManagerAddOwnedPass(pm, MLIR.API.mlirCreateConversionConvertStandardToLLVM()) + +IR.run(pm, mod) diff --git a/examples/brutus.jl b/examples/brutus.jl new file mode 100644 index 00000000..93cefe4a --- /dev/null +++ b/examples/brutus.jl @@ -0,0 +1,276 @@ +module Brutus + +using MLIR.IR +using MLIR.Dialects: arith, func, cf +using Core: PhiNode, GotoNode, GotoIfNot, SSAValue, Argument, ReturnNode, PiNode + + +const BrutusScalar = Union{Bool,Int64,Int32,Float32,Float64} + +function cmpi_pred(predicate) + function(ctx, ops; loc=Location(ctx)) + arith.cmpi(ctx, predicate, ops; loc) + end +end + +function single_op_wrapper(fop) + (ctx::Context, block::Block, args::Vector{Value}; loc=Location(ctx)) -> push!(block, fop(ctx, args; loc)) +end + +const intrinsics_to_mlir = Dict([ + Base.add_int => single_op_wrapper(arith.addi), + Base.sle_int => single_op_wrapper(cmpi_pred(arith.Predicates.sle)), + Base.slt_int => single_op_wrapper(cmpi_pred(arith.Predicates.slt)), + Base.:(===) => single_op_wrapper(cmpi_pred(arith.Predicates.eq)), + Base.mul_int => single_op_wrapper(arith.muli), + Base.mul_float => single_op_wrapper(arith.mulf), + # TODO: i don't know how to do a bitwise negation in any other way + Base.not_int => function(ctx, block, args; loc=Location(ctx)) + arg = only(args) + ones = push!(block, arith.constant(ctx, -1, IR.get_type(arg); loc)) |> IR.get_result + push!(block, arith.xori(ctx, Value[arg, ones]; loc)) + end, +]) + +"Generates a block argument for each phi node present in the block." +function prepare_block(ctx, ir, bb) + b = Block() + + for sidx in bb.stmts + stmt = ir.stmts[sidx] + inst = stmt[:inst] + inst isa Core.PhiNode || continue + + type = stmt[:type] + IR.push_argument!(b, MType(ctx, type), Location(ctx)) + end + + return b +end + +"Values to populate the Phi Node when jumping from `from` to `to`." +function collect_value_arguments(ir, from, to) + to = ir.cfg.blocks[to] + values = [] + for s in to.stmts + stmt = ir.stmts[s] + inst = stmt[:inst] + inst isa Core.PhiNode || continue + + edge = findfirst(==(from), inst.edges) + if isnothing(edge) # use dummy scalar val instead + val = zero(stmt[:type]) + push!(values, val) + else + push!(values, inst.values[edge]) + end + end + values +end + +""" + code_mlir(f, types::Type{Tuple}; ctx=Context()) -> IR.Operation + +Returns a `func.func` operation corresponding to the ircode of the provided method. +This only supports a few Julia Core primitives and scalar types of type $BrutusScalar. + +!!! note + The Julia SSAIR to MLIR conversion implemented is very primitive and only supports a + handful of primitives. A better to perform this conversion would to create a dialect + representing Julia IR and progressively lower it to base MLIR dialects. +""" +function code_mlir(f, types; ctx=Context()) + ir, ret = Core.Compiler.code_ircode(f, types) |> only + @assert first(ir.argtypes) isa Core.Const + + values = Vector{Value}(undef, length(ir.stmts)) + @show length(ir.stmts) length(values) + + for dialect in ("func", "cf") + IR.get_or_load_dialect!(ctx, dialect) + end + + blocks = [ + prepare_block(ctx, ir, bb) + for bb in ir.cfg.blocks + ] + + current_block = entry_block = blocks[begin] + + for argtype in types.parameters + IR.push_argument!(entry_block, MType(ctx, argtype), Location(ctx)) + end + + function get_value(x)::Value + if x isa Core.SSAValue + @assert isassigned(values, x.id) "value $x was not assigned" + values[x.id] + elseif x isa Core.Argument + IR.get_argument(entry_block, x.n - 1) + elseif x isa Number + IR.get_result(push!(current_block, arith.constant(ctx, x))) + else + error("could not use value $x inside MLIR") + end + end + + for (block_id, (b, bb)) in enumerate(zip(blocks, ir.cfg.blocks)) + current_block = b + n_phi_nodes = 0 + + for sidx in bb.stmts + stmt = ir.stmts[sidx] + inst = stmt[:inst] + line = ir.linetable[stmt[:line]] + + if Meta.isexpr(inst, :call) + line = ir.linetable[stmt[:line]] + val_type = stmt[:type] + if !(val_type <: BrutusScalar) + error("type $val_type is not supported") + end + out_type = MType(ctx, val_type) + + called_func = first(inst.args) + if called_func isa GlobalRef # TODO: should probably use something else here + called_func = getproperty(called_func.mod, called_func.name) + end + + fop! = intrinsics_to_mlir[called_func] + args = get_value.(@view inst.args[begin+1:end]) + + res = IR.get_result(fop!(ctx, current_block, args; loc=Location(ctx, line))) + + values[sidx] = res + elseif inst isa PhiNode + values[sidx] = IR.get_argument(current_block, n_phi_nodes += 1) + elseif inst isa PiNode + values[sidx] = get_value(inst.val) + elseif inst isa GotoNode + args = get_value.(collect_value_arguments(ir, block_id, inst.label)) + dest = blocks[inst.label] + push!(current_block, cf.br(ctx, dest, args; loc=Location(ctx, line))) + elseif inst isa GotoIfNot + false_args = get_value.(collect_value_arguments(ir, block_id, inst.dest)) + cond = get_value(inst.cond) + @assert length(bb.succs) == 2 # NOTE: We assume that length(bb.succs) == 2, this might be wrong + other_dest = setdiff(bb.succs, inst.dest) |> only + true_args = get_value.(collect_value_arguments(ir, block_id, other_dest)) + other_dest = blocks[other_dest] + dest = blocks[inst.dest] + + cond_br = cf.cond_br(ctx, cond, other_dest, dest, true_args, false_args; loc=Location(ctx, line)) + push!(current_block, cond_br) + elseif inst isa ReturnNode + line = ir.linetable[stmt[:line]] + push!(current_block, func.return_(ctx, [get_value(inst.val)]; loc=Location(ctx, line))) + else + error("unhandled ir $(inst)") + end + end + end + + func_name = nameof(f) + + region = Region() + for b in blocks + push!(region, b) + end + + state = OperationState("func.func", Location(ctx)) + + input_types = MType[ + IR.get_type(IR.get_argument(entry_block, i)) + for i in 1:IR.num_arguments(entry_block) + ] + result_types = [MType(ctx, ret)] + + ftype = MType(ctx, input_types => result_types) + IR.add_attributes!(state, [ + NamedAttribute(ctx, "sym_name", IR.Attribute(ctx, string(func_name))), + NamedAttribute(ctx, "function_type", IR.Attribute(ftype)), + ]) + IR.add_owned_regions!(state, Region[region]) + + op = Operation(state) + + IR.verifyall(op) + + op +end + +""" + @code_mlir f(args...) +""" +macro code_mlir(call) + @assert Meta.isexpr(call, :call) "only calls are supported" + + f = first(call.args) |> esc + args = Expr(:curly, + Tuple, + map(arg -> :($(Core.Typeof)($arg)), + call.args[begin+1:end])..., + ) |> esc + + quote + code_mlir($f, $args) + end +end + +end # module Brutus + +# --- + +function pow(x::F, n) where {F} + p = one(F) + for _ in 1:n + p *= x + end + p +end + +function f(x) + if x == 1 + 2 + else + 3 + end +end + +# --- + +using MLIR.IR, MLIR + +ctx = Context() + +MLIR.API.mlirContextEnableMultithreading(ctx, false) +MLIR.API.mlirRegisterAllLLVMTranslations(ctx) +MLIR.API.mlirRegisterAllPasses() + +op = Brutus.code_mlir(pow, Tuple{Float64, Int}) + +mod = MModule(ctx, Location(ctx)) +body = IR.get_body(mod) +push!(body, op) + +pm = IR.PassManager(ctx) +opm = IR.OpPassManager(pm, "builtin.module") + +# TODO: make high-level API for these +MLIR.API.mlirPassManagerEnableIRPrinting(pm) +MLIR.API.mlirPassManagerEnableVerifier(pm, true) + +# MLIR.API.mlirOpPassManagerAddOwnedPass(opm, MLIR.API.mlirCreateConversionConvertArithmeticToLLVM()) +# MLIR.API.mlirOpPassManagerAddOwnedPass(opm, MLIR.API.mlirCreateConversionConvertControlFlowToLLVM()) +# MLIR.API.mlirPassManagerAddOwnedPass(pm, MLIR.API.mlirCreateConversionConvertFuncToLLVM()) + +# MLIR.API.mlirRegisterConversionConvertFuncToLLVM() +# MLIR.API.mlirPassManagerAddOwnedPass(pm, MLIR.API.mlirCreateTransformsCanonicalizer()) +# MLIR.API.mlirPassManagerAddOwnedPass(pm, MLIR.API.mlirCreateTransformsControlFlowSink()) +MLIR.API.mlirPassManagerAddOwnedPass(pm, MLIR.API.mlirCreateTransformsTopologicalSort()) + +IR.add_pipeline!(opm, "convert-func-to-llvm") + +IR.run(pm, mod) + +mod diff --git a/lib/14/libMLIR_h.jl b/lib/14/libMLIR_h.jl index c9dfbb76..4a7e19a3 100644 --- a/lib/14/libMLIR_h.jl +++ b/lib/14/libMLIR_h.jl @@ -546,7 +546,7 @@ An auxiliary class for constructing operations. This class contains all the information necessary to construct the operation. It owns the MlirRegions it has pointers to and does not own anything else. By default, the state can be constructed from a name and location, the latter being also used to access the context, and has no other components. These components can be added progressively until the operation is constructed. Users are not expected to rely on the internals of this class and should use mlirOperationState* functions instead. """ -struct MlirOperationState +mutable struct MlirOperationState # TODO: make mutable in res name::MlirStringRef location::MlirLocation nResults::intptr_t diff --git a/lib/15/libMLIR_h.jl b/lib/15/libMLIR_h.jl index 702af778..e7b97ca8 100644 --- a/lib/15/libMLIR_h.jl +++ b/lib/15/libMLIR_h.jl @@ -656,7 +656,7 @@ An auxiliary class for constructing operations. This class contains all the information necessary to construct the operation. It owns the MlirRegions it has pointers to and does not own anything else. By default, the state can be constructed from a name and location, the latter being also used to access the context, and has no other components. These components can be added progressively until the operation is constructed. Users are not expected to rely on the internals of this class and should use mlirOperationState* functions instead. """ -struct MlirOperationState +mutable struct MlirOperationState name::MlirStringRef location::MlirLocation nResults::intptr_t diff --git a/src/Dialects.jl b/src/Dialects.jl new file mode 100644 index 00000000..f88fd4c3 --- /dev/null +++ b/src/Dialects.jl @@ -0,0 +1,204 @@ +module Dialects + +module arith + +using ...IR + +for (f, t) in Iterators.product( + (:add, :sub, :mul), + (:i, :f), +) + fname = Symbol(f, t) + @eval function $fname(context, operands, type=IR.get_type(first(operands)); loc=Location(context)) + state = OperationState($(string("arith.", fname)), loc) + IR.add_operands!(state, operands) + IR.add_results!(state, [type]) + Operation(state) + end +end + +for fname in (:xori, :andi, :ori) + @eval function $fname(context, operands, type=IR.get_type(first(operands)); loc=Location(context)) + state = OperationState($(string("arith.", fname)), loc) + IR.add_operands!(state, operands) + IR.add_results!(state, [type]) + Operation(state) + end +end + +for (f, t) in Iterators.product( + (:div, :max, :min), + (:si, :ui, :f), +) + fname = Symbol(f, t) + @eval function $fname(context, operands, type=IR.get_type(first(operands)); loc=Location(context)) + state = OperationState($(string("arith.", fname)), loc) + IR.add_operands!(state, operands) + IR.add_results!(state, [type]) + Operation(state) + end +end + +# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithindex_cast-mlirarithindexcastop +for f in (:index_cast, :index_castui) + @eval function $f(context, operand; loc=Location(context)) + state = OperationState($(string("arith.", f)), loc) + add_operands!(state, [operand]) + add_results!(state, [IR.IndexType(context)]) + Operation(state) + end +end + +# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithextf-mlirarithextfop +function extf(context, operand, type; loc=Location(context)) + state = OperationState("arith.exf", loc) + IR.add_results!(state, [type]) + IR.add_operands!(state, [operand]) + Operation(state) +end + +# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithsitofp-mlirarithsitofpop +function sitofp(context, operand, ftype=float(julia_type(eltype(get_type(operand)))); loc=Location(context)) + state = OperationState("arith.sitofp", loc) + type = get_type(operand) + IR.add_results!(state, [ + IR.is_tensor(type) ? + MType(context, ftype isa MType ? eltype(ftype) : MType(context, ftype), size(type)) : + MType(context, ftype) + ]) + IR.add_operands!(state, [operand]) + Operation(state) +end + +# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithfptosi-mlirarithfptosiop +function fptosi(context, operand, itype; loc=Location(context)) + state = OperationState("arith.fptosi", loc) + type = get_type(operand) + IR.add_results!(state, [ + IR.is_tensor(type) ? + MType(context, itype isa MType ? itype : MType(context, itype), size(type)) : + MType(context, itype) + ]) + IR.add_operands!(state, [operand]) + Operation(state) +end + + +# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithconstant-mlirarithconstantop +function constant(context, value, type=MType(context, typeof(value)); loc=Location(context)) + state = OperationState("arith.constant", loc) + IR.add_results!(state, [type]) + IR.add_attributes!(state, [ + IR.NamedAttribute(context, "value", + Attribute(context, value, type)), + ]) + Operation(state) +end + +module Predicates + const eq = 0 + const ne = 1 + const slt = 2 + const sle = 3 + const sgt = 4 + const sge = 5 + const ult = 6 + const ule = 7 + const ugt = 8 + const uge = 9 +end + +function cmpi(context, predicate, operands; loc=Location(context)) + state = OperationState("arith.cmpi", loc) + IR.add_operands!(state, operands) + IR.add_attributes!(state, [ + IR.NamedAttribute(context, "predicate", + Attribute(context, predicate)) + ]) + IR.add_results!(state, [MType(context, Bool)]) + Operation(state) +end + +end # module arith + +module std +# for llvm 14 + +using ...IR + +function return_(context, operands; loc=Location(context)) + state = OperationState("std.return", loc) + IR.add_operands!(state, operands) + Operation(state) +end + +function br(context, dest, operands; loc=Location(context)) + state = OperationState("std.br", loc) + IR.add_successors!(state, [dest]) + IR.add_operands!(state, operands) + Operation(state) +end + +function cond_br( + context, cond, + true_dest, false_dest, + true_dest_operands, + false_dest_operands; + loc=Location(context), +) + state = OperationState("std.cond_br", loc) + IR.add_successors!(state, [true_dest, false_dest]) + IR.add_operands!(state, [cond, true_dest_operands..., false_dest_operands...]) + IR.add_attributes!(state, [ + IR.NamedAttribute(context, "operand_segment_sizes", + IR.Attribute(context, Int32[1, length(true_dest_operands), length(false_dest_operands)])) + ]) + Operation(state) +end + +end # module std + +module func +# https://mlir.llvm.org/docs/Dialects/Func/ + +using ...IR + +function return_(context, operands; loc=Location(context)) + state = OperationState("func.return", loc) + IR.add_operands!(state, operands) + Operation(state) +end + +end # module func + +module cf + +using ...IR + +function br(context, dest, operands; loc=Location(context)) + state = OperationState("cf.br", loc) + IR.add_successors!(state, [dest]) + IR.add_operands!(state, operands) + Operation(state) +end + +function cond_br( + context, cond, + true_dest, false_dest, + true_dest_operands, + false_dest_operands; + loc=Location(context), +) + state = OperationState("cf.cond_br", loc) + IR.add_successors!(state, [true_dest, false_dest]) + IR.add_operands!(state, [cond, true_dest_operands..., false_dest_operands...]) + IR.add_attributes!(state, [ + IR.NamedAttribute(context, "operand_segment_sizes", + IR.Attribute(context, Int32[1, length(true_dest_operands), length(false_dest_operands)])) + ]) + Operation(state) +end + +end # module cf + +end # module Dialects diff --git a/src/IR.jl b/src/IR.jl new file mode 100644 index 00000000..636f67db --- /dev/null +++ b/src/IR.jl @@ -0,0 +1,1143 @@ +module IR + +import ..API: API as LibMLIR + +export + Operation, + OperationState, + Location, + Context, + MModule, + Value, + MType, + Region, + Block, + Attribute, + NamedAttribute + +export + add_results!, + add_attributes!, + add_owned_regions!, + add_successors! + + +import Base: ==, String +using .LibMLIR: + MlirDialectRegistry, + MlirDialectHandle, + MlirAttribute, + MlirNamedAttribute, + MlirDialect, + MlirStringRef, + MlirOperation, + MlirOperationState, + MlirLocation, + MlirBlock, + MlirRegion, + MlirModule, + MlirContext, + MlirType, + MlirValue, + MlirIdentifier, + MlirPassManager, + MlirOpPassManager + +function mlirIsNull(val) + val.ptr == C_NULL +end + +function print_callback(str::MlirStringRef, userdata) + data = unsafe_wrap(Array, Base.convert(Ptr{Cchar}, str.data), str.length; own=false) + write(userdata isa Base.RefValue ? userdata[] : userdata, data) + return Cvoid() +end + +### String Ref + +String(strref::MlirStringRef) = + Base.unsafe_string(Base.convert(Ptr{Cchar}, strref.data), strref.length) +Base.convert(::Type{MlirStringRef}, s::String) = + MlirStringRef(Base.unsafe_convert(Cstring, s), sizeof(s)) + +### Identifier + +String(ident::MlirIdentifier) = String(LibMLIR.mlirIdentifierStr(ident)) + +### Dialect + +struct Dialect + dialect::MlirDialect + + Dialect(dialect) = begin + @assert !mlirIsNull(dialect) "cannot create Dialect from null MlirDialect" + new(dialect) + end +end + +Base.convert(::Type{MlirDialect}, dialect::Dialect) = dialect.dialect +function Base.show(io::IO, dialect::Dialect) + print(io, "Dialect(\"", String(LibMLIR.mlirDialectGetNamespace(dialect)), "\")") +end + +### DialectHandle + +struct DialectHandle + handle::LibMLIR.MlirDialectHandle +end + +function DialectHandle(s::Symbol) + s = Symbol("mlirGetDialectHandle__", s, "__") + DialectHandle(getproperty(LibMLIR, s)()) +end + +Base.convert(::Type{MlirDialectHandle}, handle::DialectHandle) = handle.handle + +### Dialect Registry + +mutable struct DialectRegistry + registry::MlirDialectRegistry +end +function DialectRegistry() + registry = LibMLIR.mlirDialectRegistryCreate() + @assert !mlirIsNull(registry) "cannot create DialectRegistry with null MlirDialectRegistry" + finalizer(DialectRegistry(registry)) do registry + LibMLIR.mlirDialectRegistryDestroy(registry.registry) + end +end + +function Base.insert!(registry::DialectRegistry, handle::DialectHandle) + LibMLIR.mlirDialectHandleInsertDialect(registry, handle) +end + +### Context + +mutable struct Context + context::MlirContext +end +function Context() + context = LibMLIR.mlirContextCreate() + @assert !mlirIsNull(context) "cannot create Context with null MlirContext" + finalizer(Context(context)) do context + LibMLIR.mlirContextDestroy(context.context) + end +end + +Base.convert(::Type{MlirContext}, c::Context) = c.context + +num_loaded_dialects(context) = LibMLIR.mlirContextGetNumLoadedDialects(context) +function get_or_load_dialect!(context, handle::DialectHandle) + mlir_dialect = LibMLIR.mlirDialectHandleLoadDialect(handle, context) + if mlirIsNull(mlir_dialect) + error("could not load dialect from handle $handle") + else + Dialect(mlir_dialect) + end +end +function get_or_load_dialect!(context, dialect::String) + get_or_load_dialect!(context, DialectHandle(Symbol(dialect))) +end + +is_registered_operation(context, opname) = LibMLIR.mlirContextIsRegisteredOperation(context, opname) + +### Location + +struct Location + location::MlirLocation + + Location(location) = begin + @assert !mlirIsNull(location) "cannot create Location with null MlirLocation" + new(location) + end +end + +Location(context::Context) = Location(LibMLIR.mlirLocationUnknownGet(context)) +Location(context::Context, filename, line, column=0) = + Location(LibMLIR.mlirLocationFileLineColGet(context, filename, line, column)) +Location(context::Context, lin::Core.LineInfoNode) = + Location(context, string(lin.file), lin.line) +Location(context::Context, lin::LineNumberNode) = + isnothing(lin.file) ? + Location(context) : + Location(context, string(lin.file), lin.line) +Location(context::Context, ::Nothing) = Location(context) + +Base.convert(::Type{MlirLocation}, location::Location) = location.location + +function Base.show(io::IO, location::Location) + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + print(io, "Location(#= ") + GC.@preserve ref LibMLIR.mlirLocationPrint(location, c_print_callback, ref) + print(io, " =#)") +end + +### Type + +struct MType + type::MlirType + + MType(type) = begin + @assert !mlirIsNull(type) + new(type) + end +end + +MType(t::MType) = t +MType(context::Context, T::Type{<:Signed}) = + MType(LibMLIR.mlirIntegerTypeGet(context, sizeof(T) * 8)) +MType(context::Context, T::Type{<:Unsigned}) = + MType(LibMLIR.mlirIntegerTypeGet(context, sizeof(T) * 8)) +MType(context::Context, ::Type{Bool}) = + MType(LibMLIR.mlirIntegerTypeGet(context, 1)) +MType(context::Context, ::Type{Float32}) = + MType(LibMLIR.mlirF32TypeGet(context)) +MType(context::Context, ::Type{Float64}) = + MType(LibMLIR.mlirF64TypeGet(context)) +MType(context::Context, ft::Pair) = + MType(LibMLIR.mlirFunctionTypeGet(context, + length(ft.first), [MType(t) for t in ft.first], + length(ft.second), [MType(t) for t in ft.second])) +MType(context, a::AbstractArray{T}) where {T} = MType(context, MType(context, T), size(a)) +MType(context, ::Type{<:AbstractArray{T,N}}, dims) where {T,N} = + MType(LibMLIR.mlirRankedTensorTypeGetChecked( + Location(context), + N, collect(dims), + MType(context, T), + Attribute(), + )) +MType(context, element_type::MType, dims) = + MType(LibMLIR.mlirRankedTensorTypeGetChecked( + Location(context), + length(dims), collect(dims), + element_type, + Attribute(), + )) +MType(context, ::T) where {T<:Real} = MType(context, T) +MType(_, type::MType) = type + +IndexType(context) = MType(LibMLIR.mlirIndexTypeGet(context)) + +Base.convert(::Type{MlirType}, mtype::MType) = mtype.type + +function Base.eltype(type::MType) + if LibMLIR.mlirTypeIsAShaped(type) + MType(LibMLIR.mlirShapedTypeGetElementType(type)) + else + type + end +end + +function show_inner(io::IO, type::MType) + if LibMLIR.mlirTypeIsAInteger(type) + is_signless = LibMLIR.mlirIntegerTypeIsSignless(type) + is_signed = LibMLIR.mlirIntegerTypeIsSigned(type) + + width = LibMLIR.mlirIntegerTypeGetWidth(type) + t = if is_signed + "si" + elseif is_signless + "i" + else + "u" + end + print(io, t, width) + elseif LibMLIR.mlirTypeIsAF64(type) + print(io, "f64") + elseif LibMLIR.mlirTypeIsAF32(type) + print(io, "f32") + elseif LibMLIR.mlirTypeIsARankedTensor(type) + print(io, "tensor<") + s = size(type) + print(io, join(s, "x"), "x") + show_inner(io, eltype(type)) + print(io, ">") + elseif LibMLIR.mlirTypeIsAIndex(type) + print(io, "index") + else + print(io, "unknown") + end +end + +function Base.show(io::IO, type::MType) + print(io, "MType(#= ") + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + GC.@preserve ref LibMLIR.mlirTypePrint(type, c_print_callback, ref) + print(io, " =#)") +end + +function inttype(size, issigned) + size == 1 && issigned && return Bool + ints = (Int8, Int16, Int32, Int64, Int128) + IT = ints[Int(log2(size)) - 2] + issigned ? IT : unsigned(IT) +end + +function julia_type(type::MType) + if LibMLIR.mlirTypeIsAInteger(type) + is_signed = LibMLIR.mlirIntegerTypeIsSigned(type) || + LibMLIR.mlirIntegerTypeIsSignless(type) + width = LibMLIR.mlirIntegerTypeGetWidth(type) + + try + inttype(width, is_signed) + catch + t = is_signed ? "i" : "u" + throw("could not convert type $(t)$(width) to julia") + end + elseif LibMLIR.mlirTypeIsAF32(type) + Float32 + elseif LibMLIR.mlirTypeIsAF64(type) + Float64 + else + throw("could not convert type $type to julia") + end +end + +Base.ndims(type::MType) = + if LibMLIR.mlirTypeIsAShaped(type) && LibMLIR.mlirShapedTypeHasRank(type) + LibMLIR.mlirShapedTypeGetRank(type) + else + 0 + end + +Base.size(type::MType, i::Int) = LibMLIR.mlirShapedTypeGetDimSize(type, i - 1) +Base.size(type::MType) = Tuple(size(type, i) for i in 1:ndims(type)) + +function is_tensor(type::MType) + LibMLIR.mlirTypeIsAShaped(type) +end + +function is_integer(type::MType) + LibMLIR.mlirTypeIsAInteger(type) +end + +is_function_type(mtype) = LibMLIR.mlirTypeIsAFunction(mtype) + +function get_num_inputs(ftype) + @assert is_function_type(ftype) "cannot get the number of inputs on type $(ftype), expected a function type" + LibMLIR.mlirFunctionTypeGetNumInputs(ftype) +end +function get_num_results(ftype) + @assert is_function_type(ftype) "cannot get the number of results on type $(ftype), expected a function type" + LibMLIR.mlirFunctionTypeGetNumResults(ftype) +end + +function get_input(ftype::MType, pos) + @assert is_function_type(ftype) "cannot get input on type $(ftype), expected a function type" + MType(LibMLIR.mlirFunctionTypeGetInput(ftype, pos - 1)) +end +function get_result(ftype::MType, pos=1) + @assert is_function_type(ftype) "cannot get result on type $(ftype), expected a function type" + MType(LibMLIR.mlirFunctionTypeGetResult(ftype, pos - 1)) +end + +### Attribute + +struct Attribute + attribute::MlirAttribute +end + +Attribute() = Attribute(LibMLIR.mlirAttributeGetNull()) +Attribute(context, s::AbstractString) = Attribute(LibMLIR.mlirStringAttrGet(context, s)) +Attribute(type::MType) = Attribute(LibMLIR.mlirTypeAttrGet(type)) +Attribute(context, f::F, type=MType(context, F)) where {F<:AbstractFloat} = Attribute( + LibMLIR.mlirFloatAttrDoubleGet(context, type, Float64(f)) +) +Attribute(context, i::T) where {T<:Integer} = Attribute( + LibMLIR.mlirIntegerAttrGet(MType(context, T), Int64(i)) +) +function Attribute(context, values::T) where {T<:AbstractArray{Int32}} + type = MType(context, T, size(values)) + Attribute( + LibMLIR.mlirDenseElementsAttrInt32Get(type, length(values), values) + ) +end +function Attribute(context, values::T) where {T<:AbstractArray{Int64}} + type = MType(context, T, size(values)) + Attribute( + LibMLIR.mlirDenseElementsAttrInt64Get(type, length(values), values) + ) +end +function Attribute(context, values::T) where {T<:AbstractArray{Float64}} + type = MType(context, T, size(values)) + Attribute( + LibMLIR.mlirDenseElementsAttrDoubleGet(type, length(values), values) + ) +end +function Attribute(context, values::T) where {T<:AbstractArray{Float32}} + type = MType(context, T, size(values)) + Attribute( + LibMLIR.mlirDenseElementsAttrFloatGet(type, length(values), values) + ) +end +function Attribute(context, values::AbstractArray{Int32}, type) + Attribute( + LibMLIR.mlirDenseElementsAttrInt32Get(type, length(values), values) + ) +end +function Attribute(context, values::AbstractArray{Int}, type) + Attribute( + LibMLIR.mlirDenseElementsAttrInt64Get(type, length(values), values) + ) +end +function Attribute(context, values::AbstractArray{Float32}, type) + Attribute( + LibMLIR.mlirDenseElementsAttrFloatGet(type, length(values), values) + ) +end +function ArrayAttribute(context, values::AbstractVector{Int}) + elements = Attribute.((context,), values) + Attribute( + LibMLIR.mlirArrayAttrGet(context, length(elements), elements) + ) +end +function ArrayAttribute(context, attributes::Vector{Attribute}) + Attribute( + LibMLIR.mlirArrayAttrGet(context, length(attributes), attributes), + ) +end +function DenseArrayAttribute(context, values::AbstractVector{Int}) + Attribute( + LibMLIR.mlirDenseI64ArrayGet(context, length(values), collect(values)) + ) +end +function Attribute(context, value::Int, type::MType) + Attribute( + LibMLIR.mlirIntegerAttrGet(type, value) + ) +end +function Attribute(context, value::Bool, ::MType=nothing) + Attribute( + LibMLIR.mlirBoolAttrGet(context, value) + ) +end + +Base.convert(::Type{MlirAttribute}, attribute::Attribute) = attribute.attribute +Base.parse(::Type{Attribute}, context, s) = + Attribute(LibMLIR.mlirAttributeParseGet(context, s)) + +function get_type(attribute::Attribute) + MType(LibMLIR.mlirAttributeGetType(attribute)) +end +function get_type_value(attribute) + @assert LibMLIR.mlirAttributeIsAType(attribute) "attribute $(attribute) is not a type" + MType(LibMLIR.mlirTypeAttrGetValue(attribute)) +end +function get_bool_value(attribute) + @assert LibMLIR.mlirAttributeIsABool(attribute) "attribute $(attribute) is not a boolean" + LibMLIR.mlirBoolAttrGetValue(attribute) +end +function get_string_value(attribute) + @assert LibMLIR.mlirAttributeIsAString(attribute) "attribute $(attribute) is not a string attribute" + String(LibMLIR.mlirStringAttrGetValue(attribute)) +end + +function Base.show(io::IO, attribute::Attribute) + print(io, "Attribute(#= ") + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + GC.@preserve ref LibMLIR.mlirAttributePrint(attribute, c_print_callback, ref) + print(io, " =#)") +end + +### Named Attribute + +struct NamedAttribute + named_attribute::MlirNamedAttribute +end + +function NamedAttribute(context, name, attribute) + @assert !mlirIsNull(attribute.attribute) + NamedAttribute(LibMLIR.mlirNamedAttributeGet( + LibMLIR.mlirIdentifierGet(context, name), + attribute + )) +end + +Base.convert(::Type{MlirAttribute}, named_attribute::NamedAttribute) = + named_attribute.named_attribute + +### Value + +struct Value + value::MlirValue + + Value(value) = begin + @assert !mlirIsNull(value) "cannot create Value with null MlirValue" + new(value) + end +end + +get_type(value) = MType(LibMLIR.mlirValueGetType(value)) + +Base.convert(::Type{MlirValue}, value::Value) = value.value +Base.size(value::Value) = Base.size(get_type(value)) +Base.ndims(value::Value) = Base.ndims(get_type(value)) + +function Base.show(io::IO, value::Value) + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + GC.@preserve ref LibMLIR.mlirValuePrint(value, c_print_callback, ref) +end + +is_a_op_result(value) = LibMLIR.mlirValueIsAOpResult(value) +is_a_block_argument(value) = LibMLIR.mlirValueIsABlockArgument(value) + +function set_type!(value, type) + @assert is_a_block_argument(value) "could not set type, value is not a block argument" + LibMLIR.mlirBlockArgumentSetType(value, type) + value +end + +function get_owner(value::Value) + if is_a_block_argument(value) + raw_block = LibMLIR.mlirBlockArgumentGetOwner(value) + if mlirIsNull(raw_block) + return nothing + end + + return Block(raw_block, false) + end + + raw_op = LibMLIR.mlirOpResultGetOwner(value) + if mlirIsNull(raw_op) + return nothing + end + + return Operation(raw_op, false) +end + +### OperationState + +struct OperationState + opstate::MlirOperationState +end + +OperationState(name, location) = OperationState(LibMLIR.mlirOperationStateGet(name, location)) + +add_results!(state, results) = + LibMLIR.mlirOperationStateAddResults(state, length(results), results) +add_operands!(state, operands) = + LibMLIR.mlirOperationStateAddOperands(state, length(operands), operands) +function add_owned_regions!(state, regions) + mlir_regions = Base.convert.(MlirRegion, regions) + lose_ownership!.(regions) + LibMLIR.mlirOperationStateAddOwnedRegions(state, length(mlir_regions), mlir_regions) +end +add_attributes!(state, attributes) = + LibMLIR.mlirOperationStateAddAttributes(state, length(attributes), attributes) +add_successors!(state, successors) = + LibMLIR.mlirOperationStateAddSuccessors( + state, length(successors), + convert(Vector{LibMLIR.MlirBlock}, successors), + ) + +enable_type_inference!(state) = + LibMLIR.mlirOperationStateEnableResultTypeInference(state) + +Base.unsafe_convert(::Type{Ptr{MlirOperationState}}, state::OperationState) = + Base.unsafe_convert(Ptr{MlirOperationState}, Base.pointer_from_objref(state.opstate)) + +### Operation + +mutable struct Operation + operation::MlirOperation + @atomic owned::Bool + + Operation(operation, owned=true) = begin + @assert !mlirIsNull(operation) "cannot create Operation with null MlirOperation" + finalizer(new(operation, owned)) do op + if op.owned + LibMLIR.mlirOperationDestroy(op.operation) + end + end + end +end + +Operation(state::OperationState) = Operation(LibMLIR.mlirOperationCreate(state), true) + +Base.copy(operation::Operation) = Operation(LibMLIR.mlirOperationClone(operation)) + +num_regions(operation) = LibMLIR.mlirOperationGetNumRegions(operation) +function get_region(operation, i) + i ∈ 1:num_regions(operation) && throw(BoundsError(operation, i)) + Region(LibMLIR.mlirOperationGetRegion(operation, i - 1), false) +end +num_results(operation) = LibMLIR.mlirOperationGetNumResults(operation) +get_results(operation) = [ + get_result(operation, i) + for i in 1:num_results(operation) +] +function get_result(operation::Operation, i=1) + i ∉ 1:num_results(operation) && throw(BoundsError(operation, i)) + Value(LibMLIR.mlirOperationGetResult(operation, i - 1)) +end +num_operands(operation) = LibMLIR.mlirOperationGetNumOperands(operation) +function get_operand(operation, i=1) + i ∉ 1:num_operands(operation) && throw(BoundsError(operation, i)) + Value(LibMLIR.mlirOperationGetOperand(operation, i - 1)) +end +function set_operand!(operation, i, value) + i ∉ 1:num_operands(operation) && throw(BoundsError(operation, i)) + LibMLIR.mlirOperationSetOperand(operation, i - 1, value) + value +end + +function get_attribute_by_name(operation, name) + raw_attr = LibMLIR.mlirOperationGetAttributeByName(operation, name) + if mlirIsNull(raw_attr) + return nothing + end + Attribute(raw_attr) +end +function set_attribute_by_name!(operation, name, attribute) + LibMLIR.mlirOperationSetAttributeByName(operation, name, attribute) + operation +end + +get_location(operation) = Location(LibMLIR.mlirOperationGetLocation(operation)) +get_name(operation) = String(LibMLIR.mlirOperationGetName(operation)) +get_block(operation) = Block(LibMLIR.mlirOperationGetBlock(operation), false) +get_parent_operation(operation) = Operation(LibMLIR.mlirOperationGetParentOperation(operation), false) +get_dialect(operation) = first(split(get_name(operation), '.')) |> Symbol + +function get_first_region(op::Operation) + reg = iterate(RegionIterator(op)) + isnothing(reg) && return nothing + first(reg) +end +function get_first_block(op::Operation) + reg = get_first_region(op) + isnothing(reg) && return nothing + block = iterate(BlockIterator(reg)) + isnothing(block) && return nothing + first(block) +end +function get_first_child_op(op::Operation) + block = get_first_block(op) + isnothing(block) && return nothing + cop = iterate(OperationIterator(block)) + first(cop) +end + +op::Operation == other::Operation = LibMLIR.mlirOperationEqual(op, other) + +Base.convert(::Type{MlirOperation}, op::Operation) = op.operation + +function lose_ownership!(operation::Operation) + @assert operation.owned + @atomic operation.owned = false + operation +end + +function Base.show(io::IO, operation::Operation) + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + flags = LibMLIR.mlirOpPrintingFlagsCreate() + get(io, :debug, false) && LibMLIR.mlirOpPrintingFlagsEnableDebugInfo(flags, true, true) + GC.@preserve ref LibMLIR.mlirOperationPrintWithFlags(operation, flags, c_print_callback, ref) + println(io) +end + +verify(operation::Operation) = LibMLIR.mlirOperationVerify(operation) + +### Block + +mutable struct Block + block::MlirBlock + @atomic owned::Bool + + Block(block::MlirBlock, owned::Bool=true) = begin + @assert !mlirIsNull(block) "cannot create Block with null MlirBlock" + finalizer(new(block, owned)) do block + if block.owned + LibMLIR.mlirBlockDestroy(block.block) + end + end + end +end + +Block() = Block(MType[], Location[]) +function Block(args::Vector{MType}, locs::Vector{Location}) + @assert length(args) == length(locs) "there should be one args for each locs (got $(length(args)) & $(length(locs)))" + Block(LibMLIR.mlirBlockCreate(length(args), args, locs)) +end + +function Base.push!(block::Block, op::Operation) + LibMLIR.mlirBlockAppendOwnedOperation(block, lose_ownership!(op)) + op +end +function Base.insert!(block::Block, pos, op::Operation) + LibMLIR.mlirBlockInsertOwnedOperation(block, pos - 1, lose_ownership!(op)) + op +end +function Base.pushfirst!(block::Block, op::Operation) + insert!(block, 1, op) + op +end +function insert_after!(block::Block, reference::Operation, op::Operation) + LibMLIR.mlirBlockInsertOwnedOperationAfter(block, reference, lose_ownership!(op)) + op +end +function insert_before!(block::Block, reference::Operation, op::Operation) + LibMLIR.mlirBlockInsertOwnedOperationBefore(block, reference, lose_ownership!(op)) + op +end + +num_arguments(block::Block) = + LibMLIR.mlirBlockGetNumArguments(block) +function get_argument(block::Block, i) + i ∉ 1:num_arguments(block) && throw(BoundsError(block, i)) + Value(LibMLIR.mlirBlockGetArgument(block, i - 1)) +end +push_argument!(block::Block, type, loc) = + Value(LibMLIR.mlirBlockAddArgument(block, type, loc)) + +Base.convert(::Type{MlirBlock}, block::Block) = block.block + +function lose_ownership!(block::Block) + @assert block.owned + @atomic block.owned = false + block +end + +function Base.show(io::IO, block::Block) + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + GC.@preserve ref LibMLIR.mlirBlockPrint(block, c_print_callback, ref) +end + +### Region + +mutable struct Region + region::MlirRegion + @atomic owned::Bool # TODO: make atomic? + + Region(region, owned=true) = begin + @assert !mlirIsNull(region) + finalizer(new(region, owned)) do region + if region.owned + LibMLIR.mlirRegionDestroy(region.region) + end + end + end +end + +Region() = Region(LibMLIR.mlirRegionCreate()) + +function Base.push!(region::Region, block::Block) + LibMLIR.mlirRegionAppendOwnedBlock(region, lose_ownership!(block)) + block +end +function Base.insert!(region::Region, pos, block::Block) + LibMLIR.mlirRegionInsertOwnedBlock(region, pos - 1, lose_ownership!(block)) + block +end +function Base.pushfirst!(region::Region, block) + insert!(region, 1, block) + block +end +insert_after!(region::Region, reference::Block, block::Block) = + LibMLIR.mlirRegionInsertOwnedBlockAfter(region, reference, lose_ownership!(block)) +insert_before!(region::Region, reference::Block, block::Block) = + LibMLIR.mlirRegionInsertOwnedBlockBefore(region, reference, lose_ownership!(block)) + +function get_first_block(region::Region) + block = iterate(BlockIterator(region)) + isnothing(block) && return nothing + first(block) +end + +function lose_ownership!(region::Region) + @assert region.owned + @atomic region.owned = false + region +end + +Base.convert(::Type{MlirRegion}, region::Region) = region.region + +### Module + +mutable struct MModule + module_::MlirModule + context::Context + + MModule(module_, context) = begin + @assert !mlirIsNull(module_) "cannot create MModule with null MlirModule" + finalizer(LibMLIR.mlirModuleDestroy, new(module_, context)) + end +end + +MModule(context::Context, loc=Location(context)) = + MModule(LibMLIR.mlirModuleCreateEmpty(loc), context) +get_operation(module_) = Operation(LibMLIR.mlirModuleGetOperation(module_), false) +get_body(module_) = Block(LibMLIR.mlirModuleGetBody(module_), false) +get_first_child_op(mod::MModule) = get_first_child_op(get_operation(mod)) + +Base.convert(::Type{MlirModule}, module_::MModule) = module_.module_ +Base.parse(::Type{MModule}, context, module_) = MModule(LibMLIR.mlirModuleCreateParse(context, module_), context) + +macro mlir_str(code) + quote + ctx = Context() + parse(MModule, ctx, code) + end +end + +function Base.show(io::IO, module_::MModule) + println(io, "MModule:") + show(io, get_operation(module_)) +end + +### TypeID + +struct TypeID + typeid::LibMLIR.MlirTypeID +end + +Base.hash(typeid::TypeID) = LibMLIR.mlirTypeIDHashValue(typeid.typeid) +Base.convert(::Type{LibMLIR.MlirTypeID}, typeid::TypeID) = typeid.typeid + +@static if isdefined(LibMLIR, :MlirTypeIDAllocator) + +### TypeIDAllocator + +mutable struct TypeIDAllocator + allocator::LibMLIR.MlirTypeIDAllocator + + function TypeIDAllocator() + ptr = LibMLIR.mlirTypeIDAllocatorCreate() + @assert ptr != C_NULL "cannot create TypeIDAllocator" + finalizer(LibMLIR.mlirTypeIDAllocatorDestroy, new(ptr)) + end +end + +Base.convert(::Type{LibMLIR.MlirTypeIDAllocator}, allocator::TypeIDAllocator) = allocator.allocator + +TypeID(allocator::TypeIDAllocator) = TypeID(LibMLIR.mlirTypeIDCreate(allocator)) + +else + +struct TypeIDAllocator end + +end + +### Pass Manager + +abstract type AbstractPass end + +mutable struct ExternalPassHandle + ctx::Union{Nothing,Context} + pass::AbstractPass +end + +mutable struct PassManager + pass::MlirPassManager + context::Context + allocator::TypeIDAllocator + passes::Dict{TypeID,ExternalPassHandle} + + PassManager(pm::MlirPassManager, context) = begin + @assert !mlirIsNull(pm) "cannot create PassManager with null MlirPassManager" + finalizer(new(pm, context, TypeIDAllocator(), Dict{TypeID,ExternalPassHandle}())) do pm + LibMLIR.mlirPassManagerDestroy(pm.pass) + end + end +end + +function enable_verifier!(pm) + LibMLIR.mlirPassManagerEnableVerifier(pm) + pm +end + +PassManager(context) = + PassManager(LibMLIR.mlirPassManagerCreate(context), context) + +function run(pm::PassManager, module_) + status = LibMLIR.mlirPassManagerRun(pm, module_) + if mlirLogicalResultIsFailure(status) + throw("failed to run pass manager on module") + end + module_ +end + +Base.convert(::Type{MlirPassManager}, pass::PassManager) = pass.pass + +### Op Pass Manager + +struct OpPassManager + op_pass::MlirOpPassManager + pass::PassManager + + OpPassManager(op_pass, pass) = begin + @assert !mlirIsNull(op_pass) "cannot create OpPassManager with null MlirOpPassManager" + new(op_pass, pass) + end +end + +OpPassManager(pm::PassManager) = OpPassManager(LibMLIR.mlirPassManagerGetAsOpPassManager(pm), pm) +OpPassManager(pm::PassManager, opname) = OpPassManager(LibMLIR.mlirPassManagerGetNestedUnder(pm, opname), pm) +OpPassManager(opm::OpPassManager, opname) = OpPassManager(LibMLIR.mlirOpPassManagerGetNestedUnder(opm, opname), opm.pass) + +Base.convert(::Type{MlirOpPassManager}, op_pass::OpPassManager) = op_pass.op_pass + +function Base.show(io::IO, op_pass::OpPassManager) + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + println(io, "OpPassManager(\"\"\"") + GC.@preserve ref LibMLIR.mlirPrintPassPipeline(op_pass, c_print_callback, ref) + println(io) + print(io, "\"\"\")") +end + +struct AddPipelineException <: Exception + message::String +end + +function Base.showerror(io::IO, err::AddPipelineException) + print(io, "failed to add pipeline:", err.message) + nothing +end + +mlirLogicalResultIsFailure(result) = result.value == 0 + +function add_pipeline!(op_pass::OpPassManager, pipeline) + @static if isdefined(LibMLIR, :mlirOpPassManagerAddPipeline) + io = IOBuffer() + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + result = GC.@preserve io LibMLIR.mlirOpPassManagerAddPipeline(op_pass, pipeline, c_print_callback, io) + if LibMLIR.mlirLogicalResultIsFailure(result) + exc = AddPipelineException(String(take!(io))) + throw(exc) + end + else + result = LibMLIR.mlirParsePassPipeline(op_pass, pipeline) + if mlirLogicalResultIsFailure(result) + throw(AddPipelineException(" " * pipeline)) + end + end + op_pass +end + +@static if isdefined(LibMLIR, :mlirCreateExternalPass) + +### Pass + +# AbstractPass interface: +get_opname(::AbstractPass) = "" +function pass_run(::Context, ::P, op) where {P<:AbstractPass} + error("pass $P does not implement `MLIR.pass_run`") +end + +function _pass_construct(ptr::ExternalPassHandle) + nothing +end + +function _pass_destruct(ptr::ExternalPassHandle) + nothing +end + +function _pass_initialize(ctx, handle::ExternalPassHandle) + try + handle.ctx = Context(ctx) + LibMLIR.mlirLogicalResultSuccess() + catch + LibMLIR.mlirLogicalResultFailure() + end +end + +function _pass_clone(handle::ExternalPassHandle) + ExternalPassHandle(handle.ctx, deepcopy(handle.pass)) +end + +function _pass_run(rawop, external_pass, handle::ExternalPassHandle) + op = Operation(rawop, false) + try + pass_run(handle.ctx, handle.pass, op) + catch ex + @error "Something went wrong running pass" exception=(ex,catch_backtrace()) + LibMLIR.mlirExternalPassSignalFailure(external_pass) + end + nothing +end + +function create_external_pass!(oppass::OpPassManager, args...) + create_external_pass!(oppass.pass, args...) +end +function create_external_pass!(manager, pass, name, argument, + description, opname=get_opname(pass), + dependent_dialects=MlirDialectHandle[]) + passid = TypeID(manager.allocator) + callbacks = LibMLIR.MlirExternalPassCallbacks( + @cfunction(_pass_construct, Cvoid, (Any,)), + @cfunction(_pass_destruct, Cvoid, (Any,)), + @cfunction(_pass_initialize, LibMLIR.MlirLogicalResult, (MlirContext, Any,)), + @cfunction(_pass_clone, Any, (Any,)), + @cfunction(_pass_run, Cvoid, (MlirOperation, LibMLIR.MlirExternalPass, Any)) + ) + pass_handle = manager.passes[passid] = ExternalPassHandle(nothing, pass) + userdata = Base.pointer_from_objref(pass_handle) + mlir_pass = LibMLIR.mlirCreateExternalPass(passid, name, argument, description, opname, + length(dependent_dialects), dependent_dialects, + callbacks, userdata) + mlir_pass +end + +function add_owned_pass!(pm::PassManager, pass) + LibMLIR.mlirPassManagerAddOwnedPass(pm, pass) + pm +end + +function add_owned_pass!(opm::OpPassManager, pass) + LibMLIR.mlirOpPassManagerAddOwnedPass(opm, pass) + opm +end + +end + +### Iterators + +""" + BlockIterator(region::Region) + +Iterates over all blocks in the given region. +""" +struct BlockIterator + region::Region +end + +function Base.iterate(it::BlockIterator) + reg = it.region + raw_block = LibMLIR.mlirRegionGetFirstBlock(reg) + if mlirIsNull(raw_block) + nothing + else + b = Block(raw_block, false) + (b, b) + end +end + +function Base.iterate(it::BlockIterator, block) + raw_block = LibMLIR.mlirBlockGetNextInRegion(block) + if mlirIsNull(raw_block) + nothing + else + b = Block(raw_block, false) + (b, b) + end +end + +""" + OperationIterator(block::Block) + +Iterates over all operations for the given block. +""" +struct OperationIterator + block::Block +end + +function Base.iterate(it::OperationIterator) + raw_op = LibMLIR.mlirBlockGetFirstOperation(it.block) + if mlirIsNull(raw_op) + nothing + else + op = Operation(raw_op, false) + (op, op) + end +end + +function Base.iterate(it::OperationIterator, op) + raw_op = LibMLIR.mlirOperationGetNextInBlock(op) + if mlirIsNull(raw_op) + nothing + else + op = Operation(raw_op, false) + (op, op) + end +end + +""" + RegionIterator(::Operation) + +Iterates over all sub-regions for the given operation. +""" +struct RegionIterator + op::Operation +end + +function Base.iterate(it::RegionIterator) + raw_region = LibMLIR.mlirOperationGetFirstRegion(it.op) + if mlirIsNull(raw_region) + nothing + else + region = Region(raw_region, false) + (region, region) + end +end + +function Base.iterate(it::RegionIterator, region) + raw_region = LibMLIR.mlirRegionGetNextInOperation(region) + if mlirIsNull(raw_region) + nothing + else + region = Region(raw_region, false) + (region, region) + end +end + +### Utils + +function visit(f, op) + for region in RegionIterator(op) + for block in BlockIterator(region) + for op in OperationIterator(block) + f(op) + end + end + end +end + +""" + verifyall(operation; debug=false) + +Prints the operations which could not be verified. +""" +function verifyall(operation::Operation; debug=false) + io = IOContext(stdout, :debug => debug) + visit(operation) do op + if !verify(op) + show(io, op) + end + end +end +verifyall(module_::MModule) = get_operation(module_) |> verifyall + +function get_dialects!(dialects::Set{Symbol}, op::Operation) + push!(dialects, get_dialect(op)) + + visit(op) do op + get_dialects!(dialects, op) + end + + dialects +end + +function get_input_type(module_) + dialects = Set{Symbol}() + + op = get_operation(module_) + get_dialects!(dialects, op) + + if :mhlo ∈ dialects + # :tosa ∉ dialects || throw("cannot have both tosa and mhlo operations") + :mhlo + elseif :tosa ∈ dialects + :tosa + else + :none + end +end + +end # module IR diff --git a/src/MLIR.jl b/src/MLIR.jl index c60e67e4..4dead798 100644 --- a/src/MLIR.jl +++ b/src/MLIR.jl @@ -35,4 +35,7 @@ function Base.unsafe_convert(::Type{API.MlirStringRef}, s::Union{Symbol, String, return API.MlirStringRef(p, length(s)) end -end # module +include("./IR.jl") +include("./Dialects.jl") + +end # module MLIR From d438436e8ad82e2615c677023406f8bd8192649a Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 16 May 2023 21:56:46 +0200 Subject: [PATCH 02/11] Make toy brutus work for both 14 & 15 --- examples/brutus-14.jl | 263 --------------------------- examples/brutus.jl | 57 +++--- src/IR.jl | 408 ++++++++++++++++++++---------------------- 3 files changed, 223 insertions(+), 505 deletions(-) delete mode 100644 examples/brutus-14.jl diff --git a/examples/brutus-14.jl b/examples/brutus-14.jl deleted file mode 100644 index 53c3c640..00000000 --- a/examples/brutus-14.jl +++ /dev/null @@ -1,263 +0,0 @@ -module Brutus - -using MLIR.IR -using MLIR.Dialects: arith, std -using Core: PhiNode, GotoNode, GotoIfNot, SSAValue, Argument, ReturnNode, PiNode - - -const BrutusScalar = Union{Bool,Int64,Int32,Float32,Float64} - -function cmpi_pred(predicate) - function(ctx, ops; loc=Location(ctx)) - arith.cmpi(ctx, predicate, ops; loc) - end -end - -function single_op_wrapper(fop) - (ctx::Context, block::Block, args::Vector{Value}; loc=Location(ctx)) -> push!(block, fop(ctx, args; loc)) -end - -const intrinsics_to_mlir = Dict([ - Base.add_int => single_op_wrapper(arith.addi), - Base.sle_int => single_op_wrapper(cmpi_pred(arith.Predicates.sle)), - Base.slt_int => single_op_wrapper(cmpi_pred(arith.Predicates.slt)), - Base.:(===) => single_op_wrapper(cmpi_pred(arith.Predicates.eq)), - Base.mul_int => single_op_wrapper(arith.muli), - Base.mul_float => single_op_wrapper(arith.mulf), - # TODO: i don't know how to do a bitwise negation in any other way - Base.not_int => function(ctx, block, args; loc=Location(ctx)) - arg = only(args) - ones = push!(block, arith.constant(ctx, -1, IR.get_type(arg); loc)) |> IR.get_result - push!(block, arith.xori(ctx, Value[arg, ones]; loc)) - end, -]) - -"Generates a block argument for each phi node present in the block." -function prepare_block(ctx, ir, bb) - b = Block() - - for sidx in bb.stmts - stmt = ir.stmts[sidx] - inst = stmt[:inst] - inst isa Core.PhiNode || continue - - type = stmt[:type] - IR.push_argument!(b, MType(ctx, type), Location(ctx)) - end - - return b -end - -"Values to populate the Phi Node when jumping from `from` to `to`." -function collect_value_arguments(ir, from, to) - to = ir.cfg.blocks[to] - values = [] - for s in to.stmts - stmt = ir.stmts[s] - inst = stmt[:inst] - inst isa Core.PhiNode || continue - - edge = findfirst(==(from), inst.edges) - if isnothing(edge) # use dummy scalar val instead - val = zero(stmt[:type]) - push!(values, val) - else - push!(values, inst.values[edge]) - end - end - values -end - -""" - code_mlir(f, types::Type{Tuple}; ctx=Context()) -> IR.Operation - -Returns a `builtin.func` operation corresponding to the ircode of the provided method. -This only supports a few Julia Core primitives and scalar types of type $BrutusScalar. - -!!! note - The Julia SSAIR to MLIR conversion implemented is very primitive and only supports a - handful of primitives. A better to perform this conversion would to create a dialect - representing Julia IR and progressively lower it to base MLIR dialects. -""" -function code_mlir(f, types; ctx=Context()) - ir, ret = Core.Compiler.code_ircode(f, types) |> only - @assert first(ir.argtypes) isa Core.Const - - values = Vector{Value}(undef, length(ir.stmts)) - @show length(ir.stmts) length(values) - - for dialect in ("std",) - IR.get_or_load_dialect!(ctx, dialect) - end - - - blocks = [ - prepare_block(ctx, ir, bb) - for bb in ir.cfg.blocks - ] - - current_block = entry_block = blocks[begin] - - for argtype in types.parameters - IR.push_argument!(entry_block, MType(ctx, argtype), Location(ctx)) - end - - function get_value(x)::Value - if x isa Core.SSAValue - @assert isassigned(values, x.id) "value $x was not assigned" - values[x.id] - elseif x isa Core.Argument - IR.get_argument(entry_block, x.n - 1) - elseif x isa Number - IR.get_result(push!(current_block, arith.constant(ctx, x))) - else - error("could not use value $x inside MLIR") - end - end - - for (block_id, (b, bb)) in enumerate(zip(blocks, ir.cfg.blocks)) - current_block = b - n_phi_nodes = 0 - - for sidx in bb.stmts - stmt = ir.stmts[sidx] - inst = stmt[:inst] - line = ir.linetable[stmt[:line]] - - if Meta.isexpr(inst, :call) - line = ir.linetable[stmt[:line]] - val_type = stmt[:type] - if !(val_type <: BrutusScalar) - error("type $val_type is not supported") - end - out_type = MType(ctx, val_type) - - called_func = first(inst.args) - if called_func isa GlobalRef # TODO: should probably use something else here - called_func = getproperty(called_func.mod, called_func.name) - end - - fop! = intrinsics_to_mlir[called_func] - args = get_value.(@view inst.args[begin+1:end]) - - res = IR.get_result(fop!(ctx, current_block, args; loc=Location(ctx, line))) - - values[sidx] = res - elseif inst isa PhiNode - values[sidx] = IR.get_argument(current_block, n_phi_nodes += 1) - elseif inst isa PiNode - values[sidx] = get_value(inst.val) - elseif inst isa GotoNode - args = get_value.(collect_value_arguments(ir, block_id, inst.label)) - dest = blocks[inst.label] - push!(current_block, std.br(ctx, dest, args; loc=Location(ctx, line))) - elseif inst isa GotoIfNot - false_args = get_value.(collect_value_arguments(ir, block_id, inst.dest)) - cond = get_value(inst.cond) - @assert length(bb.succs) == 2 # NOTE: We assume that length(bb.succs) == 2, this might be wrong - other_dest = setdiff(bb.succs, inst.dest) |> only - true_args = get_value.(collect_value_arguments(ir, block_id, other_dest)) - other_dest = blocks[other_dest] - dest = blocks[inst.dest] - - cond_br = std.cond_br(ctx, cond, other_dest, dest, true_args, false_args; loc=Location(ctx, line)) - push!(current_block, cond_br) - elseif inst isa ReturnNode - line = ir.linetable[stmt[:line]] - push!(current_block, std.return_(ctx, [get_value(inst.val)]; loc=Location(ctx, line))) - else - error("unhandled ir $(inst)") - end - end - end - - func_name = nameof(f) - - region = Region() - for b in blocks - push!(region, b) - end - - state = OperationState("builtin.func", Location(ctx)) - - input_types = MType[ - IR.get_type(IR.get_argument(entry_block, i)) - for i in 1:IR.num_arguments(entry_block) - ] - result_types = [MType(ctx, ret)] - - ftype = MType(ctx, input_types => result_types) - IR.add_attributes!(state, [ - NamedAttribute(ctx, "sym_name", IR.Attribute(ctx, string(func_name))), - NamedAttribute(ctx, "type", IR.Attribute(ftype)), - ]) - IR.add_owned_regions!(state, Region[region]) - - op = Operation(state) - - IR.verifyall(op) - - op -end - -""" - @code_mlir f(args...) -""" -macro code_mlir(call) - @assert Meta.isexpr(call, :call) "only calls are supported" - - f = first(call.args) |> esc - args = Expr(:curly, - Tuple, - map(arg -> :($(Core.Typeof)($arg)), - call.args[begin+1:end])..., - ) |> esc - - quote - code_mlir($f, $args) - end -end - -end # module Brutus - -# --- - -function pow(x::F, n) where {F} - p = one(F) - for _ in 1:n - p *= x - end - p -end - -function f(x) - if x == 1 - 2 - else - 3 - end -end - -# --- - -using MLIR.IR, MLIR - -ctx = Context() -MLIR.API.mlirContextEnableMultithreading(ctx, false) - -op = Brutus.code_mlir(pow, Tuple{Float64, Int}) - -mod = MModule(ctx, Location(ctx)) -body = IR.get_body(mod) -push!(body, op) - -pm = IR.PassManager(ctx) -opm = IR.OpPassManager(pm, "builtin.func") - -# TODO: make high-level API for these -MLIR.API.mlirPassManagerEnableIRPrinting(pm) -MLIR.API.mlirPassManagerEnableVerifier(pm, true) -MLIR.API.mlirOpPassManagerAddOwnedPass(opm, MLIR.API.mlirCreateConversionConvertArithmeticToLLVM()) -MLIR.API.mlirPassManagerAddOwnedPass(pm, MLIR.API.mlirCreateConversionConvertStandardToLLVM()) - -IR.run(pm, mod) diff --git a/examples/brutus.jl b/examples/brutus.jl index 93cefe4a..496db864 100644 --- a/examples/brutus.jl +++ b/examples/brutus.jl @@ -1,10 +1,10 @@ module Brutus +import LLVM using MLIR.IR -using MLIR.Dialects: arith, func, cf +using MLIR.Dialects: arith, func, cf, std using Core: PhiNode, GotoNode, GotoIfNot, SSAValue, Argument, ReturnNode, PiNode - const BrutusScalar = Union{Bool,Int64,Int32,Float32,Float64} function cmpi_pred(predicate) @@ -24,7 +24,6 @@ const intrinsics_to_mlir = Dict([ Base.:(===) => single_op_wrapper(cmpi_pred(arith.Predicates.eq)), Base.mul_int => single_op_wrapper(arith.muli), Base.mul_float => single_op_wrapper(arith.mulf), - # TODO: i don't know how to do a bitwise negation in any other way Base.not_int => function(ctx, block, args; loc=Location(ctx)) arg = only(args) ones = push!(block, arith.constant(ctx, -1, IR.get_type(arg); loc)) |> IR.get_result @@ -84,9 +83,8 @@ function code_mlir(f, types; ctx=Context()) @assert first(ir.argtypes) isa Core.Const values = Vector{Value}(undef, length(ir.stmts)) - @show length(ir.stmts) length(values) - for dialect in ("func", "cf") + for dialect in (LLVM.version() >= v"15" ? ("func", "cf") : ("std",)) IR.get_or_load_dialect!(ctx, dialect) end @@ -107,7 +105,7 @@ function code_mlir(f, types; ctx=Context()) values[x.id] elseif x isa Core.Argument IR.get_argument(entry_block, x.n - 1) - elseif x isa Number + elseif x isa BrutusScalar IR.get_result(push!(current_block, arith.constant(ctx, x))) else error("could not use value $x inside MLIR") @@ -149,7 +147,8 @@ function code_mlir(f, types; ctx=Context()) elseif inst isa GotoNode args = get_value.(collect_value_arguments(ir, block_id, inst.label)) dest = blocks[inst.label] - push!(current_block, cf.br(ctx, dest, args; loc=Location(ctx, line))) + brop = LLVM.version() >= v"15" ? cf.br : std.br + push!(current_block, brop(ctx, dest, args; loc=Location(ctx, line))) elseif inst isa GotoIfNot false_args = get_value.(collect_value_arguments(ir, block_id, inst.dest)) cond = get_value(inst.cond) @@ -159,11 +158,13 @@ function code_mlir(f, types; ctx=Context()) other_dest = blocks[other_dest] dest = blocks[inst.dest] - cond_br = cf.cond_br(ctx, cond, other_dest, dest, true_args, false_args; loc=Location(ctx, line)) + cond_brop = LLVM.version() >= v"15" ? cf.cond_br : std.cond_br + cond_br = cond_brop(ctx, cond, other_dest, dest, true_args, false_args; loc=Location(ctx, line)) push!(current_block, cond_br) elseif inst isa ReturnNode line = ir.linetable[stmt[:line]] - push!(current_block, func.return_(ctx, [get_value(inst.val)]; loc=Location(ctx, line))) + retop = LLVM.version() >= v"15" ? func.return_ : std.return_ + push!(current_block, retop(ctx, [get_value(inst.val)]; loc=Location(ctx, line))) else error("unhandled ir $(inst)") end @@ -177,7 +178,8 @@ function code_mlir(f, types; ctx=Context()) push!(region, b) end - state = OperationState("func.func", Location(ctx)) + LLVM15 = LLVM.version() >= v"15" + state = OperationState(LLVM15 ? "func.func" : "builtin.func", Location(ctx)) input_types = MType[ IR.get_type(IR.get_argument(entry_block, i)) @@ -188,7 +190,7 @@ function code_mlir(f, types; ctx=Context()) ftype = MType(ctx, input_types => result_types) IR.add_attributes!(state, [ NamedAttribute(ctx, "sym_name", IR.Attribute(ctx, string(func_name))), - NamedAttribute(ctx, "function_type", IR.Attribute(ftype)), + NamedAttribute(ctx, LLVM15 ? "function_type" : "type", IR.Attribute(ftype)), ]) IR.add_owned_regions!(state, Region[region]) @@ -239,38 +241,33 @@ end # --- +using Test using MLIR.IR, MLIR ctx = Context() +# IR.enable_multithreading!(ctx, false) -MLIR.API.mlirContextEnableMultithreading(ctx, false) -MLIR.API.mlirRegisterAllLLVMTranslations(ctx) -MLIR.API.mlirRegisterAllPasses() - -op = Brutus.code_mlir(pow, Tuple{Float64, Int}) +op = Brutus.code_mlir(pow, Tuple{Int, Int}; ctx) mod = MModule(ctx, Location(ctx)) body = IR.get_body(mod) push!(body, op) pm = IR.PassManager(ctx) -opm = IR.OpPassManager(pm, "builtin.module") +opm = IR.OpPassManager(pm) -# TODO: make high-level API for these -MLIR.API.mlirPassManagerEnableIRPrinting(pm) -MLIR.API.mlirPassManagerEnableVerifier(pm, true) +# IR.enable_ir_printing!(pm) +IR.enable_verifier!(pm, true) -# MLIR.API.mlirOpPassManagerAddOwnedPass(opm, MLIR.API.mlirCreateConversionConvertArithmeticToLLVM()) -# MLIR.API.mlirOpPassManagerAddOwnedPass(opm, MLIR.API.mlirCreateConversionConvertControlFlowToLLVM()) -# MLIR.API.mlirPassManagerAddOwnedPass(pm, MLIR.API.mlirCreateConversionConvertFuncToLLVM()) +MLIR.API.mlirRegisterAllPasses() +MLIR.API.mlirRegisterAllLLVMTranslations(ctx) +IR.add_pipeline!(opm, Brutus.LLVM.version() >= v"15" ? "convert-arith-to-llvm,convert-func-to-llvm" : "convert-std-to-llvm") -# MLIR.API.mlirRegisterConversionConvertFuncToLLVM() -# MLIR.API.mlirPassManagerAddOwnedPass(pm, MLIR.API.mlirCreateTransformsCanonicalizer()) -# MLIR.API.mlirPassManagerAddOwnedPass(pm, MLIR.API.mlirCreateTransformsControlFlowSink()) -MLIR.API.mlirPassManagerAddOwnedPass(pm, MLIR.API.mlirCreateTransformsTopologicalSort()) +IR.run!(pm, mod) -IR.add_pipeline!(opm, "convert-func-to-llvm") +jit = MLIR.API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL) +fptr = MLIR.API.mlirExecutionEngineLookup(jit, "pow") -IR.run(pm, mod) +x, y = 3, 4 -mod +@test ccall(fptr, Int, (Int, Int), x, y) == pow(x, y) diff --git a/src/IR.jl b/src/IR.jl index 636f67db..47990358 100644 --- a/src/IR.jl +++ b/src/IR.jl @@ -1,6 +1,6 @@ module IR -import ..API: API as LibMLIR +import ..API: API export Operation, @@ -23,7 +23,7 @@ export import Base: ==, String -using .LibMLIR: +using .API: MlirDialectRegistry, MlirDialectHandle, MlirAttribute, @@ -62,7 +62,7 @@ Base.convert(::Type{MlirStringRef}, s::String) = ### Identifier -String(ident::MlirIdentifier) = String(LibMLIR.mlirIdentifierStr(ident)) +String(ident::MlirIdentifier) = String(API.mlirIdentifierStr(ident)) ### Dialect @@ -77,18 +77,18 @@ end Base.convert(::Type{MlirDialect}, dialect::Dialect) = dialect.dialect function Base.show(io::IO, dialect::Dialect) - print(io, "Dialect(\"", String(LibMLIR.mlirDialectGetNamespace(dialect)), "\")") + print(io, "Dialect(\"", String(API.mlirDialectGetNamespace(dialect)), "\")") end ### DialectHandle struct DialectHandle - handle::LibMLIR.MlirDialectHandle + handle::API.MlirDialectHandle end function DialectHandle(s::Symbol) s = Symbol("mlirGetDialectHandle__", s, "__") - DialectHandle(getproperty(LibMLIR, s)()) + DialectHandle(getproperty(API, s)()) end Base.convert(::Type{MlirDialectHandle}, handle::DialectHandle) = handle.handle @@ -99,15 +99,15 @@ mutable struct DialectRegistry registry::MlirDialectRegistry end function DialectRegistry() - registry = LibMLIR.mlirDialectRegistryCreate() + registry = API.mlirDialectRegistryCreate() @assert !mlirIsNull(registry) "cannot create DialectRegistry with null MlirDialectRegistry" finalizer(DialectRegistry(registry)) do registry - LibMLIR.mlirDialectRegistryDestroy(registry.registry) + API.mlirDialectRegistryDestroy(registry.registry) end end function Base.insert!(registry::DialectRegistry, handle::DialectHandle) - LibMLIR.mlirDialectHandleInsertDialect(registry, handle) + API.mlirDialectHandleInsertDialect(registry, handle) end ### Context @@ -116,18 +116,18 @@ mutable struct Context context::MlirContext end function Context() - context = LibMLIR.mlirContextCreate() + context = API.mlirContextCreate() @assert !mlirIsNull(context) "cannot create Context with null MlirContext" finalizer(Context(context)) do context - LibMLIR.mlirContextDestroy(context.context) + API.mlirContextDestroy(context.context) end end Base.convert(::Type{MlirContext}, c::Context) = c.context -num_loaded_dialects(context) = LibMLIR.mlirContextGetNumLoadedDialects(context) +num_loaded_dialects(context) = API.mlirContextGetNumLoadedDialects(context) function get_or_load_dialect!(context, handle::DialectHandle) - mlir_dialect = LibMLIR.mlirDialectHandleLoadDialect(handle, context) + mlir_dialect = API.mlirDialectHandleLoadDialect(handle, context) if mlirIsNull(mlir_dialect) error("could not load dialect from handle $handle") else @@ -138,7 +138,12 @@ function get_or_load_dialect!(context, dialect::String) get_or_load_dialect!(context, DialectHandle(Symbol(dialect))) end -is_registered_operation(context, opname) = LibMLIR.mlirContextIsRegisteredOperation(context, opname) +function enable_multithreading!(context, enable=true) + API.mlirContextEnableMultithreading(context, enable) + context +end + +is_registered_operation(context, opname) = API.mlirContextIsRegisteredOperation(context, opname) ### Location @@ -151,9 +156,9 @@ struct Location end end -Location(context::Context) = Location(LibMLIR.mlirLocationUnknownGet(context)) +Location(context::Context) = Location(API.mlirLocationUnknownGet(context)) Location(context::Context, filename, line, column=0) = - Location(LibMLIR.mlirLocationFileLineColGet(context, filename, line, column)) + Location(API.mlirLocationFileLineColGet(context, filename, line, column)) Location(context::Context, lin::Core.LineInfoNode) = Location(context, string(lin.file), lin.line) Location(context::Context, lin::LineNumberNode) = @@ -168,7 +173,7 @@ function Base.show(io::IO, location::Location) c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) ref = Ref(io) print(io, "Location(#= ") - GC.@preserve ref LibMLIR.mlirLocationPrint(location, c_print_callback, ref) + GC.@preserve ref API.mlirLocationPrint(location, c_print_callback, ref) print(io, " =#)") end @@ -185,29 +190,29 @@ end MType(t::MType) = t MType(context::Context, T::Type{<:Signed}) = - MType(LibMLIR.mlirIntegerTypeGet(context, sizeof(T) * 8)) + MType(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) MType(context::Context, T::Type{<:Unsigned}) = - MType(LibMLIR.mlirIntegerTypeGet(context, sizeof(T) * 8)) + MType(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) MType(context::Context, ::Type{Bool}) = - MType(LibMLIR.mlirIntegerTypeGet(context, 1)) + MType(API.mlirIntegerTypeGet(context, 1)) MType(context::Context, ::Type{Float32}) = - MType(LibMLIR.mlirF32TypeGet(context)) + MType(API.mlirF32TypeGet(context)) MType(context::Context, ::Type{Float64}) = - MType(LibMLIR.mlirF64TypeGet(context)) + MType(API.mlirF64TypeGet(context)) MType(context::Context, ft::Pair) = - MType(LibMLIR.mlirFunctionTypeGet(context, + MType(API.mlirFunctionTypeGet(context, length(ft.first), [MType(t) for t in ft.first], length(ft.second), [MType(t) for t in ft.second])) MType(context, a::AbstractArray{T}) where {T} = MType(context, MType(context, T), size(a)) MType(context, ::Type{<:AbstractArray{T,N}}, dims) where {T,N} = - MType(LibMLIR.mlirRankedTensorTypeGetChecked( + MType(API.mlirRankedTensorTypeGetChecked( Location(context), N, collect(dims), MType(context, T), Attribute(), )) MType(context, element_type::MType, dims) = - MType(LibMLIR.mlirRankedTensorTypeGetChecked( + MType(API.mlirRankedTensorTypeGetChecked( Location(context), length(dims), collect(dims), element_type, @@ -216,24 +221,24 @@ MType(context, element_type::MType, dims) = MType(context, ::T) where {T<:Real} = MType(context, T) MType(_, type::MType) = type -IndexType(context) = MType(LibMLIR.mlirIndexTypeGet(context)) +IndexType(context) = MType(API.mlirIndexTypeGet(context)) Base.convert(::Type{MlirType}, mtype::MType) = mtype.type function Base.eltype(type::MType) - if LibMLIR.mlirTypeIsAShaped(type) - MType(LibMLIR.mlirShapedTypeGetElementType(type)) + if API.mlirTypeIsAShaped(type) + MType(API.mlirShapedTypeGetElementType(type)) else type end end function show_inner(io::IO, type::MType) - if LibMLIR.mlirTypeIsAInteger(type) - is_signless = LibMLIR.mlirIntegerTypeIsSignless(type) - is_signed = LibMLIR.mlirIntegerTypeIsSigned(type) + if API.mlirTypeIsAInteger(type) + is_signless = API.mlirIntegerTypeIsSignless(type) + is_signed = API.mlirIntegerTypeIsSigned(type) - width = LibMLIR.mlirIntegerTypeGetWidth(type) + width = API.mlirIntegerTypeGetWidth(type) t = if is_signed "si" elseif is_signless @@ -242,17 +247,17 @@ function show_inner(io::IO, type::MType) "u" end print(io, t, width) - elseif LibMLIR.mlirTypeIsAF64(type) + elseif API.mlirTypeIsAF64(type) print(io, "f64") - elseif LibMLIR.mlirTypeIsAF32(type) + elseif API.mlirTypeIsAF32(type) print(io, "f32") - elseif LibMLIR.mlirTypeIsARankedTensor(type) + elseif API.mlirTypeIsARankedTensor(type) print(io, "tensor<") s = size(type) print(io, join(s, "x"), "x") show_inner(io, eltype(type)) print(io, ">") - elseif LibMLIR.mlirTypeIsAIndex(type) + elseif API.mlirTypeIsAIndex(type) print(io, "index") else print(io, "unknown") @@ -263,7 +268,7 @@ function Base.show(io::IO, type::MType) print(io, "MType(#= ") c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) ref = Ref(io) - GC.@preserve ref LibMLIR.mlirTypePrint(type, c_print_callback, ref) + GC.@preserve ref API.mlirTypePrint(type, c_print_callback, ref) print(io, " =#)") end @@ -275,10 +280,10 @@ function inttype(size, issigned) end function julia_type(type::MType) - if LibMLIR.mlirTypeIsAInteger(type) - is_signed = LibMLIR.mlirIntegerTypeIsSigned(type) || - LibMLIR.mlirIntegerTypeIsSignless(type) - width = LibMLIR.mlirIntegerTypeGetWidth(type) + if API.mlirTypeIsAInteger(type) + is_signed = API.mlirIntegerTypeIsSigned(type) || + API.mlirIntegerTypeIsSignless(type) + width = API.mlirIntegerTypeGetWidth(type) try inttype(width, is_signed) @@ -286,9 +291,9 @@ function julia_type(type::MType) t = is_signed ? "i" : "u" throw("could not convert type $(t)$(width) to julia") end - elseif LibMLIR.mlirTypeIsAF32(type) + elseif API.mlirTypeIsAF32(type) Float32 - elseif LibMLIR.mlirTypeIsAF64(type) + elseif API.mlirTypeIsAF64(type) Float64 else throw("could not convert type $type to julia") @@ -296,41 +301,41 @@ function julia_type(type::MType) end Base.ndims(type::MType) = - if LibMLIR.mlirTypeIsAShaped(type) && LibMLIR.mlirShapedTypeHasRank(type) - LibMLIR.mlirShapedTypeGetRank(type) + if API.mlirTypeIsAShaped(type) && API.mlirShapedTypeHasRank(type) + API.mlirShapedTypeGetRank(type) else 0 end -Base.size(type::MType, i::Int) = LibMLIR.mlirShapedTypeGetDimSize(type, i - 1) +Base.size(type::MType, i::Int) = API.mlirShapedTypeGetDimSize(type, i - 1) Base.size(type::MType) = Tuple(size(type, i) for i in 1:ndims(type)) function is_tensor(type::MType) - LibMLIR.mlirTypeIsAShaped(type) + API.mlirTypeIsAShaped(type) end function is_integer(type::MType) - LibMLIR.mlirTypeIsAInteger(type) + API.mlirTypeIsAInteger(type) end -is_function_type(mtype) = LibMLIR.mlirTypeIsAFunction(mtype) +is_function_type(mtype) = API.mlirTypeIsAFunction(mtype) function get_num_inputs(ftype) @assert is_function_type(ftype) "cannot get the number of inputs on type $(ftype), expected a function type" - LibMLIR.mlirFunctionTypeGetNumInputs(ftype) + API.mlirFunctionTypeGetNumInputs(ftype) end function get_num_results(ftype) @assert is_function_type(ftype) "cannot get the number of results on type $(ftype), expected a function type" - LibMLIR.mlirFunctionTypeGetNumResults(ftype) + API.mlirFunctionTypeGetNumResults(ftype) end function get_input(ftype::MType, pos) @assert is_function_type(ftype) "cannot get input on type $(ftype), expected a function type" - MType(LibMLIR.mlirFunctionTypeGetInput(ftype, pos - 1)) + MType(API.mlirFunctionTypeGetInput(ftype, pos - 1)) end function get_result(ftype::MType, pos=1) @assert is_function_type(ftype) "cannot get result on type $(ftype), expected a function type" - MType(LibMLIR.mlirFunctionTypeGetResult(ftype, pos - 1)) + MType(API.mlirFunctionTypeGetResult(ftype, pos - 1)) end ### Attribute @@ -339,106 +344,106 @@ struct Attribute attribute::MlirAttribute end -Attribute() = Attribute(LibMLIR.mlirAttributeGetNull()) -Attribute(context, s::AbstractString) = Attribute(LibMLIR.mlirStringAttrGet(context, s)) -Attribute(type::MType) = Attribute(LibMLIR.mlirTypeAttrGet(type)) +Attribute() = Attribute(API.mlirAttributeGetNull()) +Attribute(context, s::AbstractString) = Attribute(API.mlirStringAttrGet(context, s)) +Attribute(type::MType) = Attribute(API.mlirTypeAttrGet(type)) Attribute(context, f::F, type=MType(context, F)) where {F<:AbstractFloat} = Attribute( - LibMLIR.mlirFloatAttrDoubleGet(context, type, Float64(f)) + API.mlirFloatAttrDoubleGet(context, type, Float64(f)) ) Attribute(context, i::T) where {T<:Integer} = Attribute( - LibMLIR.mlirIntegerAttrGet(MType(context, T), Int64(i)) + API.mlirIntegerAttrGet(MType(context, T), Int64(i)) ) function Attribute(context, values::T) where {T<:AbstractArray{Int32}} type = MType(context, T, size(values)) Attribute( - LibMLIR.mlirDenseElementsAttrInt32Get(type, length(values), values) + API.mlirDenseElementsAttrInt32Get(type, length(values), values) ) end function Attribute(context, values::T) where {T<:AbstractArray{Int64}} type = MType(context, T, size(values)) Attribute( - LibMLIR.mlirDenseElementsAttrInt64Get(type, length(values), values) + API.mlirDenseElementsAttrInt64Get(type, length(values), values) ) end function Attribute(context, values::T) where {T<:AbstractArray{Float64}} type = MType(context, T, size(values)) Attribute( - LibMLIR.mlirDenseElementsAttrDoubleGet(type, length(values), values) + API.mlirDenseElementsAttrDoubleGet(type, length(values), values) ) end function Attribute(context, values::T) where {T<:AbstractArray{Float32}} type = MType(context, T, size(values)) Attribute( - LibMLIR.mlirDenseElementsAttrFloatGet(type, length(values), values) + API.mlirDenseElementsAttrFloatGet(type, length(values), values) ) end function Attribute(context, values::AbstractArray{Int32}, type) Attribute( - LibMLIR.mlirDenseElementsAttrInt32Get(type, length(values), values) + API.mlirDenseElementsAttrInt32Get(type, length(values), values) ) end function Attribute(context, values::AbstractArray{Int}, type) Attribute( - LibMLIR.mlirDenseElementsAttrInt64Get(type, length(values), values) + API.mlirDenseElementsAttrInt64Get(type, length(values), values) ) end function Attribute(context, values::AbstractArray{Float32}, type) Attribute( - LibMLIR.mlirDenseElementsAttrFloatGet(type, length(values), values) + API.mlirDenseElementsAttrFloatGet(type, length(values), values) ) end function ArrayAttribute(context, values::AbstractVector{Int}) elements = Attribute.((context,), values) Attribute( - LibMLIR.mlirArrayAttrGet(context, length(elements), elements) + API.mlirArrayAttrGet(context, length(elements), elements) ) end function ArrayAttribute(context, attributes::Vector{Attribute}) Attribute( - LibMLIR.mlirArrayAttrGet(context, length(attributes), attributes), + API.mlirArrayAttrGet(context, length(attributes), attributes), ) end function DenseArrayAttribute(context, values::AbstractVector{Int}) Attribute( - LibMLIR.mlirDenseI64ArrayGet(context, length(values), collect(values)) + API.mlirDenseI64ArrayGet(context, length(values), collect(values)) ) end function Attribute(context, value::Int, type::MType) Attribute( - LibMLIR.mlirIntegerAttrGet(type, value) + API.mlirIntegerAttrGet(type, value) ) end function Attribute(context, value::Bool, ::MType=nothing) Attribute( - LibMLIR.mlirBoolAttrGet(context, value) + API.mlirBoolAttrGet(context, value) ) end Base.convert(::Type{MlirAttribute}, attribute::Attribute) = attribute.attribute Base.parse(::Type{Attribute}, context, s) = - Attribute(LibMLIR.mlirAttributeParseGet(context, s)) + Attribute(API.mlirAttributeParseGet(context, s)) function get_type(attribute::Attribute) - MType(LibMLIR.mlirAttributeGetType(attribute)) + MType(API.mlirAttributeGetType(attribute)) end function get_type_value(attribute) - @assert LibMLIR.mlirAttributeIsAType(attribute) "attribute $(attribute) is not a type" - MType(LibMLIR.mlirTypeAttrGetValue(attribute)) + @assert API.mlirAttributeIsAType(attribute) "attribute $(attribute) is not a type" + MType(API.mlirTypeAttrGetValue(attribute)) end function get_bool_value(attribute) - @assert LibMLIR.mlirAttributeIsABool(attribute) "attribute $(attribute) is not a boolean" - LibMLIR.mlirBoolAttrGetValue(attribute) + @assert API.mlirAttributeIsABool(attribute) "attribute $(attribute) is not a boolean" + API.mlirBoolAttrGetValue(attribute) end function get_string_value(attribute) - @assert LibMLIR.mlirAttributeIsAString(attribute) "attribute $(attribute) is not a string attribute" - String(LibMLIR.mlirStringAttrGetValue(attribute)) + @assert API.mlirAttributeIsAString(attribute) "attribute $(attribute) is not a string attribute" + String(API.mlirStringAttrGetValue(attribute)) end function Base.show(io::IO, attribute::Attribute) print(io, "Attribute(#= ") c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) ref = Ref(io) - GC.@preserve ref LibMLIR.mlirAttributePrint(attribute, c_print_callback, ref) + GC.@preserve ref API.mlirAttributePrint(attribute, c_print_callback, ref) print(io, " =#)") end @@ -450,8 +455,8 @@ end function NamedAttribute(context, name, attribute) @assert !mlirIsNull(attribute.attribute) - NamedAttribute(LibMLIR.mlirNamedAttributeGet( - LibMLIR.mlirIdentifierGet(context, name), + NamedAttribute(API.mlirNamedAttributeGet( + API.mlirIdentifierGet(context, name), attribute )) end @@ -470,7 +475,7 @@ struct Value end end -get_type(value) = MType(LibMLIR.mlirValueGetType(value)) +get_type(value) = MType(API.mlirValueGetType(value)) Base.convert(::Type{MlirValue}, value::Value) = value.value Base.size(value::Value) = Base.size(get_type(value)) @@ -479,21 +484,21 @@ Base.ndims(value::Value) = Base.ndims(get_type(value)) function Base.show(io::IO, value::Value) c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) ref = Ref(io) - GC.@preserve ref LibMLIR.mlirValuePrint(value, c_print_callback, ref) + GC.@preserve ref API.mlirValuePrint(value, c_print_callback, ref) end -is_a_op_result(value) = LibMLIR.mlirValueIsAOpResult(value) -is_a_block_argument(value) = LibMLIR.mlirValueIsABlockArgument(value) +is_a_op_result(value) = API.mlirValueIsAOpResult(value) +is_a_block_argument(value) = API.mlirValueIsABlockArgument(value) function set_type!(value, type) @assert is_a_block_argument(value) "could not set type, value is not a block argument" - LibMLIR.mlirBlockArgumentSetType(value, type) + API.mlirBlockArgumentSetType(value, type) value end function get_owner(value::Value) if is_a_block_argument(value) - raw_block = LibMLIR.mlirBlockArgumentGetOwner(value) + raw_block = API.mlirBlockArgumentGetOwner(value) if mlirIsNull(raw_block) return nothing end @@ -501,7 +506,7 @@ function get_owner(value::Value) return Block(raw_block, false) end - raw_op = LibMLIR.mlirOpResultGetOwner(value) + raw_op = API.mlirOpResultGetOwner(value) if mlirIsNull(raw_op) return nothing end @@ -515,27 +520,27 @@ struct OperationState opstate::MlirOperationState end -OperationState(name, location) = OperationState(LibMLIR.mlirOperationStateGet(name, location)) +OperationState(name, location) = OperationState(API.mlirOperationStateGet(name, location)) add_results!(state, results) = - LibMLIR.mlirOperationStateAddResults(state, length(results), results) + API.mlirOperationStateAddResults(state, length(results), results) add_operands!(state, operands) = - LibMLIR.mlirOperationStateAddOperands(state, length(operands), operands) + API.mlirOperationStateAddOperands(state, length(operands), operands) function add_owned_regions!(state, regions) mlir_regions = Base.convert.(MlirRegion, regions) lose_ownership!.(regions) - LibMLIR.mlirOperationStateAddOwnedRegions(state, length(mlir_regions), mlir_regions) + API.mlirOperationStateAddOwnedRegions(state, length(mlir_regions), mlir_regions) end add_attributes!(state, attributes) = - LibMLIR.mlirOperationStateAddAttributes(state, length(attributes), attributes) + API.mlirOperationStateAddAttributes(state, length(attributes), attributes) add_successors!(state, successors) = - LibMLIR.mlirOperationStateAddSuccessors( + API.mlirOperationStateAddSuccessors( state, length(successors), - convert(Vector{LibMLIR.MlirBlock}, successors), + convert(Vector{API.MlirBlock}, successors), ) enable_type_inference!(state) = - LibMLIR.mlirOperationStateEnableResultTypeInference(state) + API.mlirOperationStateEnableResultTypeInference(state) Base.unsafe_convert(::Type{Ptr{MlirOperationState}}, state::OperationState) = Base.unsafe_convert(Ptr{MlirOperationState}, Base.pointer_from_objref(state.opstate)) @@ -550,57 +555,57 @@ mutable struct Operation @assert !mlirIsNull(operation) "cannot create Operation with null MlirOperation" finalizer(new(operation, owned)) do op if op.owned - LibMLIR.mlirOperationDestroy(op.operation) + API.mlirOperationDestroy(op.operation) end end end end -Operation(state::OperationState) = Operation(LibMLIR.mlirOperationCreate(state), true) +Operation(state::OperationState) = Operation(API.mlirOperationCreate(state), true) -Base.copy(operation::Operation) = Operation(LibMLIR.mlirOperationClone(operation)) +Base.copy(operation::Operation) = Operation(API.mlirOperationClone(operation)) -num_regions(operation) = LibMLIR.mlirOperationGetNumRegions(operation) +num_regions(operation) = API.mlirOperationGetNumRegions(operation) function get_region(operation, i) i ∈ 1:num_regions(operation) && throw(BoundsError(operation, i)) - Region(LibMLIR.mlirOperationGetRegion(operation, i - 1), false) + Region(API.mlirOperationGetRegion(operation, i - 1), false) end -num_results(operation) = LibMLIR.mlirOperationGetNumResults(operation) +num_results(operation) = API.mlirOperationGetNumResults(operation) get_results(operation) = [ get_result(operation, i) for i in 1:num_results(operation) ] function get_result(operation::Operation, i=1) i ∉ 1:num_results(operation) && throw(BoundsError(operation, i)) - Value(LibMLIR.mlirOperationGetResult(operation, i - 1)) + Value(API.mlirOperationGetResult(operation, i - 1)) end -num_operands(operation) = LibMLIR.mlirOperationGetNumOperands(operation) +num_operands(operation) = API.mlirOperationGetNumOperands(operation) function get_operand(operation, i=1) i ∉ 1:num_operands(operation) && throw(BoundsError(operation, i)) - Value(LibMLIR.mlirOperationGetOperand(operation, i - 1)) + Value(API.mlirOperationGetOperand(operation, i - 1)) end function set_operand!(operation, i, value) i ∉ 1:num_operands(operation) && throw(BoundsError(operation, i)) - LibMLIR.mlirOperationSetOperand(operation, i - 1, value) + API.mlirOperationSetOperand(operation, i - 1, value) value end function get_attribute_by_name(operation, name) - raw_attr = LibMLIR.mlirOperationGetAttributeByName(operation, name) + raw_attr = API.mlirOperationGetAttributeByName(operation, name) if mlirIsNull(raw_attr) return nothing end Attribute(raw_attr) end function set_attribute_by_name!(operation, name, attribute) - LibMLIR.mlirOperationSetAttributeByName(operation, name, attribute) + API.mlirOperationSetAttributeByName(operation, name, attribute) operation end -get_location(operation) = Location(LibMLIR.mlirOperationGetLocation(operation)) -get_name(operation) = String(LibMLIR.mlirOperationGetName(operation)) -get_block(operation) = Block(LibMLIR.mlirOperationGetBlock(operation), false) -get_parent_operation(operation) = Operation(LibMLIR.mlirOperationGetParentOperation(operation), false) +get_location(operation) = Location(API.mlirOperationGetLocation(operation)) +get_name(operation) = String(API.mlirOperationGetName(operation)) +get_block(operation) = Block(API.mlirOperationGetBlock(operation), false) +get_parent_operation(operation) = Operation(API.mlirOperationGetParentOperation(operation), false) get_dialect(operation) = first(split(get_name(operation), '.')) |> Symbol function get_first_region(op::Operation) @@ -622,7 +627,7 @@ function get_first_child_op(op::Operation) first(cop) end -op::Operation == other::Operation = LibMLIR.mlirOperationEqual(op, other) +op::Operation == other::Operation = API.mlirOperationEqual(op, other) Base.convert(::Type{MlirOperation}, op::Operation) = op.operation @@ -635,13 +640,13 @@ end function Base.show(io::IO, operation::Operation) c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) ref = Ref(io) - flags = LibMLIR.mlirOpPrintingFlagsCreate() - get(io, :debug, false) && LibMLIR.mlirOpPrintingFlagsEnableDebugInfo(flags, true, true) - GC.@preserve ref LibMLIR.mlirOperationPrintWithFlags(operation, flags, c_print_callback, ref) + flags = API.mlirOpPrintingFlagsCreate() + get(io, :debug, false) && API.mlirOpPrintingFlagsEnableDebugInfo(flags, true, true) + GC.@preserve ref API.mlirOperationPrintWithFlags(operation, flags, c_print_callback, ref) println(io) end -verify(operation::Operation) = LibMLIR.mlirOperationVerify(operation) +verify(operation::Operation) = API.mlirOperationVerify(operation) ### Block @@ -653,7 +658,7 @@ mutable struct Block @assert !mlirIsNull(block) "cannot create Block with null MlirBlock" finalizer(new(block, owned)) do block if block.owned - LibMLIR.mlirBlockDestroy(block.block) + API.mlirBlockDestroy(block.block) end end end @@ -662,15 +667,15 @@ end Block() = Block(MType[], Location[]) function Block(args::Vector{MType}, locs::Vector{Location}) @assert length(args) == length(locs) "there should be one args for each locs (got $(length(args)) & $(length(locs)))" - Block(LibMLIR.mlirBlockCreate(length(args), args, locs)) + Block(API.mlirBlockCreate(length(args), args, locs)) end function Base.push!(block::Block, op::Operation) - LibMLIR.mlirBlockAppendOwnedOperation(block, lose_ownership!(op)) + API.mlirBlockAppendOwnedOperation(block, lose_ownership!(op)) op end function Base.insert!(block::Block, pos, op::Operation) - LibMLIR.mlirBlockInsertOwnedOperation(block, pos - 1, lose_ownership!(op)) + API.mlirBlockInsertOwnedOperation(block, pos - 1, lose_ownership!(op)) op end function Base.pushfirst!(block::Block, op::Operation) @@ -678,22 +683,22 @@ function Base.pushfirst!(block::Block, op::Operation) op end function insert_after!(block::Block, reference::Operation, op::Operation) - LibMLIR.mlirBlockInsertOwnedOperationAfter(block, reference, lose_ownership!(op)) + API.mlirBlockInsertOwnedOperationAfter(block, reference, lose_ownership!(op)) op end function insert_before!(block::Block, reference::Operation, op::Operation) - LibMLIR.mlirBlockInsertOwnedOperationBefore(block, reference, lose_ownership!(op)) + API.mlirBlockInsertOwnedOperationBefore(block, reference, lose_ownership!(op)) op end num_arguments(block::Block) = - LibMLIR.mlirBlockGetNumArguments(block) + API.mlirBlockGetNumArguments(block) function get_argument(block::Block, i) i ∉ 1:num_arguments(block) && throw(BoundsError(block, i)) - Value(LibMLIR.mlirBlockGetArgument(block, i - 1)) + Value(API.mlirBlockGetArgument(block, i - 1)) end push_argument!(block::Block, type, loc) = - Value(LibMLIR.mlirBlockAddArgument(block, type, loc)) + Value(API.mlirBlockAddArgument(block, type, loc)) Base.convert(::Type{MlirBlock}, block::Block) = block.block @@ -706,33 +711,33 @@ end function Base.show(io::IO, block::Block) c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) ref = Ref(io) - GC.@preserve ref LibMLIR.mlirBlockPrint(block, c_print_callback, ref) + GC.@preserve ref API.mlirBlockPrint(block, c_print_callback, ref) end ### Region mutable struct Region region::MlirRegion - @atomic owned::Bool # TODO: make atomic? + @atomic owned::Bool Region(region, owned=true) = begin @assert !mlirIsNull(region) finalizer(new(region, owned)) do region if region.owned - LibMLIR.mlirRegionDestroy(region.region) + API.mlirRegionDestroy(region.region) end end end end -Region() = Region(LibMLIR.mlirRegionCreate()) +Region() = Region(API.mlirRegionCreate()) function Base.push!(region::Region, block::Block) - LibMLIR.mlirRegionAppendOwnedBlock(region, lose_ownership!(block)) + API.mlirRegionAppendOwnedBlock(region, lose_ownership!(block)) block end function Base.insert!(region::Region, pos, block::Block) - LibMLIR.mlirRegionInsertOwnedBlock(region, pos - 1, lose_ownership!(block)) + API.mlirRegionInsertOwnedBlock(region, pos - 1, lose_ownership!(block)) block end function Base.pushfirst!(region::Region, block) @@ -740,9 +745,9 @@ function Base.pushfirst!(region::Region, block) block end insert_after!(region::Region, reference::Block, block::Block) = - LibMLIR.mlirRegionInsertOwnedBlockAfter(region, reference, lose_ownership!(block)) + API.mlirRegionInsertOwnedBlockAfter(region, reference, lose_ownership!(block)) insert_before!(region::Region, reference::Block, block::Block) = - LibMLIR.mlirRegionInsertOwnedBlockBefore(region, reference, lose_ownership!(block)) + API.mlirRegionInsertOwnedBlockBefore(region, reference, lose_ownership!(block)) function get_first_block(region::Region) block = iterate(BlockIterator(region)) @@ -766,18 +771,18 @@ mutable struct MModule MModule(module_, context) = begin @assert !mlirIsNull(module_) "cannot create MModule with null MlirModule" - finalizer(LibMLIR.mlirModuleDestroy, new(module_, context)) + finalizer(API.mlirModuleDestroy, new(module_, context)) end end MModule(context::Context, loc=Location(context)) = - MModule(LibMLIR.mlirModuleCreateEmpty(loc), context) -get_operation(module_) = Operation(LibMLIR.mlirModuleGetOperation(module_), false) -get_body(module_) = Block(LibMLIR.mlirModuleGetBody(module_), false) + MModule(API.mlirModuleCreateEmpty(loc), context) +get_operation(module_) = Operation(API.mlirModuleGetOperation(module_), false) +get_body(module_) = Block(API.mlirModuleGetBody(module_), false) get_first_child_op(mod::MModule) = get_first_child_op(get_operation(mod)) Base.convert(::Type{MlirModule}, module_::MModule) = module_.module_ -Base.parse(::Type{MModule}, context, module_) = MModule(LibMLIR.mlirModuleCreateParse(context, module_), context) +Base.parse(::Type{MModule}, context, module_) = MModule(API.mlirModuleCreateParse(context, module_), context) macro mlir_str(code) quote @@ -794,29 +799,29 @@ end ### TypeID struct TypeID - typeid::LibMLIR.MlirTypeID + typeid::API.MlirTypeID end -Base.hash(typeid::TypeID) = LibMLIR.mlirTypeIDHashValue(typeid.typeid) -Base.convert(::Type{LibMLIR.MlirTypeID}, typeid::TypeID) = typeid.typeid +Base.hash(typeid::TypeID) = API.mlirTypeIDHashValue(typeid.typeid) +Base.convert(::Type{API.MlirTypeID}, typeid::TypeID) = typeid.typeid -@static if isdefined(LibMLIR, :MlirTypeIDAllocator) +@static if isdefined(API, :MlirTypeIDAllocator) ### TypeIDAllocator mutable struct TypeIDAllocator - allocator::LibMLIR.MlirTypeIDAllocator + allocator::API.MlirTypeIDAllocator function TypeIDAllocator() - ptr = LibMLIR.mlirTypeIDAllocatorCreate() + ptr = API.mlirTypeIDAllocatorCreate() @assert ptr != C_NULL "cannot create TypeIDAllocator" - finalizer(LibMLIR.mlirTypeIDAllocatorDestroy, new(ptr)) + finalizer(API.mlirTypeIDAllocatorDestroy, new(ptr)) end end -Base.convert(::Type{LibMLIR.MlirTypeIDAllocator}, allocator::TypeIDAllocator) = allocator.allocator +Base.convert(::Type{API.MlirTypeIDAllocator}, allocator::TypeIDAllocator) = allocator.allocator -TypeID(allocator::TypeIDAllocator) = TypeID(LibMLIR.mlirTypeIDCreate(allocator)) +TypeID(allocator::TypeIDAllocator) = TypeID(API.mlirTypeIDCreate(allocator)) else @@ -842,21 +847,25 @@ mutable struct PassManager PassManager(pm::MlirPassManager, context) = begin @assert !mlirIsNull(pm) "cannot create PassManager with null MlirPassManager" finalizer(new(pm, context, TypeIDAllocator(), Dict{TypeID,ExternalPassHandle}())) do pm - LibMLIR.mlirPassManagerDestroy(pm.pass) + API.mlirPassManagerDestroy(pm.pass) end end end -function enable_verifier!(pm) - LibMLIR.mlirPassManagerEnableVerifier(pm) +function enable_ir_printing!(pm) + API.mlirPassManagerEnableIRPrinting(pm) + pm +end +function enable_verifier!(pm, enable=true) + API.mlirPassManagerEnableVerifier(pm, enable) pm end PassManager(context) = - PassManager(LibMLIR.mlirPassManagerCreate(context), context) + PassManager(API.mlirPassManagerCreate(context), context) -function run(pm::PassManager, module_) - status = LibMLIR.mlirPassManagerRun(pm, module_) +function run!(pm::PassManager, module_) + status = API.mlirPassManagerRun(pm, module_) if mlirLogicalResultIsFailure(status) throw("failed to run pass manager on module") end @@ -877,9 +886,9 @@ struct OpPassManager end end -OpPassManager(pm::PassManager) = OpPassManager(LibMLIR.mlirPassManagerGetAsOpPassManager(pm), pm) -OpPassManager(pm::PassManager, opname) = OpPassManager(LibMLIR.mlirPassManagerGetNestedUnder(pm, opname), pm) -OpPassManager(opm::OpPassManager, opname) = OpPassManager(LibMLIR.mlirOpPassManagerGetNestedUnder(opm, opname), opm.pass) +OpPassManager(pm::PassManager) = OpPassManager(API.mlirPassManagerGetAsOpPassManager(pm), pm) +OpPassManager(pm::PassManager, opname) = OpPassManager(API.mlirPassManagerGetNestedUnder(pm, opname), pm) +OpPassManager(opm::OpPassManager, opname) = OpPassManager(API.mlirOpPassManagerGetNestedUnder(opm, opname), opm.pass) Base.convert(::Type{MlirOpPassManager}, op_pass::OpPassManager) = op_pass.op_pass @@ -887,7 +896,7 @@ function Base.show(io::IO, op_pass::OpPassManager) c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) ref = Ref(io) println(io, "OpPassManager(\"\"\"") - GC.@preserve ref LibMLIR.mlirPrintPassPipeline(op_pass, c_print_callback, ref) + GC.@preserve ref API.mlirPrintPassPipeline(op_pass, c_print_callback, ref) println(io) print(io, "\"\"\")") end @@ -904,24 +913,35 @@ end mlirLogicalResultIsFailure(result) = result.value == 0 function add_pipeline!(op_pass::OpPassManager, pipeline) - @static if isdefined(LibMLIR, :mlirOpPassManagerAddPipeline) + @static if isdefined(API, :mlirOpPassManagerAddPipeline) io = IOBuffer() c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) - result = GC.@preserve io LibMLIR.mlirOpPassManagerAddPipeline(op_pass, pipeline, c_print_callback, io) - if LibMLIR.mlirLogicalResultIsFailure(result) + result = GC.@preserve io API.mlirOpPassManagerAddPipeline(op_pass, pipeline, c_print_callback, io) + if API.mlirLogicalResultIsFailure(result) exc = AddPipelineException(String(take!(io))) throw(exc) end else - result = LibMLIR.mlirParsePassPipeline(op_pass, pipeline) + result = API.mlirParsePassPipeline(op_pass, pipeline) if mlirLogicalResultIsFailure(result) throw(AddPipelineException(" " * pipeline)) end end op_pass end + +function add_owned_pass!(pm::PassManager, pass) + API.mlirPassManagerAddOwnedPass(pm, pass) + pm +end + +function add_owned_pass!(opm::OpPassManager, pass) + API.mlirOpPassManagerAddOwnedPass(opm, pass) + opm +end -@static if isdefined(LibMLIR, :mlirCreateExternalPass) + +@static if isdefined(API, :mlirCreateExternalPass) ### Pass @@ -942,9 +962,9 @@ end function _pass_initialize(ctx, handle::ExternalPassHandle) try handle.ctx = Context(ctx) - LibMLIR.mlirLogicalResultSuccess() + API.mlirLogicalResultSuccess() catch - LibMLIR.mlirLogicalResultFailure() + API.mlirLogicalResultFailure() end end @@ -958,7 +978,7 @@ function _pass_run(rawop, external_pass, handle::ExternalPassHandle) pass_run(handle.ctx, handle.pass, op) catch ex @error "Something went wrong running pass" exception=(ex,catch_backtrace()) - LibMLIR.mlirExternalPassSignalFailure(external_pass) + API.mlirExternalPassSignalFailure(external_pass) end nothing end @@ -970,31 +990,21 @@ function create_external_pass!(manager, pass, name, argument, description, opname=get_opname(pass), dependent_dialects=MlirDialectHandle[]) passid = TypeID(manager.allocator) - callbacks = LibMLIR.MlirExternalPassCallbacks( + callbacks = API.MlirExternalPassCallbacks( @cfunction(_pass_construct, Cvoid, (Any,)), @cfunction(_pass_destruct, Cvoid, (Any,)), - @cfunction(_pass_initialize, LibMLIR.MlirLogicalResult, (MlirContext, Any,)), + @cfunction(_pass_initialize, API.MlirLogicalResult, (MlirContext, Any,)), @cfunction(_pass_clone, Any, (Any,)), - @cfunction(_pass_run, Cvoid, (MlirOperation, LibMLIR.MlirExternalPass, Any)) + @cfunction(_pass_run, Cvoid, (MlirOperation, API.MlirExternalPass, Any)) ) pass_handle = manager.passes[passid] = ExternalPassHandle(nothing, pass) userdata = Base.pointer_from_objref(pass_handle) - mlir_pass = LibMLIR.mlirCreateExternalPass(passid, name, argument, description, opname, + mlir_pass = API.mlirCreateExternalPass(passid, name, argument, description, opname, length(dependent_dialects), dependent_dialects, callbacks, userdata) mlir_pass end -function add_owned_pass!(pm::PassManager, pass) - LibMLIR.mlirPassManagerAddOwnedPass(pm, pass) - pm -end - -function add_owned_pass!(opm::OpPassManager, pass) - LibMLIR.mlirOpPassManagerAddOwnedPass(opm, pass) - opm -end - end ### Iterators @@ -1010,7 +1020,7 @@ end function Base.iterate(it::BlockIterator) reg = it.region - raw_block = LibMLIR.mlirRegionGetFirstBlock(reg) + raw_block = API.mlirRegionGetFirstBlock(reg) if mlirIsNull(raw_block) nothing else @@ -1020,7 +1030,7 @@ function Base.iterate(it::BlockIterator) end function Base.iterate(it::BlockIterator, block) - raw_block = LibMLIR.mlirBlockGetNextInRegion(block) + raw_block = API.mlirBlockGetNextInRegion(block) if mlirIsNull(raw_block) nothing else @@ -1039,7 +1049,7 @@ struct OperationIterator end function Base.iterate(it::OperationIterator) - raw_op = LibMLIR.mlirBlockGetFirstOperation(it.block) + raw_op = API.mlirBlockGetFirstOperation(it.block) if mlirIsNull(raw_op) nothing else @@ -1049,7 +1059,7 @@ function Base.iterate(it::OperationIterator) end function Base.iterate(it::OperationIterator, op) - raw_op = LibMLIR.mlirOperationGetNextInBlock(op) + raw_op = API.mlirOperationGetNextInBlock(op) if mlirIsNull(raw_op) nothing else @@ -1068,7 +1078,7 @@ struct RegionIterator end function Base.iterate(it::RegionIterator) - raw_region = LibMLIR.mlirOperationGetFirstRegion(it.op) + raw_region = API.mlirOperationGetFirstRegion(it.op) if mlirIsNull(raw_region) nothing else @@ -1078,7 +1088,7 @@ function Base.iterate(it::RegionIterator) end function Base.iterate(it::RegionIterator, region) - raw_region = LibMLIR.mlirRegionGetNextInOperation(region) + raw_region = API.mlirRegionGetNextInOperation(region) if mlirIsNull(raw_region) nothing else @@ -1114,30 +1124,4 @@ function verifyall(operation::Operation; debug=false) end verifyall(module_::MModule) = get_operation(module_) |> verifyall -function get_dialects!(dialects::Set{Symbol}, op::Operation) - push!(dialects, get_dialect(op)) - - visit(op) do op - get_dialects!(dialects, op) - end - - dialects -end - -function get_input_type(module_) - dialects = Set{Symbol}() - - op = get_operation(module_) - get_dialects!(dialects, op) - - if :mhlo ∈ dialects - # :tosa ∉ dialects || throw("cannot have both tosa and mhlo operations") - :mhlo - elseif :tosa ∈ dialects - :tosa - else - :none - end -end - end # module IR From b158dd55e6da4db5b0bab82fbe87a71a0c90bd18 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 16 May 2023 22:30:25 +0200 Subject: [PATCH 03/11] skip :code_coverage_effect --- examples/brutus.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/brutus.jl b/examples/brutus.jl index 496db864..c9f7e81e 100644 --- a/examples/brutus.jl +++ b/examples/brutus.jl @@ -165,6 +165,8 @@ function code_mlir(f, types; ctx=Context()) line = ir.linetable[stmt[:line]] retop = LLVM.version() >= v"15" ? func.return_ : std.return_ push!(current_block, retop(ctx, [get_value(inst.val)]; loc=Location(ctx, line))) + elseif Meta.isexpr(inst, :code_coverage_effect) + # Skip else error("unhandled ir $(inst)") end From 0e70ccf9c4b1c97c0ca1fb277aa0c51fe5d00b9f Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 16 May 2023 22:50:52 +0200 Subject: [PATCH 04/11] remove `GC.@preserve ref` && `MlirStringRef` handling --- src/IR.jl | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/src/IR.jl b/src/IR.jl index 47990358..4a870c40 100644 --- a/src/IR.jl +++ b/src/IR.jl @@ -53,13 +53,6 @@ function print_callback(str::MlirStringRef, userdata) return Cvoid() end -### String Ref - -String(strref::MlirStringRef) = - Base.unsafe_string(Base.convert(Ptr{Cchar}, strref.data), strref.length) -Base.convert(::Type{MlirStringRef}, s::String) = - MlirStringRef(Base.unsafe_convert(Cstring, s), sizeof(s)) - ### Identifier String(ident::MlirIdentifier) = String(API.mlirIdentifierStr(ident)) @@ -173,7 +166,7 @@ function Base.show(io::IO, location::Location) c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) ref = Ref(io) print(io, "Location(#= ") - GC.@preserve ref API.mlirLocationPrint(location, c_print_callback, ref) + API.mlirLocationPrint(location, c_print_callback, ref) print(io, " =#)") end @@ -268,7 +261,7 @@ function Base.show(io::IO, type::MType) print(io, "MType(#= ") c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) ref = Ref(io) - GC.@preserve ref API.mlirTypePrint(type, c_print_callback, ref) + API.mlirTypePrint(type, c_print_callback, ref) print(io, " =#)") end @@ -443,7 +436,7 @@ function Base.show(io::IO, attribute::Attribute) print(io, "Attribute(#= ") c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) ref = Ref(io) - GC.@preserve ref API.mlirAttributePrint(attribute, c_print_callback, ref) + API.mlirAttributePrint(attribute, c_print_callback, ref) print(io, " =#)") end @@ -484,7 +477,7 @@ Base.ndims(value::Value) = Base.ndims(get_type(value)) function Base.show(io::IO, value::Value) c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) ref = Ref(io) - GC.@preserve ref API.mlirValuePrint(value, c_print_callback, ref) + API.mlirValuePrint(value, c_print_callback, ref) end is_a_op_result(value) = API.mlirValueIsAOpResult(value) @@ -642,7 +635,7 @@ function Base.show(io::IO, operation::Operation) ref = Ref(io) flags = API.mlirOpPrintingFlagsCreate() get(io, :debug, false) && API.mlirOpPrintingFlagsEnableDebugInfo(flags, true, true) - GC.@preserve ref API.mlirOperationPrintWithFlags(operation, flags, c_print_callback, ref) + API.mlirOperationPrintWithFlags(operation, flags, c_print_callback, ref) println(io) end @@ -711,7 +704,7 @@ end function Base.show(io::IO, block::Block) c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) ref = Ref(io) - GC.@preserve ref API.mlirBlockPrint(block, c_print_callback, ref) + API.mlirBlockPrint(block, c_print_callback, ref) end ### Region @@ -896,7 +889,7 @@ function Base.show(io::IO, op_pass::OpPassManager) c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) ref = Ref(io) println(io, "OpPassManager(\"\"\"") - GC.@preserve ref API.mlirPrintPassPipeline(op_pass, c_print_callback, ref) + API.mlirPrintPassPipeline(op_pass, c_print_callback, ref) println(io) print(io, "\"\"\")") end From 34ea3e633c54b083942b4f32f8dfc30fc61c5ae9 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 16 May 2023 22:56:03 +0200 Subject: [PATCH 05/11] use `Base.RefValue` for `MlirOperationState` --- lib/14/libMLIR_h.jl | 2 +- lib/15/libMLIR_h.jl | 2 +- src/IR.jl | 21 +++++++++------------ 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/lib/14/libMLIR_h.jl b/lib/14/libMLIR_h.jl index 4a7e19a3..c9dfbb76 100644 --- a/lib/14/libMLIR_h.jl +++ b/lib/14/libMLIR_h.jl @@ -546,7 +546,7 @@ An auxiliary class for constructing operations. This class contains all the information necessary to construct the operation. It owns the MlirRegions it has pointers to and does not own anything else. By default, the state can be constructed from a name and location, the latter being also used to access the context, and has no other components. These components can be added progressively until the operation is constructed. Users are not expected to rely on the internals of this class and should use mlirOperationState* functions instead. """ -mutable struct MlirOperationState # TODO: make mutable in res +struct MlirOperationState name::MlirStringRef location::MlirLocation nResults::intptr_t diff --git a/lib/15/libMLIR_h.jl b/lib/15/libMLIR_h.jl index e7b97ca8..702af778 100644 --- a/lib/15/libMLIR_h.jl +++ b/lib/15/libMLIR_h.jl @@ -656,7 +656,7 @@ An auxiliary class for constructing operations. This class contains all the information necessary to construct the operation. It owns the MlirRegions it has pointers to and does not own anything else. By default, the state can be constructed from a name and location, the latter being also used to access the context, and has no other components. These components can be added progressively until the operation is constructed. Users are not expected to rely on the internals of this class and should use mlirOperationState* functions instead. """ -mutable struct MlirOperationState +struct MlirOperationState name::MlirStringRef location::MlirLocation nResults::intptr_t diff --git a/src/IR.jl b/src/IR.jl index 4a870c40..01599027 100644 --- a/src/IR.jl +++ b/src/IR.jl @@ -510,33 +510,30 @@ end ### OperationState struct OperationState - opstate::MlirOperationState + opstate::Base.RefValue{MlirOperationState} end -OperationState(name, location) = OperationState(API.mlirOperationStateGet(name, location)) +OperationState(name, location) = OperationState(Ref(API.mlirOperationStateGet(name, location))) add_results!(state, results) = - API.mlirOperationStateAddResults(state, length(results), results) + API.mlirOperationStateAddResults(state.opstate, length(results), results) add_operands!(state, operands) = - API.mlirOperationStateAddOperands(state, length(operands), operands) + API.mlirOperationStateAddOperands(state.opstate, length(operands), operands) function add_owned_regions!(state, regions) mlir_regions = Base.convert.(MlirRegion, regions) lose_ownership!.(regions) - API.mlirOperationStateAddOwnedRegions(state, length(mlir_regions), mlir_regions) + API.mlirOperationStateAddOwnedRegions(state.opstate, length(mlir_regions), mlir_regions) end add_attributes!(state, attributes) = - API.mlirOperationStateAddAttributes(state, length(attributes), attributes) + API.mlirOperationStateAddAttributes(state.opstate, length(attributes), attributes) add_successors!(state, successors) = API.mlirOperationStateAddSuccessors( - state, length(successors), + state.opstate, length(successors), convert(Vector{API.MlirBlock}, successors), ) enable_type_inference!(state) = - API.mlirOperationStateEnableResultTypeInference(state) - -Base.unsafe_convert(::Type{Ptr{MlirOperationState}}, state::OperationState) = - Base.unsafe_convert(Ptr{MlirOperationState}, Base.pointer_from_objref(state.opstate)) + API.mlirOperationStateEnableResultTypeInference(state.opstate) ### Operation @@ -554,7 +551,7 @@ mutable struct Operation end end -Operation(state::OperationState) = Operation(API.mlirOperationCreate(state), true) +Operation(state::OperationState) = Operation(API.mlirOperationCreate(state.opstate), true) Base.copy(operation::Operation) = Operation(API.mlirOperationClone(operation)) From d3527818e38d8de43bf2591b647dc35318cf14ed Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Fri, 19 May 2023 18:02:27 +0200 Subject: [PATCH 06/11] Update to `create_operation` API instead of exposing `OperationState` --- examples/brutus.jl | 45 +++++---- src/Dialects.jl | 148 ++++++++++----------------- src/IR.jl | 244 +++++++++++++++++++++++---------------------- 3 files changed, 207 insertions(+), 230 deletions(-) diff --git a/examples/brutus.jl b/examples/brutus.jl index c9f7e81e..36df2877 100644 --- a/examples/brutus.jl +++ b/examples/brutus.jl @@ -41,7 +41,7 @@ function prepare_block(ctx, ir, bb) inst isa Core.PhiNode || continue type = stmt[:type] - IR.push_argument!(b, MType(ctx, type), Location(ctx)) + IR.push_argument!(b, MLIRType(ctx, type), Location(ctx)) end return b @@ -96,7 +96,7 @@ function code_mlir(f, types; ctx=Context()) current_block = entry_block = blocks[begin] for argtype in types.parameters - IR.push_argument!(entry_block, MType(ctx, argtype), Location(ctx)) + IR.push_argument!(entry_block, MLIRType(ctx, argtype), Location(ctx)) end function get_value(x)::Value @@ -122,12 +122,11 @@ function code_mlir(f, types; ctx=Context()) line = ir.linetable[stmt[:line]] if Meta.isexpr(inst, :call) - line = ir.linetable[stmt[:line]] val_type = stmt[:type] if !(val_type <: BrutusScalar) error("type $val_type is not supported") end - out_type = MType(ctx, val_type) + out_type = MLIRType(ctx, val_type) called_func = first(inst.args) if called_func isa GlobalRef # TODO: should probably use something else here @@ -137,7 +136,8 @@ function code_mlir(f, types; ctx=Context()) fop! = intrinsics_to_mlir[called_func] args = get_value.(@view inst.args[begin+1:end]) - res = IR.get_result(fop!(ctx, current_block, args; loc=Location(ctx, line))) + loc = Location(ctx, string(line.file), line.line, 0) + res = IR.get_result(fop!(ctx, current_block, args; loc)) values[sidx] = res elseif inst isa PhiNode @@ -147,8 +147,9 @@ function code_mlir(f, types; ctx=Context()) elseif inst isa GotoNode args = get_value.(collect_value_arguments(ir, block_id, inst.label)) dest = blocks[inst.label] + loc = Location(ctx, string(line.file), line.line, 0) brop = LLVM.version() >= v"15" ? cf.br : std.br - push!(current_block, brop(ctx, dest, args; loc=Location(ctx, line))) + push!(current_block, brop(ctx, dest, args; loc)) elseif inst isa GotoIfNot false_args = get_value.(collect_value_arguments(ir, block_id, inst.dest)) cond = get_value(inst.cond) @@ -158,13 +159,15 @@ function code_mlir(f, types; ctx=Context()) other_dest = blocks[other_dest] dest = blocks[inst.dest] + loc = Location(ctx, string(line.file), line.line, 0) cond_brop = LLVM.version() >= v"15" ? cf.cond_br : std.cond_br - cond_br = cond_brop(ctx, cond, other_dest, dest, true_args, false_args; loc=Location(ctx, line)) + cond_br = cond_brop(ctx, cond, other_dest, dest, true_args, false_args; loc) push!(current_block, cond_br) elseif inst isa ReturnNode line = ir.linetable[stmt[:line]] retop = LLVM.version() >= v"15" ? func.return_ : std.return_ - push!(current_block, retop(ctx, [get_value(inst.val)]; loc=Location(ctx, line))) + loc = Location(ctx, string(line.file), line.line, 0) + push!(current_block, retop(ctx, [get_value(inst.val)]; loc)) elseif Meta.isexpr(inst, :code_coverage_effect) # Skip else @@ -181,22 +184,24 @@ function code_mlir(f, types; ctx=Context()) end LLVM15 = LLVM.version() >= v"15" - state = OperationState(LLVM15 ? "func.func" : "builtin.func", Location(ctx)) - input_types = MType[ + input_types = MLIRType[ IR.get_type(IR.get_argument(entry_block, i)) for i in 1:IR.num_arguments(entry_block) ] - result_types = [MType(ctx, ret)] - - ftype = MType(ctx, input_types => result_types) - IR.add_attributes!(state, [ - NamedAttribute(ctx, "sym_name", IR.Attribute(ctx, string(func_name))), - NamedAttribute(ctx, LLVM15 ? "function_type" : "type", IR.Attribute(ftype)), - ]) - IR.add_owned_regions!(state, Region[region]) - - op = Operation(state) + result_types = [MLIRType(ctx, ret)] + + ftype = MLIRType(ctx, input_types => result_types) + op = IR.create_operation( + LLVM15 ? "func.func" : "builtin.func", + Location(ctx); + attributes = [ + NamedAttribute(ctx, "sym_name", IR.Attribute(ctx, string(func_name))), + NamedAttribute(ctx, LLVM15 ? "function_type" : "type", IR.Attribute(ftype)), + ], + owned_regions = Region[region], + result_inference=false, + ) IR.verifyall(op) diff --git a/src/Dialects.jl b/src/Dialects.jl index f88fd4c3..4cb400eb 100644 --- a/src/Dialects.jl +++ b/src/Dialects.jl @@ -10,19 +10,13 @@ for (f, t) in Iterators.product( ) fname = Symbol(f, t) @eval function $fname(context, operands, type=IR.get_type(first(operands)); loc=Location(context)) - state = OperationState($(string("arith.", fname)), loc) - IR.add_operands!(state, operands) - IR.add_results!(state, [type]) - Operation(state) + IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) end end for fname in (:xori, :andi, :ori) @eval function $fname(context, operands, type=IR.get_type(first(operands)); loc=Location(context)) - state = OperationState($(string("arith.", fname)), loc) - IR.add_operands!(state, operands) - IR.add_results!(state, [type]) - Operation(state) + IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) end end @@ -32,67 +26,38 @@ for (f, t) in Iterators.product( ) fname = Symbol(f, t) @eval function $fname(context, operands, type=IR.get_type(first(operands)); loc=Location(context)) - state = OperationState($(string("arith.", fname)), loc) - IR.add_operands!(state, operands) - IR.add_results!(state, [type]) - Operation(state) + IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) end end # https://mlir.llvm.org/docs/Dialects/ArithOps/#arithindex_cast-mlirarithindexcastop for f in (:index_cast, :index_castui) @eval function $f(context, operand; loc=Location(context)) - state = OperationState($(string("arith.", f)), loc) - add_operands!(state, [operand]) - add_results!(state, [IR.IndexType(context)]) - Operation(state) + IR.create_operation( + $(string("arith.", f)), + loc; + operands=[operand], + results=[IR.IndexType(context)], + ) end end # https://mlir.llvm.org/docs/Dialects/ArithOps/#arithextf-mlirarithextfop function extf(context, operand, type; loc=Location(context)) - state = OperationState("arith.exf", loc) - IR.add_results!(state, [type]) - IR.add_operands!(state, [operand]) - Operation(state) + IR.create_operation("arith.exf", loc; operands=[operand], results=[type]) end -# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithsitofp-mlirarithsitofpop -function sitofp(context, operand, ftype=float(julia_type(eltype(get_type(operand)))); loc=Location(context)) - state = OperationState("arith.sitofp", loc) - type = get_type(operand) - IR.add_results!(state, [ - IR.is_tensor(type) ? - MType(context, ftype isa MType ? eltype(ftype) : MType(context, ftype), size(type)) : - MType(context, ftype) - ]) - IR.add_operands!(state, [operand]) - Operation(state) -end - -# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithfptosi-mlirarithfptosiop -function fptosi(context, operand, itype; loc=Location(context)) - state = OperationState("arith.fptosi", loc) - type = get_type(operand) - IR.add_results!(state, [ - IR.is_tensor(type) ? - MType(context, itype isa MType ? itype : MType(context, itype), size(type)) : - MType(context, itype) - ]) - IR.add_operands!(state, [operand]) - Operation(state) -end - - # https://mlir.llvm.org/docs/Dialects/ArithOps/#arithconstant-mlirarithconstantop -function constant(context, value, type=MType(context, typeof(value)); loc=Location(context)) - state = OperationState("arith.constant", loc) - IR.add_results!(state, [type]) - IR.add_attributes!(state, [ - IR.NamedAttribute(context, "value", - Attribute(context, value, type)), - ]) - Operation(state) +function constant(context, value, type=MLIRType(context, typeof(value)); loc=Location(context)) + IR.create_operation( + "arith.constant", + loc; + results=[type], + attributes=[ + IR.NamedAttribute(context, "value", + Attribute(context, value, type)), + ], + ) end module Predicates @@ -109,14 +74,16 @@ module Predicates end function cmpi(context, predicate, operands; loc=Location(context)) - state = OperationState("arith.cmpi", loc) - IR.add_operands!(state, operands) - IR.add_attributes!(state, [ - IR.NamedAttribute(context, "predicate", - Attribute(context, predicate)) - ]) - IR.add_results!(state, [MType(context, Bool)]) - Operation(state) + IR.create_operation( + "arith.cmpi", + loc; + operands, + results=[MLIRType(context, Bool)], + attributes=[ + IR.NamedAttribute(context, "predicate", + Attribute(context, predicate)) + ], + ) end end # module arith @@ -127,16 +94,11 @@ module std using ...IR function return_(context, operands; loc=Location(context)) - state = OperationState("std.return", loc) - IR.add_operands!(state, operands) - Operation(state) + IR.create_operation("std.return", loc; operands, result_inference=false) end function br(context, dest, operands; loc=Location(context)) - state = OperationState("std.br", loc) - IR.add_successors!(state, [dest]) - IR.add_operands!(state, operands) - Operation(state) + IR.create_operation("std.br", loc; operands, successors=[dest], result_inference=false) end function cond_br( @@ -146,14 +108,17 @@ function cond_br( false_dest_operands; loc=Location(context), ) - state = OperationState("std.cond_br", loc) - IR.add_successors!(state, [true_dest, false_dest]) - IR.add_operands!(state, [cond, true_dest_operands..., false_dest_operands...]) - IR.add_attributes!(state, [ - IR.NamedAttribute(context, "operand_segment_sizes", - IR.Attribute(context, Int32[1, length(true_dest_operands), length(false_dest_operands)])) - ]) - Operation(state) + IR.create_operation( + "std.cond_br", + loc; + successors=[true_dest, false_dest], + operands=[cond, true_dest_operands..., false_dest_operands...], + attributes=[ + IR.NamedAttribute(context, "operand_segment_sizes", + IR.Attribute(context, Int32[1, length(true_dest_operands), length(false_dest_operands)])) + ], + result_inference=false, + ) end end # module std @@ -164,9 +129,7 @@ module func using ...IR function return_(context, operands; loc=Location(context)) - state = OperationState("func.return", loc) - IR.add_operands!(state, operands) - Operation(state) + IR.create_operation("func.return", loc; operands, result_inference=false) end end # module func @@ -176,10 +139,7 @@ module cf using ...IR function br(context, dest, operands; loc=Location(context)) - state = OperationState("cf.br", loc) - IR.add_successors!(state, [dest]) - IR.add_operands!(state, operands) - Operation(state) + IR.create_operation("cf.br", loc; operands, successors=[dest], result_inference=false) end function cond_br( @@ -189,14 +149,16 @@ function cond_br( false_dest_operands; loc=Location(context), ) - state = OperationState("cf.cond_br", loc) - IR.add_successors!(state, [true_dest, false_dest]) - IR.add_operands!(state, [cond, true_dest_operands..., false_dest_operands...]) - IR.add_attributes!(state, [ - IR.NamedAttribute(context, "operand_segment_sizes", - IR.Attribute(context, Int32[1, length(true_dest_operands), length(false_dest_operands)])) - ]) - Operation(state) + IR.create_operation( + "cf.cond_br", loc; + operands=[cond, true_dest_operands..., false_dest_operands...], + successors=[true_dest, false_dest], + attributes=[ + IR.NamedAttribute(context, "operand_segment_sizes", + IR.Attribute(context, Int32[1, length(true_dest_operands), length(false_dest_operands)])) + ], + result_inference=false, + ) end end # module cf diff --git a/src/IR.jl b/src/IR.jl index 01599027..219ce4fd 100644 --- a/src/IR.jl +++ b/src/IR.jl @@ -9,19 +9,12 @@ export Context, MModule, Value, - MType, + MLIRType, Region, Block, Attribute, NamedAttribute -export - add_results!, - add_attributes!, - add_owned_regions!, - add_successors! - - import Base: ==, String using .API: MlirDialectRegistry, @@ -150,15 +143,8 @@ struct Location end Location(context::Context) = Location(API.mlirLocationUnknownGet(context)) -Location(context::Context, filename, line, column=0) = +Location(context::Context, filename, line, column) = Location(API.mlirLocationFileLineColGet(context, filename, line, column)) -Location(context::Context, lin::Core.LineInfoNode) = - Location(context, string(lin.file), lin.line) -Location(context::Context, lin::LineNumberNode) = - isnothing(lin.file) ? - Location(context) : - Location(context, string(lin.file), lin.line) -Location(context::Context, ::Nothing) = Location(context) Base.convert(::Type{MlirLocation}, location::Location) = location.location @@ -172,61 +158,61 @@ end ### Type -struct MType +struct MLIRType type::MlirType - MType(type) = begin + MLIRType(type) = begin @assert !mlirIsNull(type) new(type) end end -MType(t::MType) = t -MType(context::Context, T::Type{<:Signed}) = - MType(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) -MType(context::Context, T::Type{<:Unsigned}) = - MType(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) -MType(context::Context, ::Type{Bool}) = - MType(API.mlirIntegerTypeGet(context, 1)) -MType(context::Context, ::Type{Float32}) = - MType(API.mlirF32TypeGet(context)) -MType(context::Context, ::Type{Float64}) = - MType(API.mlirF64TypeGet(context)) -MType(context::Context, ft::Pair) = - MType(API.mlirFunctionTypeGet(context, - length(ft.first), [MType(t) for t in ft.first], - length(ft.second), [MType(t) for t in ft.second])) -MType(context, a::AbstractArray{T}) where {T} = MType(context, MType(context, T), size(a)) -MType(context, ::Type{<:AbstractArray{T,N}}, dims) where {T,N} = - MType(API.mlirRankedTensorTypeGetChecked( +MLIRType(t::MLIRType) = t +MLIRType(context::Context, T::Type{<:Signed}) = + MLIRType(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) +MLIRType(context::Context, T::Type{<:Unsigned}) = + MLIRType(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) +MLIRType(context::Context, ::Type{Bool}) = + MLIRType(API.mlirIntegerTypeGet(context, 1)) +MLIRType(context::Context, ::Type{Float32}) = + MLIRType(API.mlirF32TypeGet(context)) +MLIRType(context::Context, ::Type{Float64}) = + MLIRType(API.mlirF64TypeGet(context)) +MLIRType(context::Context, ft::Pair) = + MLIRType(API.mlirFunctionTypeGet(context, + length(ft.first), [MLIRType(t) for t in ft.first], + length(ft.second), [MLIRType(t) for t in ft.second])) +MLIRType(context, a::AbstractArray{T}) where {T} = MLIRType(context, MLIRType(context, T), size(a)) +MLIRType(context, ::Type{<:AbstractArray{T,N}}, dims) where {T,N} = + MLIRType(API.mlirRankedTensorTypeGetChecked( Location(context), N, collect(dims), - MType(context, T), + MLIRType(context, T), Attribute(), )) -MType(context, element_type::MType, dims) = - MType(API.mlirRankedTensorTypeGetChecked( +MLIRType(context, element_type::MLIRType, dims) = + MLIRType(API.mlirRankedTensorTypeGetChecked( Location(context), length(dims), collect(dims), element_type, Attribute(), )) -MType(context, ::T) where {T<:Real} = MType(context, T) -MType(_, type::MType) = type +MLIRType(context, ::T) where {T<:Real} = MLIRType(context, T) +MLIRType(_, type::MLIRType) = type -IndexType(context) = MType(API.mlirIndexTypeGet(context)) +IndexType(context) = MLIRType(API.mlirIndexTypeGet(context)) -Base.convert(::Type{MlirType}, mtype::MType) = mtype.type +Base.convert(::Type{MlirType}, mtype::MLIRType) = mtype.type -function Base.eltype(type::MType) +function Base.eltype(type::MLIRType) if API.mlirTypeIsAShaped(type) - MType(API.mlirShapedTypeGetElementType(type)) + MLIRType(API.mlirShapedTypeGetElementType(type)) else type end end -function show_inner(io::IO, type::MType) +function show_inner(io::IO, type::MLIRType) if API.mlirTypeIsAInteger(type) is_signless = API.mlirIntegerTypeIsSignless(type) is_signed = API.mlirIntegerTypeIsSigned(type) @@ -257,8 +243,8 @@ function show_inner(io::IO, type::MType) end end -function Base.show(io::IO, type::MType) - print(io, "MType(#= ") +function Base.show(io::IO, type::MLIRType) + print(io, "MLIRType(#= ") c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) ref = Ref(io) API.mlirTypePrint(type, c_print_callback, ref) @@ -272,7 +258,7 @@ function inttype(size, issigned) issigned ? IT : unsigned(IT) end -function julia_type(type::MType) +function julia_type(type::MLIRType) if API.mlirTypeIsAInteger(type) is_signed = API.mlirIntegerTypeIsSigned(type) || API.mlirIntegerTypeIsSignless(type) @@ -293,42 +279,42 @@ function julia_type(type::MType) end end -Base.ndims(type::MType) = +Base.ndims(type::MLIRType) = if API.mlirTypeIsAShaped(type) && API.mlirShapedTypeHasRank(type) API.mlirShapedTypeGetRank(type) else 0 end -Base.size(type::MType, i::Int) = API.mlirShapedTypeGetDimSize(type, i - 1) -Base.size(type::MType) = Tuple(size(type, i) for i in 1:ndims(type)) +Base.size(type::MLIRType, i::Int) = API.mlirShapedTypeGetDimSize(type, i - 1) +Base.size(type::MLIRType) = Tuple(size(type, i) for i in 1:ndims(type)) -function is_tensor(type::MType) +function is_tensor(type::MLIRType) API.mlirTypeIsAShaped(type) end -function is_integer(type::MType) +function is_integer(type::MLIRType) API.mlirTypeIsAInteger(type) end is_function_type(mtype) = API.mlirTypeIsAFunction(mtype) -function get_num_inputs(ftype) +function num_inputs(ftype::MLIRType) @assert is_function_type(ftype) "cannot get the number of inputs on type $(ftype), expected a function type" API.mlirFunctionTypeGetNumInputs(ftype) end -function get_num_results(ftype) +function num_results(ftype::MLIRType) @assert is_function_type(ftype) "cannot get the number of results on type $(ftype), expected a function type" API.mlirFunctionTypeGetNumResults(ftype) end -function get_input(ftype::MType, pos) +function get_input(ftype::MLIRType, pos) @assert is_function_type(ftype) "cannot get input on type $(ftype), expected a function type" - MType(API.mlirFunctionTypeGetInput(ftype, pos - 1)) + MLIRType(API.mlirFunctionTypeGetInput(ftype, pos - 1)) end -function get_result(ftype::MType, pos=1) +function get_result(ftype::MLIRType, pos=1) @assert is_function_type(ftype) "cannot get result on type $(ftype), expected a function type" - MType(API.mlirFunctionTypeGetResult(ftype, pos - 1)) + MLIRType(API.mlirFunctionTypeGetResult(ftype, pos - 1)) end ### Attribute @@ -339,33 +325,33 @@ end Attribute() = Attribute(API.mlirAttributeGetNull()) Attribute(context, s::AbstractString) = Attribute(API.mlirStringAttrGet(context, s)) -Attribute(type::MType) = Attribute(API.mlirTypeAttrGet(type)) -Attribute(context, f::F, type=MType(context, F)) where {F<:AbstractFloat} = Attribute( +Attribute(type::MLIRType) = Attribute(API.mlirTypeAttrGet(type)) +Attribute(context, f::F, type=MLIRType(context, F)) where {F<:AbstractFloat} = Attribute( API.mlirFloatAttrDoubleGet(context, type, Float64(f)) ) Attribute(context, i::T) where {T<:Integer} = Attribute( - API.mlirIntegerAttrGet(MType(context, T), Int64(i)) + API.mlirIntegerAttrGet(MLIRType(context, T), Int64(i)) ) function Attribute(context, values::T) where {T<:AbstractArray{Int32}} - type = MType(context, T, size(values)) + type = MLIRType(context, T, size(values)) Attribute( API.mlirDenseElementsAttrInt32Get(type, length(values), values) ) end function Attribute(context, values::T) where {T<:AbstractArray{Int64}} - type = MType(context, T, size(values)) + type = MLIRType(context, T, size(values)) Attribute( API.mlirDenseElementsAttrInt64Get(type, length(values), values) ) end function Attribute(context, values::T) where {T<:AbstractArray{Float64}} - type = MType(context, T, size(values)) + type = MLIRType(context, T, size(values)) Attribute( API.mlirDenseElementsAttrDoubleGet(type, length(values), values) ) end function Attribute(context, values::T) where {T<:AbstractArray{Float32}} - type = MType(context, T, size(values)) + type = MLIRType(context, T, size(values)) Attribute( API.mlirDenseElementsAttrFloatGet(type, length(values), values) ) @@ -401,12 +387,12 @@ function DenseArrayAttribute(context, values::AbstractVector{Int}) API.mlirDenseI64ArrayGet(context, length(values), collect(values)) ) end -function Attribute(context, value::Int, type::MType) +function Attribute(context, value::Int, type::MLIRType) Attribute( API.mlirIntegerAttrGet(type, value) ) end -function Attribute(context, value::Bool, ::MType=nothing) +function Attribute(context, value::Bool, ::MLIRType=nothing) Attribute( API.mlirBoolAttrGet(context, value) ) @@ -417,17 +403,17 @@ Base.parse(::Type{Attribute}, context, s) = Attribute(API.mlirAttributeParseGet(context, s)) function get_type(attribute::Attribute) - MType(API.mlirAttributeGetType(attribute)) + MLIRType(API.mlirAttributeGetType(attribute)) end -function get_type_value(attribute) +function type_value(attribute) @assert API.mlirAttributeIsAType(attribute) "attribute $(attribute) is not a type" - MType(API.mlirTypeAttrGetValue(attribute)) + MLIRType(API.mlirTypeAttrGetValue(attribute)) end -function get_bool_value(attribute) +function bool_value(attribute) @assert API.mlirAttributeIsABool(attribute) "attribute $(attribute) is not a boolean" API.mlirBoolAttrGetValue(attribute) end -function get_string_value(attribute) +function string_value(attribute) @assert API.mlirAttributeIsAString(attribute) "attribute $(attribute) is not a string attribute" String(API.mlirStringAttrGetValue(attribute)) end @@ -468,7 +454,7 @@ struct Value end end -get_type(value) = MType(API.mlirValueGetType(value)) +get_type(value) = MLIRType(API.mlirValueGetType(value)) Base.convert(::Type{MlirValue}, value::Value) = value.value Base.size(value::Value) = Base.size(get_type(value)) @@ -507,34 +493,6 @@ function get_owner(value::Value) return Operation(raw_op, false) end -### OperationState - -struct OperationState - opstate::Base.RefValue{MlirOperationState} -end - -OperationState(name, location) = OperationState(Ref(API.mlirOperationStateGet(name, location))) - -add_results!(state, results) = - API.mlirOperationStateAddResults(state.opstate, length(results), results) -add_operands!(state, operands) = - API.mlirOperationStateAddOperands(state.opstate, length(operands), operands) -function add_owned_regions!(state, regions) - mlir_regions = Base.convert.(MlirRegion, regions) - lose_ownership!.(regions) - API.mlirOperationStateAddOwnedRegions(state.opstate, length(mlir_regions), mlir_regions) -end -add_attributes!(state, attributes) = - API.mlirOperationStateAddAttributes(state.opstate, length(attributes), attributes) -add_successors!(state, successors) = - API.mlirOperationStateAddSuccessors( - state.opstate, length(successors), - convert(Vector{API.MlirBlock}, successors), - ) - -enable_type_inference!(state) = - API.mlirOperationStateEnableResultTypeInference(state.opstate) - ### Operation mutable struct Operation @@ -551,7 +509,56 @@ mutable struct Operation end end -Operation(state::OperationState) = Operation(API.mlirOperationCreate(state.opstate), true) +function create_operation( + name, loc; + results=nothing, + operands=nothing, + owned_regions=nothing, + successors=nothing, + attributes=nothing, + result_inference=isnothing(results), +) + GC.@preserve name loc begin + state = Ref(API.mlirOperationStateGet(name, loc)) + if !isnothing(results) + if result_inference + error("Result inference and provided results conflict") + end + API.mlirOperationStateAddResults(state, length(results), results) + end + if !isnothing(operands) + API.mlirOperationStateAddOperands(state, length(operands), operands) + end + if !isnothing(owned_regions) + lose_ownership!.(owned_regions) + GC.@preserve owned_regions begin + mlir_regions = Base.unsafe_convert.(MlirRegion, owned_regions) + API.mlirOperationStateAddOwnedRegions(state, length(mlir_regions), mlir_regions) + end + end + if !isnothing(successors) + GC.@preserve successors begin + mlir_blocks = Base.unsafe_convert.(MlirBlock, successors) + API.mlirOperationStateAddSuccessors( + state, + length(mlir_blocks), + mlir_blocks, + ) + end + end + if !isnothing(attributes) + API.mlirOperationStateAddAttributes(state, length(attributes), attributes) + end + if result_inference + API.mlirOperationStateEnableResultTypeInference(state) + end + op = API.mlirOperationCreate(state) + if mlirIsNull(op) + error("Create Operation failed") + end + Operation(op, true) + end +end Base.copy(operation::Operation) = Operation(API.mlirOperationClone(operation)) @@ -560,7 +567,7 @@ function get_region(operation, i) i ∈ 1:num_regions(operation) && throw(BoundsError(operation, i)) Region(API.mlirOperationGetRegion(operation, i - 1), false) end -num_results(operation) = API.mlirOperationGetNumResults(operation) +num_results(operation::Operation) = API.mlirOperationGetNumResults(operation) get_results(operation) = [ get_result(operation, i) for i in 1:num_results(operation) @@ -592,11 +599,11 @@ function set_attribute_by_name!(operation, name, attribute) operation end -get_location(operation) = Location(API.mlirOperationGetLocation(operation)) -get_name(operation) = String(API.mlirOperationGetName(operation)) -get_block(operation) = Block(API.mlirOperationGetBlock(operation), false) -get_parent_operation(operation) = Operation(API.mlirOperationGetParentOperation(operation), false) -get_dialect(operation) = first(split(get_name(operation), '.')) |> Symbol +location(operation) = Location(API.mlirOperationGetLocation(operation)) +name(operation) = String(API.mlirOperationGetName(operation)) +block(operation) = Block(API.mlirOperationGetBlock(operation), false) +parent_operation(operation) = Operation(API.mlirOperationGetParentOperation(operation), false) +dialect(operation) = first(split(get_name(operation), '.')) |> Symbol function get_first_region(op::Operation) reg = iterate(RegionIterator(op)) @@ -619,7 +626,8 @@ end op::Operation == other::Operation = API.mlirOperationEqual(op, other) -Base.convert(::Type{MlirOperation}, op::Operation) = op.operation +Base.cconvert(::Type{MlirOperation}, operation::Operation) = operation +Base.unsafe_convert(::Type{MlirOperation}, operation::Operation) = operation.operation function lose_ownership!(operation::Operation) @assert operation.owned @@ -654,8 +662,8 @@ mutable struct Block end end -Block() = Block(MType[], Location[]) -function Block(args::Vector{MType}, locs::Vector{Location}) +Block() = Block(MLIRType[], Location[]) +function Block(args::Vector{MLIRType}, locs::Vector{Location}) @assert length(args) == length(locs) "there should be one args for each locs (got $(length(args)) & $(length(locs)))" Block(API.mlirBlockCreate(length(args), args, locs)) end @@ -690,7 +698,8 @@ end push_argument!(block::Block, type, loc) = Value(API.mlirBlockAddArgument(block, type, loc)) -Base.convert(::Type{MlirBlock}, block::Block) = block.block +Base.cconvert(::Type{MlirBlock}, block::Block) = block +Base.unsafe_convert(::Type{MlirBlock}, block::Block) = block.block function lose_ownership!(block::Block) @assert block.owned @@ -751,7 +760,8 @@ function lose_ownership!(region::Region) region end -Base.convert(::Type{MlirRegion}, region::Region) = region.region +Base.cconvert(::Type{MlirRegion}, region::Region) = region +Base.unsafe_convert(::Type{MlirRegion}, region::Region) = region.region ### Module @@ -936,7 +946,7 @@ end ### Pass # AbstractPass interface: -get_opname(::AbstractPass) = "" +opname(::AbstractPass) = "" function pass_run(::Context, ::P, op) where {P<:AbstractPass} error("pass $P does not implement `MLIR.pass_run`") end @@ -977,7 +987,7 @@ function create_external_pass!(oppass::OpPassManager, args...) create_external_pass!(oppass.pass, args...) end function create_external_pass!(manager, pass, name, argument, - description, opname=get_opname(pass), + description, opname=opname(pass), dependent_dialects=MlirDialectHandle[]) passid = TypeID(manager.allocator) callbacks = API.MlirExternalPassCallbacks( From 781f688a1bc4b08453b578b2f9bffc2ce6cd375b Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Sun, 13 Aug 2023 11:55:07 +0200 Subject: [PATCH 07/11] Fix boundscheck Co-authored-by: jumerckx <31353884+jumerckx@users.noreply.github.com> --- src/IR.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/IR.jl b/src/IR.jl index 219ce4fd..956a1c97 100644 --- a/src/IR.jl +++ b/src/IR.jl @@ -564,7 +564,7 @@ Base.copy(operation::Operation) = Operation(API.mlirOperationClone(operation)) num_regions(operation) = API.mlirOperationGetNumRegions(operation) function get_region(operation, i) - i ∈ 1:num_regions(operation) && throw(BoundsError(operation, i)) + i ∉ 1:num_regions(operation) && throw(BoundsError(operation, i)) Region(API.mlirOperationGetRegion(operation, i - 1), false) end num_results(operation::Operation) = API.mlirOperationGetNumResults(operation) From ef31356b092705bba22bba0946ab8c885cb5dde0 Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Sun, 13 Aug 2023 11:56:15 +0200 Subject: [PATCH 08/11] Add parse(::Type{MLIRType}, ::Context, s) Co-authored-by: jumerckx <31353884+jumerckx@users.noreply.github.com> --- src/IR.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/IR.jl b/src/IR.jl index 956a1c97..3941336d 100644 --- a/src/IR.jl +++ b/src/IR.jl @@ -203,7 +203,9 @@ MLIRType(_, type::MLIRType) = type IndexType(context) = MLIRType(API.mlirIndexTypeGet(context)) Base.convert(::Type{MlirType}, mtype::MLIRType) = mtype.type - +Base.parse(::Type{MLIRType}, context, s) = + MLIRType(API.mlirTypeParseGet(context, s)) + function Base.eltype(type::MLIRType) if API.mlirTypeIsAShaped(type) MLIRType(API.mlirShapedTypeGetElementType(type)) From 812f803837fca104875c276e66e02ee32e1813c3 Mon Sep 17 00:00:00 2001 From: Paul Date: Sun, 13 Aug 2023 12:16:38 +0200 Subject: [PATCH 09/11] Refactor in files to mimic `mlir-c/include/` --- src/{ => IR}/IR.jl | 307 +-------------------------------------------- src/IR/Pass.jl | 176 ++++++++++++++++++++++++++ src/IR/Support.jl | 133 ++++++++++++++++++++ src/MLIR.jl | 2 +- 4 files changed, 314 insertions(+), 304 deletions(-) rename src/{ => IR}/IR.jl (75%) create mode 100644 src/IR/Pass.jl create mode 100644 src/IR/Support.jl diff --git a/src/IR.jl b/src/IR/IR.jl similarity index 75% rename from src/IR.jl rename to src/IR/IR.jl index 3941336d..3606d329 100644 --- a/src/IR.jl +++ b/src/IR/IR.jl @@ -36,20 +36,12 @@ using .API: MlirPassManager, MlirOpPassManager -function mlirIsNull(val) - val.ptr == C_NULL -end - function print_callback(str::MlirStringRef, userdata) data = unsafe_wrap(Array, Base.convert(Ptr{Cchar}, str.data), str.length; own=false) write(userdata isa Base.RefValue ? userdata[] : userdata, data) return Cvoid() end -### Identifier - -String(ident::MlirIdentifier) = String(API.mlirIdentifierStr(ident)) - ### Dialect struct Dialect @@ -821,7 +813,8 @@ mutable struct TypeIDAllocator end end -Base.convert(::Type{API.MlirTypeIDAllocator}, allocator::TypeIDAllocator) = allocator.allocator +Base.cconvert(::Type{API.MlirTypeIDAllocator}, allocator::TypeIDAllocator) = allocator +Base.unsafe_convert(::Type{API.MlirTypeIDAllocator}, allocator) = allocator.allocator TypeID(allocator::TypeIDAllocator) = TypeID(API.mlirTypeIDCreate(allocator)) @@ -831,299 +824,7 @@ struct TypeIDAllocator end end -### Pass Manager - -abstract type AbstractPass end - -mutable struct ExternalPassHandle - ctx::Union{Nothing,Context} - pass::AbstractPass -end - -mutable struct PassManager - pass::MlirPassManager - context::Context - allocator::TypeIDAllocator - passes::Dict{TypeID,ExternalPassHandle} - - PassManager(pm::MlirPassManager, context) = begin - @assert !mlirIsNull(pm) "cannot create PassManager with null MlirPassManager" - finalizer(new(pm, context, TypeIDAllocator(), Dict{TypeID,ExternalPassHandle}())) do pm - API.mlirPassManagerDestroy(pm.pass) - end - end -end - -function enable_ir_printing!(pm) - API.mlirPassManagerEnableIRPrinting(pm) - pm -end -function enable_verifier!(pm, enable=true) - API.mlirPassManagerEnableVerifier(pm, enable) - pm -end - -PassManager(context) = - PassManager(API.mlirPassManagerCreate(context), context) - -function run!(pm::PassManager, module_) - status = API.mlirPassManagerRun(pm, module_) - if mlirLogicalResultIsFailure(status) - throw("failed to run pass manager on module") - end - module_ -end - -Base.convert(::Type{MlirPassManager}, pass::PassManager) = pass.pass - -### Op Pass Manager - -struct OpPassManager - op_pass::MlirOpPassManager - pass::PassManager - - OpPassManager(op_pass, pass) = begin - @assert !mlirIsNull(op_pass) "cannot create OpPassManager with null MlirOpPassManager" - new(op_pass, pass) - end -end - -OpPassManager(pm::PassManager) = OpPassManager(API.mlirPassManagerGetAsOpPassManager(pm), pm) -OpPassManager(pm::PassManager, opname) = OpPassManager(API.mlirPassManagerGetNestedUnder(pm, opname), pm) -OpPassManager(opm::OpPassManager, opname) = OpPassManager(API.mlirOpPassManagerGetNestedUnder(opm, opname), opm.pass) - -Base.convert(::Type{MlirOpPassManager}, op_pass::OpPassManager) = op_pass.op_pass - -function Base.show(io::IO, op_pass::OpPassManager) - c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) - ref = Ref(io) - println(io, "OpPassManager(\"\"\"") - API.mlirPrintPassPipeline(op_pass, c_print_callback, ref) - println(io) - print(io, "\"\"\")") -end - -struct AddPipelineException <: Exception - message::String -end - -function Base.showerror(io::IO, err::AddPipelineException) - print(io, "failed to add pipeline:", err.message) - nothing -end - -mlirLogicalResultIsFailure(result) = result.value == 0 - -function add_pipeline!(op_pass::OpPassManager, pipeline) - @static if isdefined(API, :mlirOpPassManagerAddPipeline) - io = IOBuffer() - c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) - result = GC.@preserve io API.mlirOpPassManagerAddPipeline(op_pass, pipeline, c_print_callback, io) - if API.mlirLogicalResultIsFailure(result) - exc = AddPipelineException(String(take!(io))) - throw(exc) - end - else - result = API.mlirParsePassPipeline(op_pass, pipeline) - if mlirLogicalResultIsFailure(result) - throw(AddPipelineException(" " * pipeline)) - end - end - op_pass -end - -function add_owned_pass!(pm::PassManager, pass) - API.mlirPassManagerAddOwnedPass(pm, pass) - pm -end - -function add_owned_pass!(opm::OpPassManager, pass) - API.mlirOpPassManagerAddOwnedPass(opm, pass) - opm -end - - -@static if isdefined(API, :mlirCreateExternalPass) - -### Pass - -# AbstractPass interface: -opname(::AbstractPass) = "" -function pass_run(::Context, ::P, op) where {P<:AbstractPass} - error("pass $P does not implement `MLIR.pass_run`") -end - -function _pass_construct(ptr::ExternalPassHandle) - nothing -end - -function _pass_destruct(ptr::ExternalPassHandle) - nothing -end - -function _pass_initialize(ctx, handle::ExternalPassHandle) - try - handle.ctx = Context(ctx) - API.mlirLogicalResultSuccess() - catch - API.mlirLogicalResultFailure() - end -end - -function _pass_clone(handle::ExternalPassHandle) - ExternalPassHandle(handle.ctx, deepcopy(handle.pass)) -end - -function _pass_run(rawop, external_pass, handle::ExternalPassHandle) - op = Operation(rawop, false) - try - pass_run(handle.ctx, handle.pass, op) - catch ex - @error "Something went wrong running pass" exception=(ex,catch_backtrace()) - API.mlirExternalPassSignalFailure(external_pass) - end - nothing -end - -function create_external_pass!(oppass::OpPassManager, args...) - create_external_pass!(oppass.pass, args...) -end -function create_external_pass!(manager, pass, name, argument, - description, opname=opname(pass), - dependent_dialects=MlirDialectHandle[]) - passid = TypeID(manager.allocator) - callbacks = API.MlirExternalPassCallbacks( - @cfunction(_pass_construct, Cvoid, (Any,)), - @cfunction(_pass_destruct, Cvoid, (Any,)), - @cfunction(_pass_initialize, API.MlirLogicalResult, (MlirContext, Any,)), - @cfunction(_pass_clone, Any, (Any,)), - @cfunction(_pass_run, Cvoid, (MlirOperation, API.MlirExternalPass, Any)) - ) - pass_handle = manager.passes[passid] = ExternalPassHandle(nothing, pass) - userdata = Base.pointer_from_objref(pass_handle) - mlir_pass = API.mlirCreateExternalPass(passid, name, argument, description, opname, - length(dependent_dialects), dependent_dialects, - callbacks, userdata) - mlir_pass -end - -end - -### Iterators - -""" - BlockIterator(region::Region) - -Iterates over all blocks in the given region. -""" -struct BlockIterator - region::Region -end - -function Base.iterate(it::BlockIterator) - reg = it.region - raw_block = API.mlirRegionGetFirstBlock(reg) - if mlirIsNull(raw_block) - nothing - else - b = Block(raw_block, false) - (b, b) - end -end - -function Base.iterate(it::BlockIterator, block) - raw_block = API.mlirBlockGetNextInRegion(block) - if mlirIsNull(raw_block) - nothing - else - b = Block(raw_block, false) - (b, b) - end -end - -""" - OperationIterator(block::Block) - -Iterates over all operations for the given block. -""" -struct OperationIterator - block::Block -end - -function Base.iterate(it::OperationIterator) - raw_op = API.mlirBlockGetFirstOperation(it.block) - if mlirIsNull(raw_op) - nothing - else - op = Operation(raw_op, false) - (op, op) - end -end - -function Base.iterate(it::OperationIterator, op) - raw_op = API.mlirOperationGetNextInBlock(op) - if mlirIsNull(raw_op) - nothing - else - op = Operation(raw_op, false) - (op, op) - end -end - -""" - RegionIterator(::Operation) - -Iterates over all sub-regions for the given operation. -""" -struct RegionIterator - op::Operation -end - -function Base.iterate(it::RegionIterator) - raw_region = API.mlirOperationGetFirstRegion(it.op) - if mlirIsNull(raw_region) - nothing - else - region = Region(raw_region, false) - (region, region) - end -end - -function Base.iterate(it::RegionIterator, region) - raw_region = API.mlirRegionGetNextInOperation(region) - if mlirIsNull(raw_region) - nothing - else - region = Region(raw_region, false) - (region, region) - end -end - -### Utils - -function visit(f, op) - for region in RegionIterator(op) - for block in BlockIterator(region) - for op in OperationIterator(block) - f(op) - end - end - end -end - -""" - verifyall(operation; debug=false) - -Prints the operations which could not be verified. -""" -function verifyall(operation::Operation; debug=false) - io = IOContext(stdout, :debug => debug) - visit(operation) do op - if !verify(op) - show(io, op) - end - end -end -verifyall(module_::MModule) = get_operation(module_) |> verifyall +include("./Support.jl") +include("./Pass.jl") end # module IR diff --git a/src/IR/Pass.jl b/src/IR/Pass.jl new file mode 100644 index 00000000..7eef5b88 --- /dev/null +++ b/src/IR/Pass.jl @@ -0,0 +1,176 @@ +### Pass Manager + +abstract type AbstractPass end + +mutable struct ExternalPassHandle + ctx::Union{Nothing,Context} + pass::AbstractPass +end + +mutable struct PassManager + pass::MlirPassManager + context::Context + allocator::TypeIDAllocator + passes::Dict{TypeID,ExternalPassHandle} + + PassManager(pm::MlirPassManager, context) = begin + @assert !mlirIsNull(pm) "cannot create PassManager with null MlirPassManager" + finalizer(new(pm, context, TypeIDAllocator(), Dict{TypeID,ExternalPassHandle}())) do pm + API.mlirPassManagerDestroy(pm.pass) + end + end +end + +function enable_ir_printing!(pm) + API.mlirPassManagerEnableIRPrinting(pm) + pm +end +function enable_verifier!(pm, enable=true) + API.mlirPassManagerEnableVerifier(pm, enable) + pm +end + +PassManager(context) = + PassManager(API.mlirPassManagerCreate(context), context) + +function run!(pm::PassManager, module_) + status = API.mlirPassManagerRun(pm, module_) + if mlirLogicalResultIsFailure(status) + throw("failed to run pass manager on module") + end + module_ +end + +Base.convert(::Type{MlirPassManager}, pass::PassManager) = pass.pass + +### Op Pass Manager + +struct OpPassManager + op_pass::MlirOpPassManager + pass::PassManager + + OpPassManager(op_pass, pass) = begin + @assert !mlirIsNull(op_pass) "cannot create OpPassManager with null MlirOpPassManager" + new(op_pass, pass) + end +end + +OpPassManager(pm::PassManager) = OpPassManager(API.mlirPassManagerGetAsOpPassManager(pm), pm) +OpPassManager(pm::PassManager, opname) = OpPassManager(API.mlirPassManagerGetNestedUnder(pm, opname), pm) +OpPassManager(opm::OpPassManager, opname) = OpPassManager(API.mlirOpPassManagerGetNestedUnder(opm, opname), opm.pass) + +Base.convert(::Type{MlirOpPassManager}, op_pass::OpPassManager) = op_pass.op_pass + +function Base.show(io::IO, op_pass::OpPassManager) + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + println(io, "OpPassManager(\"\"\"") + API.mlirPrintPassPipeline(op_pass, c_print_callback, ref) + println(io) + print(io, "\"\"\")") +end + +struct AddPipelineException <: Exception + message::String +end + +function Base.showerror(io::IO, err::AddPipelineException) + print(io, "failed to add pipeline:", err.message) + nothing +end + +function add_pipeline!(op_pass::OpPassManager, pipeline) + @static if isdefined(API, :mlirOpPassManagerAddPipeline) + io = IOBuffer() + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + result = GC.@preserve io API.mlirOpPassManagerAddPipeline(op_pass, pipeline, c_print_callback, io) + if mlirLogicalResultIsFailure(result) + exc = AddPipelineException(String(take!(io))) + throw(exc) + end + else + result = API.mlirParsePassPipeline(op_pass, pipeline) + if mlirLogicalResultIsFailure(result) + throw(AddPipelineException(" " * pipeline)) + end + end + op_pass +end + +function add_owned_pass!(pm::PassManager, pass) + API.mlirPassManagerAddOwnedPass(pm, pass) + pm +end + +function add_owned_pass!(opm::OpPassManager, pass) + API.mlirOpPassManagerAddOwnedPass(opm, pass) + opm +end + + +@static if isdefined(API, :mlirCreateExternalPass) + +### Pass + +# AbstractPass interface: +opname(::AbstractPass) = "" +function pass_run(::Context, ::P, op) where {P<:AbstractPass} + error("pass $P does not implement `MLIR.pass_run`") +end + +function _pass_construct(ptr::ExternalPassHandle) + nothing +end + +function _pass_destruct(ptr::ExternalPassHandle) + nothing +end + +function _pass_initialize(ctx, handle::ExternalPassHandle) + try + handle.ctx = Context(ctx) + mlirLogicalResultSuccess() + catch + mlirLogicalResultFailure() + end +end + +function _pass_clone(handle::ExternalPassHandle) + ExternalPassHandle(handle.ctx, deepcopy(handle.pass)) +end + +function _pass_run(rawop, external_pass, handle::ExternalPassHandle) + op = Operation(rawop, false) + try + pass_run(handle.ctx, handle.pass, op) + catch ex + @error "Something went wrong running pass" exception=(ex,catch_backtrace()) + API.mlirExternalPassSignalFailure(external_pass) + end + nothing +end + +function create_external_pass!(oppass::OpPassManager, args...) + create_external_pass!(oppass.pass, args...) +end +function create_external_pass!(manager, pass, name, argument, + description, opname=opname(pass), + dependent_dialects=MlirDialectHandle[]) + passid = TypeID(manager.allocator) + callbacks = API.MlirExternalPassCallbacks( + @cfunction(_pass_construct, Cvoid, (Any,)), + @cfunction(_pass_destruct, Cvoid, (Any,)), + @cfunction(_pass_initialize, API.MlirLogicalResult, (MlirContext, Any,)), + @cfunction(_pass_clone, Any, (Any,)), + @cfunction(_pass_run, Cvoid, (MlirOperation, API.MlirExternalPass, Any)) + ) + pass_handle = manager.passes[passid] = ExternalPassHandle(nothing, pass) + userdata = Base.pointer_from_objref(pass_handle) + mlir_pass = API.mlirCreateExternalPass(passid, name, argument, description, opname, + length(dependent_dialects), dependent_dialects, + callbacks, userdata) + mlir_pass +end + +end + diff --git a/src/IR/Support.jl b/src/IR/Support.jl new file mode 100644 index 00000000..f84689e3 --- /dev/null +++ b/src/IR/Support.jl @@ -0,0 +1,133 @@ +function mlirIsNull(val) + val.ptr == C_NULL +end + +### Identifier + +String(ident::MlirIdentifier) = String(API.mlirIdentifierStr(ident)) + +### Logical Result + +mlirLogicalResultSuccess() = API.MlirLogicalResult(1) +mlirLogicalResultFailure() = API.MlirLogicalResult(0) + +mlirLogicalResultIsSuccess(result) = result.value != 0 +mlirLogicalResultIsFailure(result) = result.value == 0 + +### Iterators + +""" + BlockIterator(region::Region) + +Iterates over all blocks in the given region. +""" +struct BlockIterator + region::Region +end + +function Base.iterate(it::BlockIterator) + reg = it.region + raw_block = API.mlirRegionGetFirstBlock(reg) + if mlirIsNull(raw_block) + nothing + else + b = Block(raw_block, false) + (b, b) + end +end + +function Base.iterate(it::BlockIterator, block) + raw_block = API.mlirBlockGetNextInRegion(block) + if mlirIsNull(raw_block) + nothing + else + b = Block(raw_block, false) + (b, b) + end +end + +""" + OperationIterator(block::Block) + +Iterates over all operations for the given block. +""" +struct OperationIterator + block::Block +end + +function Base.iterate(it::OperationIterator) + raw_op = API.mlirBlockGetFirstOperation(it.block) + if mlirIsNull(raw_op) + nothing + else + op = Operation(raw_op, false) + (op, op) + end +end + +function Base.iterate(it::OperationIterator, op) + raw_op = API.mlirOperationGetNextInBlock(op) + if mlirIsNull(raw_op) + nothing + else + op = Operation(raw_op, false) + (op, op) + end +end + +""" + RegionIterator(::Operation) + +Iterates over all sub-regions for the given operation. +""" +struct RegionIterator + op::Operation +end + +function Base.iterate(it::RegionIterator) + raw_region = API.mlirOperationGetFirstRegion(it.op) + if mlirIsNull(raw_region) + nothing + else + region = Region(raw_region, false) + (region, region) + end +end + +function Base.iterate(it::RegionIterator, region) + raw_region = API.mlirRegionGetNextInOperation(region) + if mlirIsNull(raw_region) + nothing + else + region = Region(raw_region, false) + (region, region) + end +end + +### Utils + +function visit(f, op) + for region in RegionIterator(op) + for block in BlockIterator(region) + for op in OperationIterator(block) + f(op) + end + end + end +end + +""" + verifyall(operation; debug=false) + +Prints the operations which could not be verified. +""" +function verifyall(operation::Operation; debug=false) + io = IOContext(stdout, :debug => debug) + visit(operation) do op + if !verify(op) + show(io, op) + end + end +end +verifyall(module_::MModule) = get_operation(module_) |> verifyall + diff --git a/src/MLIR.jl b/src/MLIR.jl index 4dead798..36638296 100644 --- a/src/MLIR.jl +++ b/src/MLIR.jl @@ -35,7 +35,7 @@ function Base.unsafe_convert(::Type{API.MlirStringRef}, s::Union{Symbol, String, return API.MlirStringRef(p, length(s)) end -include("./IR.jl") +include("./IR/IR.jl") include("./Dialects.jl") end # module MLIR From 7a1ee3e6883cca84696ad7064b7a1626f0893666 Mon Sep 17 00:00:00 2001 From: jumerckx Date: Sun, 13 Aug 2023 21:45:59 +0200 Subject: [PATCH 10/11] add LLVM.jl-like context state handling. * functions in Pass.jl have not yet been properly adapted. * I might've been a bit too enthousiastic with removing contexts: MModule and PassManager might still need to keep theirs. * The way code_mlir in brutus.jl now uses context might need a closer look. --- examples/brutus.jl | 63 +++++++------- src/Dialects.jl | 50 +++++------ src/IR/IR.jl | 201 +++++++++++++++++++++++---------------------- src/IR/Pass.jl | 108 ++++++++++++------------ src/IR/state.jl | 43 ++++++++++ src/MLIR.jl | 9 +- 6 files changed, 266 insertions(+), 208 deletions(-) create mode 100644 src/IR/state.jl diff --git a/examples/brutus.jl b/examples/brutus.jl index 36df2877..bba5835b 100644 --- a/examples/brutus.jl +++ b/examples/brutus.jl @@ -8,13 +8,13 @@ using Core: PhiNode, GotoNode, GotoIfNot, SSAValue, Argument, ReturnNode, PiNode const BrutusScalar = Union{Bool,Int64,Int32,Float32,Float64} function cmpi_pred(predicate) - function(ctx, ops; loc=Location(ctx)) - arith.cmpi(ctx, predicate, ops; loc) + function(ops; loc=Location()) + arith.cmpi(predicate, ops; loc) end end function single_op_wrapper(fop) - (ctx::Context, block::Block, args::Vector{Value}; loc=Location(ctx)) -> push!(block, fop(ctx, args; loc)) + (block::Block, args::Vector{Value}; loc=Location()) -> push!(block, fop(args; loc)) end const intrinsics_to_mlir = Dict([ @@ -24,15 +24,15 @@ const intrinsics_to_mlir = Dict([ Base.:(===) => single_op_wrapper(cmpi_pred(arith.Predicates.eq)), Base.mul_int => single_op_wrapper(arith.muli), Base.mul_float => single_op_wrapper(arith.mulf), - Base.not_int => function(ctx, block, args; loc=Location(ctx)) + Base.not_int => function(block, args; loc=Location()) arg = only(args) - ones = push!(block, arith.constant(ctx, -1, IR.get_type(arg); loc)) |> IR.get_result - push!(block, arith.xori(ctx, Value[arg, ones]; loc)) + ones = push!(block, arith.constant(-1, IR.get_type(arg); loc)) |> IR.get_result + push!(block, arith.xori(Value[arg, ones]; loc)) end, ]) "Generates a block argument for each phi node present in the block." -function prepare_block(ctx, ir, bb) +function prepare_block(ir, bb) b = Block() for sidx in bb.stmts @@ -41,7 +41,7 @@ function prepare_block(ctx, ir, bb) inst isa Core.PhiNode || continue type = stmt[:type] - IR.push_argument!(b, MLIRType(ctx, type), Location(ctx)) + IR.push_argument!(b, MLIRType(type), Location()) end return b @@ -68,7 +68,7 @@ function collect_value_arguments(ir, from, to) end """ - code_mlir(f, types::Type{Tuple}; ctx=Context()) -> IR.Operation + code_mlir(f, types::Type{Tuple}) -> IR.Operation Returns a `func.func` operation corresponding to the ircode of the provided method. This only supports a few Julia Core primitives and scalar types of type $BrutusScalar. @@ -78,25 +78,26 @@ This only supports a few Julia Core primitives and scalar types of type $BrutusS handful of primitives. A better to perform this conversion would to create a dialect representing Julia IR and progressively lower it to base MLIR dialects. """ -function code_mlir(f, types; ctx=Context()) +function code_mlir(f, types) + ctx = context() ir, ret = Core.Compiler.code_ircode(f, types) |> only @assert first(ir.argtypes) isa Core.Const values = Vector{Value}(undef, length(ir.stmts)) for dialect in (LLVM.version() >= v"15" ? ("func", "cf") : ("std",)) - IR.get_or_load_dialect!(ctx, dialect) + IR.get_or_load_dialect!(dialect) end blocks = [ - prepare_block(ctx, ir, bb) + prepare_block(ir, bb) for bb in ir.cfg.blocks ] current_block = entry_block = blocks[begin] for argtype in types.parameters - IR.push_argument!(entry_block, MLIRType(ctx, argtype), Location(ctx)) + IR.push_argument!(entry_block, MLIRType(argtype), Location()) end function get_value(x)::Value @@ -106,7 +107,7 @@ function code_mlir(f, types; ctx=Context()) elseif x isa Core.Argument IR.get_argument(entry_block, x.n - 1) elseif x isa BrutusScalar - IR.get_result(push!(current_block, arith.constant(ctx, x))) + IR.get_result(push!(current_block, arith.constant(x))) else error("could not use value $x inside MLIR") end @@ -126,7 +127,7 @@ function code_mlir(f, types; ctx=Context()) if !(val_type <: BrutusScalar) error("type $val_type is not supported") end - out_type = MLIRType(ctx, val_type) + out_type = MLIRType(val_type) called_func = first(inst.args) if called_func isa GlobalRef # TODO: should probably use something else here @@ -136,8 +137,8 @@ function code_mlir(f, types; ctx=Context()) fop! = intrinsics_to_mlir[called_func] args = get_value.(@view inst.args[begin+1:end]) - loc = Location(ctx, string(line.file), line.line, 0) - res = IR.get_result(fop!(ctx, current_block, args; loc)) + loc = Location(string(line.file), line.line, 0) + res = IR.get_result(fop!(current_block, args; loc)) values[sidx] = res elseif inst isa PhiNode @@ -147,9 +148,9 @@ function code_mlir(f, types; ctx=Context()) elseif inst isa GotoNode args = get_value.(collect_value_arguments(ir, block_id, inst.label)) dest = blocks[inst.label] - loc = Location(ctx, string(line.file), line.line, 0) + loc = Location(string(line.file), line.line, 0) brop = LLVM.version() >= v"15" ? cf.br : std.br - push!(current_block, brop(ctx, dest, args; loc)) + push!(current_block, brop(dest, args; loc)) elseif inst isa GotoIfNot false_args = get_value.(collect_value_arguments(ir, block_id, inst.dest)) cond = get_value(inst.cond) @@ -159,15 +160,15 @@ function code_mlir(f, types; ctx=Context()) other_dest = blocks[other_dest] dest = blocks[inst.dest] - loc = Location(ctx, string(line.file), line.line, 0) + loc = Location(string(line.file), line.line, 0) cond_brop = LLVM.version() >= v"15" ? cf.cond_br : std.cond_br - cond_br = cond_brop(ctx, cond, other_dest, dest, true_args, false_args; loc) + cond_br = cond_brop(cond, other_dest, dest, true_args, false_args; loc) push!(current_block, cond_br) elseif inst isa ReturnNode line = ir.linetable[stmt[:line]] retop = LLVM.version() >= v"15" ? func.return_ : std.return_ - loc = Location(ctx, string(line.file), line.line, 0) - push!(current_block, retop(ctx, [get_value(inst.val)]; loc)) + loc = Location(string(line.file), line.line, 0) + push!(current_block, retop([get_value(inst.val)]; loc)) elseif Meta.isexpr(inst, :code_coverage_effect) # Skip else @@ -189,15 +190,15 @@ function code_mlir(f, types; ctx=Context()) IR.get_type(IR.get_argument(entry_block, i)) for i in 1:IR.num_arguments(entry_block) ] - result_types = [MLIRType(ctx, ret)] + result_types = [MLIRType(ret)] - ftype = MLIRType(ctx, input_types => result_types) + ftype = MLIRType(input_types => result_types) op = IR.create_operation( LLVM15 ? "func.func" : "builtin.func", - Location(ctx); + Location(); attributes = [ - NamedAttribute(ctx, "sym_name", IR.Attribute(ctx, string(func_name))), - NamedAttribute(ctx, LLVM15 ? "function_type" : "type", IR.Attribute(ftype)), + NamedAttribute("sym_name", IR.Attribute(string(func_name))), + NamedAttribute(LLVM15 ? "function_type" : "type", IR.Attribute(ftype)), ], owned_regions = Region[region], result_inference=false, @@ -254,13 +255,13 @@ using MLIR.IR, MLIR ctx = Context() # IR.enable_multithreading!(ctx, false) -op = Brutus.code_mlir(pow, Tuple{Int, Int}; ctx) +op = Brutus.code_mlir(pow, Tuple{Int, Int}) -mod = MModule(ctx, Location(ctx)) +mod = MModule(Location()) body = IR.get_body(mod) push!(body, op) -pm = IR.PassManager(ctx) +pm = IR.PassManager() opm = IR.OpPassManager(pm) # IR.enable_ir_printing!(pm) diff --git a/src/Dialects.jl b/src/Dialects.jl index 4cb400eb..cd6f4244 100644 --- a/src/Dialects.jl +++ b/src/Dialects.jl @@ -9,13 +9,13 @@ for (f, t) in Iterators.product( (:i, :f), ) fname = Symbol(f, t) - @eval function $fname(context, operands, type=IR.get_type(first(operands)); loc=Location(context)) + @eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location()) IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) end end for fname in (:xori, :andi, :ori) - @eval function $fname(context, operands, type=IR.get_type(first(operands)); loc=Location(context)) + @eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location()) IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) end end @@ -25,37 +25,37 @@ for (f, t) in Iterators.product( (:si, :ui, :f), ) fname = Symbol(f, t) - @eval function $fname(context, operands, type=IR.get_type(first(operands)); loc=Location(context)) + @eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location()) IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) end end # https://mlir.llvm.org/docs/Dialects/ArithOps/#arithindex_cast-mlirarithindexcastop for f in (:index_cast, :index_castui) - @eval function $f(context, operand; loc=Location(context)) + @eval function $f(operand; loc=Location()) IR.create_operation( $(string("arith.", f)), loc; operands=[operand], - results=[IR.IndexType(context)], + results=[IR.IndexType()], ) end end # https://mlir.llvm.org/docs/Dialects/ArithOps/#arithextf-mlirarithextfop -function extf(context, operand, type; loc=Location(context)) +function extf(operand, type; loc=Location()) IR.create_operation("arith.exf", loc; operands=[operand], results=[type]) end # https://mlir.llvm.org/docs/Dialects/ArithOps/#arithconstant-mlirarithconstantop -function constant(context, value, type=MLIRType(context, typeof(value)); loc=Location(context)) +function constant(value, type=MLIRType(typeof(value)); loc=Location()) IR.create_operation( "arith.constant", loc; results=[type], attributes=[ - IR.NamedAttribute(context, "value", - Attribute(context, value, type)), + IR.NamedAttribute("value", + Attribute(value, type)), ], ) end @@ -73,15 +73,15 @@ module Predicates const uge = 9 end -function cmpi(context, predicate, operands; loc=Location(context)) +function cmpi(predicate, operands; loc=Location()) IR.create_operation( "arith.cmpi", loc; operands, - results=[MLIRType(context, Bool)], + results=[MLIRType(Bool)], attributes=[ - IR.NamedAttribute(context, "predicate", - Attribute(context, predicate)) + IR.NamedAttribute("predicate", + Attribute(predicate)) ], ) end @@ -93,20 +93,20 @@ module std using ...IR -function return_(context, operands; loc=Location(context)) +function return_(operands; loc=Location()) IR.create_operation("std.return", loc; operands, result_inference=false) end -function br(context, dest, operands; loc=Location(context)) +function br(dest, operands; loc=Location()) IR.create_operation("std.br", loc; operands, successors=[dest], result_inference=false) end function cond_br( - context, cond, + cond, true_dest, false_dest, true_dest_operands, false_dest_operands; - loc=Location(context), + loc=Location(), ) IR.create_operation( "std.cond_br", @@ -114,8 +114,8 @@ function cond_br( successors=[true_dest, false_dest], operands=[cond, true_dest_operands..., false_dest_operands...], attributes=[ - IR.NamedAttribute(context, "operand_segment_sizes", - IR.Attribute(context, Int32[1, length(true_dest_operands), length(false_dest_operands)])) + IR.NamedAttribute("operand_segment_sizes", + IR.Attribute(Int32[1, length(true_dest_operands), length(false_dest_operands)])) ], result_inference=false, ) @@ -128,7 +128,7 @@ module func using ...IR -function return_(context, operands; loc=Location(context)) +function return_(operands; loc=Location()) IR.create_operation("func.return", loc; operands, result_inference=false) end @@ -138,24 +138,24 @@ module cf using ...IR -function br(context, dest, operands; loc=Location(context)) +function br(dest, operands; loc=Location()) IR.create_operation("cf.br", loc; operands, successors=[dest], result_inference=false) end function cond_br( - context, cond, + cond, true_dest, false_dest, true_dest_operands, false_dest_operands; - loc=Location(context), + loc=Location(), ) IR.create_operation( "cf.cond_br", loc; operands=[cond, true_dest_operands..., false_dest_operands...], successors=[true_dest, false_dest], attributes=[ - IR.NamedAttribute(context, "operand_segment_sizes", - IR.Attribute(context, Int32[1, length(true_dest_operands), length(false_dest_operands)])) + IR.NamedAttribute("operand_segment_sizes", + IR.Attribute(Int32[1, length(true_dest_operands), length(false_dest_operands)])) ], result_inference=false, ) diff --git a/src/IR/IR.jl b/src/IR/IR.jl index 3606d329..72701e7e 100644 --- a/src/IR/IR.jl +++ b/src/IR/IR.jl @@ -1,7 +1,3 @@ -module IR - -import ..API: API - export Operation, OperationState, @@ -90,38 +86,53 @@ end ### Context -mutable struct Context +struct Context context::MlirContext end + function Context() context = API.mlirContextCreate() @assert !mlirIsNull(context) "cannot create Context with null MlirContext" - finalizer(Context(context)) do context - API.mlirContextDestroy(context.context) + context = Context(context) + activate(context) + context +end + +function dispose(ctx::Context) + deactivate(ctx) + API.mlirContextDestroy(context.context) +end + +function Context(f::Core.Function) + ctx = Context() + try + f(ctx) + finally + dispose(ctx) end end Base.convert(::Type{MlirContext}, c::Context) = c.context -num_loaded_dialects(context) = API.mlirContextGetNumLoadedDialects(context) -function get_or_load_dialect!(context, handle::DialectHandle) - mlir_dialect = API.mlirDialectHandleLoadDialect(handle, context) +num_loaded_dialects() = API.mlirContextGetNumLoadedDialects(context()) +function get_or_load_dialect!(handle::DialectHandle) + mlir_dialect = API.mlirDialectHandleLoadDialect(handle, context()) if mlirIsNull(mlir_dialect) error("could not load dialect from handle $handle") else Dialect(mlir_dialect) end end -function get_or_load_dialect!(context, dialect::String) - get_or_load_dialect!(context, DialectHandle(Symbol(dialect))) +function get_or_load_dialect!(dialect::String) + get_or_load_dialect!(DialectHandle(Symbol(dialect))) end -function enable_multithreading!(context, enable=true) - API.mlirContextEnableMultithreading(context, enable) - context +function enable_multithreading!(enable=true) + API.mlirContextEnableMultithreading(context(), enable) + context() end -is_registered_operation(context, opname) = API.mlirContextIsRegisteredOperation(context, opname) +is_registered_operation(opname) = API.mlirContextIsRegisteredOperation(context(), opname) ### Location @@ -134,9 +145,9 @@ struct Location end end -Location(context::Context) = Location(API.mlirLocationUnknownGet(context)) -Location(context::Context, filename, line, column) = - Location(API.mlirLocationFileLineColGet(context, filename, line, column)) +Location() = Location(API.mlirLocationUnknownGet(context())) +Location(filename, line, column) = + Location(API.mlirLocationFileLineColGet(context(), filename, line, column)) Base.convert(::Type{MlirLocation}, location::Location) = location.location @@ -160,39 +171,39 @@ struct MLIRType end MLIRType(t::MLIRType) = t -MLIRType(context::Context, T::Type{<:Signed}) = - MLIRType(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) -MLIRType(context::Context, T::Type{<:Unsigned}) = - MLIRType(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) -MLIRType(context::Context, ::Type{Bool}) = - MLIRType(API.mlirIntegerTypeGet(context, 1)) -MLIRType(context::Context, ::Type{Float32}) = - MLIRType(API.mlirF32TypeGet(context)) -MLIRType(context::Context, ::Type{Float64}) = - MLIRType(API.mlirF64TypeGet(context)) -MLIRType(context::Context, ft::Pair) = - MLIRType(API.mlirFunctionTypeGet(context, +MLIRType(T::Type{<:Signed}) = + MLIRType(API.mlirIntegerTypeGet(context(), sizeof(T) * 8)) +MLIRType(T::Type{<:Unsigned}) = + MLIRType(API.mlirIntegerTypeGet(context(), sizeof(T) * 8)) +MLIRType(::Type{Bool}) = + MLIRType(API.mlirIntegerTypeGet(context(), 1)) +MLIRType(::Type{Float32}) = + MLIRType(API.mlirF32TypeGet(context())) +MLIRType(::Type{Float64}) = + MLIRType(API.mlirF64TypeGet(context())) +MLIRType(ft::Pair) = + MLIRType(API.mlirFunctionTypeGet(context(), length(ft.first), [MLIRType(t) for t in ft.first], length(ft.second), [MLIRType(t) for t in ft.second])) -MLIRType(context, a::AbstractArray{T}) where {T} = MLIRType(context, MLIRType(context, T), size(a)) -MLIRType(context, ::Type{<:AbstractArray{T,N}}, dims) where {T,N} = +MLIRType(a::AbstractArray{T}) where {T} = MLIRType(MLIRType(T), size(a)) +MLIRType(::Type{<:AbstractArray{T,N}}, dims) where {T,N} = MLIRType(API.mlirRankedTensorTypeGetChecked( - Location(context), + Location(), N, collect(dims), - MLIRType(context, T), + MLIRType(T), Attribute(), )) -MLIRType(context, element_type::MLIRType, dims) = +MLIRType(element_type::MLIRType, dims) = MLIRType(API.mlirRankedTensorTypeGetChecked( - Location(context), + Location(), length(dims), collect(dims), element_type, Attribute(), )) -MLIRType(context, ::T) where {T<:Real} = MLIRType(context, T) +MLIRType(::T) where {T<:Real} = MLIRType(T) MLIRType(_, type::MLIRType) = type -IndexType(context) = MLIRType(API.mlirIndexTypeGet(context)) +IndexType() = MLIRType(API.mlirIndexTypeGet(context())) Base.convert(::Type{MlirType}, mtype::MLIRType) = mtype.type Base.parse(::Type{MLIRType}, context, s) = @@ -246,10 +257,10 @@ function Base.show(io::IO, type::MLIRType) end function inttype(size, issigned) - size == 1 && issigned && return Bool - ints = (Int8, Int16, Int32, Int64, Int128) - IT = ints[Int(log2(size)) - 2] - issigned ? IT : unsigned(IT) + size == 1 && issigned && return Bool + ints = (Int8, Int16, Int32, Int64, Int128) + IT = ints[Int(log2(size))-2] + issigned ? IT : unsigned(IT) end function julia_type(type::MLIRType) @@ -318,83 +329,83 @@ struct Attribute end Attribute() = Attribute(API.mlirAttributeGetNull()) -Attribute(context, s::AbstractString) = Attribute(API.mlirStringAttrGet(context, s)) +Attribute(s::AbstractString) = Attribute(API.mlirStringAttrGet(context(), s)) Attribute(type::MLIRType) = Attribute(API.mlirTypeAttrGet(type)) -Attribute(context, f::F, type=MLIRType(context, F)) where {F<:AbstractFloat} = Attribute( - API.mlirFloatAttrDoubleGet(context, type, Float64(f)) +Attribute(f::F, type=MLIRType(F)) where {F<:AbstractFloat} = Attribute( + API.mlirFloatAttrDoubleGet(context(), type, Float64(f)) ) -Attribute(context, i::T) where {T<:Integer} = Attribute( - API.mlirIntegerAttrGet(MLIRType(context, T), Int64(i)) +Attribute(i::T) where {T<:Integer} = Attribute( + API.mlirIntegerAttrGet(MLIRType(T), Int64(i)) ) -function Attribute(context, values::T) where {T<:AbstractArray{Int32}} - type = MLIRType(context, T, size(values)) +function Attribute(values::T) where {T<:AbstractArray{Int32}} + type = MLIRType(T, size(values)) Attribute( API.mlirDenseElementsAttrInt32Get(type, length(values), values) ) end -function Attribute(context, values::T) where {T<:AbstractArray{Int64}} - type = MLIRType(context, T, size(values)) +function Attribute(values::T) where {T<:AbstractArray{Int64}} + type = MLIRType(T, size(values)) Attribute( API.mlirDenseElementsAttrInt64Get(type, length(values), values) ) end -function Attribute(context, values::T) where {T<:AbstractArray{Float64}} - type = MLIRType(context, T, size(values)) +function Attribute(values::T) where {T<:AbstractArray{Float64}} + type = MLIRType(T, size(values)) Attribute( API.mlirDenseElementsAttrDoubleGet(type, length(values), values) ) end -function Attribute(context, values::T) where {T<:AbstractArray{Float32}} - type = MLIRType(context, T, size(values)) +function Attribute(values::T) where {T<:AbstractArray{Float32}} + type = MLIRType(T, size(values)) Attribute( API.mlirDenseElementsAttrFloatGet(type, length(values), values) ) end -function Attribute(context, values::AbstractArray{Int32}, type) +function Attribute(values::AbstractArray{Int32}, type) Attribute( API.mlirDenseElementsAttrInt32Get(type, length(values), values) ) end -function Attribute(context, values::AbstractArray{Int}, type) +function Attribute(values::AbstractArray{Int}, type) Attribute( API.mlirDenseElementsAttrInt64Get(type, length(values), values) ) end -function Attribute(context, values::AbstractArray{Float32}, type) +function Attribute(values::AbstractArray{Float32}, type) Attribute( API.mlirDenseElementsAttrFloatGet(type, length(values), values) ) end -function ArrayAttribute(context, values::AbstractVector{Int}) - elements = Attribute.((context,), values) +function ArrayAttribute(values::AbstractVector{Int}) + elements = Attribute.((context(),), values) Attribute( - API.mlirArrayAttrGet(context, length(elements), elements) + API.mlirArrayAttrGet(context(), length(elements), elements) ) end -function ArrayAttribute(context, attributes::Vector{Attribute}) +function ArrayAttribute(attributes::Vector{Attribute}) Attribute( - API.mlirArrayAttrGet(context, length(attributes), attributes), + API.mlirArrayAttrGet(context(), length(attributes), attributes), ) end -function DenseArrayAttribute(context, values::AbstractVector{Int}) +function DenseArrayAttribute(values::AbstractVector{Int}) Attribute( - API.mlirDenseI64ArrayGet(context, length(values), collect(values)) + API.mlirDenseI64ArrayGet(context(), length(values), collect(values)) ) end -function Attribute(context, value::Int, type::MLIRType) +function Attribute(value::Int, type::MLIRType) Attribute( API.mlirIntegerAttrGet(type, value) ) end -function Attribute(context, value::Bool, ::MLIRType=nothing) +function Attribute(value::Bool, ::MLIRType=nothing) Attribute( - API.mlirBoolAttrGet(context, value) + API.mlirBoolAttrGet(context(), value) ) end Base.convert(::Type{MlirAttribute}, attribute::Attribute) = attribute.attribute -Base.parse(::Type{Attribute}, context, s) = - Attribute(API.mlirAttributeParseGet(context, s)) +Base.parse(::Type{Attribute}, s) = + Attribute(API.mlirAttributeParseGet(context(), s)) function get_type(attribute::Attribute) MLIRType(API.mlirAttributeGetType(attribute)) @@ -426,10 +437,10 @@ struct NamedAttribute named_attribute::MlirNamedAttribute end -function NamedAttribute(context, name, attribute) +function NamedAttribute(name, attribute) @assert !mlirIsNull(attribute.attribute) NamedAttribute(API.mlirNamedAttributeGet( - API.mlirIdentifierGet(context, name), + API.mlirIdentifierGet(context(), name), attribute )) end @@ -510,7 +521,7 @@ function create_operation( owned_regions=nothing, successors=nothing, attributes=nothing, - result_inference=isnothing(results), + result_inference=isnothing(results) ) GC.@preserve name loc begin state = Ref(API.mlirOperationStateGet(name, loc)) @@ -534,9 +545,9 @@ function create_operation( GC.@preserve successors begin mlir_blocks = Base.unsafe_convert.(MlirBlock, successors) API.mlirOperationStateAddSuccessors( - state, - length(mlir_blocks), - mlir_blocks, + state, + length(mlir_blocks), + mlir_blocks, ) end end @@ -761,22 +772,21 @@ Base.unsafe_convert(::Type{MlirRegion}, region::Region) = region.region mutable struct MModule module_::MlirModule - context::Context - MModule(module_, context) = begin + MModule(module_) = begin @assert !mlirIsNull(module_) "cannot create MModule with null MlirModule" - finalizer(API.mlirModuleDestroy, new(module_, context)) + finalizer(API.mlirModuleDestroy, new(module_)) end end -MModule(context::Context, loc=Location(context)) = - MModule(API.mlirModuleCreateEmpty(loc), context) +MModule(loc::Location=Location()) = + MModule(API.mlirModuleCreateEmpty(loc)) get_operation(module_) = Operation(API.mlirModuleGetOperation(module_), false) get_body(module_) = Block(API.mlirModuleGetBody(module_), false) get_first_child_op(mod::MModule) = get_first_child_op(get_operation(mod)) Base.convert(::Type{MlirModule}, module_::MModule) = module_.module_ -Base.parse(::Type{MModule}, context, module_) = MModule(API.mlirModuleCreateParse(context, module_), context) +Base.parse(::Type{MModule}, module_) = MModule(API.mlirModuleCreateParse(context(), module_), context()) macro mlir_str(code) quote @@ -801,30 +811,29 @@ Base.convert(::Type{API.MlirTypeID}, typeid::TypeID) = typeid.typeid @static if isdefined(API, :MlirTypeIDAllocator) -### TypeIDAllocator + ### TypeIDAllocator -mutable struct TypeIDAllocator - allocator::API.MlirTypeIDAllocator + mutable struct TypeIDAllocator + allocator::API.MlirTypeIDAllocator - function TypeIDAllocator() - ptr = API.mlirTypeIDAllocatorCreate() - @assert ptr != C_NULL "cannot create TypeIDAllocator" - finalizer(API.mlirTypeIDAllocatorDestroy, new(ptr)) + function TypeIDAllocator() + ptr = API.mlirTypeIDAllocatorCreate() + @assert ptr != C_NULL "cannot create TypeIDAllocator" + finalizer(API.mlirTypeIDAllocatorDestroy, new(ptr)) + end end -end -Base.cconvert(::Type{API.MlirTypeIDAllocator}, allocator::TypeIDAllocator) = allocator -Base.unsafe_convert(::Type{API.MlirTypeIDAllocator}, allocator) = allocator.allocator + Base.cconvert(::Type{API.MlirTypeIDAllocator}, allocator::TypeIDAllocator) = allocator + Base.unsafe_convert(::Type{API.MlirTypeIDAllocator}, allocator) = allocator.allocator -TypeID(allocator::TypeIDAllocator) = TypeID(API.mlirTypeIDCreate(allocator)) + TypeID(allocator::TypeIDAllocator) = TypeID(API.mlirTypeIDCreate(allocator)) else -struct TypeIDAllocator end + struct TypeIDAllocator end end include("./Support.jl") include("./Pass.jl") -end # module IR diff --git a/src/IR/Pass.jl b/src/IR/Pass.jl index 7eef5b88..f4718dbe 100644 --- a/src/IR/Pass.jl +++ b/src/IR/Pass.jl @@ -9,13 +9,12 @@ end mutable struct PassManager pass::MlirPassManager - context::Context allocator::TypeIDAllocator passes::Dict{TypeID,ExternalPassHandle} - PassManager(pm::MlirPassManager, context) = begin + PassManager(pm::MlirPassManager) = begin @assert !mlirIsNull(pm) "cannot create PassManager with null MlirPassManager" - finalizer(new(pm, context, TypeIDAllocator(), Dict{TypeID,ExternalPassHandle}())) do pm + finalizer(new(pm, TypeIDAllocator(), Dict{TypeID,ExternalPassHandle}())) do pm API.mlirPassManagerDestroy(pm.pass) end end @@ -30,8 +29,8 @@ function enable_verifier!(pm, enable=true) pm end -PassManager(context) = - PassManager(API.mlirPassManagerCreate(context), context) +PassManager() = + PassManager(API.mlirPassManagerCreate(context())) function run!(pm::PassManager, module_) status = API.mlirPassManagerRun(pm, module_) @@ -96,7 +95,7 @@ function add_pipeline!(op_pass::OpPassManager, pipeline) end op_pass end - + function add_owned_pass!(pm::PassManager, pass) API.mlirPassManagerAddOwnedPass(pm, pass) pm @@ -110,67 +109,66 @@ end @static if isdefined(API, :mlirCreateExternalPass) -### Pass + ### Pass -# AbstractPass interface: -opname(::AbstractPass) = "" -function pass_run(::Context, ::P, op) where {P<:AbstractPass} - error("pass $P does not implement `MLIR.pass_run`") -end + # AbstractPass interface: + opname(::AbstractPass) = "" + function pass_run(::Context, ::P, op) where {P<:AbstractPass} + error("pass $P does not implement `MLIR.pass_run`") + end -function _pass_construct(ptr::ExternalPassHandle) - nothing -end + function _pass_construct(ptr::ExternalPassHandle) + nothing + end -function _pass_destruct(ptr::ExternalPassHandle) - nothing -end + function _pass_destruct(ptr::ExternalPassHandle) + nothing + end -function _pass_initialize(ctx, handle::ExternalPassHandle) - try - handle.ctx = Context(ctx) - mlirLogicalResultSuccess() - catch - mlirLogicalResultFailure() + function _pass_initialize(ctx, handle::ExternalPassHandle) + try + handle.ctx = Context(ctx) + mlirLogicalResultSuccess() + catch + mlirLogicalResultFailure() + end end -end -function _pass_clone(handle::ExternalPassHandle) - ExternalPassHandle(handle.ctx, deepcopy(handle.pass)) -end + function _pass_clone(handle::ExternalPassHandle) + ExternalPassHandle(handle.ctx, deepcopy(handle.pass)) + end -function _pass_run(rawop, external_pass, handle::ExternalPassHandle) - op = Operation(rawop, false) - try - pass_run(handle.ctx, handle.pass, op) - catch ex - @error "Something went wrong running pass" exception=(ex,catch_backtrace()) - API.mlirExternalPassSignalFailure(external_pass) + function _pass_run(rawop, external_pass, handle::ExternalPassHandle) + op = Operation(rawop, false) + try + pass_run(handle.ctx, handle.pass, op) + catch ex + @error "Something went wrong running pass" exception = (ex, catch_backtrace()) + API.mlirExternalPassSignalFailure(external_pass) + end + nothing end - nothing -end -function create_external_pass!(oppass::OpPassManager, args...) - create_external_pass!(oppass.pass, args...) -end -function create_external_pass!(manager, pass, name, argument, - description, opname=opname(pass), - dependent_dialects=MlirDialectHandle[]) - passid = TypeID(manager.allocator) - callbacks = API.MlirExternalPassCallbacks( + function create_external_pass!(oppass::OpPassManager, args...) + create_external_pass!(oppass.pass, args...) + end + function create_external_pass!(manager, pass, name, argument, + description, opname=opname(pass), + dependent_dialects=MlirDialectHandle[]) + passid = TypeID(manager.allocator) + callbacks = API.MlirExternalPassCallbacks( @cfunction(_pass_construct, Cvoid, (Any,)), @cfunction(_pass_destruct, Cvoid, (Any,)), @cfunction(_pass_initialize, API.MlirLogicalResult, (MlirContext, Any,)), @cfunction(_pass_clone, Any, (Any,)), @cfunction(_pass_run, Cvoid, (MlirOperation, API.MlirExternalPass, Any)) - ) - pass_handle = manager.passes[passid] = ExternalPassHandle(nothing, pass) - userdata = Base.pointer_from_objref(pass_handle) - mlir_pass = API.mlirCreateExternalPass(passid, name, argument, description, opname, - length(dependent_dialects), dependent_dialects, - callbacks, userdata) - mlir_pass -end - -end + ) + pass_handle = manager.passes[passid] = ExternalPassHandle(nothing, pass) + userdata = Base.pointer_from_objref(pass_handle) + mlir_pass = API.mlirCreateExternalPass(passid, name, argument, description, opname, + length(dependent_dialects), dependent_dialects, + callbacks, userdata) + mlir_pass + end +end \ No newline at end of file diff --git a/src/IR/state.jl b/src/IR/state.jl new file mode 100644 index 00000000..072f65ee --- /dev/null +++ b/src/IR/state.jl @@ -0,0 +1,43 @@ +# Global state + +# to simplify the API, we maintain a stack of contexts in task local storage +# and pass them implicitly to MLIR API's that require them. + +export context, activate, deactivate, context! + +using ..IR + +_has_context() = haskey(task_local_storage(), :MLIRContext) && + !isempty(task_local_storage(:MLIRContext)) + +function context(; throw_error::Core.Bool=true) + if !_has_context() + throw_error && error("No MLIR context is active") + return nothing + end + last(task_local_storage(:MLIRContext)) +end + +function activate(ctx::Context) + stack = get!(task_local_storage(), :MLIRContext) do + Context[] + end + push!(stack, ctx) + return +end + +function deactivate(ctx::Context) + context() == ctx || error("Deactivating wrong context") + pop!(task_local_storage(:MLIRContext)) +end + +function context!(f, ctx::Context) + activate(ctx) + try + f() + finally + deactivate(ctx) + end +end + + diff --git a/src/MLIR.jl b/src/MLIR.jl index 36638296..4d200805 100644 --- a/src/MLIR.jl +++ b/src/MLIR.jl @@ -35,7 +35,14 @@ function Base.unsafe_convert(::Type{API.MlirStringRef}, s::Union{Symbol, String, return API.MlirStringRef(p, length(s)) end -include("./IR/IR.jl") +module IR + import ..API: API + + include("./IR/IR.jl") + include("./IR/state.jl") +end # module IR + include("./Dialects.jl") + end # module MLIR From 644ea20784f4a18b45fcfd398492aa1bf3bc1f0e Mon Sep 17 00:00:00 2001 From: jumerckx Date: Sat, 19 Aug 2023 20:06:44 +0200 Subject: [PATCH 11/11] idea for higher level API. This doesn't work well as code in the macros is `eval`ed. --- src/Dialects.jl | 52 +++++++++++------ src/MLIR.jl | 1 + src/highlevel.jl | 144 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 181 insertions(+), 16 deletions(-) create mode 100644 src/highlevel.jl diff --git a/src/Dialects.jl b/src/Dialects.jl index cd6f4244..4b1211db 100644 --- a/src/Dialects.jl +++ b/src/Dialects.jl @@ -1,8 +1,9 @@ module Dialects -module arith +module Arith using ...IR +using ...Builder: blockbuilder, _has_blockbuilder for (f, t) in Iterators.product( (:add, :sub, :mul), @@ -10,13 +11,17 @@ for (f, t) in Iterators.product( ) fname = Symbol(f, t) @eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location()) - IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) + op = IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) + push!(blockbuilder().block, op) + return IR.get_result(op, 1) end end for fname in (:xori, :andi, :ori) @eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location()) - IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) + op = IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) + push!(blockbuilder().block, op) + return IR.get_result(op, 1) end end @@ -26,30 +31,36 @@ for (f, t) in Iterators.product( ) fname = Symbol(f, t) @eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location()) - IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) + op = IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) + push!(blockbuilder().block, op) + return IR.get_result(op, 1) end end # https://mlir.llvm.org/docs/Dialects/ArithOps/#arithindex_cast-mlirarithindexcastop for f in (:index_cast, :index_castui) @eval function $f(operand; loc=Location()) - IR.create_operation( + op = IR.create_operation( $(string("arith.", f)), loc; operands=[operand], results=[IR.IndexType()], ) + push!(blockbuilder().block, op) + return IR.get_result(op, 1) end end # https://mlir.llvm.org/docs/Dialects/ArithOps/#arithextf-mlirarithextfop function extf(operand, type; loc=Location()) - IR.create_operation("arith.exf", loc; operands=[operand], results=[type]) + op = IR.create_operation("arith.exf", loc; operands=[operand], results=[type]) + push!(blockbuilder().block, op) + return IR.get_result(op , 1) end # https://mlir.llvm.org/docs/Dialects/ArithOps/#arithconstant-mlirarithconstantop function constant(value, type=MLIRType(typeof(value)); loc=Location()) - IR.create_operation( + op = IR.create_operation( "arith.constant", loc; results=[type], @@ -58,6 +69,8 @@ function constant(value, type=MLIRType(typeof(value)); loc=Location()) Attribute(value, type)), ], ) + push!(blockbuilder().block, op) + return IR.get_result(op, 1) end module Predicates @@ -74,7 +87,7 @@ module Predicates end function cmpi(predicate, operands; loc=Location()) - IR.create_operation( + op = IR.create_operation( "arith.cmpi", loc; operands, @@ -84,11 +97,13 @@ function cmpi(predicate, operands; loc=Location()) Attribute(predicate)) ], ) + push!(blockbuilder().block, op) + return get_result(op, 1) end end # module arith -module std +module STD # for llvm 14 using ...IR @@ -123,7 +138,7 @@ end end # module std -module func +module Func # https://mlir.llvm.org/docs/Dialects/Func/ using ...IR @@ -134,22 +149,25 @@ end end # module func -module cf +module CF using ...IR +using ...Builder -function br(dest, operands; loc=Location()) - IR.create_operation("cf.br", loc; operands, successors=[dest], result_inference=false) +function br(dest, operands=[]; loc=Location()) + op = IR.create_operation("cf.br", loc; operands, successors=[dest], result_inference=false) + push!(Builder.blockbuilder().block, op) + return op # no value so returning operation itself (?) end function cond_br( cond, true_dest, false_dest, - true_dest_operands, - false_dest_operands; + true_dest_operands=[], + false_dest_operands=[]; loc=Location(), ) - IR.create_operation( + op = IR.create_operation( "cf.cond_br", loc; operands=[cond, true_dest_operands..., false_dest_operands...], successors=[true_dest, false_dest], @@ -159,6 +177,8 @@ function cond_br( ], result_inference=false, ) + push!(blockbuilder().block, op) + return op end end # module cf diff --git a/src/MLIR.jl b/src/MLIR.jl index 4d200805..3fc155ee 100644 --- a/src/MLIR.jl +++ b/src/MLIR.jl @@ -42,6 +42,7 @@ module IR include("./IR/state.jl") end # module IR +include("./highlevel.jl") include("./Dialects.jl") diff --git a/src/highlevel.jl b/src/highlevel.jl new file mode 100644 index 00000000..3c26fb43 --- /dev/null +++ b/src/highlevel.jl @@ -0,0 +1,144 @@ +module Builder + +export @Block, @Region + +using ...IR + +ctx = IR.Context() +loc = IR.Location() + +struct BlockBuilder + block::IR.Block + expr::Expr +end +_has_blockbuilder() = haskey(task_local_storage(), :BlockBuilder) && + !isempty(task_local_storage(:BlockBuilder)) + +function blockbuilder() + if !_has_blockbuilder() + error("No BlockBuilder is active") + return nothing + end + last(task_local_storage(:BlockBuilder)) +end +function activate(b::BlockBuilder) + stack = get!(task_local_storage(), :BlockBuilder) do + BlockBuilder[] + end + push!(stack, b) +end +function deactivate(b::BlockBuilder) + blockbuilder() == b || error("Deactivating wrong RegionBuilder") + pop!(task_local_storage(:BlockBuilder)) +end + +struct RegionBuilder + region::IR.Region + blockbuilders::Vector{BlockBuilder} +end +_has_regionbuilder() = haskey(task_local_storage(), :RegionBuilder) && + !isempty(task_local_storage(:RegionBuilder)) +function regionbuilder() + if !_has_regionbuilder() + error("No RegionBuilder is active") + return nothing + end + last(task_local_storage(:RegionBuilder)) +end +function activate(r::RegionBuilder) + stack = get!(task_local_storage(), :RegionBuilder) do + RegionBuilder[] + end + push!(stack, r) +end +function deactivate(r::RegionBuilder) + regionbuilder() == r || error("Deactivating wrong RegionBuilder") + pop!(task_local_storage(:RegionBuilder)) +end + +function Region(expr) + exprs = Expr[] + + #= Create region =# + region = IR.Region() + #= Push region on the stack =# + regionbuilder = RegionBuilder(region, BlockBuilder[]) + activate(regionbuilder) + #= + `expr` calls to @block. + These calls will create the block variables that + are referenced in control flow operations. + Blocks are added to the region at the top of the + stack and a queue of blocks is kept. The + expressions to generate the operations in each + block can't be executed yet since they can't + reference the blocks before their creation. + =# + push!(exprs, expr) + #= + Once the blocks are created, the operation + code can be run. This happens in order. All the + operations are pushed to the block at the front + of the queue + =# + push!(exprs, quote + for blockbuilder in $regionbuilder.blockbuilders + $activate(blockbuilder) + eval(blockbuilder.expr) + $deactivate(blockbuilder) + end + end) + + push!(exprs, quote + $deactivate($regionbuilder) + $region + end) + + return Expr(:block, exprs...) +end +macro Region(expr) + quote + $(esc(Region(expr))) + end +end + +function Block(expr) + block = IR.Block() + blockbuilder = BlockBuilder(block, expr) + + if (_has_regionbuilder()) + #= Add block to current region =# + push!(regionbuilder().region, block) + #= + Add blockbuilder to the queue to come back later to + generate its operations. + =# + push!(regionbuilder().blockbuilders, blockbuilder) + + #= + Only return the block, don't create the + operations yet. + =# + return quote + $block + end + else + #= + If there's no regionbuilder, the operations + defined in `expr` can immediately get executed + =# + return quote + $activate($blockbuilder) + $expr + $deactivate($blockbuilder) + $block + end + end +end +macro Block(expr) + quote + $(esc(Block(expr))) + end +end + +end # Builder \ No newline at end of file