From 61d6d4cdb67f61d1b1cbf6e8b155a5e8bfd1805a Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Tue, 28 Jan 2025 16:00:17 +0000 Subject: [PATCH 1/2] Support onnx.GridSampleV22 Signed-off-by: Rickert, Jonas --- docs/Dialects/onnx.md | 53 +++++++++++++++++++++- src/Builder/OpBuildTable.inc | 4 +- src/Dialect/ONNX/ONNXOps.td.inc | 54 ++++++++++++++++++++++- src/Dialect/ONNX/ONNXUnsupportedOps.hpp | 1 + src/Dialect/ONNX/Transforms/Decompose.cpp | 12 +++++ src/Dialect/ONNX/Transforms/Decompose.td | 9 ++++ test/mlir/onnx/onnx_decompose.mlir | 36 +++++++++++++++ utils/gen_onnx_mlir.py | 2 +- 8 files changed, 166 insertions(+), 5 deletions(-) diff --git a/docs/Dialects/onnx.md b/docs/Dialects/onnx.md index 3996ad35d6..d447b19224 100644 --- a/docs/Dialects/onnx.md +++ b/docs/Dialects/onnx.md @@ -3531,6 +3531,57 @@ Effects: `MemoryEffects::Effect{}` _ONNX GridSample operation_ +Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`. +For spatial input `X` with shape (N, C, H, W), the `grid` will have shape (N, H_out, W_out, 2), +the output `Y` will have shape (N, C, H_out, W_out). For volumetric input `X` with shape (N, C, D, H, W), +the `grid` will have shape (N, D_out, H_out, W_out, 3), the output `Y` will have shape (N, C, D_out, H_out, W_out). +More generally, for an input `X` of rank r+2 with shape (N, C, d1, d2, ..., dr), +the `grid` will have shape (N, D1_out, D2_out, ..., Dr_out, r), the output `Y` will have shape (N, C, D1_out, D2_out, ..., Dr_out). + +The tensor `X` contains values at centers of square pixels (voxels, etc) locations such as (n, c, d1_in, d2_in, ..., dr_in). +The (n, d1_out, d2_out, ..., dr_out, :) values from the tensor `grid` are the normalized positions for interpolating the values +at the (n, c, d1_out, d2_out, ..., dr_out) locations from the output tensor `Y` using a specified interpolation method (the mode) +and a padding mode (for `grid` positions falling outside the 2-dimensional image). + +For example, the values in `grid[n, h_out, w_out, :]` are size-2 vectors specifying normalized positions in the 2-dimensional space of `X`. +They are used to interpolate output values of `Y[n, c, h_out, w_out]`. + +The GridSample operator is often used in doing grid generator and sampler in the +[Spatial Transformer Networks](https://arxiv.org/abs/1506.02025). +See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html). + +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` + +Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` + +Effects: `MemoryEffects::Effect{}` + +#### Attributes: + + + + + + +
AttributeMLIR TypeDescription
align_corners::mlir::IntegerAttr64-bit signed integer attribute
mode::mlir::StringAttrstring attribute
padding_mode::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of string type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values +| `grid` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values + +#### Results: + +| Result | Description | +| :----: | ----------- | +| `Y` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of string type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values + +### `onnx.GridSampleV16` (ONNXGridSampleV16Op) + +_ONNX GridSample operation_ + Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from `grid`. Currently, only spatial (4-D) inputs are supported. For input `X` with shape (N, C, H, W) and `grid` with shape (N, H_out, W_out, 2), the output `Y` will have shape (N, C, H_out, W_out). @@ -3545,7 +3596,7 @@ They are used to interpolate output values of `Y[N, C, H_out, W_out]`. The GridSample operator is often used in doing grid generator and sampler in the [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025). See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/master/generated/torch.nn.functional.grid_sample.html#torch-nn-functional-grid-sample). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<16>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` diff --git a/src/Builder/OpBuildTable.inc b/src/Builder/OpBuildTable.inc index 0a63f65ff7..0407168972 100644 --- a/src/Builder/OpBuildTable.inc +++ b/src/Builder/OpBuildTable.inc @@ -79,7 +79,7 @@ op_dialect_version_map_["GlobalMaxPool"] = {1}; op_dialect_version_map_["Gradient"] = {1}; op_dialect_version_map_["Greater"] = {13}; op_dialect_version_map_["GreaterOrEqual"] = {16}; -op_dialect_version_map_["GridSample"] = {16}; +op_dialect_version_map_["GridSample"] = {22, 16}; op_dialect_version_map_["GroupNormalization"] = {21, 18}; op_dialect_version_map_["HammingWindow"] = {17}; op_dialect_version_map_["HannWindow"] = {17}; @@ -358,6 +358,8 @@ import_handler_map_["GreaterOrEqual"] = &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["GridSample"] = &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["GridSampleV16"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["GroupNormalization"] = &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["GroupNormalizationV18"] = diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index 0516cd5f3e..89883de50a 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -3117,6 +3117,57 @@ def ONNXGreaterOrEqualOp:ONNX_Op<"GreaterOrEqual", } def ONNXGridSampleOp:ONNX_Op<"GridSample", + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "ONNX GridSample operation"; + let description = [{ + Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`. + For spatial input `X` with shape (N, C, H, W), the `grid` will have shape (N, H_out, W_out, 2), + the output `Y` will have shape (N, C, H_out, W_out). For volumetric input `X` with shape (N, C, D, H, W), + the `grid` will have shape (N, D_out, H_out, W_out, 3), the output `Y` will have shape (N, C, D_out, H_out, W_out). + More generally, for an input `X` of rank r+2 with shape (N, C, d1, d2, ..., dr), + the `grid` will have shape (N, D1_out, D2_out, ..., Dr_out, r), the output `Y` will have shape (N, C, D1_out, D2_out, ..., Dr_out). + + The tensor `X` contains values at centers of square pixels (voxels, etc) locations such as (n, c, d1_in, d2_in, ..., dr_in). + The (n, d1_out, d2_out, ..., dr_out, :) values from the tensor `grid` are the normalized positions for interpolating the values + at the (n, c, d1_out, d2_out, ..., dr_out) locations from the output tensor `Y` using a specified interpolation method (the mode) + and a padding mode (for `grid` positions falling outside the 2-dimensional image). + + For example, the values in `grid[n, h_out, w_out, :]` are size-2 vectors specifying normalized positions in the 2-dimensional space of `X`. + They are used to interpolate output values of `Y[n, c, h_out, w_out]`. + + The GridSample operator is often used in doing grid generator and sampler in the + [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025). + See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html). + }]; + let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$X, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$grid, + DefaultValuedAttr:$align_corners, + DefaultValuedStrAttr:$mode, + DefaultValuedStrAttr:$padding_mode); + let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {30}; + } + }]; + let extraClassDefinition = [{ + onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef oper, + onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { + onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGridSampleOpShapeHelper(op, oper, ieb, scope); + assert(sh && "failed to allocate shape helper"); + return sh; + } + }]; + let hasVerifier = 1; +} + +def ONNXGridSampleV16Op:ONNX_Op<"GridSampleV16", [Pure, OpVersionTrait<16>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX GridSample operation"; let description = [{ @@ -3154,12 +3205,11 @@ def ONNXGridSampleOp:ONNX_Op<"GridSample", let extraClassDefinition = [{ onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { - onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGridSampleOpShapeHelper(op, oper, ieb, scope); + onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGridSampleV16OpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); return sh; } }]; - let hasVerifier = 1; } def ONNXGroupNormalizationOp:ONNX_Op<"GroupNormalization", diff --git a/src/Dialect/ONNX/ONNXUnsupportedOps.hpp b/src/Dialect/ONNX/ONNXUnsupportedOps.hpp index e5ecb07917..f314fa42d6 100644 --- a/src/Dialect/ONNX/ONNXUnsupportedOps.hpp +++ b/src/Dialect/ONNX/ONNXUnsupportedOps.hpp @@ -73,6 +73,7 @@ CONVERTED_TO_SUPPORTED_OPS(ONNXClipV11Op) CONVERTED_TO_SUPPORTED_OPS(ONNXClipV12Op) CONVERTED_TO_SUPPORTED_OPS(ONNXClipV6Op) CONVERTED_TO_SUPPORTED_OPS(ONNXDFTV17Op) +CONVERTED_TO_SUPPORTED_OPS(ONNXGridSampleV16Op) CONVERTED_TO_SUPPORTED_OPS(ONNXGroupNormalizationOp) CONVERTED_TO_SUPPORTED_OPS(ONNXGroupNormalizationV18Op) CONVERTED_TO_SUPPORTED_OPS(ONNXPadV18Op) diff --git a/src/Dialect/ONNX/Transforms/Decompose.cpp b/src/Dialect/ONNX/Transforms/Decompose.cpp index 22a5b7a179..ff7b41cc15 100644 --- a/src/Dialect/ONNX/Transforms/Decompose.cpp +++ b/src/Dialect/ONNX/Transforms/Decompose.cpp @@ -390,6 +390,18 @@ bool canSequenceAtBeReplaced(Value sequenceAtResult) { return true; } +Attribute upgradeGridSampleV16Mode(PatternRewriter &rewriter, Attribute mode) { + const auto stringMode = mlir::cast(mode); + if (stringMode.strref() == "bilinear") { + return rewriter.getStringAttr("linear"); + } + if (stringMode.strref() == "bicubic") { + return rewriter.getStringAttr("cubic"); + } + assert(stringMode.strref() == "nearest"); + return mode; +} + Value replaceSequenceAt( PatternRewriter &rewriter, Location loc, Value sequenceAtResult) { ONNXSequenceAtOp op = sequenceAtResult.getDefiningOp(); diff --git a/src/Dialect/ONNX/Transforms/Decompose.td b/src/Dialect/ONNX/Transforms/Decompose.td index 00ae9f6ff3..bc7044b524 100644 --- a/src/Dialect/ONNX/Transforms/Decompose.td +++ b/src/Dialect/ONNX/Transforms/Decompose.td @@ -73,6 +73,9 @@ def ReshapeElementsAttrToRank0 : NativeCodeCall< def ReplaceSequenceAt : NativeCodeCall< "onnx_mlir::replaceSequenceAt($_builder, $_loc, $0)">; + +def UpgradeGridSampleV16Mode : NativeCodeCall< + "onnx_mlir::upgradeGridSampleV16Mode($_builder, $0)">; def CanSequenceAtBeReplaced : Constraint, "check whether the SequenceAt can be replaced with split">; @@ -365,6 +368,12 @@ def ClipV12Pattern : Pat< (ONNXClipOp $x, $min, $max) >; +// Rewrite GridSample 16 to GridSample 22 +def GridSampleV16Pattern : Pat< + (ONNXGridSampleV16Op $x, $grid, $align_corners, $mode, $padding_mode), + (ONNXGridSampleOp $x, $grid, $align_corners, (UpgradeGridSampleV16Mode $mode), $padding_mode) +>; + def DFTV17Pattern : Pat< (ONNXDFTV17Op $x, $dft_length, $axis, $inverse, $onesided), (ONNXDFTOp $x, $dft_length, (ONNXConstantOpFromDenseAttr(createScalarDenseAttrRank0 $axis)), $inverse, $onesided) diff --git a/test/mlir/onnx/onnx_decompose.mlir b/test/mlir/onnx/onnx_decompose.mlir index 57a7a4397c..8bf7df604a 100644 --- a/test/mlir/onnx/onnx_decompose.mlir +++ b/test/mlir/onnx/onnx_decompose.mlir @@ -2,6 +2,42 @@ // ----- +func.func @test_grid_sample_v16_bicubic(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + %0 = "onnx.GridSampleV16"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bicubic", onnx_node_name = "GridSample_181", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +// CHECK-LABEL: func.func @test_grid_sample_v16_bicubic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x1x4x4xf32>, [[PARAM_1_:%.+]]: tensor<2x6x6x2xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.GridSample"([[PARAM_0_]], [[PARAM_1_]]) {align_corners = 1 : si64, mode = "cubic", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> +// CHECK: return [[VAR_0_]] : tensor<*xf32> +// CHECK: } +} + +// ----- + +func.func @test_grid_sample_v16_bilinear(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + %0 = "onnx.GridSampleV16"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +// CHECK-LABEL: func.func @test_grid_sample_v16_bilinear +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x1x4x4xf32>, [[PARAM_1_:%.+]]: tensor<2x6x6x2xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.GridSample"([[PARAM_0_]], [[PARAM_1_]]) {align_corners = 1 : si64, mode = "linear", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> +// CHECK: return [[VAR_0_]] : tensor<*xf32> +// CHECK: } +} + +// ----- + +func.func @test_grid_sample_v16_nearest(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + %0 = "onnx.GridSampleV16"(%arg0, %arg1) {align_corners = 1 : si64, mode = "nearest", onnx_node_name = "GridSample_181", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +// CHECK-LABEL: func.func @test_grid_sample_v16_nearest +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x1x4x4xf32>, [[PARAM_1_:%.+]]: tensor<2x6x6x2xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.GridSample"([[PARAM_0_]], [[PARAM_1_]]) {align_corners = 1 : si64, mode = "nearest", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> +// CHECK: return [[VAR_0_]] : tensor<*xf32> +// CHECK: } +} + +// ----- + func.func @test_dft(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { %cst = "onnx.NoValue"() {value} : () -> none %0 ="onnx.DFTV17"(%arg0, %arg1) : (tensor, tensor)-> tensor<*xf32> diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index 45e879b0e1..aa036f8dc5 100755 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -160,7 +160,7 @@ "Gradient": [1], "Greater": [13], "GreaterOrEqual": [16], - "GridSample": [16], + "GridSample": [22, 16], "GroupNormalization": [21, 18], "HammingWindow": [17], "HannWindow": [17], From e48ba6c73ffb87fd325b86029a9ed9ac2f3b53a3 Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Tue, 28 Jan 2025 16:00:51 +0000 Subject: [PATCH 2/2] Check attributes in GridSample verifier Signed-off-by: Rickert, Jonas --- .../ONNX/ONNXOps/Tensor/GridSample.cpp | 15 ++++++++++ test/mlir/onnx/invalid.mlir | 30 +++++++++++++++++-- test/mlir/onnx/onnx_shape_inference.mlir | 24 +++++++-------- 3 files changed, 54 insertions(+), 15 deletions(-) diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp index 027129a657..1dbfca9b0f 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp @@ -59,6 +59,21 @@ LogicalResult ONNXGridSampleOpShapeHelper::computeShape() { LogicalResult ONNXGridSampleOp::verify() { ONNXGridSampleOpAdaptor operandAdaptor(*this); + auto op = mlir::cast(*this); + + const auto alignCorners = op.getAlignCorners(); + if (alignCorners != 0 && alignCorners != 1) { + return emitOpError("align_corners needs to be 0 or 1"); + } + const auto mode = op.getMode(); + if (mode != "linear" && mode != "nearest" && mode != "cubic") { + return emitOpError("mode needs to be linear, nearest or cubic"); + } + const auto paddingMode = op.getPaddingMode(); + if (paddingMode != "zeros" && paddingMode != "border" && + paddingMode != "reflection") { + return emitOpError("padding_mode needs to be zeros, border or reflection"); + } if (!hasShapeAndRank(getOperation())) return success(); diff --git a/test/mlir/onnx/invalid.mlir b/test/mlir/onnx/invalid.mlir index 3fa25e883b..6f57e14198 100644 --- a/test/mlir/onnx/invalid.mlir +++ b/test/mlir/onnx/invalid.mlir @@ -820,7 +820,7 @@ func.func @test_mod_diff_element_type(%arg0: tensor<16x32xf32>, %arg1: tensor<16 func.func @test_grid_sample_diff_ranks(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x2xf32>) -> tensor<*xf32> { // expected-error @+1 {{'onnx.GridSample' op Input(=4) and grid(=3) have different dim sizes.}} - %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x2xf32>) -> tensor<*xf32> + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x2xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -828,7 +828,31 @@ func.func @test_grid_sample_diff_ranks(%arg0: tensor<1x3x1152x1344xf32>, %arg1: func.func @test_grid_sample_diff_batch(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { // expected-error @+1 {{'onnx.GridSample' op Input and grid must have the same batch value.}} - %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func.func @test_grid_sample_align_corners(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + // expected-error @+1 {{'onnx.GridSample' op align_corners needs to be 0 or 1}} + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 2 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func.func @test_grid_sample_mode(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + // expected-error @+1 {{'onnx.GridSample' op mode needs to be linear, nearest or cubic}} + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "sampling", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func.func @test_grid_sample_padding(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + // expected-error @+1 {{'onnx.GridSample' op padding_mode needs to be zeros, border or reflection}} + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "cubic", onnx_node_name = "GridSample_181", padding_mode = "bottom"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -836,6 +860,6 @@ func.func @test_grid_sample_diff_batch(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor func.func @test_grid_sample_wrong_dim_grid(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<1x6x6x3xf32>) -> tensor<*xf32> { // expected-error @+1 {{'onnx.GridSample' op Grid last dim must have been '2' instead of '3'.}} - %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x3xf32>) -> tensor<*xf32> + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x3xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 639601513a..9a73d90c0c 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -3984,69 +3984,69 @@ func.func @test_RMSlayer_norm_2inputs(%arg0: tensor<12x3x5xf32>, %arg1: tensor<5 // Test Grid Sample func.func @test_grid_sample_same_dims(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> { - %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<*xf32> + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> // mlir2FileCheck.py // CHECK-LABEL: func.func @test_grid_sample_same_dims // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x1152x1344xf32>, [[PARAM_1_:%.+]]: tensor<1x1152x1344x2xf32>) -> tensor<1x3x1152x1344xf32> { -// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<1x3x1152x1344xf32> +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<1x3x1152x1344xf32> // CHECK: return [[GRID]] : tensor<1x3x1152x1344xf32> // CHECK: } } func.func @test_grid_sample_diff_dims(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<1x6x6x2xf32>) -> tensor<*xf32> { - %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x2xf32>) -> tensor<*xf32> + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x2xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> // mlir2FileCheck.py // CHECK-LABEL: func.func @test_grid_sample_diff_dims // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x4x4xf32>, [[PARAM_1_:%.+]]: tensor<1x6x6x2xf32>) -> tensor<1x1x6x6xf32> { -// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x2xf32>) -> tensor<1x1x6x6xf32> +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x2xf32>) -> tensor<1x1x6x6xf32> // CHECK: return [[GRID]] : tensor<1x1x6x6xf32> // CHECK: } } func.func @test_grid_sample_6d(%arg0: tensor<1x2x4x4x4x4xf32>, %arg1: tensor<1x6x6x4x4x4xf32>) -> tensor<*xf32> { - %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x2x4x4x4x4xf32>, tensor<1x6x6x4x4x4xf32>) -> tensor<*xf32> + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x2x4x4x4x4xf32>, tensor<1x6x6x4x4x4xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> // mlir2FileCheck.py // CHECK-LABEL: func.func @test_grid_sample_6d // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2x4x4x4x4xf32>, [[PARAM_1_:%.+]]: tensor<1x6x6x4x4x4xf32>) -> tensor<1x2x6x6x4x4xf32> { -// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x2x4x4x4x4xf32>, tensor<1x6x6x4x4x4xf32>) -> tensor<1x2x6x6x4x4xf32> +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x2x4x4x4x4xf32>, tensor<1x6x6x4x4x4xf32>) -> tensor<1x2x6x6x4x4xf32> // CHECK: return [[GRID]] : tensor<1x2x6x6x4x4xf32> // CHECK: } } func.func @test_grid_sample_dim_shape(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> { - %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32> + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32> // mlir2FileCheck.py // CHECK-LABEL: func.func @test_grid_sample_dim_shape // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor { -// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor // CHECK: return [[GRID]] : tensor // CHECK: } return %0 : tensor<*xf32> } func.func @test_grid_sample_dim_shape2(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> { - %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32> + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32> // mlir2FileCheck.py // CHECK-LABEL: func.func @test_grid_sample_dim_shape2 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor { -// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor // CHECK: return [[GRID]] : tensor // CHECK: } return %0 : tensor<*xf32> } func.func @test_grid_sample_dim_shape3(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> { - %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32> + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32> // mlir2FileCheck.py // CHECK-LABEL: func.func @test_grid_sample_dim_shape3 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor { -// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor // CHECK: return [[GRID]] : tensor // CHECK: } return %0 : tensor<*xf32>