Add topk operator (#66)
mgehre-amd authored Aug 21, 2024
1 parent b0f0905 commit 9722914
Showing 4 changed files with 192 additions and 18 deletions.
26 changes: 26 additions & 0 deletions include/xten/Dialect/XTenNN/IR/XTenNNOps.td
@@ -512,6 +512,32 @@ def XtenNN_SignOp: XTenNN_Op<"sign", [Pure, TosaExtension, ElementwiseUnary, Sam
let assemblyFormat = [{ operands attr-dict `:` functional-type(operands, results) }];
}

def XtenNN_TopK: XTenNN_Op<"topk", [
InferTensorTypeAdaptor,
Pure, TosaExtension]> {
let summary = "Calculate the topk";
let description = [{
Follows the specification of ONNX TopK at opset 11
}];
let arguments = (ins
AnyTensor:$input,
I64:$k,
I64Attr:$axis,
I1Attr:$largest,
I1Attr:$sorted
);
let results = (outs AnyRankedTensor:$output, AnyRankedTensor:$indices);

let assemblyFormat = [{ `(`$input `:` type($input) `,` $k `:` type($k)`)` attr-dict `->` type($output) `,` type($indices) }];

let extraClassDeclaration = [{
/// Returns when two result types are compatible for this op; method used by
/// InferTypeOpInterface.
static bool isCompatibleReturnTypes(mlir::TypeRange l, mlir::TypeRange r);
}];
}
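For reference, a minimal usage sketch of the new operator's custom assembly format, consistent with the tests added below; the SSA names, shapes, and attribute values here are illustrative only:

%k = arith.constant 4 : i64
// Take the 4 largest values (and their indices) along axis 1 of a 16x32 input.
%values, %indices = xten_nn.topk(%input : tensor<16x32xf32>, %k : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<16x4xf32>, tensor<16x4xi64>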


def XtenNN_ConvTransposeOp: XTenNN_Op<"ConvTranspose",[Pure, TosaExtension]> {
let summary = "Perform ConvTranspose operation";
let description = [{
102 changes: 84 additions & 18 deletions lib/Dialect/XTenNN/IR/XTenNNOps.cpp
@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpImplementation.h"
@@ -96,9 +97,7 @@ static ParseResult parseCaptures(OpAsmParser &p,
/// See parseCaptures() for more details.
static void printCaptures(OpAsmPrinter &p, ValueRange srcs) {
p << '(';
llvm::interleaveComma(srcs, p, [&](auto src) {
printCapture(p, src);
});
llvm::interleaveComma(srcs, p, [&](auto src) { printCapture(p, src); });
p << ')';
}

@@ -172,7 +171,6 @@ static void printEnclaveOp(OpAsmPrinter &p, EnclaveOp op) {
};
}


//===----------------------------------------------------------------------===//
// KernelOp
//===----------------------------------------------------------------------===//
@@ -286,8 +284,7 @@ LogicalResult SubgraphOp::verify() {
}

// The type of the arguments must match the types of the block arguments
for (auto [idx, argType] :
enumerate(optBody->getArgumentTypes())) {
for (auto [idx, argType] : enumerate(optBody->getArgumentTypes())) {
if (this->getCapture(idx).getType() != argType) {
return this->emitOpError()
<< "type of operand #" << idx << " ("
@@ -349,11 +346,12 @@ OpFoldResult amd::xten_nn::QuantizeOp::fold(FoldAdaptor adaptor) {
}

OpFoldResult amd::xten_nn::GroupQuantizeOp::fold(FoldAdaptor adaptor) {
// Fold away cases where a xten_nn.group_quantize is preceeded by xten_nn.group_dequantize
// that uses the same shift factor and has same types.
// Fold away cases where a xten_nn.group_quantize is preceded by
// xten_nn.group_dequantize that uses the same shift factor and has the same
// types.

auto dequantizeOp =
dyn_cast_or_null<amd::xten_nn::GroupDequantizeOp>(getInput().getDefiningOp());
auto dequantizeOp = dyn_cast_or_null<amd::xten_nn::GroupDequantizeOp>(
getInput().getDefiningOp());
if (!dequantizeOp)
return {};

@@ -412,19 +410,25 @@ LogicalResult amd::xten_nn::GroupQuantizeOp::verify() {
auto quantsShape = cast<ShapedType>(getQuants().getType()).getShape();

if (inputShape != quantsShape) {
return emitOpError() << "input and quants must have the same shape (" << inputShape << " v " << quantsShape << ")";
return emitOpError() << "input and quants must have the same shape ("
<< inputShape << " v " << quantsShape << ")";
}

if (scalesShape != zerosShape) {
return emitOpError() << "scales and zeros must have the same shape (" << scalesShape << " v " << zerosShape << ")";
return emitOpError() << "scales and zeros must have the same shape ("
<< scalesShape << " v " << zerosShape << ")";
}

if (scalesShape.back() != 1) {
return emitOpError() << "groups needs to be expressed in the innermost dimension of scales vs quants (" << scalesShape.back() << ")" ;
return emitOpError() << "groups needs to be expressed in the innermost "
"dimension of scales vs quants ("
<< scalesShape.back() << ")";
}

if (scalesShape.drop_back() != quantsShape.drop_back()) {
return emitOpError() << "scales and quants must have the same shape except for the innermost dimension (" << scalesShape << " v " << quantsShape << ")";
return emitOpError() << "scales and quants must have the same shape except "
"for the innermost dimension ("
<< scalesShape << " v " << quantsShape << ")";
}

// TODO validate:
@@ -441,19 +445,25 @@ LogicalResult amd::xten_nn::GroupDequantizeOp::verify() {
auto quantsShape = cast<ShapedType>(getQuants().getType()).getShape();

if (outputShape != quantsShape) {
return emitOpError() << "output and quants must have the same shape (" << outputShape << " v " << quantsShape << ")";
return emitOpError() << "output and quants must have the same shape ("
<< outputShape << " v " << quantsShape << ")";
}

if (scalesShape != zerosShape) {
return emitOpError() << "scales and zeros must have the same shape (" << scalesShape << " v " << zerosShape << ")";
return emitOpError() << "scales and zeros must have the same shape ("
<< scalesShape << " v " << zerosShape << ")";
}

if (scalesShape.back() != 1) {
return emitOpError() << "groups needs to be expressed in the innermost dimension of scales vs quants (" << scalesShape.back() << ")" ;
return emitOpError() << "groups needs to be expressed in the innermost "
"dimension of scales vs quants ("
<< scalesShape.back() << ")";
}

if (scalesShape.drop_back() != quantsShape.drop_back()) {
return emitOpError() << "scales and quants must have the same shape except for the innermost dimension (" << scalesShape << " v " << quantsShape << ")";
return emitOpError() << "scales and quants must have the same shape except "
"for the innermost dimension ("
<< scalesShape << " v " << quantsShape << ")";
}

// TODO validate:
@@ -519,3 +529,59 @@ LogicalResult amd::xten_nn::ResizeOp::verify() {

return success();
}

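/// Returns the value of `k` if it is defined by an integer arith.constant, or
/// std::nullopt when k is not known statically.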
static std::optional<uint64_t> getConstantK(Value k) {
auto *op = k.getDefiningOp();
if (!op) {
return {};
}
auto constantOp = dyn_cast<arith::ConstantOp>(op);
if (!constantOp)
return {};
auto intAttr = dyn_cast<IntegerAttr>(constantOp.getValue());
if (!intAttr)
return {};
return (uint64_t)
intAttr.getInt(); // Always positive by definition of onnx.topk
}

LogicalResult TopK::inferReturnTypeComponents(
MLIRContext *context, std::optional<Location> location,
TopK::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {

auto inTy = cast<RankedTensorType>(adaptor.getInput().getType());

auto axis = adaptor.getAxis();
if (axis >= (uint64_t)inTy.getRank()) {
return emitOptionalError(location, "expected axis <= rank of input");
}

auto dimSize = inTy.getDimSize(axis);
auto k = getConstantK(adaptor.getK());
// If both k and dim are known statically, we can check that k <= dim
if (k && dimSize != ShapedType::kDynamic) {
if ((uint64_t)dimSize < *k) {
return emitOptionalError(location, "expected k <= dimension size");
}
}

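// The result shape matches the input shape, except that the extent along
// `axis` becomes k, or dynamic when k is not a compile-time constant.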
SmallVector<int64_t> resultShape{inTy.getShape()};
resultShape[axis] = k ? *k : ShapedType::kDynamic;

inferredReturnShapes.push_back(
ShapedTypeComponents(resultShape, inTy.getElementType()));
inferredReturnShapes.push_back(
ShapedTypeComponents(resultShape, IntegerType::get(context, 64)));
return success();
}

bool TopK::isCompatibleReturnTypes(mlir::TypeRange l, mlir::TypeRange r) {
if (l.size() != r.size() || l.size() != 2)
return false;

auto sameElementType =
getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]) &&
getElementTypeOrSelf(l[1]) == getElementTypeOrSelf(r[1]);
return sameElementType && succeeded(verifyCompatibleShapes(l, r));
}
42 changes: 42 additions & 0 deletions test/Dialect/XTenNN/ops.mlir
@@ -36,3 +36,45 @@ func.func @kernel(%arg0: tensor<2xi64>, %arg1 : tensor<4xi64>) {
// CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64>
return
}

// -----

// CHECK-LABEL: topk
func.func @topk(%arg0: tensor<10x8xf32>) {
%k = arith.constant 7 : i64
// CHECK: %[[C7:.*]] = arith.constant 7 : i64
xten_nn.topk(%arg0 : tensor<10x8xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>, %[[C7]] : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64>
xten_nn.topk(%arg0 : tensor<10x8xf32>, %k : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>, %[[C7]] : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64>
return
}

// -----

// CHECK-LABEL: topk_arg
func.func @topk_arg(%arg0: tensor<10x8xf32>, %k: i64) {
xten_nn.topk(%arg0 : tensor<10x8xf32>, %k : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x?xf32>, tensor<10x?xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>, %arg1 : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x?xf32>, tensor<10x?xi64>
return
}

// -----

// Make sure that the topk verification does not fail if the result type is
// static even though it cannot be statically inferred due to the dynamic k
// CHECK-LABEL: topk_arg_type_inference
func.func @topk_arg_type_inference(%arg0: tensor<10x8xf32>, %k: i64) {
xten_nn.topk(%arg0 : tensor<10x8xf32>, %k : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x3xf32>, tensor<10x3xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>, %arg1 : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x3xf32>, tensor<10x3xi64>
return
}

// -----

// CHECK-LABEL: topk_arg_dyn_in
func.func @topk_arg_dyn_in(%arg0: tensor<?x?xf32>, %k: i64) {
xten_nn.topk(%arg0 : tensor<?x?xf32>, %k : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<?x?xf32>, tensor<?x?xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<?x?xf32>, %arg1 : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<?x?xf32>, tensor<?x?xi64>
return
}
40 changes: 40 additions & 0 deletions test/Dialect/XTenNN/ops_invalid.mlir
@@ -67,3 +67,43 @@ func.func @kernel_missing_result(%arg0: i8, %arg1: i8) {
// expected-error@+1 {{expected non-function type}}
xten_nn.kernel "myKernel" () ->
}

// -----

func.func @topk_wrong_output_shape(%arg0: tensor<10x10xf32>) {
%k = arith.constant 7 : i64
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'xten_nn.topk' op inferred type(s) 'tensor<7x10xf32>', 'tensor<7x10xi64>' are incompatible with return type(s) of operation 'tensor<1xf32>', 'tensor<1xi64>'}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<1xf32>, tensor<1xi64>
return
}

// -----

func.func @topk_wrong_indices_shape(%arg0: tensor<10x10xf32>) {
%k = arith.constant 7 : i64
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'xten_nn.topk' op inferred type(s) 'tensor<7x10xf32>', 'tensor<7x10xi64>' are incompatible with return type(s) of operation 'tensor<7x10xf32>', 'tensor<7x10xf32>'}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<7x10xf32>, tensor<7x10xf32>
return
}

// -----

func.func @topk_wrong_axis(%arg0: tensor<10x10xf32>) {
%k = arith.constant 7 : i64
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{expected axis < rank of input}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 3 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
return
}

// -----

func.func @topk_large_k(%arg0: tensor<10x10xf32>) {
%k = arith.constant 100 : i64
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{expected k <= dimension size}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
return
}
