Skip to content

Commit

Permalink
add TruncatedDistribution and TruncatedNormal
Browse files Browse the repository at this point in the history
  • Loading branch information
alicanb committed Feb 2, 2018
1 parent 71fd505 commit c5ef53e
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 0 deletions.
4 changes: 4 additions & 0 deletions torch/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
from .poisson import Poisson
from .studentT import StudentT
from .transformed_distribution import TransformedDistribution
from .truncated_distribution import TruncatedDistribution
from .truncated_normal import TruncatedNormal
from .uniform import Uniform

__all__ = [
Expand All @@ -79,8 +81,10 @@
'Pareto',
'StudentT',
'Poisson',
'TruncatedNormal',
'Uniform',
'TransformedDistribution',
'TruncatedDistribution',
'biject_to',
'kl_divergence',
'register_kl',
Expand Down
59 changes: 59 additions & 0 deletions torch/distributions/truncated_distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all


class TruncatedDistribution(Distribution):
r"""
Extension of the Distribution class, which applies truncation to base distribution.
"""
def __init__(self, base_distribution, lower_bound=-float('inf'), upper_bound=float('inf'), *args, **kwargs):
super(TruncatedDistribution, self).__init__(*args, **kwargs)
self.base_dist = base_distribution
self.params = self.base_dist.params
self.lower_bound, self.upper_bound, _ = broadcast_all(lower_bound, upper_bound,
getattr(self.base_dist,
list(self.self.base_dist.params.keys)[0]))
self.params['lower_bound': constraints.dependent, 'upper_bound': constraints.dependent]

@constraints.dependent_property
def support(self):
# Note: The proper way to implement this is intersection([lower_bound, upper_bound], base_dist.support)
# This requires intersection method to be implemented for constraints.
return constraints.interval(self.lower_bound, self.upper_bound)

@property
def batch_shape(self):
return self.base_dist.batch_shape

@property
def event_shape(self):
return self.base_dist.event_shape

def rsample(self, sample_shape=torch.Size()):
"""
Generates a sample_shape shaped sample or sample_shape shaped batch of
samples if the distribution parameters are batched via inverse cdf sampling. Note that this
is a generic sampler which is not the most efficient or accurate around tails of base distribution.
"""
shape = shape = self._extended_shape(sample_shape)
u = getattr(self.base_dist, list(self.base_dist.params.keys())[0]).new(shape).uniform()
return (self.icdf(self.base_dist.cdf(self.lower_bound) +
u * (self.base_dist.cdf(self.upper_bound) - self.base_dist.cdf(self.lower_bound))))

def log_prob(self, value):
"""
Returns the log of the probability density/mass function evaluated at `value`.
Returns -inf in value is out of bounds
"""
log_prob = self.base_dist.log_prob(value)
log_prob[(value < self.lower_bound) | (value > self.upper_bound)] = -float('inf')
log_prob = log_prob - (self.base_dist.cdf(self.upper_bound) - self.base_dist.cdf(self.lower_bound)).log()

def cdf(self, value):
"""
Cumulative distribution function for the truncated distribution
"""
return ((self.base_dist.cdf(value) - self.base_dist.cdf(self.lower_bound)) /
(self.base_dist.cdf(self.upper_bound) - self.base_dist.cdf(self.lower_bound)))
25 changes: 25 additions & 0 deletions torch/distributions/truncated_normal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from torch.distributions.normal import Normal
from torch.distributions.truncated_distribution import TruncatedDistribution


class TruncatedNormal(TruncatedDistribution):
r"""
Creates a log-normal distribution parameterized by
`mean` and `std` where::
X ~ Normal(loc, scale)
Y = exp(X) ~ LogNormal(loc, scale)
Example::
>>> m = LogNormal(torch.Tensor([0.0]), torch.Tensor([1.0]))
>>> m.sample() # log-normal distributed with mean=0 and stddev=1
0.1046
[torch.FloatTensor of size 1]
Args:
loc (float or Tensor or Variable): mean of log of distribution
scale (float or Tensor or Variable): standard deviation of log ofthe distribution
"""
def __init__(self, loc, scale, lower_bound, upper_bound):
super(TruncatedNormal, self).__init__(Normal(loc, scale), lower_bound, upper_bound)

0 comments on commit c5ef53e

Please sign in to comment.