Chaitany.convtranspose as option #279

Merged
7 changes: 0 additions & 7 deletions CMakeLists.txt
@@ -8,7 +8,6 @@ project(onnx-mlir)
option(ONNX_MLIR_BUILD_TESTS "Build ONNX-MLIR test executables. If OFF, just generate build targets." ON)
option(ONNX_MLIR_CCACHE_BUILD "Set to ON for a ccache enabled build." OFF)
option(ONNX_MLIR_ENABLE_STABLEHLO "Enable StableHLO support." ON)
option(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE "Enable ONNXConvTransposeOp decomposition." ON)
option(ONNX_MLIR_ENABLE_WERROR "Enable warnings as errors." OFF)
option(ONNX_MLIR_SUPPRESS_THIRD_PARTY_WARNINGS "Suppress warning in third_party code." ON)
option(ONNX_MLIR_ENABLE_JAVA "Set to ON for building the Java runtime, tools, and tests" ON)
@@ -223,12 +222,6 @@ if (ONNX_MLIR_ENABLE_STABLEHLO)
add_compile_definitions(ONNX_MLIR_ENABLE_STABLEHLO)
endif()

if (ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE)
add_compile_definitions(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE)
set(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE_ENABLED 1)
else()
set(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE_ENABLED 0)
endif()

add_subdirectory(utils)
add_subdirectory(include)
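The hunk above removes the build-time switch entirely. Before this change, keeping ONNXConvTransposeOp undecomposed meant reconfiguring the build, roughly as in this sketch (a hypothetical configure line; the build directory and generator are assumptions):

  # Before this PR: disabling the ConvTranspose decomposition was a build-time choice
  cmake -DONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE=OFF /path/to/onnx-mlir
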
7 changes: 7 additions & 0 deletions src/Compiler/CompilerOptions.cpp
@@ -77,6 +77,7 @@ bool enableParallel; // onnx-mlir only
bool disableSimdOption; // onnx-mlir only
bool enableFastMathOption; // onnx-mlir only
bool disableRecomposeOption; // onnx-mlir only
bool disableConvTransposeDecomposeOption; // onnx-mlir only
bool enableSimdDataLayout; // onnx-mlir only
bool verifyInputTensors; // onnx-mlir only
bool allowSorting; // onnx-mlir only
@@ -247,6 +248,12 @@ static llvm::cl::opt<bool, true> disableRecomposeOptionOpt("disable-recompose",
llvm::cl::location(disableRecomposeOption), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));

static llvm::cl::opt<bool, true> disableConvTransposeDecomposeOptionOpt(
"disable-convtranspose-decompose",
llvm::cl::desc("Disable decomposition of ONNX ConvTranspose operator."),
llvm::cl::location(disableConvTransposeDecomposeOption),
llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions));

// Options for onnx-mlir only
static llvm::cl::opt<EmissionTargetType, true> emissionTargetOpt(
llvm::cl::desc("Choose target to emit:"),
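In its place, the new disable-convtranspose-decompose flag makes this a per-compilation choice. A minimal usage sketch: the onnx-mlir-opt line mirrors the RUN line of the new lit test added below; passing the same flag to the onnx-mlir driver is assumed to work because the option is registered under OnnxMlirCommonOptions.

  # Keep ONNXConvTransposeOp intact instead of decomposing it during --decompose-onnx
  onnx-mlir-opt --shape-inference --decompose-onnx --disable-convtranspose-decompose model.mlir
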
1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
@@ -122,6 +122,7 @@ extern bool enableParallel; // onnx-mlir only
extern bool disableSimdOption; // onnx-mlir only
extern bool enableFastMathOption; // onnx-mlir only
extern bool disableRecomposeOption; // onnx-mlir only
extern bool disableConvTransposeDecomposeOption; // onnx-mlir only
extern bool enableSimdDataLayout; // onnx-mlir only
extern bool verifyInputTensors; // onnx-mlir only
extern bool allowSorting; // onnx-mlir only
10 changes: 5 additions & 5 deletions src/Dialect/ONNX/Transforms/Decompose.cpp
@@ -30,6 +30,7 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -451,15 +452,14 @@ Value replaceSequenceAt(
}

bool shouldDecomposeConvTransposeOp(Value convTransposeResult) {
#ifdef ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE
if (onnx_mlir::disableConvTransposeDecomposeOption) {
// Disable the ONNXConvTransposeOp decomposition patterns.
return false;
}
ONNXConvTransposeOp op =
mlir::cast<ONNXConvTransposeOp>(convTransposeResult.getDefiningOp());
return hasShapeAndRank(convTransposeResult) &&
hasStaticSpatialDims(op.getX()) && hasStaticSpatialDims(op.getW());
#else
// Disable the ONNXConvTransposeOp decomposition patterns.
return false;
#endif
}

// Split on the specified axis. The length of each output is one.
3 changes: 0 additions & 3 deletions test/mlir/lit.cfg.py
@@ -51,6 +51,3 @@
# execution based on the available targets
for arch in config.targets_to_build.split():
config.available_features.add(arch.lower())

if config.decomp_onnx_convtranspose:
config.available_features.add("decomp_onnx_convtranspose")
1 change: 0 additions & 1 deletion test/mlir/lit.site.cfg.py.in
@@ -10,7 +10,6 @@ config.onnx_mlir_obj_root = r"@ONNX_MLIR_BIN_ROOT@"

config.enable_stablehlo = @ONNX_MLIR_STABLEHLO_ENABLED@
config.enable_nnpa= 0x0@NNPA_LIT_ENABLED@
config.decomp_onnx_convtranspose = @ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE_ENABLED@

# Support substitution of the tools_dir with user parameters. This is
# used when we can't determine the tool dir at configuration time.
1 change: 0 additions & 1 deletion test/mlir/onnx/onnx_decompose_convtranspose.mlir
@@ -1,6 +1,5 @@
// RUN: onnx-mlir-opt --shape-inference --decompose-onnx %s -split-input-file | FileCheck %s

// REQUIRES: decomp_onnx_convtranspose

// -----

104 changes: 104 additions & 0 deletions test/mlir/onnx/onnx_decompose_convtranspose_disable.mlir
@@ -0,0 +1,104 @@
// RUN: onnx-mlir-opt --shape-inference --decompose-onnx --disable-convtranspose-decompose %s -split-input-file | FileCheck %s


// -----

// Test unit strides. With decomposition disabled, the ConvTranspose op is kept as is.

func.func @test_convtrans_unitstrides(%arg0: tensor<1x1x3x3xf32>, %arg1: tensor<1x2x3x3xf32>) -> tensor<1x2x5x5xf32> {
%0 = "onnx.NoValue"() {value} : () -> none
%1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x5x5xf32>
onnx.Return %1 : tensor<1x2x5x5xf32>
// CHECK-LABEL: func.func @test_convtrans_unitstrides(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x3xf32>) -> tensor<1x2x5x5xf32> {
// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x5x5xf32>
// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x5x5xf32>
// CHECK: }
}

// -----

// Test 1d input

func.func @test_convtrans1d_unitstrides(%arg0: tensor<1x1x3xf32>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x5xf32> {
%0 = "onnx.NoValue"() {value} : () -> none
%1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3xf32>, tensor<1x2x3xf32>, none) -> tensor<1x2x5xf32>
onnx.Return %1 : tensor<1x2x5xf32>
// CHECK-LABEL: func.func @test_convtrans1d_unitstrides(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3xf32>) -> tensor<1x2x5xf32> {
// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3xf32>, tensor<1x2x3xf32>, none) -> tensor<1x2x5xf32>
// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x5xf32>
// CHECK: }
}

// -----

// Test 3d input

func.func @test_convtrans3d_unitstrides(%arg0: tensor<1x1x3x4x5xf32>, %arg1: tensor<1x2x3x3x3xf32>) -> tensor<1x2x5x6x7xf32> {
%0 = "onnx.NoValue"() {value} : () -> none
%1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3x4x5xf32>, tensor<1x2x3x3x3xf32>, none) -> tensor<1x2x5x6x7xf32>
onnx.Return %1 : tensor<1x2x5x6x7xf32>
// CHECK-LABEL: func.func @test_convtrans3d_unitstrides(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3x4x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x3x3xf32>) -> tensor<1x2x5x6x7xf32> {
// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3x4x5xf32>, tensor<1x2x3x3x3xf32>, none) -> tensor<1x2x5x6x7xf32>
// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x5x6x7xf32>
// CHECK: }
}

// -----

// Test non-unit strides. The ConvTranspose op is not decomposed.

func.func @test_convtrans_strides(%arg0: tensor<1x1x3x3xf32>, %arg1: tensor<1x2x3x3xf32>) -> tensor<1x2x7x3xf32> {
%0 = "onnx.NoValue"() {value} : () -> none
%1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 2, 1, 2], strides = [3, 2]} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x7x3xf32>
onnx.Return %1 : tensor<1x2x7x3xf32>
// CHECK-LABEL: func.func @test_convtrans_strides(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x3xf32>) -> tensor<1x2x7x3xf32> {
// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 2, 1, 2], strides = [3, 2]} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x7x3xf32>
// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x7x3xf32>
// CHECK: }
}

// -----

// Test output_padding. The ConvTranspose op is not decomposed.

func.func @test_convtrans_outputpadding(%arg0: tensor<1x1x3x3xf32>, %arg1: tensor<1x2x3x3xf32>) -> tensor<1x2x10x8xf32> {
%0 = "onnx.NoValue"() {value} : () -> none
%1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64, output_shape = [10, 8], strides = [3, 2]} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x10x8xf32>
onnx.Return %1 : tensor<1x2x10x8xf32>
// CHECK-LABEL: func.func @test_convtrans_outputpadding(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x3xf32>) -> tensor<1x2x10x8xf32> {
// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64, output_shape = [10, 8], strides = [3, 2]} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x10x8xf32>
// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x10x8xf32>
// CHECK: }
}

// -----

// Test for unknown dimension in spatial dimensions

func.func @test_convtranspose_unknown_spatial_dim(%arg0: tensor<?x?x3x3xf32>, %arg1: tensor<?x?x3x3xf32>) -> tensor<?x?x10x8xf32> {
%0 = "onnx.NoValue"() {value} : () -> none
%1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64, kernel_shape = [3, 3], onnx_node_name = "test", output_padding = [1, 1], output_shape = [10, 8], strides = [3, 2]} : (tensor<?x?x3x3xf32>, tensor<?x?x3x3xf32>, none) -> tensor<?x?x10x8xf32>
onnx.Return %1 : tensor<?x?x10x8xf32>
// CHECK-LABEL: func.func @test_convtranspose_unknown_spatial_dim(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?x3x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?x3x3xf32>) -> tensor<?x?x10x8xf32> {
// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64, kernel_shape = [3, 3], onnx_node_name = "test", output_padding = [1, 1], output_shape = [10, 8], strides = [3, 2]} : (tensor<?x?x3x3xf32>, tensor<?x?x3x3xf32>, none) -> tensor<?x?x10x8xf32>
// CHECK: onnx.Return %[[VAL_3]] : tensor<?x?x10x8xf32>
// CHECK: }
}