Skip to content

Commit

Permalink
topk: allow negative axis (#77)
Browse files Browse the repository at this point in the history
topk: allow for negative axis
  • Loading branch information
josel-amd authored Aug 27, 2024
1 parent d7ba85f commit b29a99f
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 7 deletions.
20 changes: 14 additions & 6 deletions lib/Dialect/XTenNN/IR/XTenNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,7 @@ ParseResult SubgraphOp::parse(OpAsmParser &p, OperationState &result) {
return parseEnclaveOp(p, result);
}

void SubgraphOp::print(OpAsmPrinter &p) {
printEnclaveOp(p, *this);
}
void SubgraphOp::print(OpAsmPrinter &p) { printEnclaveOp(p, *this); }

LogicalResult SubgraphOp::verify() {
Block *optBody = this->getOptionalEnclaveBody();
Expand Down Expand Up @@ -552,11 +550,21 @@ LogicalResult TopK::inferReturnTypeComponents(

auto inTy = cast<RankedTensorType>(adaptor.getInput().getType());

auto axis = adaptor.getAxis();
if (axis >= (uint64_t)inTy.getRank()) {
return emitOptionalError(location, "expected axis <= rank of input");
auto axis = (int64_t)adaptor.getAxis();
// onnx spec: axis: [-r, r-1]
if (axis < -inTy.getRank() || axis >= inTy.getRank()) {
return emitOptionalError(location,
"expected axis to be within [-rank,rank) (where "
"rank is the rank of the input)");
}

// normalize axis: [0, r)
if (axis < 0) {
axis += inTy.getRank();
}

assert((axis >= 0 && axis < inTy.getRank()) && "axis has invalid value");

auto dimSize = inTy.getDimSize(axis);
auto k = getConstantK(adaptor.getK());
// If both k and dim are known statically, we can check that k <= dim
Expand Down
10 changes: 10 additions & 0 deletions test/Dialect/XTenNN/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,13 @@ func.func @topk_arg_dyn_in(%arg0: tensor<?x?xf32>, %k: i64) {
// CHECK: xten_nn.topk(%arg0 : tensor<?x?xf32>, %arg1 : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<?x?xf32>, tensor<?x?xi64>
return
}


// -----

// CHECK-LABEL: topk_neg_axis
func.func @topk_neg_axis(%arg0: tensor<10x8xf32>, %k: i64) {
xten_nn.topk(%arg0 : tensor<10x8xf32>, %k : i64) {axis = -1 : i64, largest = true, sorted = true} -> tensor<10x?xf32>, tensor<10x?xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>, %arg1 : i64) {axis = -1 : i64, largest = true, sorted = true} -> tensor<10x?xf32>, tensor<10x?xi64>
return
}
12 changes: 11 additions & 1 deletion test/Dialect/XTenNN/ops_invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func.func @topk_wrong_indices_shape(%arg0: tensor<10x10xf32>) {
func.func @topk_wrong_axis(%arg0: tensor<10x10xf32>) {
%k = arith.constant 7 : i64
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{expected axis <= rank of input}}
// expected-error@+1 {{expected axis to be within [-rank,rank) (where rank is the rank of the input}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 3 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
return
}
Expand All @@ -107,3 +107,13 @@ func.func @topk_large_k(%arg0: tensor<10x10xf32>) {
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
return
}

// -----

func.func @topk_negative_axis(%arg0: tensor<10x10xf32>) {
%k = arith.constant 100 : i64
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{expected axis to be within [-rank,rank) (where rank is the rank of the input)}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = -3 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
return
}

0 comments on commit b29a99f

Please sign in to comment.