Skip to content

Commit

Permalink
Merge branch 'aesara-devs:main' into ifelse-mixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
larryshamalama authored Dec 17, 2022
2 parents 8c9c0f3 + 0959489 commit b6aa902
Show file tree
Hide file tree
Showing 12 changed files with 64 additions and 59 deletions.
9 changes: 5 additions & 4 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ containing Aesara ``RandomVariable``\s:
from aeppl import joint_logprob, pprint
srng = at.random.RandomStream()
# A simple scale mixture model
S_rv = at.random.invgamma(0.5, 0.5)
Y_rv = at.random.normal(0.0, at.sqrt(S_rv))
S_rv = srng.invgamma(0.5, 0.5)
Y_rv = srng.normal(0.0, at.sqrt(S_rv))
# Compute the joint log-probability
logprob, (y, s) = joint_logprob(Y_rv, S_rv)
Expand Down Expand Up @@ -94,8 +95,8 @@ Joint log-probabilities can be computed for some terms that are *derived* from
.. code-block:: python
# Create a switching model from a Bernoulli distributed index
Z_rv = at.random.normal([-100, 100], 1.0, name="Z")
I_rv = at.random.bernoulli(0.5, name="I")
Z_rv = srng.normal([-100, 100], 1.0, name="Z")
I_rv = srng.bernoulli(0.5, name="I")
M_rv = Z_rv[I_rv]
M_rv.name = "M"
Expand Down
5 changes: 3 additions & 2 deletions aeppl/joint_logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def conditional_logprob(
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
sigma2_rv = srng.invgamma(0.5, 0.5)
Y_rv = srng.normal(0, at.sqrt(sigma2_rv))
Expand Down Expand Up @@ -267,7 +268,7 @@ def joint_logprob(
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
sigma2_rv = srng.invgamma(0.5, 0.5)
Y_rv = srng.normal(0, at.sqrt(sigma2_rv))
Expand Down
2 changes: 1 addition & 1 deletion aeppl/logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ def categorical_logprob(op, values, *inputs, **kwargs):
)
)
# FIXME: `take_along_axis` drops a broadcastable dimension
# when `value.broadcastable == p.broadcastable == (True, True, False)`.
# when `value.type.shape == p.type.shape == (1, 1, None)`.
else:
res = at.log(p[value])

Expand Down
16 changes: 8 additions & 8 deletions aeppl/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,17 +180,17 @@ def rv_pull_down(x: TensorVariable, dont_touch_vars=None) -> TensorVariable:
class MixtureRV(Op):
"""A placeholder used to specify a log-likelihood for a mixture sub-graph."""

__props__ = ("indices_end_idx", "out_dtype", "out_broadcastable")
__props__ = ("indices_end_idx", "out_dtype", "out_shape")

def __init__(self, indices_end_idx, out_dtype, out_broadcastable):
def __init__(self, indices_end_idx, out_dtype, out_shape):
super().__init__()
self.indices_end_idx = indices_end_idx
self.out_dtype = out_dtype
self.out_broadcastable = out_broadcastable
self.out_shape = out_shape

def make_node(self, *inputs):
return Apply(
self, list(inputs), [TensorType(self.out_dtype, self.out_broadcastable)()]
self, list(inputs), [TensorType(self.out_dtype, shape=self.out_shape)()]
)

def perform(self, node, inputs, outputs):
Expand Down Expand Up @@ -284,8 +284,8 @@ def mixture_replace(fgraph, node):
# Replace this sub-graph with a `MixtureRV`
mix_op = MixtureRV(
1 + len(mixing_indices),
old_mixture_rv.dtype,
old_mixture_rv.broadcastable,
old_mixture_rv.type.dtype,
old_mixture_rv.type.shape,
)
new_node = mix_op.make_node(*([join_axis] + mixing_indices + mixture_rvs))

