diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py
index 9a7e9d0723..33751cb4d8 100644
--- a/src/finn/transformation/streamline/reorder.py
+++ b/src/finn/transformation/streamline/reorder.py
@@ -518,7 +518,9 @@ def apply(self, model):
 
 
 class MoveLinearPastEltwiseAdd(Transformation):
-    """Move linear operations (mul, add) past elementwise add operations where possible.
+    """
+    DEPRECATED, use MoveAddPastJoinAdd() and MoveMulPastJoinAdd()
+    Move linear operations (mul, add) past elementwise add operations where possible.
     Specifically,matches and transforms the following patterns:
     (x*C) + (y*C) -> (x + y) * C
     (x+A) + (y+B) -> (x + y) + (A + B)
@@ -918,6 +920,140 @@ def __init__(self):
         super().__init__(["Transpose"])
 
 
+def permute_shape(shape, perm):
+    """Return shape reordered by perm as a list of ints: result[i] == shape[perm[i]]."""
+    # index directly; no need to round-trip the dimensions through a float array
+    return [int(shape[p]) for p in perm]
+
+
+class MoveScalarLinearPastSplit(Transformation):
+    """
+    Move scalar Mul and Add nodes past channel split operation.
+    The Mul/Add in front of the Split is replaced by one copy on every Split output.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.ops_to_move = ["Mul", "Add"]
+        self.fork_ops = ["Split"]
+
+    def apply(self, model):
+        """Rewrites matching Mul/Add -> Split patterns, returns (model, graph_modified)."""
+        graph = model.graph
+        graph_modified = False
+        node_ind = 0
+        for n in graph.node:
+            node_ind += 1
+            # if n.op_type in self.fork_ops and model.is_fork_node(n):
+            if n.op_type in self.fork_ops:
+                producer = model.find_producer(n.input[0])
+                if producer is not None and producer.op_type in self.ops_to_move:
+                    # Check if single input (test the arity before indexing input[1])
+                    if len(producer.input) != 2:
+                        continue
+                    linear_param = model.get_initializer(producer.input[1])
+                    # Check if scalar initializer
+                    if linear_param is None or np.prod(linear_param.shape) != 1:
+                        continue
+                    # the producer is removed below, so it must not feed any other node
+                    if model.is_fork_node(producer):
+                        continue
+                    split_outputs = n.output
+                    for split_output_idx, old_split_output in enumerate(split_outputs):
+                        new_mul_node = deepcopy(producer)
+                        new_split_output = model.make_new_valueinfo_name()
+                        model.set_tensor_datatype(
+                            new_split_output, model.get_tensor_datatype(producer.input[0])
+                        )
+
+                        model.set_tensor_shape(
+                            new_split_output, model.get_tensor_shape(old_split_output)
+                        )
+
+                        n.output[split_output_idx] = new_split_output
+                        new_mul_node.input[0] = new_split_output
+                        new_mul_node.output[0] = old_split_output
+
+                        graph.node.insert(node_ind, new_mul_node)
+                        node_ind += 1
+
+                    # remove the original Mul/Add node
+                    n.input[0] = producer.input[0]
+                    graph.node.remove(producer)
+                    graph_modified = True
+
+        if graph_modified:
+            model = model.transform(SortGraph(), make_deepcopy=False, cleanup=False)
+
+        return (model, graph_modified)
+
+
+class MoveTransposePastSplit(Transformation):
+    """
+    Move a Transpose node past a Split operation.
+    The Transpose is replicated on every Split output and the split axis is
+    translated through the permutation.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.ops_to_move = ["Transpose"]
+        self.fork_ops = ["Split"]
+
+    def apply(self, model):
+        """Rewrites matching Transpose -> Split patterns, returns (model, graph_modified)."""
+        graph = model.graph
+        graph_modified = False
+        node_ind = 0
+        for n in graph.node:
+            node_ind += 1
+            # if n.op_type in self.fork_ops and model.is_fork_node(n):
+            if n.op_type in self.fork_ops:
+                producer = model.find_producer(n.input[0])
+                if producer is not None and producer.op_type in self.ops_to_move:
+                    # the producer is removed below, so it must not feed any other node
+                    if model.is_fork_node(producer):
+                        continue
+                    perm_attr = get_by_name(producer.attribute, "perm")
+                    axis_attr = get_by_name(n.attribute, "axis")
+                    # both attributes are optional in ONNX, skip the node instead of
+                    # failing on a missing one after the graph was already rewired
+                    if perm_attr is None or axis_attr is None:
+                        continue
+                    initial_perm = perm_attr.ints
+                    reverse_perm = np.argsort(initial_perm)
+                    split_outputs = n.output
+                    for split_output_idx, old_split_output in enumerate(split_outputs):
+                        new_trans_node = deepcopy(producer)
+                        new_split_output = model.make_new_valueinfo_name()
+                        old_split_output_shape = model.get_tensor_shape(old_split_output)
+                        model.set_tensor_datatype(
+                            new_split_output, model.get_tensor_datatype(producer.input[0])
+                        )
+
+                        model.set_tensor_shape(
+                            new_split_output, permute_shape(old_split_output_shape, reverse_perm)
+                        )
+
+                        n.output[split_output_idx] = new_split_output
+                        new_trans_node.input[0] = new_split_output
+                        new_trans_node.output[0] = old_split_output
+
+                        graph.node.insert(node_ind, new_trans_node)
+                        node_ind += 1
+
+                    # remove the transpose node and change the split axis
+                    axis_attr.i = initial_perm[axis_attr.i]
+                    n.input[0] = producer.input[0]
+                    graph.node.remove(producer)
+                    graph_modified = True
+
+        if graph_modified:
+            model = model.transform(SortGraph(), make_deepcopy=False, cleanup=False)
+
+        return (model, graph_modified)
+
+
 class MoveMaxPoolPastMultiThreshold(Transformation):
     """Move MaxPool nodes past MultiThreshold nodes on linear segments of the graph."""
 
