Skip to content

Commit

Permalink
Merge pull request #96 from Xilinx/jrickert.filter_external_consts
Browse files Browse the repository at this point in the history
Skip external consts when walking operands
  • Loading branch information
ttjost authored Oct 4, 2024
2 parents 5793d27 + f953f8e commit b58d735
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
8 changes: 5 additions & 3 deletions lib/Transform/XTenMinimizeLiveTensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ FailureOr<SmallVector<Value>> getFmOperands(Operation *op) {
return {getSubgraphIFMs(op)};

if (isTemplatedGraph(op))
return {op->getOperands()};
return {getSubgraphIFMs(op)};

// Otherwise, this is a PseudoOp and IFM is the first operand.
if (!(isAnyPseudoOp(op) || isInterfaceOp(op))) {
Expand All @@ -225,7 +225,8 @@ size_t getSize(Value val) {

if (auto complexType = elementType.dyn_cast<ComplexType>()) {
elementType = complexType.getElementType();
return (elementType.getIntOrFloatBitWidth() * type.getNumElements() * 2) / 8;
return (elementType.getIntOrFloatBitWidth() * type.getNumElements() * 2) /
8;
}
llvm_unreachable("Does not know how to compute size");
}
Expand Down Expand Up @@ -299,7 +300,8 @@ class XTenMinimizeLiveTensorsPass
} else {
fmResults = SmallVector<Value>(currFn.getBody().front().getArguments());
}
std::optional<Value> const sharesResultMemory = sharesMemoryWithResult(defOp);
std::optional<Value> const sharesResultMemory =
sharesMemoryWithResult(defOp);
OpInfo info = {.op = defOp,
.operands = *fmOperands,
.results = fmResults,
Expand Down
15 changes: 15 additions & 0 deletions test/Transform/XTenMinimizeLiveTensors/other_subgraphs.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -361,4 +361,19 @@ func.func @support_for_inteface_op(%arg0: tensor<1x3x224x224xf32>) -> tensor<1x6
xten_nn.output %6 : tensor<1x64x56x56xf32>
} -> tensor<1x64x56x56xf32>
return %3 : tensor<1x64x56x56xf32>
}

// -----

// CHECK-LABEL: func.func @tg_with_constant_ops
// CHECK: LayerName = "TGConst"{{.*}} Reason = "TemplatedGraph"
// Regression test: a TemplatedGraph subgraph whose operands include values
// produced by xten_nn.load_external_const. Per the IfmOperands attribute,
// only operand 0 (%arg0) is an IFM; the external-const operands (%0, %1)
// must be skipped when the pass walks operands (see commit: "Skip external
// consts when walking operands").
func.func @tg_with_constant_ops(%arg0: tensor<1x1x64x8xbf16>) -> tensor<1x1x64x8xbf16> {
// External constants loaded from an HDF5 file; they feed the subgraph but
// are not feature-map inputs.
%0 = xten_nn.load_external_const {file = "constants.h5", key = "Test/Constant_2_0"} -> tensor<8xbf16>
%1 = xten_nn.load_external_const {file = "constants.h5", key = "Test/Constant_1_0"} -> tensor<8xbf16>
// IfmOperands = [0] marks only %arg1 (= %arg0) as the IFM.
%2 = xten_nn.subgraph (%arg1 = %arg0: tensor<1x1x64x8xbf16>, %arg2 = %1: tensor<8xbf16>, %arg3 = %0: tensor<8xbf16>) attributes {IfmOperands = [0 : index], LayerName = "TGConst", Reason = "TemplatedGraph"}
{
%6 = tensor.empty() : tensor<1x1x64x8xbf16>
xten_nn.output %6 : tensor<1x1x64x8xbf16>
} -> tensor<1x1x64x8xbf16>
return %2 : tensor<1x1x64x8xbf16>
}

0 comments on commit b58d735

Please sign in to comment.