Expand Down Expand Up @@ -364,8 +364,8 @@ def ifelse_mixture_replace(fgraph, node):
"""
mix_op = MixtureRV(
2,
old_mixture_rv.dtype,
old_mixture_rv.broadcastable,
old_mixture_rv.type.dtype,
old_mixture_rv.type.shape,
)

if node.inputs[0].ndim == 0:
Expand Down
3 changes: 2 additions & 1 deletion aeppl/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,8 @@ class PreamblePPrinter(PPrinter):
-------
>>> import aesara.tensor as at
>>> from aeppl.printing import pprint
>>> X_rv = at.random.normal(at.scalar('\\mu'), at.scalar('\\sigma'), name='X')
>>> srng = at.random.RandomStream()
>>> X_rv = srng.normal(at.scalar('\\mu'), at.scalar('\\sigma'), name='X')
>>> print(pprint(X_rv))
\\mu in R
\\sigma in R
Expand Down
64 changes: 32 additions & 32 deletions docs/source/api/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ The :py:func:`aeppl.logprob.logprob` function can be called on any random variab
import aesara.tensor as at
from aeppl.logprob import _logprob
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
mu = at.scalar("mu")
sigma = at.scalar("sigma")
Expand All @@ -29,7 +29,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
p = at.scalar("p")
x_rv = snrg.bernoulli(p)
Expand All @@ -43,7 +43,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
a = at.scalar("a")
b = at.scalar("b")
Expand All @@ -59,7 +59,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
n = at.iscalar("n")
a = at.scalar("a")
Expand All @@ -76,7 +76,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
n = at.iscalar("n")
p = at.scalar("p")
Expand All @@ -92,7 +92,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
loc = at.scalar("loc")
scale = at.scalar("scale")
Expand All @@ -107,7 +107,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
p = at.vector("p")
x_rv = snrg.categorical(p)
Expand All @@ -121,7 +121,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
df = at.scalar("df")
x_rv = snrg.chisquare(df)
Expand All @@ -148,7 +148,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
alpha = at.vector("alpha")
x_rv = snrg.dirichlet(alpha)
Expand All @@ -167,7 +167,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
beta = at.scalar("beta")
x_rv = snrg.exponential(beta)
Expand All @@ -181,7 +181,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
alpha = at.scalar('alpha')
beta = at.scalar('beta')
Expand All @@ -196,7 +196,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
p = at.scalar("p")
x_rv = snrg.geometric(p)
Expand All @@ -210,7 +210,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
mu = at.scalar('mu')
beta = at.scalar('beta')
Expand All @@ -225,7 +225,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
x0 = at.scalar('x0')
gamma = at.scalar('gamma')
Expand All @@ -240,7 +240,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
mu = at.scalar('mu')
sigma = at.scalar('sigma')
Expand All @@ -255,7 +255,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
ngood = at.scalar("ngood")
nbad = at.scalar("nbad")
Expand All @@ -271,7 +271,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
alpha = at.scalar('alpha')
beta = at.scalar('beta')
Expand All @@ -286,7 +286,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
mu = at.scalar("mu")
lmbda = at.scalar("lambda")
Expand All @@ -301,7 +301,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
mu = at.scalar("mu")
s = at.scalar("s")
Expand All @@ -316,7 +316,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
mu = at.scalar("mu")
sigma = at.scalar("sigma")
Expand All @@ -331,7 +331,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
n = at.iscalar("n")
p = at.vector("p")
Expand All @@ -346,7 +346,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
mu = at.vector('mu')
Sigma = at.matrix('sigma')
Expand All @@ -362,7 +362,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
n = at.iscalar("n")
p = at.scalar("p")
Expand All @@ -377,7 +377,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
mu = at.scalar('mu')
sigma = at.scalar('sigma')
Expand All @@ -392,7 +392,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
b = at.scalar("b")
scale = at.scalar("scale")
Expand All @@ -407,7 +407,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
lmbda = at.scalar("lambda")
x_rv = snrg.poisson(lmbda)
Expand All @@ -421,7 +421,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
df = at.scalar('df')
loc = at.scalar('loc')
Expand All @@ -437,7 +437,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
left = at.scalar('left')
mode = at.scalar('mode')
Expand All @@ -453,7 +453,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
low = at.scalar('low')
high = at.scalar('high')
Expand All @@ -468,7 +468,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
mu = at.scalar('mu')
kappa = at.scalar('kappa')
Expand All @@ -483,7 +483,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
mu = at.scalar('mu')
lmbda = at.scalar('lambda')
Expand All @@ -499,7 +499,7 @@ Documentation for the Aesara implementation can be found here: :external:py:clas
import aesara.tensor as at
srng = at.random.RandomStream(0)
srng = at.random.RandomStream()
k = at.scalar('k')
x_rv = srng.weibull(k)
Loading

0 comments on commit b6aa902

Please sign in to comment.