@@ -1188,13 +1324,8 @@ def apply(self, model):
 
 class MoveIdenticalOpPastJoinOp(Transformation):
     """
-    Move identical operations on different branches past the common join node.
-    This transformation assumes that the identical operations only change the
-    data layout. For linear operations, see the transformation MoveLinearPastEltwiseAdd.
-    Specifically, this transformation matches and transforms the following patterns:
-    f(x) + f(y) -> f(x + y)
-    where f(.) is currently only supporting 'Transpose', and an 'Add' node is
-    the join node.
+    Move multiple identical operations on different branches past the common join node.
+    The default move_node() assumes that the join op preserves the shape of its inputs.
     """
 
     def __init__(self, identical_op_list, join_node_list):
@@ -1202,52 +1333,79 @@ def __init__(self, identical_op_list, join_node_list):
         self.ops_to_move = identical_op_list
         self.join_node_op = join_node_list
 
-    def move_node(self, model, n, prod0, prod1):
-        # Found! move one of the identical_ops to output, remove the other one
-        identical_op0_in0 = prod0.input[0]
-        identical_op1_in0 = prod1.input[0]
-        add_in0 = n.input[0]
-        add_out = n.output[0]
+    def move_node(self, model, n, producers):
+        """
+        Moves producers[0] behind the join node n and removes the other producers.
+        Should be overwritten for some operations
+
+        Returns:
+            bool: whether moving the node was successful
+        """
+        identical_ops_inputs = [p.input[0] for p in producers]
+        join_out = n.output[0]
 
-        # Rewire
-        n.input[0] = identical_op0_in0
-        n.input[1] = identical_op1_in0
+        # Rewire join op inputs
+        for i, new_input in enumerate(identical_ops_inputs):
+            n.input[i] = new_input
 
         # Output tensor of the join node must have the same shape as
         # its input tensor (original shape is preserved)
-        new_shape = model.get_tensor_shape(identical_op0_in0)
+        new_join_output = model.make_new_valueinfo_name()
+        new_shape = model.get_tensor_shape(identical_ops_inputs[0])
+        new_layout = model.get_tensor_layout(identical_ops_inputs[0])
 
         # Set new tensor shape
-        model.set_tensor_shape(tensor_name=add_in0, tensor_shape=new_shape)
-
-        n.output[0] = add_in0
-        prod0.input[0] = add_in0
-        prod0.output[0] = add_out
-
-        model.graph.node.remove(prod1)
+        model.set_tensor_shape(new_join_output, new_shape)
+        if new_layout:
+            model.set_tensor_layout(new_join_output, new_layout)
+
+        # Rewire join op output through the newly created tensor
+        n.output[0] = new_join_output
+        producers[0].input[0] = new_join_output
+        producers[0].output[0] = join_out
+
+        for prod in producers[1:]:
+            model.graph.node.remove(prod)
+
+        return True
+
+    def are_producers_identical(self, model, producers):
+        """
+        Checks only op_types (a missing producer never counts as identical)
+        Should be overwritten for additional checks
+        """
+        # find_producer() yields None for tensors that no node produces
+        if any(prod is None for prod in producers):
+            return False
+        first_op_type = producers[0].op_type
+        return all(prod.op_type == first_op_type for prod in producers)
 
     def apply(self, model):
         graph = model.graph
         graph_modified = False
         for n in graph.node:
             if n.op_type in self.join_node_op and model.is_join_node(n):
-                in0 = n.input[0]
-                in1 = n.input[1]
-                if in0 is None or in1 is None:
+                inputs = n.input
+                if None in inputs:
                     continue
 
-                prod0 = model.find_producer(in0)
-                prod1 = model.find_producer(in1)
-                # Checks if the join node is preceded by
-                # two different, but identical operations
-                if prod0 == prod1:
+                producers = [model.find_producer(inp) for inp in inputs]
+                # a join input without producer (e.g. a graph input) cannot be moved
+                if producers[0] is None or producers[0].op_type not in self.ops_to_move:
+                    continue
+                identical_ops = self.are_producers_identical(model, producers)
+                if not identical_ops:
+                    warnings.warn("Producers not identical, skipping")
                     continue
 
-                identical_op = prod0.op_type == prod1.op_type
-
-                if identical_op and prod0.op_type in self.ops_to_move:
-                    self.move_node(model, n, prod0, prod1)
-                    graph_modified = True
+                # check for producers that are fork nodes (need to fork them before our transform)
+                for prod in producers:
+                    if model.is_fork_node(prod) and not model.is_join_node(prod):
+                        model = model.transform(MoveOpPastFork(self.ops_to_move))
+                        # topology modified, "ask" ModelWrapper to apply this transform again
+                        return (model, True)
+                # do not lose a modification that was made for an earlier join node
+                graph_modified = self.move_node(model, n, producers) or graph_modified
 
         if graph_modified:
             model = model.transform(SortGraph(), make_deepcopy=False, cleanup=False)
