diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp index 027129a657..1dbfca9b0f 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp @@ -59,6 +59,21 @@ LogicalResult ONNXGridSampleOpShapeHelper::computeShape() { LogicalResult ONNXGridSampleOp::verify() { ONNXGridSampleOpAdaptor operandAdaptor(*this); + auto op = mlir::cast(*this); + + const auto alignCorners = op.getAlignCorners(); + if (alignCorners != 0 && alignCorners != 1) { + return emitOpError("align_corners needs to be 0 or 1"); + } + const auto mode = op.getMode(); + if (mode != "linear" && mode != "nearest" && mode != "cubic") { + return emitOpError("mode needs to be linear, nearest or cubic"); + } + const auto paddingMode = op.getPaddingMode(); + if (paddingMode != "zeros" && paddingMode != "border" && + paddingMode != "reflection") { + return emitOpError("padding_mode needs to be zeros, border or reflection"); + } if (!hasShapeAndRank(getOperation())) return success(); diff --git a/test/mlir/onnx/invalid.mlir b/test/mlir/onnx/invalid.mlir index 3fa25e883b..6f57e14198 100644 --- a/test/mlir/onnx/invalid.mlir +++ b/test/mlir/onnx/invalid.mlir @@ -820,7 +820,7 @@ func.func @test_mod_diff_element_type(%arg0: tensor<16x32xf32>, %arg1: tensor<16 func.func @test_grid_sample_diff_ranks(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x2xf32>) -> tensor<*xf32> { // expected-error @+1 {{'onnx.GridSample' op Input(=4) and grid(=3) have different dim sizes.}} - %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x2xf32>) -> tensor<*xf32> + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x2xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -828,7 +828,31 @@ func.func @test_grid_sample_diff_ranks(%arg0: tensor<1x3x1152x1344xf32>, %arg1: func.func @test_grid_sample_diff_batch(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { // expected-error @+1 {{'onnx.GridSample' op Input and grid must have the same batch value.}} - %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func.func @test_grid_sample_align_corners(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + // expected-error @+1 {{'onnx.GridSample' op align_corners needs to be 0 or 1}} + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 2 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func.func @test_grid_sample_mode(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + // expected-error @+1 {{'onnx.GridSample' op mode needs to be linear, nearest or cubic}} + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "sampling", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func.func @test_grid_sample_padding(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + // expected-error @+1 {{'onnx.GridSample' op padding_mode needs to be zeros, border or reflection}} + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "cubic", onnx_node_name = "GridSample_181", padding_mode = "bottom"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -836,6 +860,6 @@ func.func @test_grid_sample_diff_batch(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor func.func @test_grid_sample_wrong_dim_grid(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<1x6x6x3xf32>) -> tensor<*xf32> { // expected-error @+1 {{'onnx.GridSample' op Grid last dim must have been '2' instead of '3'.}} - %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x3xf32>) -> tensor<*xf32> + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x3xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 639601513a..9a73d90c0c 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -3984,69 +3984,69 @@ func.func @test_RMSlayer_norm_2inputs(%arg0: tensor<12x3x5xf32>, %arg1: tensor<5 // Test Grid Sample func.func @test_grid_sample_same_dims(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> { - %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<*xf32> + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> // mlir2FileCheck.py // CHECK-LABEL: func.func @test_grid_sample_same_dims // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x1152x1344xf32>, [[PARAM_1_:%.+]]: tensor<1x1152x1344x2xf32>) -> tensor<1x3x1152x1344xf32> { -// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<1x3x1152x1344xf32> +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<1x3x1152x1344xf32> // CHECK: return [[GRID]] : tensor<1x3x1152x1344xf32> // CHECK: } } func.func @test_grid_sample_diff_dims(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<1x6x6x2xf32>) -> tensor<*xf32> { - %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x2xf32>) -> tensor<*xf32> + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x2xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> // mlir2FileCheck.py // CHECK-LABEL: func.func @test_grid_sample_diff_dims // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x4x4xf32>, [[PARAM_1_:%.+]]: tensor<1x6x6x2xf32>) -> tensor<1x1x6x6xf32> { -// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x2xf32>) -> tensor<1x1x6x6xf32> +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x2xf32>) -> tensor<1x1x6x6xf32> // CHECK: return [[GRID]] : tensor<1x1x6x6xf32> // CHECK: } } func.func @test_grid_sample_6d(%arg0: tensor<1x2x4x4x4x4xf32>, %arg1: tensor<1x6x6x4x4x4xf32>) -> tensor<*xf32> { - %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x2x4x4x4x4xf32>, tensor<1x6x6x4x4x4xf32>) -> tensor<*xf32> + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x2x4x4x4x4xf32>, tensor<1x6x6x4x4x4xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> // mlir2FileCheck.py // CHECK-LABEL: func.func @test_grid_sample_6d // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2x4x4x4x4xf32>, [[PARAM_1_:%.+]]: tensor<1x6x6x4x4x4xf32>) -> tensor<1x2x6x6x4x4xf32> { -// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x2x4x4x4x4xf32>, tensor<1x6x6x4x4x4xf32>) -> tensor<1x2x6x6x4x4xf32> +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x2x4x4x4x4xf32>, tensor<1x6x6x4x4x4xf32>) -> tensor<1x2x6x6x4x4xf32> // CHECK: return [[GRID]] : tensor<1x2x6x6x4x4xf32> // CHECK: } } func.func @test_grid_sample_dim_shape(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> { - %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32> + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32> // mlir2FileCheck.py // CHECK-LABEL: func.func @test_grid_sample_dim_shape // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor { -// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor // CHECK: return [[GRID]] : tensor // CHECK: } return %0 : tensor<*xf32> } func.func @test_grid_sample_dim_shape2(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> { - %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32> + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32> // mlir2FileCheck.py // CHECK-LABEL: func.func @test_grid_sample_dim_shape2 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor { -// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor // CHECK: return [[GRID]] : tensor // CHECK: } return %0 : tensor<*xf32> } func.func @test_grid_sample_dim_shape3(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> { - %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32> + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32> // mlir2FileCheck.py // CHECK-LABEL: func.func @test_grid_sample_dim_shape3 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor { -// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor // CHECK: return [[GRID]] : tensor // CHECK: } return %0 : tensor<*xf32>