LIP0009
LIP | 9 |
---|---|
Title | Optimizing mixture input distributions |
Author | J. van de Wolfshaar |
Status | Draft |
Type | Standard |
Discussion | Issue #39 |
PR | |
Created | Feb 22, 2018 |
Consider the situation in which we have `IVs` nodes as inputs to our SPN. If we have mixture nodes
that merely perform a weighted sum over the `IVs`, we effectively multiply by (or add, in log-space)
only the weights corresponding to a single value of the IV.
The resulting value is the same as if we gathered a
single row/column from our weight matrix for that particular sample. Hence, we propose to change
the behavior of a `Sum`, a `ParSum` or a `SumsLayer` with `IVs` inputs so that it directly uses
the `IVs`' placeholder `_feed` for gathering the correct rows or columns.
An `IVs` node might be connected as input to a `Sum`, `ParSum` or `SumsLayer` node.
If we pass any non-negative integer to indicate the evidence, we first compute the dense
representation of the `IVs`.
Then, the dense representation is applied component-wise to the weight matrix to get the weighted value of a sum node. Finally, the result is reduced to obtain the value of each sum being modeled.
In terms of TF operations, the current implementation involves:
- Computing the one-hot representation of the `IVs` value.
- Assigning rows of `1`s wherever the input had no evidence (`-1`).
- Applying the `IVs` dense representation component-wise to the weights. This is addition in log space and multiplication in non-log space.
- Reducing to get the result of a single sum being modeled. This might be a `reduce_logsum` (consisting of multiple Ops itself) in log space and otherwise a `reduce_sum` in non-log space.
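The four steps above can be sketched in NumPy, used here only as a stand-in for the corresponding TF ops. The function name `dense_sum_value` and the example weights are hypothetical and not part of the library; this is a minimal illustration of the non-log path, not the actual implementation.

```python
import numpy as np

def dense_sum_value(weights, iv_feed):
    """Sum values for an indicator input via one-hot, multiply, reduce.

    weights: [num_sums, num_vals] array with normalized rows (assumed layout).
    iv_feed: [batch] integer array; -1 marks missing evidence.
    """
    num_vals = weights.shape[1]
    # Step 1: one-hot encode the evidence (rows for -1 are all zero at this point).
    one_hot = (iv_feed[:, None] == np.arange(num_vals)[None, :]).astype(weights.dtype)
    # Step 2: rows without evidence become all ones.
    dense = np.where(iv_feed[:, None] == -1, 1.0, one_hot)
    # Step 3: apply the dense representation component-wise to the weights.
    weighted = dense[:, None, :] * weights[None, :, :]  # [batch, num_sums, num_vals]
    # Step 4: reduce over the value axis to get one value per sum.
    return weighted.sum(axis=-1)  # [batch, num_sums]

weights = np.array([[0.2, 0.3, 0.5],
                    [0.6, 0.1, 0.3]])
iv_feed = np.array([2, 0, -1])
values = dense_sum_value(weights, iv_feed)
```

Every sample touches all `num_sums * num_vals` weights, even though at most one weight per sum ends up in the result.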
This seems to be an expensive way of simply selecting one weight per sum, or `1` in the
case of no evidence.
For a `Sum` node the weights are represented by a tensor of shape `[num_weights]`. This means that
we could perform a `tf.gather` operation on this tensor to obtain the values in case of an `IVs`
input. Before we perform gathering, we might even choose to add another element to the tensor which
contains `1` (i.e. the sum of the normalized weights), which corresponds to the value that needs to
be gathered whenever the `IVs`' input is `-1`.
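As a rough illustration of the single-`Sum` case, the sketch below uses NumPy integer indexing as a stand-in for `tf.gather`. The example weights and the choice to append the extra `1` at the end are assumptions for illustration. The `-1` entries are remapped explicitly to the appended position rather than relying on NumPy's negative-index wrap-around.

```python
import numpy as np

weights = np.array([0.2, 0.3, 0.5])  # shape [num_weights], normalized
iv_feed = np.array([2, 0, -1, 1])    # -1 marks missing evidence

# Append the sum of the normalized weights (1.0) as an extra last element.
padded = np.append(weights, 1.0)
# Point missing evidence at the appended element.
indices = np.where(iv_feed == -1, weights.shape[0], iv_feed)
# One lookup per sample; no one-hot encoding, multiplication or reduction.
sum_values = padded[indices]
```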
For a higher-level sum node, e.g. a `ParSum` or a `SumsLayer`, the weight tensor will
have shape `[num_sums, num_weights]`, where each row corresponds to a certain sum being modeled by
the node. In this case, we would add a column of ones and gather from that column in case of `-1`,
similar to the single `Sum` case above.
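For the `[num_sums, num_weights]` case, a column gather might look like the following NumPy sketch. Again NumPy indexing stands in for the TF gather op, and the example values and variable names are hypothetical.

```python
import numpy as np

weights = np.array([[0.2, 0.3, 0.5],
                    [0.6, 0.1, 0.3]])  # [num_sums, num_weights], rows normalized
iv_feed = np.array([2, 0, -1])         # -1 marks missing evidence

num_sums, num_weights = weights.shape
# Add a column of ones that is selected whenever there is no evidence.
padded = np.concatenate([weights, np.ones((num_sums, 1))], axis=1)
cols = np.where(iv_feed == -1, num_weights, iv_feed)
# Gather one column per sample, then transpose to [batch, num_sums].
values = padded[:, cols].T
```

Gathering rows from the transposed weight matrix would give the same values; which layout is preferable depends on how the weights are already stored.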
Inside the `_compute_value` and `_compute_log_value` methods of our sum nodes, we should
distinguish between different kinds of input nodes. Whenever an input is an `IVs` node, we no longer
compute its value directly; we merely gather the elements indicated by the 'sparse' representation
of the `IVs` input feed.
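The log-space variant follows the same pattern, assuming the padded column holds `log(1) = 0` instead of `1`. The NumPy sketch below compares a dense log-sum-exp path with the gather path; both function names are hypothetical and only meant to show that the two agree.

```python
import numpy as np

def log_value_dense(log_weights, iv_feed):
    """Dense path: add the log of the dense IVs representation, then log-sum-exp."""
    num_vals = log_weights.shape[1]
    has_evidence = iv_feed[:, None] != -1
    mismatch = iv_feed[:, None] != np.arange(num_vals)[None, :]
    # log(1) = 0 where the value matches or evidence is missing, log(0) = -inf elsewhere.
    log_dense = np.where(has_evidence & mismatch, -np.inf, 0.0)
    summed = log_dense[:, None, :] + log_weights[None, :, :]
    # Numerically stable log-sum-exp over the value axis.
    peak = summed.max(axis=-1, keepdims=True)
    return (peak + np.log(np.exp(summed - peak).sum(axis=-1, keepdims=True)))[..., 0]

def log_value_gather(log_weights, iv_feed):
    """Gather path: pad with a zero column (log 1) and select one column per sample."""
    num_sums, num_vals = log_weights.shape
    padded = np.concatenate([log_weights, np.zeros((num_sums, 1))], axis=1)
    cols = np.where(iv_feed == -1, num_vals, iv_feed)
    return padded[:, cols].T

log_weights = np.log(np.array([[0.2, 0.3, 0.5],
                               [0.6, 0.1, 0.3]]))
iv_feed = np.array([2, 0, -1])
dense_result = log_value_dense(log_weights, iv_feed)
gather_result = log_value_gather(log_weights, iv_feed)
```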
Note that this method is completely independent of the size of the indicator variable, as there
is no reduction or component-wise operation at all. In the overview below we compare the current
application of the `IVs` value, `cwise_reduce`, with `gather_col` and `gather_row`. We have included
both ways of gathering to investigate whether one would have a large advantage over the other.
Since this is not the case, we will just pick the option that has the least
impact on the current implementation.
The results below also show that the run time of the gather-based method is nearly identical in log-space and non-log space.
Input mixture (the CPU header lists no rows; all rows below appear under the GPU header)
op | dt | size | setup_time | first_run_time | rest_run_time | correct |
---|---|---|---|---|---|---|
cwise_reduce | int32 | 144 | 73.90 | 67.27 | 37.66 | True |
gather_col | int32 | 189 | 95.12 | 26.63 | 1.09 | True |
gather_row | int32 | 99 | 51.17 | 17.01 | 0.98 | True |
Input mixture Log (the CPU header lists no rows; all rows below appear under the GPU header)
op | dt | size | setup_time | first_run_time | rest_run_time | correct |
---|---|---|---|---|---|---|
cwise_reduce | int32 | 265 | 139.13 | 124.29 | 78.53 | True |
gather_col | int32 | 199 | 133.17 | 29.89 | 1.09 | True |
gather_row | int32 | 109 | 54.68 | 22.14 | 0.97 | True |