@@ -1258,3 +1416,212 @@ def apply(self, model):
 class MoveTransposePastJoinAdd(MoveIdenticalOpPastJoinOp):
     def __init__(self):
         super().__init__(["Transpose"], ["Add"])
+
+    def are_producers_identical(self, model, producers):
+        """Producers are identical if they are all Transposes with the same perm."""
+        if not super().are_producers_identical(model, producers):
+            return False
+        first_perm = list(get_by_name(producers[0].attribute, "perm").ints)
+        for producer in producers:
+            if first_perm != list(get_by_name(producer.attribute, "perm").ints):
+                return False
+        return True
+
+
+class MoveMulPastJoinAdd(MoveIdenticalOpPastJoinOp):
+    """Move identical constant Mul nodes past a join Add: (x*C) + (y*C) -> (x + y) * C"""
+
+    def __init__(self):
+        super().__init__(["Mul"], ["Add"])
+
+    def are_producers_identical(self, model, producers):
+        if not super().are_producers_identical(model, producers):
+            return False
+        first_mul = model.get_initializer(producers[0].input[1])
+        if first_mul is None:
+            return False
+        for producer in producers:
+            mul_param = model.get_initializer(producer.input[1])
+            if mul_param is None:
+                return False
+            # "!=" on ndarrays is elementwise: require matching shapes (or two scalars)
+            # and reduce with any() instead of using the array as a truth value
+            both_scalar = first_mul.size == 1 and mul_param.size == 1
+            if not both_scalar and first_mul.shape != mul_param.shape:
+                return False
+            if (first_mul != mul_param).any():
+                return False
+        return True
+
+
+class MoveAddPastJoinAdd(MoveIdenticalOpPastJoinOp):
+    """Move constant Add nodes past a join Add: (x+A) + (y+B) -> (x + y) + (A + B)"""
+
+    def __init__(self):
+        super().__init__(["Add"], ["Add"])
+
+    def are_producers_identical(self, model, producers):
+        if not super().are_producers_identical(model, producers):
+            return False
+        for producer in producers:
+            if model.get_initializer(producer.input[1]) is None:
+                return False
+        return True
+
+    def move_node(self, model, n, producers):
+        """
+        We use the base move_node method to move the first producer
+        past the join node (and delete the rest)
+        """
+        add_inits = [model.get_initializer(producer.input[1]) for producer in producers]
+        # add elementwise (with broadcasting); np.sum() over the list would also
+        # sum up the individual entries of non-scalar parameters
+        new_init = add_inits[0]
+        for add_init in add_inits[1:]:
+            new_init = new_init + add_init
+        model.set_initializer(producers[0].input[1], new_init)
+        return super().move_node(model, n, producers)
+
+
+class MoveTransposePastJoinConcat(MoveIdenticalOpPastJoinOp):
+    """Move identical Transpose nodes past a join Concat and remap the concat axis."""
+
+    def __init__(self):
+        super().__init__(["Transpose"], ["Concat"])
+
+    def are_producers_identical(self, model, producers):
+        if not super().are_producers_identical(model, producers):
+            return False
+        first_perm = list(get_by_name(producers[0].attribute, "perm").ints)
+        for producer in producers:
+            if first_perm != list(get_by_name(producer.attribute, "perm").ints):
+                return False
+        return True
+
+    def move_node(self, model, n, producers):
+        trans_inputs = [prod.input[0] for prod in producers]
+        concat_out = n.output[0]
+        # Rewire concat inputs
+        for i, new_input in enumerate(trans_inputs):
+            n.input[i] = new_input
+
+        new_concat_out = model.make_new_valueinfo_name()
+        # reverse the permutation of the concat output
+        transpose_perm = get_by_name(producers[0].attribute, "perm").ints
+        reverse_perm = np.argsort(transpose_perm)
+        new_concat_out_shape = permute_shape(model.get_tensor_shape(concat_out), reverse_perm)
+        new_concat_out_layout = model.get_tensor_layout(trans_inputs[0])
+        # Set tensor layout, shape and datatype of the new concatenation output
+        model.set_tensor_shape(new_concat_out, new_concat_out_shape)
+        if new_concat_out_layout:
+            model.set_tensor_layout(new_concat_out, new_concat_out_layout)
+        model.set_tensor_datatype(new_concat_out, model.get_tensor_datatype(trans_inputs[0]))
+        # Change concatenation axis
+        axis_attr = get_by_name(n.attribute, "axis")
+        axis_attr.i = transpose_perm[axis_attr.i]
+
+        # Rewire concat output
+        n.output[0] = new_concat_out
+        producers[0].input[0] = new_concat_out
+        producers[0].output[0] = concat_out
+
+        for prod in producers[1:]:
+            model.graph.node.remove(prod)
+
+        return True
+
+
+class MoveAffinePastJoinConcat(MoveIdenticalOpPastJoinOp):
+    """
+    Applies to scalar linear ops with the same parameter value, or to channelwise
+    affine ops whose parameters get concatenated along the concat axis
+    """
+
+    def __init__(self, linear_ops=None):
+        # build the default per instance instead of sharing one mutable default list
+        if linear_ops is None:
+            linear_ops = ["Mul", "Add"]
+        super().__init__(linear_ops, ["Concat"])
+
+    def are_producers_identical_scalar_ops(self, model, producers):
+        """Checks that all producers have a single-element parameter of the same value."""
+        first_param = model.get_initializer(producers[0].input[1])
+        for producer in producers:
+            producer_param = model.get_initializer(producer.input[1])
+            # test the size first so that only single-element parameters get compared
+            if np.prod(producer_param.shape) != 1 or (first_param != producer_param).any():
+                return False
+
+        return True
+
+    def are_producers_channelwise_ops(self, channel_dim, model, producers):
+        """Checks that all producer parameters span the channels along channel_dim."""
+        for producer in producers:
+            producer_input = producer.input[0]
+            num_channels = model.get_tensor_shape(producer_input)[channel_dim]
+            param_shape = model.get_initializer(producer.input[1]).shape
+            # the parameter must have an axis channel_dim before it can be indexed
+            if not -len(param_shape) <= channel_dim < len(param_shape):
+                return False
+            if param_shape[channel_dim] != num_channels:
+                return False
+
+        return True
+
+    def move_node(self, model, n, producers):
+        # check if single input (test the arity before indexing input[1])
+        for producer in producers:
+            if len(producer.input) != 2 or model.get_initializer(producer.input[1]) is None:
+                warnings.warn("Producer found that is not single-input, skipping")
+                return False
+
+        # decide if producers are identical scalar ops or channelwise ops
+        channelwise_op = False
+        identical_scalar_op = self.are_producers_identical_scalar_ops(model, producers)
+        if not identical_scalar_op:
+            channel_dim = get_by_name(n.attribute, "axis").i
+            channelwise_op = self.are_producers_channelwise_ops(channel_dim, model, producers)
+            if not channelwise_op:
+                warnings.warn(
+                    "Producers are neither identical scalar ops nor channelwise ops, skipping"
+                )
+                return False
+
+        # Rewire concat inputs
+        producers_inputs = [prod.input[0] for prod in producers]
+        concat_out = n.output[0]
+        for i, new_input in enumerate(producers_inputs):
+            n.input[i] = new_input
+        # Set tensor layout and shape of the new concatenation output
+        new_concat_out = model.make_new_valueinfo_name()
+        new_concat_out_layout = model.get_tensor_layout(producers_inputs[0])
+        model.set_tensor_shape(new_concat_out, model.get_tensor_shape(concat_out))
+        if new_concat_out_layout:
+            model.set_tensor_layout(new_concat_out, new_concat_out_layout)
+        model.set_tensor_datatype(new_concat_out, model.get_tensor_datatype(producers_inputs[0]))
+
+        if channelwise_op:
+            # concatenate op params of producers into one parameter tensor
+            producers_params = [model.get_initializer(prod.input[1]) for prod in producers]
+            new_param_tensor = np.concatenate(producers_params, axis=channel_dim)
+            model.set_initializer(producers[0].input[1], new_param_tensor)
+
+        # Rewire concat output
+        n.output[0] = new_concat_out
+        producers[0].input[0] = new_concat_out
+        producers[0].output[0] = concat_out
+
+        for prod in producers[1:]:
+            model.graph.node.remove(prod)
+
+        return True
+
+
+class MoveMulPastJoinConcat(MoveAffinePastJoinConcat):
+    def __init__(self):
+        super().__init__(["Mul"])
+
+
+class MoveAddPastJoinConcat(MoveAffinePastJoinConcat):
+    def __init__(self):
+        super().__init__(["Add"])
diff --git a/tests/transformation/streamline/test_move_identical_op_past_join_add.py b/tests/transformation/streamline/test_move_identical_op_past_join_add.py
new file mode 100644
index 0000000000..7226d31589
--- /dev/null
+++ b/tests/transformation/streamline/test_move_identical_op_past_join_add.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import pytest
+
+import numpy as np
+from onnx import TensorProto
+from onnx import helper as oh
+from qonnx.core.modelwrapper import ModelWrapper
+from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model
+
+import finn.core.onnx_exec as oxe
+from finn.transformation.streamline.reorder import (
+    MoveAddPastJoinAdd,
+    MoveMulPastJoinAdd,
+    MoveTransposePastJoinAdd,
+)
+
+
+def create_add_model(identical_op):
+    perm = None
+    if "Transpose" in identical_op:
+        perm = identical_op.split("_")[1]
+        identical_op = identical_op.split("_")[0]
+        perm = [int(char) for char in perm]
+    if perm == [0, 2, 3, 1]:
+        in_shape = [1, 64, 10, 9]
+        out_shape = [1, 10, 9, 64]
+    elif perm == [0, 3, 1, 2]:
+        in_shape = [1, 10, 9, 64]
+        out_shape = [1, 64, 10, 9]
+    else:
+        in_shape = [1, 64, 10, 9]
+        out_shape = in_shape
+    op_value = 1.5
+
+    op1_node = oh.make_node(identical_op, inputs=["in1"], outputs=["op1_out"])
+
+    op2_node = oh.make_node(identical_op, inputs=["in2"], outputs=["op2_out"])
+
+    if identical_op == "Transpose":
+        new_attr = oh.make_attribute("perm", perm)
+        op1_node.attribute.append(new_attr)
+        op2_node.attribute.append(new_attr)
+    elif identical_op == "Mul" or identical_op == "Add":
+        op1_init = oh.make_tensor_value_info("op1_param", TensorProto.FLOAT, [1])
+        op2_init = oh.make_tensor_value_info("op2_param", TensorProto.FLOAT, [1])
+        op1_node.input.append(op1_init.name)
+        op2_node.input.append(op2_init.name)
+
+    add_node = oh.make_node("Add", inputs=["op1_out", "op2_out"], outputs=["out_join1"])
+
+    in1 = oh.make_tensor_value_info("in1", TensorProto.FLOAT, in_shape)
+    in2 = oh.make_tensor_value_info("in2", TensorProto.FLOAT, in_shape)
+    op1_out = oh.make_tensor_value_info("op1_out", TensorProto.FLOAT, out_shape)
+    op2_out = oh.make_tensor_value_info("op2_out", TensorProto.FLOAT, out_shape)
+    out_join1 = oh.make_tensor_value_info("out_join1", TensorProto.FLOAT, out_shape)
+
+    graph = oh.make_graph(
+        nodes=[op1_node, op2_node, add_node],
+        name="test_graph",
+        inputs=[in1, in2],
+        outputs=[out_join1],
+        value_info=[
+            op1_out,
+            op2_out,
+        ],
+    )
+
+    onnx_model = qonnx_make_model(graph, producer_name="test_model")
+    model = ModelWrapper(onnx_model)
+    if identical_op == "Mul" or identical_op == "Add":
+        model.set_initializer("op1_param", np.array(op_value).astype(np.float32))
+        model.set_initializer("op2_param", np.array(op_value).astype(np.float32))
+
+    return model
+
+
+transform_dict = {
+    "Transpose_0231": MoveTransposePastJoinAdd(),
+    "Transpose_0312": MoveTransposePastJoinAdd(),
+    "Mul": MoveMulPastJoinAdd(),
+    "Add": MoveAddPastJoinAdd(),
+}
+
+
+@pytest.mark.streamline
+# Op placed on both branches (Transpose variants carry their perm in the name)
+@pytest.mark.parametrize("identical_op", ["Transpose_0231", "Transpose_0312", "Mul", "Add"])
+def test_move_identical_op_past_join_op(identical_op):
+    model = create_add_model(identical_op)
+    # build_dir = os.environ["FINN_BUILD_DIR"]
+    # model.save(join(build_dir, "add_pytest_model_{}.onnx".format(identical_op)))
+
+    # Create input data
+    input0_tensor_name = model.graph.input[0].name
+    input1_tensor_name = model.graph.input[1].name
+
+    # Note: it is assumed that both tensors have the same shape and data type
+    input_shape = model.get_tensor_shape(input0_tensor_name)
+    input_dtype = model.get_tensor_datatype(input0_tensor_name)
+    input_val = gen_finn_dt_tensor(input_dtype, input_shape)
+    input_dict = {}
+    input_dict[input0_tensor_name] = input_val
+    input_dict[input1_tensor_name] = input_val
+
+    model_transformed = model.transform(transform_dict[identical_op])
+    # model_transformed.save(join(build_dir, "add_pytest_model_{}_trans.onnx".format(identical_op)))
+
+    assert oxe.compare_execution(model, model_transformed, input_dict)
+
+    # Check if order changed: both inputs must feed the Add, the moved op comes last
+    node0_optype_model = model.find_consumers(model.graph.input[0].name)[0].op_type
+    node1_optype_model = model.find_consumers(model.graph.input[1].name)[0].op_type
+    node0_optype_model_transformed = model_transformed.find_consumers(
+        model_transformed.graph.input[0].name
+    )[0].op_type
+    node1_optype_model_transformed = model_transformed.find_consumers(
+        model_transformed.graph.input[1].name
+    )[0].op_type
+    last_node_optype_model_transformed = model_transformed.find_producer(
+        model_transformed.graph.output[0].name
+    ).op_type
+    assert node0_optype_model == last_node_optype_model_transformed
+    assert node1_optype_model == last_node_optype_model_transformed
+    assert node0_optype_model_transformed == node1_optype_model_transformed == "Add"
diff --git a/tests/transformation/streamline/test_move_identical_op_past_join_concat.py b/tests/transformation/streamline/test_move_identical_op_past_join_concat.py
new file mode 100644
index 0000000000..2dcf90d10a
--- /dev/null
+++ b/tests/transformation/streamline/test_move_identical_op_past_join_concat.py
@@ -0,0 +1,185 @@
+# Copyright (c) 2020, Xilinx
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of FINN nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import pytest
+
+import numpy as np
+import os
+from onnx import TensorProto
+from onnx import helper as oh
+from os.path import join
+from qonnx.core.modelwrapper import ModelWrapper
+from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model
+
+import finn.core.onnx_exec as oxe
+from finn.transformation.streamline.reorder import (
+    MoveAddPastJoinConcat,
+    MoveMulPastJoinConcat,
+    MoveTransposePastJoinConcat,
+)
+
+
+def create_concat_model(identical_op):
+    """Builds two identical_op branches (64 and 32 channels) joined by a Concat node."""
+    perm = None
+    channelwise = False
+    if "Transpose" in identical_op:
+        perm = identical_op.split("_")[1]
+        identical_op = identical_op.split("_")[0]
+        perm = [int(char) for char in perm]
+    if "channelwise" in identical_op:
+        channelwise = True
+        identical_op = identical_op.split("_")[0]
+    if perm == [0, 2, 3, 1]:
+        in_shape1 = [1, 64, 10, 9]
+        in_shape2 = [1, 32, 10, 9]
+        out_shape1 = [1, 10, 9, 64]
+        out_shape2 = [1, 10, 9, 32]
+        out_join_shape = [1, 10, 9, 96]
+        concat_axis = 3
+    elif perm == [0, 3, 1, 2]:
+        in_shape1 = [1, 10, 9, 64]
+        in_shape2 = [1, 10, 9, 32]
+        out_shape1 = [1, 64, 10, 9]
+        out_shape2 = [1, 32, 10, 9]
+        out_join_shape = [1, 96, 10, 9]
+        concat_axis = 1
+    else:
+        in_shape1 = [1, 64, 10, 9]
+        in_shape2 = [1, 32, 10, 9]
+        out_shape1 = in_shape1
+        out_shape2 = in_shape2
+        out_join_shape = [1, 96, 10, 9]
+        concat_axis = 1
+        if channelwise:
+            op1_param_shape = [1, 64, 1, 1]
+            op2_param_shape = [1, 32, 1, 1]
+            op1_param = np.ones((1, 64, 1, 1)) * 2
+            op2_param = np.ones((1, 32, 1, 1)) * 3
+        else:
+            op1_param_shape = [1]
+            op2_param_shape = [1]
+            op1_param = 1.5
+            op2_param = 1.5
+
+    op1_node = oh.make_node(identical_op, inputs=["in1"], outputs=["op1_out"])
+
+    op2_node = oh.make_node(identical_op, inputs=["in2"], outputs=["op2_out"])
+
+    if identical_op == "Transpose":
+        new_attr = oh.make_attribute("perm", perm)
+        op1_node.attribute.append(new_attr)
+        op2_node.attribute.append(new_attr)
+    elif identical_op == "Mul" or identical_op == "Add":
+        op1_init = oh.make_tensor_value_info("op1_param", TensorProto.FLOAT, op1_param_shape)
+        op2_init = oh.make_tensor_value_info("op2_param", TensorProto.FLOAT, op2_param_shape)
+        op1_node.input.append(op1_init.name)
+        op2_node.input.append(op2_init.name)
+
+    concat_node = oh.make_node(
+        "Concat", inputs=["op1_out", "op2_out"], outputs=["out_join1"], axis=concat_axis
+    )
+
+    in1 = oh.make_tensor_value_info("in1", TensorProto.FLOAT, in_shape1)
+    in2 = oh.make_tensor_value_info("in2", TensorProto.FLOAT, in_shape2)
+    op1_out = oh.make_tensor_value_info("op1_out", TensorProto.FLOAT, out_shape1)
+    op2_out = oh.make_tensor_value_info("op2_out", TensorProto.FLOAT, out_shape2)
+    out_join1 = oh.make_tensor_value_info("out_join1", TensorProto.FLOAT, out_join_shape)
+
+    graph = oh.make_graph(
+        nodes=[op1_node, op2_node, concat_node],
+        name="test_graph",
+        inputs=[in1, in2],
+        outputs=[out_join1],
+        value_info=[
+            op1_out,
+            op2_out,
+        ],
+    )
+
+    onnx_model = qonnx_make_model(graph, producer_name="test_model")
+    model = ModelWrapper(onnx_model)
+    if identical_op == "Mul" or identical_op == "Add":
+        model.set_initializer("op1_param", np.array(op1_param).astype(np.float32))
+        model.set_initializer("op2_param", np.array(op2_param).astype(np.float32))
+
+    return model
+
+
+transform_dict = {
+    "Transpose_0231": MoveTransposePastJoinConcat(),
+    "Transpose_0312": MoveTransposePastJoinConcat(),
+    "Mul": MoveMulPastJoinConcat(),
+    "Mul_channelwise": MoveMulPastJoinConcat(),
+    "Add": MoveAddPastJoinConcat(),
+    "Add_channelwise": MoveAddPastJoinConcat(),
+}
+
+
+@pytest.mark.streamline
+# Op placed on both branches (Transpose variants carry their perm in the name)
+@pytest.mark.parametrize(
+    "identical_op",
+    ["Transpose_0231", "Transpose_0312", "Mul", "Add", "Mul_channelwise", "Add_channelwise"],
+)
+def test_move_identical_op_past_join_concat(identical_op):
+    """Checks that the op is moved behind the Concat without changing the model output."""
+    model = create_concat_model(identical_op)
+    build_dir = os.environ["FINN_BUILD_DIR"]
+    model.save(join(build_dir, "concat_pytest_model_{}.onnx".format(identical_op)))
+
+    # Create input data
+    input0_tensor_name = model.graph.input[0].name
+    input1_tensor_name = model.graph.input[1].name
+
+    # Note: the two inputs differ in shape, so a tensor is generated for each of them
+    input_dict = {}
+    input_dict[input0_tensor_name] = gen_finn_dt_tensor(
+        model.get_tensor_datatype(input0_tensor_name), model.get_tensor_shape(input0_tensor_name)
+    )
+    input_dict[input1_tensor_name] = gen_finn_dt_tensor(
+        model.get_tensor_datatype(input1_tensor_name), model.get_tensor_shape(input1_tensor_name)
+    )
+
+    model_transformed = model.transform(transform_dict[identical_op])
+    model_transformed.save(
+        join(build_dir, "concat_pytest_model_{}_trans.onnx".format(identical_op))
+    )
+
+    assert oxe.compare_execution(model, model_transformed, input_dict)
+
+    # Check if order changed: the inputs must no longer feed the identical op directly
+    node0_input0_model = model.find_consumers(model.graph.input[0].name)[0].op_type
+    node1_input1_model = model.find_consumers(model.graph.input[1].name)[0].op_type
+    node0_input0_model_transformed = model_transformed.find_consumers(
+        model_transformed.graph.input[0].name
+    )[0].op_type
+    node1_input1_model_transformed = model_transformed.find_consumers(
+        model_transformed.graph.input[1].name
+    )[0].op_type
+    assert node0_input0_model != node0_input0_model_transformed
+    assert node1_input1_model != node1_input1_model_transformed
diff --git a/tests/transformation/streamline/test_move_identical_op_past_join_op.py b/tests/transformation/streamline/test_move_identical_op_past_join_op.py
deleted file mode 100644 index dd83681fc2..0000000000 --- a/tests/transformation/streamline/test_move_identical_op_past_join_op.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) 2020, Xilinx -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# * Neither the name of FINN nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-import pytest - -from onnx import TensorProto -from onnx import helper as oh -from qonnx.core.modelwrapper import ModelWrapper -from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model - -import finn.core.onnx_exec as oxe -from finn.transformation.streamline.reorder import MoveTransposePastJoinAdd - - -def create_model(perm): - if perm == [0, 3, 1, 2]: - in_shape = [1, 128, 1, 256] - out_shape = [1, 256, 128, 1] - if perm == [0, 2, 3, 1]: - in_shape = [1, 256, 128, 1] - out_shape = [1, 128, 1, 256] - - Transpose1_node = oh.make_node( - "Transpose", inputs=["in_transpose1"], outputs=["out_transpose1"], perm=perm - ) - - Transpose2_node = oh.make_node( - "Transpose", inputs=["in_transpose2"], outputs=["out_transpose2"], perm=perm - ) - - Join1_node = oh.make_node( - "Add", inputs=["out_transpose1", "out_transpose2"], outputs=["out_join1"] - ) - - in_transpose1 = oh.make_tensor_value_info("in_transpose1", TensorProto.FLOAT, in_shape) - in_transpose2 = oh.make_tensor_value_info("in_transpose2", TensorProto.FLOAT, in_shape) - out_transpose1 = oh.make_tensor_value_info("out_transpose1", TensorProto.FLOAT, out_shape) - out_transpose2 = oh.make_tensor_value_info("out_transpose2", TensorProto.FLOAT, out_shape) - out_join1 = oh.make_tensor_value_info("out_join1", TensorProto.FLOAT, out_shape) - - graph = oh.make_graph( - nodes=[Transpose1_node, Transpose2_node, Join1_node], - name="test_graph", - inputs=[in_transpose1, in_transpose2], - outputs=[out_join1], - value_info=[ - out_transpose1, - out_transpose2, - ], - ) - - onnx_model = qonnx_make_model(graph, producer_name="test_model") - model = ModelWrapper(onnx_model) - - return model - - -@pytest.mark.streamline -# Permutation of transpose node -@pytest.mark.parametrize("perm", [[0, 3, 1, 2], [0, 2, 3, 1]]) -def test_move_identical_op_past_join_op(perm): - model = create_model(perm) - - # Create input data - input0_tensor_name = model.graph.input[0].name - input1_tensor_name = model.graph.input[1].name - - # Note: 
it is assumed that both tensors have the same shape and data type - input_shape = model.get_tensor_shape(input0_tensor_name) - input_dtype = model.get_tensor_datatype(input0_tensor_name) - input_val = gen_finn_dt_tensor(input_dtype, input_shape) - input_dict = {} - input_dict[input0_tensor_name] = input_val - input_dict[input1_tensor_name] = input_val - - model_transformed = model.transform(MoveTransposePastJoinAdd()) - - assert oxe.compare_execution(model, model_transformed, input_dict) - - # Check if order changed - node0_input0_model = model.find_consumers(model.graph.input[0].name)[0].op_type - node1_input1_model = model.find_consumers(model.graph.input[1].name)[0].op_type - node0_input0_model_transformed = model_transformed.find_consumers( - model_transformed.graph.input[0].name - )[0].op_type - node1_input1_model_transformed = model_transformed.find_consumers( - model_transformed.graph.input[1].name - )[0].op_type - assert node0_input0_model != node0_input0_model_transformed - assert node1_input1_model != node1_input1_model_transformed diff --git a/tests/transformation/streamline/test_move_identical_op_past_split.py b/tests/transformation/streamline/test_move_identical_op_past_split.py new file mode 100644 index 0000000000..a104f179be --- /dev/null +++ b/tests/transformation/streamline/test_move_identical_op_past_split.py @@ -0,0 +1,145 @@ +# Copyright (c) 2020, Xilinx +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
#
# * Neither the name of FINN nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import pytest

