forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add TruncatedDistribution and TruncatedNormal
- Loading branch information
Showing
3 changed files
with
88 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import torch | ||
from torch.distributions import constraints | ||
from torch.distributions.distribution import Distribution | ||
from torch.distributions.utils import broadcast_all | ||
|
||
|
||
class TruncatedDistribution(Distribution): | ||
r""" | ||
Extension of the Distribution class, which applies truncation to base distribution. | ||
""" | ||
def __init__(self, base_distribution, lower_bound=-float('inf'), upper_bound=float('inf'), *args, **kwargs): | ||
super(TruncatedDistribution, self).__init__(*args, **kwargs) | ||
self.base_dist = base_distribution | ||
self.params = self.base_dist.params | ||
self.lower_bound, self.upper_bound, _ = broadcast_all(lower_bound, upper_bound, | ||
getattr(self.base_dist, | ||
list(self.self.base_dist.params.keys)[0])) | ||
self.params['lower_bound': constraints.dependent, 'upper_bound': constraints.dependent] | ||
|
||
@constraints.dependent_property | ||
def support(self): | ||
# Note: The proper way to implement this is intersection([lower_bound, upper_bound], base_dist.support) | ||
# This requires intersection method to be implemented for constraints. | ||
return constraints.interval(self.lower_bound, self.upper_bound) | ||
|
||
@property | ||
def batch_shape(self): | ||
return self.base_dist.batch_shape | ||
|
||
@property | ||
def event_shape(self): | ||
return self.base_dist.event_shape | ||
|
||
def rsample(self, sample_shape=torch.Size()): | ||
""" | ||
Generates a sample_shape shaped sample or sample_shape shaped batch of | ||
samples if the distribution parameters are batched via inverse cdf sampling. Note that this | ||
is a generic sampler which is not the most efficient or accurate around tails of base distribution. | ||
""" | ||
shape = shape = self._extended_shape(sample_shape) | ||
u = getattr(self.base_dist, list(self.base_dist.params.keys())[0]).new(shape).uniform() | ||
return (self.icdf(self.base_dist.cdf(self.lower_bound) + | ||
u * (self.base_dist.cdf(self.upper_bound) - self.base_dist.cdf(self.lower_bound)))) | ||
|
||
def log_prob(self, value): | ||
""" | ||
Returns the log of the probability density/mass function evaluated at `value`. | ||
Returns -inf in value is out of bounds | ||
""" | ||
log_prob = self.base_dist.log_prob(value) | ||
log_prob[(value < self.lower_bound) | (value > self.upper_bound)] = -float('inf') | ||
log_prob = log_prob - (self.base_dist.cdf(self.upper_bound) - self.base_dist.cdf(self.lower_bound)).log() | ||
|
||
def cdf(self, value): | ||
""" | ||
Cumulative distribution function for the truncated distribution | ||
""" | ||
return ((self.base_dist.cdf(value) - self.base_dist.cdf(self.lower_bound)) / | ||
(self.base_dist.cdf(self.upper_bound) - self.base_dist.cdf(self.lower_bound))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from torch.distributions.normal import Normal | ||
from torch.distributions.truncated_distribution import TruncatedDistribution | ||
|
||
|
||
class TruncatedNormal(TruncatedDistribution): | ||
r""" | ||
Creates a log-normal distribution parameterized by | ||
`mean` and `std` where:: | ||
X ~ Normal(loc, scale) | ||
Y = exp(X) ~ LogNormal(loc, scale) | ||
Example:: | ||
>>> m = LogNormal(torch.Tensor([0.0]), torch.Tensor([1.0])) | ||
>>> m.sample() # log-normal distributed with mean=0 and stddev=1 | ||
0.1046 | ||
[torch.FloatTensor of size 1] | ||
Args: | ||
loc (float or Tensor or Variable): mean of log of distribution | ||
scale (float or Tensor or Variable): standard deviation of log ofthe distribution | ||
""" | ||
def __init__(self, loc, scale, lower_bound, upper_bound): | ||
super(TruncatedNormal, self).__init__(Normal(loc, scale), lower_bound, upper_bound) |