Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce graph rewrite for mixture sub-graphs defined via IfElse Op #169

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions aeppl/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
node_rewriter,
pre_greedy_node_rewriter,
)
from aesara.ifelse import ifelse
from aesara.ifelse import IfElse, ifelse
from aesara.scalar.basic import Switch
from aesara.tensor.basic import Join, MakeVector
from aesara.tensor.elemwise import Elemwise
Expand Down Expand Up @@ -309,6 +309,31 @@ def mixture_replace(fgraph, node):
def switch_mixture_replace(fgraph, node):
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return None # pragma: no cover

old_mixture_rv = node.default_output()

# Add an extra dimension to the indices so that the `MixtureRV` we
# construct represents a valid
# `at.stack(node.inputs[1:])[f(node.inputs[0])]`, for some function `f`,
# that's equivalent to `at.switch(*node.inputs)`.
out_shape = at.broadcast_shape(
*(tuple(v.shape) for v in node.inputs[1:]), arrays_are_shapes=True
)
switch_indices = (node.inputs[0],) + tuple(at.arange(s) for s in out_shape)

# Construct the proxy/intermediate mixture representation
switch_stack = at.stack(node.inputs[::-1])[switch_indices]
switch_stack.name = old_mixture_rv.name

return mixture_replace.transform(fgraph, switch_stack.owner)


@node_rewriter((IfElse,))
def ifelse_mixture_replace(fgraph, node):
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return None # pragma: no cover

Expand All @@ -332,14 +357,25 @@ def switch_mixture_replace(fgraph, node):
new_comp_rv = new_node.outputs[out_idx]
mixture_rvs.append(new_comp_rv)

"""
Unlike mixtures generated via at.stack, there is only one condition, i.e. index
for switch/ifelse-defined mixture sub-graphs. However, this condition can be
non-scalar for Switch Ops.
"""
mix_op = MixtureRV(
2,
old_mixture_rv.type.dtype,
old_mixture_rv.type.shape,
)
new_node = mix_op.make_node(
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
)

if node.inputs[0].ndim == 0:
# as_nontensor_scalar to allow graphs to be identical to mixture sub-graphs
# created using at.stack and Subtensor indexing
new_node = mix_op.make_node(
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
)
else:
new_node = mix_op.make_node(*([at.constant(0), node.inputs[0]] + mixture_rvs))

new_mixture_rv = new_node.default_output()

Expand Down Expand Up @@ -420,7 +456,7 @@ def logprob_MixtureRV(
logprob_rewrites_db.register(
"mixture_replace",
EquilibriumGraphRewriter(
[mixture_replace, switch_mixture_replace],
[mixture_replace, switch_mixture_replace, ifelse_mixture_replace],
max_use_ratio=aesara.config.optdb__max_use_ratio,
),
"basic",
Expand Down
135 changes: 111 additions & 24 deletions tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
import scipy.stats.distributions as sp
from aesara.graph.basic import Variable, equal_computations
from aesara.ifelse import ifelse
from aesara.tensor.random.basic import CategoricalRV
from aesara.tensor.shape import shape_tuple
from aesara.tensor.subtensor import as_index_constant
Expand Down Expand Up @@ -232,25 +233,6 @@ def test_hetero_mixture_binomial(p_val, size):
(),
0,
),
(
(
np.array(0, dtype=aesara.config.floatX),
np.array(1, dtype=aesara.config.floatX),
),
(
np.array(0.5, dtype=aesara.config.floatX),
np.array(0.5, dtype=aesara.config.floatX),
),
(
np.array(100, dtype=aesara.config.floatX),
np.array(1, dtype=aesara.config.floatX),
),
np.array([0.1, 0.5, 0.4], dtype=aesara.config.floatX),
(),
(),
(),
0,
),
(
(
np.array(0, dtype=aesara.config.floatX),
Expand Down Expand Up @@ -682,17 +664,122 @@ def test_mixture_with_DiracDelta():
assert M_rv in logp_res


def test_switch_mixture():
@pytest.mark.parametrize(
"op, X_args, Y_args, p_val, comp_size, idx_size",
[
[op] + list(test_args)
for op in [at.switch, ifelse]
for test_args in [
(
(
np.array(-10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array(10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array(0.5, dtype=aesara.config.floatX),
(),
(),
),
(
(
np.array(-10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array(10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array(0.5, dtype=aesara.config.floatX),
(),
(6,),
),
(
(
np.array([10, 20], dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array([-10, -20], dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array([0.9, 0.1], dtype=aesara.config.floatX),
(2,),
(2,),
),
(
(
np.array([10, 20], dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array([-10, -20], dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array([0.9, 0.1], dtype=aesara.config.floatX),
None,
None,
),
(
(
np.array(-10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array(10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array(0.5, dtype=aesara.config.floatX),
(2, 3),
(2, 3),
),
(
(
np.array(10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array(-10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array(0.5, dtype=aesara.config.floatX),
(2, 3),
(),
),
(
(
np.array(10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
(
np.array(-10, dtype=aesara.config.floatX),
np.array(0.1, dtype=aesara.config.floatX),
),
np.array(0.5, dtype=aesara.config.floatX),
(3,),
(3,),
),
]
if not ((test_args[-1] is None or len(test_args[-1]) > 0) and op == ifelse)
],
)
def test_switch_ifelse_mixture(op, X_args, Y_args, p_val, comp_size, idx_size):
"""
The argument size is both the input to srng.normal and the expected
size of the mixture RV Z1_rv
"""
srng = at.random.RandomStream(29833)

X_rv = srng.normal(-10.0, 0.1, name="X")
Y_rv = srng.normal(10.0, 0.1, name="Y")
X_rv = srng.normal(*X_args, size=comp_size, name="X")
Y_rv = srng.normal(*Y_args, size=comp_size, name="Y")

I_rv = srng.bernoulli(0.5, name="I")
I_rv = srng.bernoulli(p_val, size=idx_size, name="I")
i_vv = I_rv.clone()
i_vv.name = "i"

Z1_rv = at.switch(I_rv, X_rv, Y_rv)
Z1_rv = op(I_rv, X_rv, Y_rv)
larryshamalama marked this conversation as resolved.
Show resolved Hide resolved
z_vv = Z1_rv.clone()
z_vv.name = "z1"

Expand Down