import numpy as np
from onnx import TensorProto
from onnx import helper as oh
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.general import GiveUniqueNodeNames
from qonnx.util.basic import gen_finn_dt_tensor

import finn.core.onnx_exec as oxe
from finn.transformation.streamline.reorder import (
    MoveScalarLinearPastSplit,
    MoveTransposePastSplit,
)

# Shapes used for the Transpose test cases, keyed by permutation:
# (input shape, transposed shape, first split output shape,
#  second split output shape, split axis)
_TRANSPOSE_CASES = {
    (0, 2, 3, 1): ([1, 96, 10, 9], [1, 10, 9, 96], [1, 10, 9, 32], [1, 10, 9, 64], 3),
    (0, 3, 1, 2): ([1, 10, 9, 96], [1, 96, 10, 9], [1, 32, 10, 9], [1, 64, 10, 9], 1),
}


def create_split_model(identical_op):
    """Build a two-node test graph in which `identical_op` feeds a two-way Split.

    `identical_op` is either "Mul", "Add" or "Transpose_<perm>", where <perm>
    is the permutation written as a digit string (e.g. "Transpose_0231" gives
    perm=[0, 2, 3, 1]). Mul/Add get a single-element float32 initializer
    ("op_param") as second input; Transpose gets the `perm` attribute. The
    Split node divides the 96-element channel axis into parts of 32 and 64.

    Returns the graph wrapped in a ModelWrapper, after GiveUniqueNodeNames.

    Raises ValueError if a Transpose permutation is requested for which no
    test shapes are defined.
    """
    perm = None
    if "Transpose" in identical_op:
        identical_op, _, perm_digits = identical_op.partition("_")
        perm = [int(digit) for digit in perm_digits]
        try:
            (
                in_shape,
                out_shape,
                out1_split_shape,
                out2_split_shape,
                split_axis,
            ) = _TRANSPOSE_CASES[tuple(perm)]
        except KeyError:
            # fail here with a clear message instead of an UnboundLocalError later on
            raise ValueError(f"No test shapes defined for Transpose perm={perm}") from None
    else:
        in_shape = [1, 96, 10, 9]
        out_shape = in_shape
        out1_split_shape = [1, 32, 10, 9]
        out2_split_shape = [1, 64, 10, 9]
        split_axis = 1
    op_value = 1.5
    split = [32, 64]
    is_scalar_linear = identical_op in ("Mul", "Add")

    op_param_name = "op_param"
    split_param_name = "split"

    op_node = oh.make_node(identical_op, inputs=["in1"], outputs=["op_out"])
    if identical_op == "Transpose":
        op_node.attribute.append(oh.make_attribute("perm", perm))
    elif is_scalar_linear:
        # the parameter itself is registered as an initializer further below
        op_node.input.append(op_param_name)

    in1 = oh.make_tensor_value_info("in1", TensorProto.FLOAT, in_shape)
    op_out = oh.make_tensor_value_info("op_out", TensorProto.FLOAT, out_shape)
    out1_split = oh.make_tensor_value_info("out1_split", TensorProto.FLOAT, out1_split_shape)
    out2_split = oh.make_tensor_value_info("out2_split", TensorProto.FLOAT, out2_split_shape)

    # the split sizes are passed to the Split node as its second input
    split_node = oh.make_node(
        "Split",
        [op_out.name, split_param_name],
        [out1_split.name, out2_split.name],
        axis=split_axis,
    )

    graph = oh.make_graph(
        nodes=[op_node, split_node],
        name="test_graph",
        inputs=[in1],
        outputs=[out1_split, out2_split],
        value_info=[op_out],
    )

    model = ModelWrapper(oh.make_model(graph))
    model.set_initializer(split_param_name, np.array(split, dtype=np.int64))
    if is_scalar_linear:
        model.set_initializer(op_param_name, np.array(op_value).astype(np.float32))
    model = model.transform(GiveUniqueNodeNames())

    return model


