Jos van de Wolfshaar edited this page Feb 27, 2018 · 2 revisions

LIP 10 - Gathering columns with TensorFlow core op

LIP 10
Title: Gathering columns with TensorFlow core op
Author: J. van de Wolfshaar
Status: Draft
Type: Standard
Discussion: Issue #40
PR:
Created: Feb 23, 2018

Introduction

A common operation in the TensorFlow graph of an SPN is gathering the columns of some input tensor. Currently, we use our own custom implementation, gather_cols.

Since TensorFlow 1.3, tf.gather accepts an axis parameter, so it can gather columns directly. This LIP compares that core operator with our custom implementation of gather_cols.
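To make "gathering columns" concrete, here is a minimal NumPy sketch of the selection semantics. The name gather_cols_reference is hypothetical and this is not the LibSPN implementation; it only illustrates that output column j is input column indices[j], which is the same selection tf.gather performs with axis=1.

```python
import numpy as np


def gather_cols_reference(params, indices):
    """Reference semantics only: out[:, j] == params[:, indices[j]].

    Hypothetical helper for illustration; the custom LibSPN op may do
    extra validation or optimization that is not modeled here.
    """
    params = np.asarray(params)
    indices = np.asarray(indices, dtype=np.int64)
    # np.take along axis 1 picks whole columns; indices may repeat
    # and may appear in any order.
    return np.take(params, indices, axis=1)


x = np.array([[10, 11, 12, 13],
              [20, 21, 22, 23]])

# Select column 3, then column 0 twice.
cols = gather_cols_reference(x, [3, 0, 0])
print(cols)
```

The result has one row per input row and one column per index, so a 2x4 input with three indices gives a 2x3 output.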

Technical Background

Columns are gathered whenever certain parts of an input tensor are selected, e.g. when the indices of an Input node are specified, or when we compute an input mixture. In a SumsLayer, the custom gather_cols_3d operation is actually one of the most expensive operations in the graph. Hence, any optimization that can be obtained by using the TensorFlow core implementation is worth considering.

Proposal

The impact on the code would be minimal: we can deprecate the custom implementation of gather_cols in libspn/utils/math.py and make the default implementation use tf.gather with axis=1 instead.
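A rough sketch of what such a replacement might look like is given here. The function name gather_cols_core and its two-argument signature are assumptions made for illustration; the actual helper in libspn/utils/math.py may take other arguments or handle special cases that are not shown. The only library call relied on is tf.gather with its axis argument.

```python
import tensorflow as tf


def gather_cols_core(params, indices):
    """Gather columns of a 2-D tensor with the TensorFlow core op.

    Hypothetical stand-in for the custom gather_cols helper; the
    signature is an assumption, not the library's actual API.
    """
    params = tf.convert_to_tensor(params)
    indices = tf.convert_to_tensor(indices, dtype=tf.int32)
    # axis=1 selects along the column dimension, leaving rows intact.
    return tf.gather(params, indices, axis=1)


x = tf.constant([[1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0]])

# Columns 2 and 0, in that order.
picked = gather_cols_core(x, [2, 0])
```

Under TensorFlow 2.x with eager execution, picked holds the values directly; under 1.x it is a symbolic tensor that has to be evaluated in a session.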

Performance comparison

As can be seen below, the performance of tf.gather (gather_tf) and our custom Op (custom) is comparable.

#-----------------------
2d_1index
#-----------------------
CPU          op    dt:  size  setup_time  first_run_time  rest_run_time    correct
GPU          op    dt:  size  setup_time  first_run_time  rest_run_time    correct
         custom int32:    49       15.35           52.41          48.23       True
         custom int64:    49       28.88           57.75          48.16       True
      gather_tf int32:    59       21.22           51.25          45.96       True
      gather_tf int64:    59       20.46           40.95          47.66       True
       slice_2d int32:    59       28.79           52.02          50.32       True
       slice_2d int64:    59       23.79           50.46          46.72       True

#-----------------------
2d_passthrough_500indices
#-----------------------
CPU          op    dt:  size  setup_time  first_run_time  rest_run_time    correct
GPU          op    dt:  size  setup_time  first_run_time  rest_run_time    correct
         custom int32:    49       16.09           90.91          55.57       True
         custom int64:    49       14.49           63.92          56.58       True
      gather_tf int32:    59       19.25           52.75          58.66       True
      gather_tf int64:    59       22.05           68.10          60.07       True
           noop int32:    29       10.10           54.85          58.25       True
           noop int64:    29       12.61           65.24          54.89       True

#-----------------------
2d_opt_500indices
#-----------------------
CPU          op    dt:  size  setup_time  first_run_time  rest_run_time    correct
GPU          op    dt:  size  setup_time  first_run_time  rest_run_time    correct
         custom int32:    49       14.29           49.42          56.34       True
         custom int64:    49       15.30           62.71          59.09       True
      gather_tf int32:    59       19.03           61.05          56.34       True
      gather_tf int64:    59       20.28           63.21          59.65       True

#-----------------------
2d_worst_500indices
#-----------------------
CPU          op    dt:  size  setup_time  first_run_time  rest_run_time    correct
GPU          op    dt:  size  setup_time  first_run_time  rest_run_time    correct
         custom int32:    49       13.96           58.28          57.07       True
         custom int64:    49       16.14           66.50          55.19       True
      gather_tf int32:    59       20.83           50.41          56.20       True
      gather_tf int64:    59       18.51           61.05          53.34       True

#-----------------------
2d_random_100indices
#-----------------------
CPU          op    dt:  size  setup_time  first_run_time  rest_run_time    correct
GPU          op    dt:  size  setup_time  first_run_time  rest_run_time    correct
         custom int32:    49       13.82           52.36          54.00       True
         custom int64:    49       13.94           41.48          50.46       True
      gather_tf int32:    59       18.52           52.58          52.45       True
      gather_tf int64:    59       21.18           54.50          49.28       True

The tables below compare our custom gather Op (custom_gather_cols=True) with the TF core Op (custom_gather_cols=False) when training on a 2-class MNIST problem with multi-nodes.

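The performance difference is only marginal and more often in favor of the custom Op: its rest_run_time is lower for three of the four inference types, with MARGINAL as the exception. In addition, the custom Op implementation yields a smaller graph, with a tf_size that is 1976 lower in every configuration. Hence, we should still favor the custom implementation.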

-----------------------
InferenceType: MPE
-----------------------
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------+
| op_name   | on_gpu   | multi_nodes   |   spn_size |   tf_size |   memory_used | input_dist   |   setup_time |   weights_init_time |   first_run_time |   rest_run_time |   test_accuracy | config                   |
|-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------|
| mnist_01  | True     | True          |       1400 |     53250 |     500596736 | MIXTURE      |      25.6499 |             2.73446 |          3248.72 |         1175.65 |        0.463357 | custom_gather_cols=False |
| mnist_01  | True     | True          |       1400 |     51274 |     500788992 | MIXTURE      |      29.0768 |             2.72073 |          3297.86 |         1110.14 |        0.463357 | custom_gather_cols=True  |
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------+

-----------------------
InferenceType: MARGINAL
-----------------------
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------+
| op_name   | on_gpu   | multi_nodes   |   spn_size |   tf_size |   memory_used | input_dist   |   setup_time |   weights_init_time |   first_run_time |   rest_run_time |   test_accuracy | config                   |
|-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------|
| mnist_01  | True     | True          |       1400 |     50874 |     508240384 | MIXTURE      |      25.8084 |             2.73312 |          3066.35 |         1019.17 |        0.463357 | custom_gather_cols=False |
| mnist_01  | True     | True          |       1400 |     48898 |     508240384 | MIXTURE      |      26.7667 |             2.68207 |          2805.65 |         1086.93 |        0.463357 | custom_gather_cols=True  |
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------+

-----------------------
InferenceType: MPE-LOG
-----------------------
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------+
| op_name   | on_gpu   | multi_nodes   |   spn_size |   tf_size |   memory_used | input_dist   |   setup_time |   weights_init_time |   first_run_time |   rest_run_time |   test_accuracy | config                   |
|-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------|
| mnist_01  | True     | True          |       1400 |     54448 |     508240384 | MIXTURE      |      32.4243 |             2.82695 |          3559.33 |         1167.01 |        0.998109 | custom_gather_cols=False |
| mnist_01  | True     | True          |       1400 |     52472 |     510517760 | MIXTURE      |      30.1937 |             2.79218 |          3076.52 |         1108.25 |        0.997636 | custom_gather_cols=True  |
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------+

-----------------------
InferenceType: MARGINAL-LOG
-----------------------
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------+
| op_name   | on_gpu   | multi_nodes   |   spn_size |   tf_size |   memory_used | input_dist   |   setup_time |   weights_init_time |   first_run_time |   rest_run_time |   test_accuracy | config                   |
|-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------|
| mnist_01  | True     | True          |       1400 |     61612 |     510517760 | MIXTURE      |      47.097  |             3.01621 |          3886.95 |         1321.13 |        0.99669  | custom_gather_cols=False |
| mnist_01  | True     | True          |       1400 |     59636 |     510517760 | MIXTURE      |      33.4144 |             3.03568 |          4013.45 |         1311.92 |        0.996217 | custom_gather_cols=True  |
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+--------------------------+
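The comparison drawn from the four tables above can be reproduced with a few lines of arithmetic. The sketch below copies the tf_size and rest_run_time columns from those tables and reports, per inference type, how much smaller the graph is with the custom Op and how its rest_run_time changes relative to the TF core Op. The variable and function names are illustrative only.

```python
# inference type -> ((tf_size, rest_run_time) for custom_gather_cols=False,
#                    (tf_size, rest_run_time) for custom_gather_cols=True)
# Values copied from the tables above.
rows = {
    "MPE":          ((53250, 1175.65), (51274, 1110.14)),
    "MARGINAL":     ((50874, 1019.17), (48898, 1086.93)),
    "MPE-LOG":      ((54448, 1167.01), (52472, 1108.25)),
    "MARGINAL-LOG": ((61612, 1321.13), (59636, 1311.92)),
}


def compare(rows):
    """Return the tf_size saving and relative rest_run_time change per type."""
    summary = {}
    for name, ((core_size, core_time), (custom_size, custom_time)) in rows.items():
        summary[name] = {
            # Positive value: the custom Op produces a smaller graph.
            "size_saving": core_size - custom_size,
            # Negative value: the custom Op is faster than the core Op.
            "time_change_pct": 100.0 * (custom_time - core_time) / core_time,
        }
    return summary


summary = compare(rows)
custom_faster = [name for name, s in summary.items() if s["time_change_pct"] < 0]

for name, s in summary.items():
    print(f"{name:13s} tf_size saving: {s['size_saving']:5d}  "
          f"rest_run_time change: {s['time_change_pct']:+.2f}%")
print("custom Op faster for:", ", ".join(custom_faster))
```

Each row is a single run, so small differences in rest_run_time should be read with that in mind.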

Decision
