Skip to content

Commit

Permalink
Consider the data layout of the input tensor for the MultiThreshold
Browse files Browse the repository at this point in the history
  • Loading branch information
iksnagreb committed Sep 13, 2024
1 parent 2b20386 commit 4a69267
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions src/finn/transformation/qonnx/qonnx_activation_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def replace_quant_node(self):
graph.value_info.append(thresh_tensor)
model.set_initializer(thresh_tensor.name, thresholds)

data_layout = model.get_tensor_layout(n.input[0])

# Insert MultiThreshold node
outp_trans_node = helper.make_node(
"MultiThreshold",
Expand All @@ -154,6 +156,11 @@ def replace_quant_node(self):
mt_node = graph.node[running_node_index - 1]
mt_inst = getCustomOp(mt_node)

# Inherit the data layout from the input tensor if available
if data_layout is not None:
# Convert list to string representation of the data layout
mt_inst.set_nodeattr("data_layout", "".join(data_layout))

# Set scale and bias
# If these values are scalar then they can be set as attributes
# of the MultiThreshold node, if not they get inserted as adder and mul nodes
Expand Down Expand Up @@ -395,9 +402,9 @@ def _calculate_thresholds(self):
else:
thresholds[c][t] = step / selu_scale

# First try to consider the tensor layout of the output for determining
# First try to consider the tensor layout of the input for determining
# the number of output channels
layout = self._model.get_tensor_layout(self._q_node.output[0])
layout = self._model.get_tensor_layout(self._q_node.input[0])
# If there is a layout annotation, use this to determine the index of
# the channel dimension
if layout is not None and "C" in layout:
Expand All @@ -410,7 +417,7 @@ def _calculate_thresholds(self):
cdim = 1
# Issue a warning to the user, so they are aware of this
warnings.warn(
f"No layout annotations for {self._q_node.output[0]}:"
f"No layout annotations for {self._q_node.input[0]}:"
f" Assuming channel dimension at index {cdim}"
)

Expand Down Expand Up @@ -556,9 +563,9 @@ def _calculate_thresholds(self):
for t in range(num_thresholds):
thresholds[c][t] = min_threshold[c] + step[c] * t

# First try to consider the tensor layout of the output for
# First try to consider the tensor layout of the input for
# determining the number of output channels
layout = self._model.get_tensor_layout(self._q_node.output[0])
layout = self._model.get_tensor_layout(self._q_node.input[0])
# If there is a layout annotation, use this to determine the index
# of the channel dimension
if layout is not None and "C" in layout:
Expand All @@ -571,7 +578,7 @@ def _calculate_thresholds(self):
cdim = 1
# Issue a warning to the user, so they are aware of this
warnings.warn(
f"No layout annotations for {self._q_node.output[0]}:"
f"No layout annotations for {self._q_node.input[0]}:"
f" Assuming channel dimension at index {cdim}"
)

Expand Down

0 comments on commit 4a69267

Please sign in to comment.