From 103025b404a03d948edeff2d0a1f355231f9babf Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 20 Feb 2023 20:25:06 +0100 Subject: [PATCH 1/4] Implement utility to convert Model to and from FunctionGraph --- docs/api_reference.rst | 3 + pymc_experimental/tests/utils/__init__.py | 0 .../tests/utils/test_model_fgraph.py | 273 +++++++++++++++ pymc_experimental/utils/__init__.py | 8 +- pymc_experimental/utils/model_fgraph.py | 310 ++++++++++++++++++ pymc_experimental/utils/pytensorf.py | 33 ++ 6 files changed, 626 insertions(+), 1 deletion(-) create mode 100644 pymc_experimental/tests/utils/__init__.py create mode 100644 pymc_experimental/tests/utils/test_model_fgraph.py create mode 100644 pymc_experimental/utils/model_fgraph.py create mode 100644 pymc_experimental/utils/pytensorf.py diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 2d85c3ab..0dcd149a 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -49,5 +49,8 @@ Utils .. autosummary:: :toctree: generated/ + clone_model spline.bspline_interpolation prior.prior_from_idata + model_fgraph.fgraph_from_model + model_fgraph.model_from_fgraph diff --git a/pymc_experimental/tests/utils/__init__.py b/pymc_experimental/tests/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pymc_experimental/tests/utils/test_model_fgraph.py b/pymc_experimental/tests/utils/test_model_fgraph.py new file mode 100644 index 00000000..b1e6fefe --- /dev/null +++ b/pymc_experimental/tests/utils/test_model_fgraph.py @@ -0,0 +1,273 @@ +import numpy as np +import pymc as pm +import pytensor.tensor as pt +import pytest +from pytensor.graph import Constant, FunctionGraph, node_rewriter +from pytensor.graph.rewriting.basic import in2out +from pytensor.tensor.exceptions import NotScalarConstantError + +from pymc_experimental.utils.model_fgraph import ( + ModelFreeRV, + ModelVar, + fgraph_from_model, + model_deterministic, + model_free_rv, + model_from_fgraph, +) + + +def test_basic(): + """Test we can convert from a PyMC Model to a FunctionGraph and back""" + with pm.Model(coords={"test_dim": range(3)}) as m_old: + x = pm.Normal("x") + y = pm.Deterministic("y", x + 1) + w = pm.HalfNormal("w", pm.math.exp(y)) + z = pm.Normal("z", y, w, observed=[0, 1, 2], dims=("test_dim",)) + pm.Potential("pot", x * 2) + + m_fgraph = fgraph_from_model(m_old) + assert isinstance(m_fgraph, FunctionGraph) + + m_new = model_from_fgraph(m_fgraph) + assert isinstance(m_new, pm.Model) + + assert m_new.coords == {"test_dim": tuple(range(3))} + assert m_new._dim_lengths["test_dim"].eval() == 3 + assert m_new.named_vars_to_dims == {"z": ["test_dim"]} + + named_vars = {"x", "y", "w", "z", "pot"} + assert set(m_new.named_vars) == named_vars + for named_var in named_vars: + assert m_new[named_var] is not m_old[named_var] + for value_new, value_old in zip(m_new.rvs_to_values.values(), m_old.rvs_to_values.values()): + # Constants are not cloned + if not isinstance(value_new, Constant): + assert value_new is not value_old + assert m_new["x"] in m_new.free_RVs + assert m_new["w"] in m_new.free_RVs + assert m_new["y"] in m_new.deterministics + assert m_new["z"] in m_new.observed_RVs + assert m_new["pot"] in m_new.potentials + assert m_new.rvs_to_transforms[m_new["x"]] is None + assert m_new.rvs_to_transforms[m_new["w"]] is pm.distributions.transforms.log + assert m_new.rvs_to_transforms[m_new["z"]] is None + + # Test random + new_y_draw, new_z_draw = pm.draw([m_new["y"], m_new["z"]], draws=5, random_seed=1) + old_y_draw, old_z_draw = 
pm.draw([m_old["y"], m_old["z"]], draws=5, random_seed=1)
+    np.testing.assert_array_equal(new_y_draw, old_y_draw)
+    np.testing.assert_array_equal(new_z_draw, old_z_draw)
+
+    # Test logp
+    ip = m_new.initial_point()
+    np.testing.assert_equal(
+        m_new.compile_logp()(ip),
+        m_old.compile_logp()(ip),
+    )
+
+
+def test_data():
+    """Test shared RNGs, MutableData, ConstantData and dim lengths are handled correctly.
+
+    Everything should be preserved across the new and old models, except for shared RNGs.
+    """
+    with pm.Model(coords_mutable={"test_dim": range(3)}) as m_old:
+        x = pm.MutableData("x", [0.0, 1.0, 2.0], dims=("test_dim",))
+        y = pm.MutableData("y", [10.0, 11.0, 12.0], dims=("test_dim",))
+        b0 = pm.ConstantData("b0", 0.0)
+        b1 = pm.Normal("b1")
+        mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",))
+        obs = pm.Normal("obs", mu, sigma=1e-5, observed=y, dims=("test_dim",))
+
+    m_new = model_from_fgraph(fgraph_from_model(m_old))
+
+    # ConstantData is preserved
+    assert m_new["b0"].data == m_old["b0"].data
+
+    # Non-RNG shared variables are preserved
+    assert m_new["x"].container is x.container
+    assert m_new["y"].container is y.container
+    assert m_new.rvs_to_values[m_new["obs"]] is m_new["y"]
+
+    # RNG shared variables are not preserved
+    assert m_new["b1"].owner.inputs[0].container is not m_old["b1"].owner.inputs[0].container
+
+    with m_old:
+        pm.set_data({"x": [100.0, 200.0]}, coords={"test_dim": range(2)})
+
+    assert m_new.dim_lengths["test_dim"].eval() == 2
+    np.testing.assert_array_almost_equal(pm.draw(m_new["x"]), [100.0, 200.0])
+
+
+def test_deterministics():
+    """Test handling of deterministics.
+
+    We don't want Deterministics in the middle of the FunctionGraph, as they would make rewrites cumbersome.
+    However, we do want them in the middle of Model.basic_RVs, so they display nicely in graphviz.
+
+    There is one edge case that has to be considered: when a Deterministic is just a copy of an RV,
+    we don't bother to reintroduce it in between other Model.basic_RVs.
+    """
+    with pm.Model() as m:
+        x = pm.Normal("x")
+        mu = pm.Deterministic("mu", pm.math.abs(x))
+        sigma = pm.math.exp(x)
+        pm.Deterministic("sigma", sigma)
+        y = pm.Normal("y", mu, sigma)
+        # Special case where the Deterministic
+        # is a direct view on another model variable
+        y_ = pm.Deterministic("y_", y)
+        # Just for kicks, make it a double one!
+ y__ = pm.Deterministic("y__", y_) + z = pm.Normal("z", y__) + + # Deterministic mu is in the graph of x to y but not sigma + assert m["y"].owner.inputs[3] is m["mu"] + assert m["y"].owner.inputs[4] is not m["sigma"] + + fg = fgraph_from_model(m) + + # Check that no Deterministics are in graph of x to y and y to z + x, y, z, det_mu, det_sigma, det_y_, det_y__ = fg.outputs + # [Det(mu), Det(sigma)] + mu = det_mu.owner.inputs[0] + sigma = det_sigma.owner.inputs[0] + # [FreeRV(y(mu, sigma))] not [FreeRV(y(Det(mu), Det(sigma)))] + assert y.owner.inputs[0].owner.inputs[3] is mu + assert y.owner.inputs[0].owner.inputs[4] is sigma + # [FreeRV(z(y))] not [FreeRV(z(Det(Det(y))))] + assert z.owner.inputs[0].owner.inputs[3] is y + # [Det(y), Det(y)], not [Det(y), Det(Det(y))] + assert det_y_.owner.inputs[0] is y + assert det_y__.owner.inputs[0] is y + assert det_y_ is not det_y__ + + # Both mu and sigma deterministics are now in the graph of x to y + m = model_from_fgraph(fg) + assert m["y"].owner.inputs[3] is m["mu"] + assert m["y"].owner.inputs[4] is m["sigma"] + # But not y_* in y to z, since there was no real Op in between + assert m["z"].owner.inputs[3] is m["y"] + assert m["y_"].owner.inputs[0] is m["y"] + assert m["y__"].owner.inputs[0] is m["y"] + + +def test_context_error(): + """Test that model_from_fgraph fails when called inside a Model context. + + We can't allow it, because the new Model that's returned would be a child of whatever Model context is active. + """ + with pm.Model() as m: + x = pm.Normal("x") + + fg = fgraph_from_model(m) + + with pytest.raises(RuntimeError, match="cannot be called inside a PyMC model context"): + model_from_fgraph(fg) + + +def test_sub_model_error(): + """Test Error is raised when trying to convert a sub-model to fgraph.""" + with pm.Model() as m: + x = pm.Beta("x", 1, 1) + with pm.Model() as sub_m: + y = pm.Normal("y", x) + + nodes = [v for v in fgraph_from_model(m).toposort() if not isinstance(v.op, ModelVar)] + assert len(nodes) == 2 + assert isinstance(nodes[0].op, pm.Beta) + assert isinstance(nodes[1].op, pm.Normal) + + with pytest.raises(ValueError, match="Nested sub-models cannot be converted"): + fgraph_from_model(sub_m) + + +@pytest.fixture() +def non_centered_rewrite(): + @node_rewriter(tracks=[ModelFreeRV]) + def non_centered_param(fgraph: FunctionGraph, node): + """Rewrite that replaces centered normal by non-centered parametrization.""" + + rv, value, *dims = node.inputs + if not isinstance(rv.owner.op, pm.Normal): + return + rng, size, dtype, loc, scale = rv.owner.inputs + + # Only apply rewrite if size information is explicit + if size.ndim == 0: + return None + + try: + is_unit = ( + pt.get_underlying_scalar_constant_value(loc) == 0 + and pt.get_underlying_scalar_constant_value(scale) == 1 + ) + except NotScalarConstantError: + is_unit = False + + # Nothing to do here + if is_unit: + return + + raw_norm = pm.Normal.dist(0, 1, size=size, rng=rng) + raw_norm.name = f"{rv.name}_raw_" + raw_norm_value = raw_norm.clone() + fgraph.add_input(raw_norm_value) + raw_norm = model_free_rv(raw_norm, raw_norm_value, node.op.transform, *dims) + + new_norm = loc + raw_norm * scale + new_norm.name = rv.name + new_norm_det = model_deterministic(new_norm, *dims) + fgraph.add_output(new_norm_det) + + return [new_norm] + + return in2out(non_centered_param) + + +def test_fgraph_rewrite(non_centered_rewrite): + """Test we can apply a simple rewrite to a PyMC Model.""" + + with pm.Model(coords={"subject": range(10)}) as m_old: + group_mean = 
pm.Normal("group_mean") + group_std = pm.HalfNormal("group_std") + subject_mean = pm.Normal("subject_mean", group_mean, group_std, dims=("subject",)) + obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10), dims=("subject",)) + + fg = fgraph_from_model(m_old) + non_centered_rewrite.apply(fg) + + m_new = model_from_fgraph(fg) + assert m_new.named_vars_to_dims == { + "subject_mean": ["subject"], + "subject_mean_raw_": ["subject"], + "obs": ["subject"], + } + assert set(m_new.named_vars) == { + "group_mean", + "group_std", + "subject_mean_raw_", + "subject_mean", + "obs", + } + assert {rv.name for rv in m_new.free_RVs} == {"group_mean", "group_std", "subject_mean_raw_"} + assert {rv.name for rv in m_new.observed_RVs} == {"obs"} + assert {rv.name for rv in m_new.deterministics} == {"subject_mean"} + + with pm.Model() as m_ref: + group_mean = pm.Normal("group_mean") + group_std = pm.HalfNormal("group_std") + subject_mean_raw = pm.Normal("subject_mean_raw_", 0, 1, shape=(10,)) + subject_mean = pm.Deterministic("subject_mean", group_mean + subject_mean_raw * group_std) + obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10)) + + np.testing.assert_array_equal( + pm.draw(m_new["subject_mean_raw_"], draws=7, random_seed=1), + pm.draw(m_ref["subject_mean_raw_"], draws=7, random_seed=1), + ) + + ip = m_new.initial_point() + np.testing.assert_equal( + m_new.compile_logp()(ip), + m_ref.compile_logp()(ip), + ) diff --git a/pymc_experimental/utils/__init__.py b/pymc_experimental/utils/__init__.py index db751aa2..705d2107 100644 --- a/pymc_experimental/utils/__init__.py +++ b/pymc_experimental/utils/__init__.py @@ -15,5 +15,11 @@ from pymc_experimental.utils import prior, spline from pymc_experimental.utils.linear_cg import linear_cg +from pymc_experimental.utils.model_fgraph import clone_model -# from pymc_experimental.utils.pivoted_cholesky import pivoted_cholesky +__all__ = ( + "clone_model", + "linear_cg", + "prior", + "spline", +) diff --git a/pymc_experimental/utils/model_fgraph.py b/pymc_experimental/utils/model_fgraph.py new file mode 100644 index 00000000..edb773f2 --- /dev/null +++ b/pymc_experimental/utils/model_fgraph.py @@ -0,0 +1,310 @@ +from typing import Optional + +import pytensor +from pymc.logprob.transforms import RVTransform +from pymc.model import Model +from pymc.pytensorf import find_rng_nodes +from pytensor import Variable +from pytensor.graph import Apply, FunctionGraph, Op, node_rewriter +from pytensor.graph.rewriting.basic import out2in +from pytensor.scalar import Identity +from pytensor.tensor.elemwise import Elemwise + +from pymc_experimental.utils.pytensorf import StringType + + +class ModelVar(Op): + """A dummy Op that describes the purpose of a Model variable and contains + meta-information as additional inputs (value and dims). 
+ """ + + def make_node(self, rv, *dims): + assert isinstance(rv, Variable) + dims = self._parse_dims(rv, *dims) + return Apply(self, [rv, *dims], [rv.type(name=rv.name)]) + + def _parse_dims(self, rv, *dims): + if dims: + dims = [pytensor.as_symbolic(dim) for dim in dims] + assert all(isinstance(dim.type, StringType) for dim in dims) + assert len(dims) == rv.type.ndim + return dims + + def infer_shape(self, fgraph, node, inputs_shape): + return [inputs_shape[0]] + + def do_constant_folding(self, fgraph, node): + return False + + def perform(self, *args, **kwargs): + raise RuntimeError("ModelVars should never be in a final graph!") + + +class ModelValuedVar(ModelVar): + + __props__ = ("transform",) + + def __init__(self, transform: Optional[RVTransform] = None): + if transform is not None and not isinstance(transform, RVTransform): + raise TypeError(f"transform must be None or RVTransform type, got {type(transform)}") + self.transform = transform + super().__init__() + + def make_node(self, rv, value, *dims): + assert isinstance(rv, Variable) + dims = self._parse_dims(rv, *dims) + if value is not None: + assert isinstance(value, Variable) + assert rv.type.in_same_class(value.type) + return Apply(self, [rv, value, *dims], [rv.type(name=rv.name)]) + + +class ModelFreeRV(ModelValuedVar): + pass + + +class ModelObservedRV(ModelValuedVar): + pass + + +class ModelPotential(ModelVar): + pass + + +class ModelDeterministic(ModelVar): + pass + + +class ModelNamed(ModelVar): + pass + + +def model_free_rv(rv, value, transform, *dims): + return ModelFreeRV(transform=transform)(rv, value, *dims) + + +model_observed_rv = ModelObservedRV() +model_potential = ModelPotential() +model_deterministic = ModelDeterministic() +model_named = ModelNamed() + + +def toposort_replace(fgraph: FunctionGraph, replacements) -> None: + """Replace multiple variables in topological order.""" + toposort = fgraph.toposort() + sorted_replacements = sorted( + replacements, key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1 + ) + fgraph.replace_all(tuple(sorted_replacements), import_missing=True) + + +@node_rewriter([Elemwise]) +def local_remove_identity(fgraph, node): + if isinstance(node.op.scalar_op, Identity): + return [node.inputs[0]] + + +remove_identity_rewrite = out2in(local_remove_identity) + + +def fgraph_from_model(model: Model) -> FunctionGraph: + """Convert Model to FunctionGraph. + + Create a FunctionGraph that includes a copy of model variables, wrapped in dummy `ModelVar` Ops. + + PyTensor rewrites can be used to transform the FunctionGraph. + + It should be possible to reconstruct a valid PyMC model using `model_from_fgraph`. + + See: model_from_fgraph + """ + + if any(v is not None for v in model.rvs_to_initial_values.values()): + raise NotImplementedError("Cannot convert models with non-default initial_values") + + if model.parent is not None: + raise ValueError( + "Nested sub-models cannot be converted to fgraph. 
Convert the parent model instead" + ) + + # Collect PyTensor variables + rvs_to_values = model.rvs_to_values + rvs = list(rvs_to_values.keys()) + free_rvs = model.free_RVs + observed_rvs = model.observed_RVs + potentials = model.potentials + # We copy Deterministics (Identity Op) so that they don't show in between "main" variables + # We later remove these Identity Ops when we have a Deterministic ModelVar Op as a separator + old_deterministics = model.deterministics + deterministics = [det.copy(det.name) for det in old_deterministics] + # Other variables that are in model.named_vars but are not any of the categories above + # E.g., MutableData, ConstantData, _dim_lengths + accounted_for = free_rvs + observed_rvs + potentials + old_deterministics + other_named_vars = [ + var.copy(var.name) for var in model.named_vars.values() if var not in accounted_for + ] + value_vars = [val for val in rvs_to_values.values() if val not in other_named_vars] + + model_vars = rvs + potentials + deterministics + other_named_vars + value_vars + + memo = {} + + # Replace RNG nodes so that seeding does not interfere with old model + for rng in find_rng_nodes(model_vars): + new_rng = rng.clone() + new_rng.set_value(rng.get_value(borrow=False)) + memo[rng] = new_rng + + fgraph = FunctionGraph( + outputs=model_vars, + clone=True, + memo=memo, + copy_orphans=True, + copy_inputs=True, + ) + # Copy model meta-info to fgraph + fgraph._coords = model._coords.copy() + fgraph._dim_lengths = model._dim_lengths.copy() + + rvs_to_transforms = model.rvs_to_transforms + named_vars_to_dims = model.named_vars_to_dims + + # Introduce dummy `ModelVar` Ops + free_rvs_to_transforms = {memo[k]: tr for k, tr in rvs_to_transforms.items()} + free_rvs_to_values = {memo[k]: memo[v] for k, v in rvs_to_values.items() if k in free_rvs} + observed_rvs_to_values = { + memo[k]: memo[v] for k, v in rvs_to_values.items() if k in observed_rvs + } + potentials = [memo[k] for k in potentials] + deterministics = [memo[k] for k in deterministics] + other_named_vars = [memo[k] for k in other_named_vars] + + vars = fgraph.outputs + new_vars = [] + for var in vars: + dims = named_vars_to_dims.get(var.name, ()) + if var in free_rvs_to_values: + new_var = model_free_rv( + var, free_rvs_to_values[var], free_rvs_to_transforms[var], *dims + ) + elif var in observed_rvs_to_values: + new_var = model_observed_rv(var, observed_rvs_to_values[var], *dims) + elif var in potentials: + new_var = model_potential(var, *dims) + elif var in deterministics: + new_var = model_deterministic(var, *dims) + elif var in other_named_vars: + new_var = model_named(var, *dims) + else: + # Value variables + new_var = var + new_vars.append(new_var) + + toposort_replace(fgraph, tuple(zip(vars, new_vars))) + + # Remove value variable as outputs, now that they are graph inputs + for _ in range(len(value_vars)): + fgraph.remove_output(-1) + + # Now that we have Deterministic dummy Ops, we remove the noisy `Identity`s from the graph + remove_identity_rewrite.apply(fgraph) + + return fgraph + + +def model_from_fgraph(fgraph: FunctionGraph) -> Model: + """Convert FunctionGraph to PyMC model. + + This requires nodes to be properly tagged with `ModelVar` dummy Ops. 
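+
+    A minimal round-trip sketch (assuming an existing `model`):
+
+    .. code-block:: python
+
+        fg = fgraph_from_model(model)
+        # ... apply PyTensor rewrites to fg here ...
+        new_model = model_from_fgraph(fg)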
+
+    See: fgraph_from_model
+    """
+    model = Model()
+    if model.parent is not None:
+        raise RuntimeError("model_from_fgraph cannot be called inside a PyMC model context")
+    model._coords = getattr(fgraph, "_coords", {})
+    model._dim_lengths = getattr(fgraph, "_dim_lengths", {})
+
+    # Replace dummy `ModelVar` Ops by the underlying variables,
+    # except for Deterministics, which could reintroduce the old graphs
+    fgraph = fgraph.clone()
+    model_dummy_vars = [
+        model_node.outputs[0]
+        for model_node in fgraph.toposort()
+        if isinstance(model_node.op, ModelVar)
+    ]
+    model_dummy_vars_to_vars = {
+        dummy_var: dummy_var.owner.inputs[0]
+        for dummy_var in model_dummy_vars
+        # Don't include Deterministics!
+        if not isinstance(dummy_var.owner.op, ModelDeterministic)
+    }
+    toposort_replace(fgraph, tuple(model_dummy_vars_to_vars.items()))
+
+    # Populate new PyMC model mappings
+    non_det_model_vars = set(model_dummy_vars_to_vars.values())
+    for model_var in model_dummy_vars:
+        if isinstance(model_var.owner.op, ModelFreeRV):
+            var, value, *dims = model_var.owner.inputs
+            transform = model_var.owner.op.transform
+            model.free_RVs.append(var)
+            # PyMC does not allow setting transform when we pass a value_var. Why?
+            model.create_value_var(var, transform=None, value_var=value)
+            model.rvs_to_transforms[var] = transform
+            model.set_initval(var, initval=None)
+        elif isinstance(model_var.owner.op, ModelObservedRV):
+            var, value, *dims = model_var.owner.inputs
+            model.observed_RVs.append(var)
+            model.create_value_var(var, transform=None, value_var=value)
+        elif isinstance(model_var.owner.op, ModelPotential):
+            var, *dims = model_var.owner.inputs
+            model.potentials.append(var)
+        elif isinstance(model_var.owner.op, ModelDeterministic):
+            var, *dims = model_var.owner.inputs
+            # Register the original var (not the copy) as the Deterministic,
+            # so it shows in the expected place in graphviz;
+            # unless it's another model var, in which case we need a copy!
+            if var in non_det_model_vars:
+                var = var.copy()
+            model.deterministics.append(var)
+        elif isinstance(model_var.owner.op, ModelNamed):
+            var, *dims = model_var.owner.inputs
+        else:
+            raise TypeError(f"Unexpected ModelVar type {type(model_var)}")
+
+        var.name = model_var.name
+        dims = [dim.data for dim in dims] if dims else None
+        model.add_named_variable(var, dims=dims)
+
+    return model
+
+
+def clone_model(model: Model) -> Model:
+    """Clone a PyMC model.
+
+    Recreates a PyMC model with clones of the original variables.
+    Shared variables will point to the same container but be otherwise different objects.
+    Constants are not cloned.
+
+    Examples
+    --------
+
+    .. 
code-block:: python + + import pymc as pm + from pymc_experimental.utils import clone_model + + with pm.Model() as m: + p = pm.Beta("p", 1, 1) + x = pm.Bernoulli("x", p=p, shape=(3,)) + + with clone_model(m) as clone_m: + # Access cloned variables by name + clone_x = clone_m["x"] + + # z will be part of clone_m but not m + z = pm.Deterministic("z", clone_x + 1) + + """ + return model_from_fgraph(fgraph_from_model(model)) diff --git a/pymc_experimental/utils/pytensorf.py b/pymc_experimental/utils/pytensorf.py new file mode 100644 index 00000000..76358c27 --- /dev/null +++ b/pymc_experimental/utils/pytensorf.py @@ -0,0 +1,33 @@ +import pytensor +from pytensor.graph import Constant, Type + + +class StringType(Type[str]): + def clone(self, **kwargs): + return type(self)() + + def filter(self, x, strict=False, allow_downcast=None): + if isinstance(x, str): + return x + else: + raise TypeError("Expected a string!") + + def __str__(self): + return "string" + + @staticmethod + def may_share_memory(a, b): + return isinstance(a, str) and a is b + + +stringtype = StringType() + + +class StringConstant(Constant): + pass + + +@pytensor._as_symbolic.register(str) +def as_symbolic_string(x, **kwargs): + + return StringConstant(stringtype, x) From 96fbc30bb8a077eb73fd493c61f192819fe90dba Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 14 Apr 2023 07:18:07 +0200 Subject: [PATCH 2/4] Implement uncensor and forecast_timeseries model transformation --- pymc_experimental/model_transform.py | 175 ++++++++++++++++++ .../tests/test_model_transform.py | 90 +++++++++ 2 files changed, 265 insertions(+) create mode 100644 pymc_experimental/model_transform.py create mode 100644 pymc_experimental/tests/test_model_transform.py diff --git a/pymc_experimental/model_transform.py b/pymc_experimental/model_transform.py new file mode 100644 index 00000000..49301326 --- /dev/null +++ b/pymc_experimental/model_transform.py @@ -0,0 +1,175 @@ +from pymc import DiracDelta +from pymc.distributions.censored import CensoredRV +from pymc.distributions.timeseries import AR, AutoRegressiveRV +from pymc.model import Model +from pytensor import shared +from pytensor.graph import FunctionGraph, node_rewriter +from pytensor.graph.basic import get_var_by_name +from pytensor.graph.rewriting.basic import in2out + +from pymc_experimental.utils.model_fgraph import ( + ModelObservedRV, + ModelValuedVar, + fgraph_from_model, + model_free_rv, + model_from_fgraph, + model_named, +) + +__all__ = ( + "forecast_timeseries", + "uncensor", +) + + +@node_rewriter(tracks=[ModelValuedVar]) +def uncensor_node_rewrite(fgraph, node): + """Rewrite that replaces censored variables by uncensored ones""" + + ( + censored_rv, + value, + *dims, + ) = node.inputs + if not isinstance(censored_rv.owner.op, CensoredRV): + return + + model_rv = node.outputs[0] + base_rv = censored_rv.owner.inputs[0] + uncensored_rv = node.op.make_node(base_rv, value, *dims).default_output() + uncensored_rv.name = f"{model_rv.name}_uncensored" + return [uncensored_rv] + + +uncensor_rewrite = in2out(uncensor_node_rewrite) + + +def uncensor(model: Model) -> Model: + """Replace censored variables in the model by uncensored equivalent. + + Replaced variables have the same name as original ones with an additional "_uncensored" suffix. + + .. 
code-block:: python + + import arviz as az + import pymc as pm + from pymc_experimental.model_transform import uncensor + + with pm.Model() as model: + x = pm.Normal("x") + dist_raw = pm.Normal.dist(x) + y = pm.Censored("y", dist=dist_raw, lower=-1, upper=1, observed=[-1, 0.5, 1, 1, 1]) + idata = pm.sample() + + with uncensor(model): + idata_pp = pm.sample_posterior_predictive(idata, var_names=["y_uncensored"]) + + az.summary(idata_pp) + """ + fg = fgraph_from_model(model) + + (_, nodes_changed, *_) = uncensor_rewrite.apply(fg) + if not nodes_changed: + raise RuntimeError("No censored variables were replaced by uncensored counterparts") + + return model_from_fgraph(fg) + + +@node_rewriter(tracks=[ModelValuedVar]) +def forecast_timeseries_node_rewrite(fgraph: FunctionGraph, node): + """Rewrite that replaces timeseries variables by new ones starting at the last timepoint(s).""" + + ( + timeseries_rv, + value, + *dims, + ) = node.inputs + if not isinstance(timeseries_rv.owner.op, AutoRegressiveRV): + return + + forecast_steps = get_var_by_name(fgraph.inputs, "forecast_steps_") + if len(forecast_steps) != 1: + return False + + forecast_steps = forecast_steps[0] + + op = timeseries_rv.owner.op + model_rv = node.outputs[0] + + # We cannot reference the variable we are planning to replace + # Or it will introduce circularities in the graph + # FIXME: This special logic shouldn't be needed for ObservedRVs + # but PyMC does not allow one to not resample observed. + # We hack around by conditioning on the value variable directly, + # even though that should not be part of the generative graph... + if isinstance(node.op, ModelObservedRV): + init_dist = DiracDelta.dist(value[..., -op.ar_order :]) + else: + cloned_model_rv = model_rv.owner.clone().default_output() + fgraph.add_output(cloned_model_rv, import_missing=True) + init_dist = DiracDelta.dist(cloned_model_rv[..., -op.ar_order :]) + + if isinstance(timeseries_rv.owner.op, AutoRegressiveRV): + rhos, sigma, *_ = timeseries_rv.owner.inputs + new_timeseries_rv = AR.rv_op( + rhos=rhos, + sigma=sigma, + init_dist=init_dist, + steps=forecast_steps, + ar_order=op.ar_order, + constant_term=op.constant_term, + ) + + new_name = f"{model_rv.name}_forecast" + new_value = new_timeseries_rv.type(name=new_name) + new_timeseries_rv = model_free_rv(new_timeseries_rv, new_value, transform=None) + new_timeseries_rv.name = new_name + + # Import new variables into fgraph (value and RNG) + fgraph.import_var(new_timeseries_rv, import_missing=True) + + return [new_timeseries_rv] + + +forecast_timeseries_rewrite = in2out(forecast_timeseries_node_rewrite, ignore_newtrees=True) + + +def forecast_timeseries(model: Model, forecast_steps: int) -> Model: + """Replace timeseries variables in the model by forecast that start at the last value. + + Replaced variables have the same name as original ones with an additional "_forecast" suffix. + + The function will fail if any variables with fixed static shape depend on the timeseries being replaced, + and forecast_steps differs from the original timeseries steps. + + .. 
code-block:: python
+
+        import arviz as az
+        import numpy as np
+        import pymc as pm
+        from pymc_experimental.model_transform import forecast_timeseries
+
+        with pm.Model() as model:
+            rho = pm.Normal("rho")
+            sigma = pm.HalfNormal("sigma")
+            init_dist = pm.Normal.dist()
+            y = pm.AR("y", init_dist=init_dist, rho=rho, sigma=sigma, observed=np.zeros(100))
+            idata = pm.sample()
+
+        forecast_model = forecast_timeseries(model, forecast_steps=20)
+        with forecast_model:
+            idata_pp = pm.sample_posterior_predictive(idata, var_names=["y_forecast"])
+
+        az.summary(idata_pp)
+    """
+
+    fg = fgraph_from_model(model)
+
+    forecast_steps_sh = shared(forecast_steps, name="forecast_steps_")
+    forecast_steps_sh = model_named(forecast_steps_sh)
+    fg.add_output(forecast_steps_sh, import_missing=True)
+
+    (_, nodes_changed, *_) = forecast_timeseries_rewrite.apply(fg)
+    if not nodes_changed:
+        raise RuntimeError("No timeseries were replaced by forecast counterparts")
+
+    res = model_from_fgraph(fg)
+    return res
diff --git a/pymc_experimental/tests/test_model_transform.py b/pymc_experimental/tests/test_model_transform.py
new file mode 100644
index 00000000..25df4a9f
--- /dev/null
+++ b/pymc_experimental/tests/test_model_transform.py
@@ -0,0 +1,90 @@
+import arviz as az
+import numpy as np
+import pymc as pm
+import pytest
+
+from pymc_experimental.model_transform import forecast_timeseries, uncensor
+
+
+@pytest.mark.parametrize(
+    "transform, kwargs",
+    [
+        (uncensor, dict()),
+        (forecast_timeseries, dict(forecast_steps=20)),
+    ],
+)
+def test_transform_error(transform, kwargs):
+    """Test informative error is raised when the transform is not applicable to a model."""
+    with pm.Model() as model:
+        x = pm.Normal("x")
+        y = pm.Normal("y", x, observed=[0, 5, 10])
+
+    with pytest.raises(RuntimeError, match="No .* were replaced by .* counterparts"):
+        transform(model, **kwargs)
+
+
+def test_uncensor():
+    with pm.Model() as model:
+        x = pm.Normal("x")
+        dist_raw = pm.Normal.dist(x)
+        y = pm.Censored("y", dist=dist_raw, lower=-1, upper=1, observed=[0, 5, 10])
+        det = pm.Deterministic("det", y * 2)
+
+    idata = az.from_dict({"x": np.zeros((2, 500))})
+
+    with uncensor(model):
+        pp = pm.sample_posterior_predictive(
+            idata,
+            var_names=["y_uncensored", "det"],
+            random_seed=18,
+        ).posterior_predictive
+
+    assert np.any(np.abs(pp["y_uncensored"]) > 1)
+    np.testing.assert_allclose(pp["y_uncensored"] * 2, pp["det"])
+
+
+@pytest.mark.parametrize("observed", (True, False))
+@pytest.mark.parametrize("ar_order", (1, 2))
+def test_forecast_timeseries_ar(observed, ar_order):
+    data_steps = 3
+    data = np.hstack((np.zeros(ar_order), (np.arange(data_steps) + 1) * 100.0))
+    with pm.Model() as model:
+        rho = pm.Normal("rho", shape=(ar_order,))
+        sigma = pm.HalfNormal("sigma")
+        init_dist = pm.Normal.dist(0, 1e-3)
+        y = pm.AR(
+            "y",
+            init_dist=init_dist,
+            rho=rho,
+            sigma=sigma,
+            observed=data if observed else None,
+            steps=data_steps,
+        )
+        det = pm.Deterministic("det", y * 2)
+
+    draws = (2, 50)
+    # These rhos mean that all steps will be data[-1] for ar_order > 1
+    idata_dict = {
+        "rho": np.full((*draws, ar_order), (0.1,) + (0,) * (ar_order - 1)),
+        "sigma": np.full(draws, 1e-5),
+    }
+    if observed:
+        idata = az.from_dict(idata_dict, observed_data={"y": data})
+    else:
+        idata_dict["y"] = np.full((*draws, len(data)), data)
+        idata = az.from_dict(idata_dict)
+
+    forecast_steps = 5
+    with forecast_timeseries(model, forecast_steps=forecast_steps):
+        pp = pm.sample_posterior_predictive(
+            idata,
+            var_names=["y_forecast", "det"],
+            random_seed=50,
+        ).posterior_predictive
+
+    
expected = data[-1] / np.logspace(0, forecast_steps, forecast_steps + 1) + expected = np.hstack((data[-ar_order:-1], expected)) + np.testing.assert_allclose( + pp["y_forecast"].values, np.full((*draws, forecast_steps + ar_order), expected), rtol=0.01 + ) + np.testing.assert_allclose(pp["y_forecast"] * 2, pp["det"]) From e65d0abbc28d2ac6397fbfb4ebc84b0e6e22d53d Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 21 Apr 2023 18:47:08 +0200 Subject: [PATCH 3/4] WIP NB --- pymc_experimental/gp/GPs from scratch.ipynb | 55 +++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 pymc_experimental/gp/GPs from scratch.ipynb diff --git a/pymc_experimental/gp/GPs from scratch.ipynb b/pymc_experimental/gp/GPs from scratch.ipynb new file mode 100644 index 00000000..c056ba36 --- /dev/null +++ b/pymc_experimental/gp/GPs from scratch.ipynb @@ -0,0 +1,55 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from pytensor.graph.op import Op" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "hide_input": false, + "kernelspec": { + "display_name": "pymc_experimental", + "language": "python", + "name": "pymc_experimental" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 2084e167486a42b313abcec77958d3a24e001ae3 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 21 Apr 2023 20:41:30 +0200 Subject: [PATCH 4/4] WIP NB --- pymc_experimental/gp/GPs from scratch.ipynb | 484 +++++++++++++++++++- 1 file changed, 482 insertions(+), 2 deletions(-) diff --git a/pymc_experimental/gp/GPs from scratch.ipynb b/pymc_experimental/gp/GPs from scratch.ipynb index c056ba36..3cb9e84c 100644 --- a/pymc_experimental/gp/GPs from scratch.ipynb +++ b/pymc_experimental/gp/GPs from scratch.ipynb @@ -2,11 +2,491 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ - "from pytensor.graph.op import Op" + "import numpy as np\n", + "\n", + "import pytensor\n", + "import pytensor.tensor as pt\n", + "import pymc as pm\n", + "\n", + "from pytensor.tensor.slinalg import cholesky\n", + "from pytensor.graph import FunctionGraph\n", + "from pytensor.graph.rewriting.basic import node_rewriter, in2out\n", + "from pytensor.tensor.rewriting.basic import register_canonicalize\n", + "from pytensor.graph.op import Op, Apply\n", + "from pymc.gp.util import stabilize\n", + "from pymc.logprob.abstract import _logprob, _get_measurable_outputs, MeasurableVariable" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "class Cov(Op):\n", + "\n", + " __props__ = (\"fn\",)\n", + "\n", + " def __init__(self, fn):\n", + " self.fn = fn\n", + "\n", + " def make_node(self, ls):\n", + " ls = pt.as_tensor(ls)\n", + " out = pt.matrix(shape=(None, None))\n", + "\n", + " return Apply(self, [ls], [out])\n", 
+ "\n", + " def __call__(self, ls=1.0):\n", + " return super().__call__(ls)\n", + "\n", + " def perform(self, node, inputs, output_storage):\n", + " raise NotImplementedError(\"You should convert Cov into a TensorVariable expression!\")\n", + "\n", + " def do_constant_folding(self, fgraph, node):\n", + " return False\n", + "\n", + "\n", + "class GP(Op):\n", + "\n", + " __props__ = (\"approx\",)\n", + "\n", + " def __init__(self, approx):\n", + " self.approx = approx\n", + "\n", + " def make_node(self, mean, cov):\n", + " mean = pt.as_tensor(mean)\n", + " cov = pt.as_tensor(cov)\n", + "\n", + " if not (cov.owner and isinstance(cov.owner.op, Cov)):\n", + " raise ValueError(\"Second argument should be a Cov output.\")\n", + "\n", + " out = pt.vector(shape=(None,))\n", + "\n", + " return Apply(self, [mean, cov], [out])\n", + "\n", + " def perform(self, node, inputs, output_storage):\n", + " raise NotImplementedError(\"You cannot evaluate a GP, not enough RAM in the Universe.\")\n", + "\n", + " def do_constant_folding(self, fgraph, node):\n", + " return False\n", + "\n", + "\n", + "class PriorFromGP(Op):\n", + " \"\"\"This Op will be replaced by the right MvNormal.\"\"\"\n", + "\n", + " def make_node(self, gp, x, rng):\n", + " gp = pt.as_tensor(gp)\n", + " if not (gp.owner and isinstance(gp.owner.op, GP)):\n", + " raise ValueError(\"First argument should be a GP output.\")\n", + "\n", + " # TODO: Assert RNG has the right type\n", + " x = pt.as_tensor(x)\n", + " out = x.type()\n", + "\n", + " return Apply(self, [gp, x, rng], [out])\n", + "\n", + " def __call__(self, gp, x, rng=None):\n", + " if rng is None:\n", + " rng = pytensor.shared(np.random.default_rng())\n", + " return super().__call__(gp, x, rng)\n", + "\n", + " def perform(self, node, inputs, output_storage):\n", + " raise NotImplementedError(\"You should convert PriorFromGP into a MvNormal!\")\n", + "\n", + " def do_constant_folding(self, fgraph, node):\n", + " return False\n", + "\n", + "\n", + "cov_op = Cov(fn=pm.gp.cov.ExpQuad)\n", + "gp_op = GP(\"vanilla\")\n", + "# SymbolicRandomVariable.register(type(gp_op))\n", + "prior_from_gp = PriorFromGP()\n", + "\n", + "MeasurableVariable.register(type(prior_from_gp))\n", + "\n", + "\n", + "@_get_measurable_outputs.register(type(prior_from_gp))\n", + "def gp_measurable_outputs(op, node):\n", + " return node.outputs" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PriorFromGP [id A] \n", + " |GP{approx='vanilla'} [id B] \n", + " | |mean [id C] \n", + " | |Cov{fn=} [id D] \n", + " | |ls [id E] \n", + " |x [id F] \n", + " |RandomGeneratorSharedVariable() [id G] \n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mean = pt.vector(\"mean\")\n", + "x = pt.vector(\"x\", shape=(50,))\n", + "ls = pt.scalar(\"ls\")\n", + "\n", + "cov = cov_op(ls)\n", + "gp = gp_op(mean, cov)\n", + "f = prior_from_gp(gp, x)\n", + "pytensor.dprint(f, print_type=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[PriorFromGP.0]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from pymc.logprob.abstract import get_measurable_outputs\n", + "\n", + "get_measurable_outputs(f.owner.op, f.owner)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + 
"metadata": {}, + "outputs": [], + "source": [ + "# You can only run once\n", + "@register_canonicalize\n", + "@node_rewriter(tracks=[PriorFromGP])\n", + "def prior_from_gp_to_mvnormal(fgraph: FunctionGraph, node: Apply):\n", + " out = node.outputs[0]\n", + " gp, X, rng = node.inputs\n", + " # TODO: Check GP is still a GP Op\n", + " mean, cov = gp.owner.inputs\n", + "\n", + " if gp.owner.op.approx != \"vanilla\":\n", + " return False\n", + "\n", + " # Materialize cov\n", + " ls = cov.owner.inputs[0]\n", + " cov = cov.owner.op.fn(input_dim=1, ls=ls).full(X[:, None])\n", + "\n", + " size = pt.shape(X)[0]\n", + " fgraph.add_input(rng)\n", + "\n", + " # TODO: Give names\n", + " L = cholesky(stabilize(cov))\n", + " # L.name = \"L\"\n", + " v = pm.Normal.dist(0, 1, size=size, rng=rng)\n", + " f = mean + pt.dot(L, v)\n", + "\n", + " return [f]" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "FromFunctionNodeRewriter(, [], ())" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prior_from_gp_to_mvnormal" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "fg = FunctionGraph(outputs=[f], clone=False)\n", + "[out] = prior_from_gp_to_mvnormal.transform(fg, fg.outputs[0].owner)\n", + "# pytensor.dprint(out)" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [], + "source": [ + "@_logprob.register(PriorFromGP)\n", + "def prior_gp_logprob(op, values, gp, X, rng, **kwargs):\n", + " [value] = values\n", + "\n", + " # TODO: Check GP is still a GP Op\n", + " mean, cov = gp.owner.inputs\n", + "\n", + " if gp.owner.op.approx != \"vanilla\":\n", + " raise NotImplementedError()\n", + "\n", + " # Materialize cov\n", + " ls = cov.owner.inputs[0]\n", + " cov = cov.owner.op.fn(input_dim=1, ls=ls).full(X[:, None])\n", + "\n", + " f = pm.MvNormal.dist(mu=mean, cov=stabilize(cov))\n", + " return pm.logp(f, value)" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[ls ~ Gamma(4, f()), f]" + ] + }, + "execution_count": 68, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with pm.Model() as m:\n", + " mean = pt.zeros(())\n", + " x = pm.ConstantData(\"x\", np.linspace(0, 10, 20))\n", + " ls = pm.Gamma(\"ls\", alpha=4, beta=1)\n", + "\n", + " cov = cov_op(ls)\n", + " gp = gp_op(mean, cov)\n", + " f = prior_from_gp(gp, x)\n", + "\n", + " m.register_rv(f, name=\"f\", initval=np.zeros(20))\n", + "\n", + "m.basic_RVs" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'ls_log__': array(1.38629436),\n", + " 'f': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0.])}" + ] + }, + "execution_count": 69, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ip = m.initial_point()\n", + "ip" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(78.72276943)" + ] + }, + "execution_count": 70, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "m.compile_logp()(m.initial_point())" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": 
[ + "Sampling: [f, ls]\n" + ] + } + ], + "source": [ + "with m:\n", + " idata = pm.sample_prior_predictive()" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": {}, + "outputs": [], + "source": [ + "# idata.prior[\"f\"].mean((\"chain\", \"draw\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Check{posdef}.0" + ] + }, + "execution_count": 73, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pm.logp(f, np.ones(20))" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "__logp" + ] + }, + "execution_count": 74, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "m.logp()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Auto-assigning NUTS sampler...\n", + "Initializing NUTS using jitter+adapt_diag...\n", + "Multiprocess sampling (4 chains in 4 jobs)\n", + "NUTS: [ls, f]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " 2.76% [221/8000 01:11<41:50 Sampling 4 chains, 0 divergences]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "with m:\n", + " idata = pm.sample()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# pytensor.dprint(out)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "# out.eval({mean: np.ones(50), x: np.linspace(0, 10, 50), ls: 1.0})" ] }, {