LIP0010
LIP | 10 |
---|---|
Title | Gathering columns with TensorFlow core op |
Author | J. van de Wolfshaar |
Status | Draft |
Type | Standard |
Discussion | Issue #40 |
PR | |
Created | Feb 23, 2018 |
A common operation in the TensorFlow graph of an SPN is to gather the columns of
some input tensor. Currently, we have our own custom implementation,
`gather_cols`. Since TensorFlow 1.3, `tf.gather` has an `axis` parameter. It would be
interesting to see how this operator compares to our custom implementation.
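To make the operation concrete, here is a minimal sketch of gathering columns with the core op. The tensor values are made up for illustration, and the comment on the result assumes TensorFlow 2.x eager execution (under 1.x the tensor would have to be evaluated in a session).

```python
import tensorflow as tf

# A 2x4 input: rows are samples, columns are variables.
# The values are illustrative only.
params = tf.constant([[10, 11, 12, 13],
                      [20, 21, 22, 23]])

# Column indices may repeat and may appear in any order.
indices = tf.constant([3, 0, 0], dtype=tf.int32)

# axis=1 selects columns rather than rows.
gathered = tf.gather(params, indices, axis=1)
# With eager execution, gathered.numpy() is
# [[13, 10, 10], [23, 20, 20]].
```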
Columns are gathered whenever part of an input tensor is selected, e.g.
when the `indices` of an `Input` node are specified, or when we compute an
input mixture. In a `SumsLayer`, the custom `gather_cols_3d` operation is
actually one of the most expensive operations in the graph. Hence, any optimization that can be
obtained by using the TensorFlow core implementation is worth considering.
The impact on the code will be minimal. We can deprecate the implementation of
`gather_cols` in `libspn/utils/math.py` and change the default implementation to use
`tf.gather` with `axis=1` instead.
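The change could look roughly like the sketch below. It is not the actual code in `libspn/utils/math.py`: the `custom_op` argument is a hypothetical hook standing in for the existing custom Op, and the 2D-only check is an assumption made to keep the example short.

```python
import tensorflow as tf


def gather_cols(params, indices, custom_op=None):
    """Gather columns of a 2D tensor.

    custom_op is a hypothetical callable standing in for the existing
    custom Op; when it is None, the TensorFlow core op is used.
    """
    params = tf.convert_to_tensor(params)
    indices = tf.convert_to_tensor(indices)
    # Assumption for this sketch: only 2D params are supported.
    if params.shape.ndims != 2:
        raise ValueError("params must be a 2D tensor")
    if custom_op is not None:
        return custom_op(params, indices)
    # Default path: core op, gathering along the column axis.
    return tf.gather(params, indices, axis=1)
```

Keeping the old implementation reachable through a switch makes it easy to benchmark both paths on the same graph before removing either one.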
As the benchmark output below shows, `tf.gather` and our custom Op perform comparably.
-----------------------
2d_1index
-----------------------
CPU op dt: size setup_time first_run_time rest_run_time correct
GPU op dt: size setup_time first_run_time rest_run_time correct
custom int32: 49 15.35 52.41 48.23 True
custom int64: 49 28.88 57.75 48.16 True
gather_tf int32: 59 21.22 51.25 45.96 True
gather_tf int64: 59 20.46 40.95 47.66 True
slice_2d int32: 59 28.79 52.02 50.32 True
slice_2d int64: 59 23.79 50.46 46.72 True
-----------------------
2d_passthrough_500indices
-----------------------
CPU op dt: size setup_time first_run_time rest_run_time correct
GPU op dt: size setup_time first_run_time rest_run_time correct
custom int32: 49 16.09 90.91 55.57 True
custom int64: 49 14.49 63.92 56.58 True
gather_tf int32: 59 19.25 52.75 58.66 True
gather_tf int64: 59 22.05 68.10 60.07 True
noop int32: 29 10.10 54.85 58.25 True
noop int64: 29 12.61 65.24 54.89 True
-----------------------
2d_opt_500indices
-----------------------
CPU op dt: size setup_time first_run_time rest_run_time correct
GPU op dt: size setup_time first_run_time rest_run_time correct
custom int32: 49 14.29 49.42 56.34 True
custom int64: 49 15.30 62.71 59.09 True
gather_tf int32: 59 19.03 61.05 56.34 True
gather_tf int64: 59 20.28 63.21 59.65 True
-----------------------
2d_worst_500indices
-----------------------
CPU op dt: size setup_time first_run_time rest_run_time correct
GPU op dt: size setup_time first_run_time rest_run_time correct
custom int32: 49 13.96 58.28 57.07 True
custom int64: 49 16.14 66.50 55.19 True
gather_tf int32: 59 20.83 50.41 56.20 True
gather_tf int64: 59 18.51 61.05 53.34 True
-----------------------
2d_random_100indices
-----------------------
CPU op dt: size setup_time first_run_time rest_run_time correct
GPU op dt: size setup_time first_run_time rest_run_time correct
custom int32: 49 13.82 52.36 54.00 True
custom int64: 49 13.94 41.48 50.46 True
gather_tf int32: 59 18.52 52.58 52.45 True
gather_tf int64: 59 21.18 54.50 49.28 True
Below we list a performance comparison of our custom gather Op and the TF core Op when training on a 2-class MNIST problem with multi-nodes.
The performance difference is only marginal and more often in favor of the custom Op: its `rest_run_time` is lower for three of the four inference types (all but MARGINAL). In addition, the custom Op yields a smaller graph: `tf_size` is 1976 lower in every configuration. Hence, we should still favor the custom implementation.
-----------------------
InferenceType: MPE
-----------------------
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------+
| op_name | on_gpu | multi_nodes | spn_size | tf_size | memory_used | input_dist | setup_time | weights_init_time | first_run_time | rest_run_time | test_accuracy | config |
|-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------|
| mnist_01 | True | True | 1400 | 53250 | 500596736 | MIXTURE | 25.6499 | 2.73446 | 3248.72 | 1175.65 | 0.463357 | custom_gather_cols=False |
| mnist_01 | True | True | 1400 | 51274 | 500788992 | MIXTURE | 29.0768 | 2.72073 | 3297.86 | 1110.14 | 0.463357 | custom_gather_cols=True |
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------+
-----------------------
InferenceType: MARGINAL
-----------------------
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------+
| op_name | on_gpu | multi_nodes | spn_size | tf_size | memory_used | input_dist | setup_time | weights_init_time | first_run_time | rest_run_time | test_accuracy | config |
|-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------|
| mnist_01 | True | True | 1400 | 50874 | 508240384 | MIXTURE | 25.8084 | 2.73312 | 3066.35 | 1019.17 | 0.463357 | custom_gather_cols=False |
| mnist_01 | True | True | 1400 | 48898 | 508240384 | MIXTURE | 26.7667 | 2.68207 | 2805.65 | 1086.93 | 0.463357 | custom_gather_cols=True |
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------+
-----------------------
InferenceType: MPE-LOG
-----------------------
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------+
| op_name | on_gpu | multi_nodes | spn_size | tf_size | memory_used | input_dist | setup_time | weights_init_time | first_run_time | rest_run_time | test_accuracy | config |
|-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------|
| mnist_01 | True | True | 1400 | 54448 | 508240384 | MIXTURE | 32.4243 | 2.82695 | 3559.33 | 1167.01 | 0.998109 | custom_gather_cols=False |
| mnist_01 | True | True | 1400 | 52472 | 510517760 | MIXTURE | 30.1937 | 2.79218 | 3076.52 | 1108.25 | 0.997636 | custom_gather_cols=True |
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------+
-----------------------
InferenceType: MARGINAL-LOG
-----------------------
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------+
| op_name | on_gpu | multi_nodes | spn_size | tf_size | memory_used | input_dist | setup_time | weights_init_time | first_run_time | rest_run_time | test_accuracy | config |
|-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------|
| mnist_01 | True | True | 1400 | 61612 | 510517760 | MIXTURE | 47.097 | 3.01621 | 3886.95 | 1321.13 | 0.99669 | custom_gather_cols=False |
| mnist_01 | True | True | 1400 | 59636 | 510517760 | MIXTURE | 33.4144 | 3.03568 | 4013.45 | 1311.92 | 0.996217 | custom_gather_cols=True |
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------+
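The comparison in the tables above can be summarized with a little arithmetic. The sketch below copies `rest_run_time` and `tf_size` from the four tables and computes, per inference type, the relative change in run time and the reduction in graph size when switching from `tf.gather` to the custom Op. The dictionary layout and the function name `summarize` are illustrative choices, not part of the benchmark code.

```python
# (rest_run_time, tf_size) per inference type, copied from the tables above.
# "tf" is custom_gather_cols=False, "custom" is custom_gather_cols=True.
results = {
    "MPE":          {"tf": (1175.65, 53250), "custom": (1110.14, 51274)},
    "MARGINAL":     {"tf": (1019.17, 50874), "custom": (1086.93, 48898)},
    "MPE-LOG":      {"tf": (1167.01, 54448), "custom": (1108.25, 52472)},
    "MARGINAL-LOG": {"tf": (1321.13, 61612), "custom": (1311.92, 59636)},
}


def summarize(results):
    """Map each inference type to (run time change in %, tf_size saving)."""
    summary = {}
    for name, row in results.items():
        tf_time, tf_size = row["tf"]
        custom_time, custom_size = row["custom"]
        # A negative percentage means the custom Op is faster than tf.gather.
        rel_change = 100.0 * (custom_time - tf_time) / tf_time
        summary[name] = (rel_change, tf_size - custom_size)
    return summary


summary = summarize(results)
custom_faster = [name for name, (rel, _) in summary.items() if rel < 0]

for name, (rel, saved) in summary.items():
    print("%-13s rest_run_time %+.2f%%  tf_size saving %d" % (name, rel, saved))
```

Here `custom_faster` contains MPE, MPE-LOG and MARGINAL-LOG, and the `tf_size` saving is 1976 in all four cases.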