# Transformation under test for each parametrized test case
transform_dict = {
    "Transpose_0231": MoveTransposePastSplit(),
    "Transpose_0312": MoveTransposePastSplit(),
    "Mul": MoveScalarLinearPastSplit(),
    "Add": MoveScalarLinearPastSplit(),
}


@pytest.mark.streamline
# Operation placed in front of the Split node (for Transpose incl. its permutation);
# the cases are taken from transform_dict so the two cannot drift apart
@pytest.mark.parametrize("identical_op", list(transform_dict))
def test_move_identical_op_past_join_concat(identical_op):
    """Apply the move-past-Split transformation for `identical_op` and check the result.

    The transformed model must pass the execution comparison against the
    original model, and the first consumer of the graph input must have a
    different op type than before the transformation.
    """
    model = create_split_model(identical_op)

    # Create input data matching the datatype and shape of the single graph input
    input0_tensor_name = model.graph.input[0].name
    input_dict = {
        input0_tensor_name: gen_finn_dt_tensor(
            model.get_tensor_datatype(input0_tensor_name),
            model.get_tensor_shape(input0_tensor_name),
        )
    }

    model_transformed = model.transform(transform_dict[identical_op])

    assert oxe.compare_execution(model, model_transformed, input_dict)

    # Check if order changed: the graph input must now be consumed by a different op type
    node0_input0_model = model.find_consumers(input0_tensor_name)[0].op_type
    node0_input0_model_transformed = model_transformed.find_consumers(
        model_transformed.graph.input[0].name
    )[0].op_type
    assert node0_input0_model != node0_input0_model_transformed