# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import singledispatch
import numpy as np
import pytensor.tensor as pt
from pytensor import config, scan
from pytensor.graph import Op
from pytensor.graph.basic import Apply
from pytensor.graph.replace import graph_replace
from pytensor.raise_op import CheckAndRaise
from pytensor.scan import until
from pytensor.tensor import TensorConstant, TensorVariable
from pytensor.tensor.random.basic import NormalRV
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomType
from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import (
Distribution,
SymbolicRandomVariable,
_support_point,
support_point,
)
from pymc.distributions.shape_utils import (
_change_dist_size,
change_dist_size,
rv_size_is_none,
to_tuple,
)
from pymc.distributions.transforms import _default_transform
from pymc.exceptions import TruncationError
from pymc.logprob.abstract import _logcdf, _logprob
from pymc.logprob.basic import icdf, logccdf, logcdf, logp
from pymc.math import logdiffexp
from pymc.pytensorf import collect_default_updates
from pymc.util import check_dist_not_registered
class TruncatedRV(SymbolicRandomVariable):
    """An `Op` constructed from a PyTensor graph that represents a truncated univariate random variable.

    The inner graph is assembled by `rv_op`. Its inputs are the inputs of the
    base random variable followed by the lower and upper truncation points.
    Its first output is the truncated draw and the remaining outputs are the
    updated RNG variables.
    """

    # The truncated draw is the first output of the inner graph (RNG updates follow it)
    default_output: int = 0
    # `Op` of the distribution that is being truncated
    base_rv_op: Op
    # Number of scan steps allowed by the rejection-sampling fallback built in `rv_op`
    max_n_steps: int

    def __init__(
        self,
        *args,
        base_rv_op: Op,
        max_n_steps: int,
        **kwargs,
    ):
        """Store the base `Op` and the step limit, and derive the printing name.

        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged to the parent class constructor.
        base_rv_op : Op
            `Op` of the untruncated random variable. Its `_print_name` pair is
            reused to build this Op's `_print_name`.
        max_n_steps : int
            Step limit kept on the instance as `max_n_steps`.
        """
        self.base_rv_op = base_rv_op
        self.max_n_steps = max_n_steps
        # (plain name, LaTeX name): prefix the plain name with "Truncated" and
        # wrap the LaTeX name in \operatorname{...}
        self._print_name = (
            f"Truncated{self.base_rv_op._print_name[0]}",
            f"\\operatorname{{{self.base_rv_op._print_name[1]}}}",
        )
        super().__init__(*args, **kwargs)

    @classmethod
    def rv_op(cls, dist, lower, upper, max_n_steps, *, size=None):
        """Build the truncated version of `dist`.

        Three strategies are tried in order:

        1. A specialized implementation registered on `_truncated` for the
           `Op` of `dist`.
        2. Inverse CDF sampling, which needs `logcdf` and `icdf` of the base
           variable to be available.
        3. Rejection sampling inside a `scan` limited to `max_n_steps` steps,
           followed by a `TruncationCheck` on convergence.

        A `NotImplementedError` raised by strategy 1 or 2 moves on to the next one.

        Parameters
        ----------
        dist
            Variable whose owner `Op` is the distribution to truncate.
        lower, upper
            Truncation points. `None` is replaced by a constant `-inf` / `inf`.
        max_n_steps
            Stored on the returned `TruncatedRV` and used as `n_steps` of the
            rejection-sampling scan.
        size
            Optional size. When it is considered "none" by `rv_size_is_none`,
            the broadcast shape of `dist`, `lower` and `upper` is used instead.

        Returns
        -------
        The result of the specialized `_truncated` implementation, or the result
        of applying a newly built `TruncatedRV` to the graph inputs.
        """
        # We don't accept rng because we don't have control over it when using a specialized Op
        # and there may be a need for multiple RNGs in dist.

        # Try to use specialized Op
        try:
            return _truncated(dist.owner.op, lower, upper, size, *dist.owner.inputs)
        except NotImplementedError:
            pass

        # Missing bounds become infinite constants so the rest of the graph can treat
        # both bounds uniformly
        lower = pt.as_tensor_variable(lower) if lower is not None else pt.constant(-np.inf)
        upper = pt.as_tensor_variable(upper) if upper is not None else pt.constant(np.inf)

        if size is not None:
            size = pt.as_tensor(size, dtype="int64", ndim=1)

        if rv_size_is_none(size):
            size = pt.broadcast_shape(dist, lower, upper)

        dist = change_dist_size(dist, new_size=size)

        # Every RNG input of the base RV is swapped for a newly created shared RNG
        rv_inputs = [
            inp if not isinstance(inp.type, RandomType) else pt.random.shared_rng(seed=None)
            for inp in dist.owner.inputs
        ]
        graph_inputs = [*rv_inputs, lower, upper]

        # Variables with `_` suffix identify dummy inputs for the OpFromGraph
        graph_inputs_ = [
            inp.type() if not isinstance(inp.type, RandomType) else inp for inp in graph_inputs
        ]
        *rv_inputs_, lower_, upper_ = graph_inputs_

        # Base RV rebuilt on top of the dummy inputs
        rv_ = dist.owner.op.make_node(*rv_inputs_).default_output()

        # Try to use inverted cdf sampling
        # truncated_rv = icdf(rv, draw(uniform(cdf(lower), cdf(upper))))
        try:
            logcdf_lower_, logcdf_upper_ = TruncatedRV._create_logcdf_exprs(
                rv_, rv_, lower_, upper_
            )
            # We use the first RNG from the base RV, so we don't have to introduce a new one
            # This is not problematic because the RNG won't be used in the RV logcdf graph
            # NOTE(review): `next` has no default here, so this raises StopIteration when
            # none of the inputs is an RNG -- confirm that cannot happen for supported dists
            uniform_rng_ = next(inp_ for inp_ in rv_inputs_ if isinstance(inp_.type, RandomType))
            uniform_next_rng_, uniform_ = pt.random.uniform(
                pt.exp(logcdf_lower_),
                pt.exp(logcdf_upper_),
                rng=uniform_rng_,
                size=rv_.shape,
                return_next_rng=True,
            )
            truncated_rv_ = icdf(rv_, uniform_, warn_rvs=False)
            return TruncatedRV(
                base_rv_op=dist.owner.op,
                inputs=graph_inputs_,
                outputs=[truncated_rv_, uniform_next_rng_],
                ndim_supp=0,
                max_n_steps=max_n_steps,
            )(*graph_inputs)
        except NotImplementedError:
            pass

        # Fallback to rejection sampling
        # truncated_rv = zeros(rv.shape)
        # reject_draws = ones(rv.shape, dtype=bool)
        # while any(reject_draws):
        #    truncated_rv[reject_draws] = draw(rv)[reject_draws]
        #    reject_draws = (truncated_rv < lower) | (truncated_rv > upper)

        # We need to split the rv_inputs on whether they are rng or not, as rng must be updated inside the scan
        # TODO: This will be simplified by https://github.com/pymc-devs/pytensor/pull/1968
        is_rng_arg = tuple(isinstance(arg.type, RandomType) for arg in rv_inputs_)
        len_rng = sum(is_rng_arg)

        def loop_fn(truncated_rv, reject_draws, *truncated_args):
            # `truncated_args` layout: the recurrent RNGs first, then the non-sequences
            # (lower, upper, and the non-RNG inputs of the base RV)
            rngs = truncated_args[:len_rng]
            lower, upper, *other_args = truncated_args[len_rng:]

            # Interleave RNG and non-RNG arguments back into the original input order
            step_rv_inputs = []
            rngs_iter = iter(rngs)
            other_args_iter = iter(other_args)
            for is_rng in is_rng_arg:
                if is_rng:
                    step_rv_inputs.append(next(rngs_iter))
                else:
                    step_rv_inputs.append(next(other_args_iter))

            new_truncated_rv = dist.owner.op.make_node(*step_rv_inputs).default_output()
            # Avoid scalar boolean indexing
            if truncated_rv.type.ndim == 0:
                truncated_rv = new_truncated_rv
            else:
                # Only the entries that are still rejected receive the new draw
                truncated_rv = pt.set_subtensor(
                    truncated_rv[reject_draws],
                    new_truncated_rv[reject_draws],
                )
            reject_draws = pt.or_((truncated_rv < lower), (truncated_rv > upper))

            # Carry the updated RNGs to the next step, in the same order as `rngs`
            rng_updates = collect_default_updates(new_truncated_rv, must_be_shared=False)
            next_rngs = [rng_updates[rng] for rng in rngs]
            return (
                (truncated_rv, reject_draws, *next_rngs),
                # Stop as soon as no draw is rejected
                until(~pt.any(reject_draws)),
            )

        truncated_rv_, reject_draws_, *final_rngs = scan(
            loop_fn,
            outputs_info=[
                pt.zeros_like(rv_),
                pt.ones_like(rv_, dtype=bool),
                *(arg for is_rng, arg in zip(is_rng_arg, rv_inputs_) if is_rng),
            ],
            non_sequences=[
                lower_,
                upper_,
                *(arg for is_rng, arg in zip(is_rng_arg, rv_inputs_) if not is_rng),
            ],
            n_steps=max_n_steps,
            strict=True,
            return_updates=False,
        )

        # Keep only the last step and verify that nothing is rejected in it
        truncated_rv_ = truncated_rv_[-1]
        convergence_ = ~pt.any(reject_draws_[-1])
        truncated_rv_ = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
            truncated_rv_, convergence_
        )

        return TruncatedRV(
            base_rv_op=dist.owner.op,
            inputs=graph_inputs_,
            outputs=[truncated_rv_, *final_rngs],
            ndim_supp=0,
            max_n_steps=max_n_steps,
        )(*graph_inputs)

    @staticmethod
    def _create_logcdf_exprs(
        base_rv: TensorVariable,
        value: TensorVariable,
        lower: TensorVariable,
        upper: TensorVariable,
    ) -> tuple[TensorVariable, TensorVariable]:
        """Create lower and upper logcdf expressions for base_rv.

        Uses `value` as a template for broadcasting.

        The upper expression is not built separately: it is the lower
        expression with the lower evaluation point replaced by the upper one.

        Returns
        -------
        tuple of (lower_logcdf, upper_logcdf)
        """
        # For left truncated discrete RVs, we need to include the whole lower bound.
        lower_value = lower - 1 if base_rv.type.dtype.startswith("int") else lower
        lower_value = pt.full_like(value, lower_value, dtype=config.floatX)
        upper_value = pt.full_like(value, upper, dtype=config.floatX)
        lower_logcdf = logcdf(base_rv, lower_value, warn_rvs=False)
        upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value})
        return lower_logcdf, upper_logcdf

    @staticmethod
    def _create_lower_logccdf_expr(
        base_rv: TensorVariable,
        value: TensorVariable,
        lower: TensorVariable,
    ) -> TensorVariable:
        """Create logccdf expression at lower bound for base_rv.

        Uses `value` as a template for broadcasting. This is numerically more
        stable than computing log(1 - exp(logcdf)) for distributions that have
        a registered logccdf method.
        """
        # For left truncated discrete RVs, we need to include the whole lower bound.
        lower_value = lower - 1 if base_rv.type.dtype.startswith("int") else lower
        lower_value = pt.full_like(value, lower_value, dtype=config.floatX)
        return logccdf(base_rv, lower_value, warn_rvs=False)

    def update(self, node: Apply):
        """Return the update mapping for the internal RNGs.

        TruncatedRVs are created in a way that the rng updates follow the same order as the input RNGs.

        Parameters
        ----------
        node : Apply
            Node whose RNG-typed inputs are paired, positionally, with its
            RNG-typed outputs.

        Returns
        -------
        dict
            Mapping from each input RNG to the corresponding output RNG.
        """
        rngs = [inp for inp in node.inputs if isinstance(inp.type, RandomType)]
        next_rngs = [out for out in node.outputs if isinstance(out.type, RandomType)]
        return dict(zip(rngs, next_rngs))
@singledispatch
def _truncated(op: Op, lower, upper, size, *params):
    """Dispatch point for specialized truncated versions of a `RandomVariable`.

    This default implementation is used for every `Op` type without a
    registered specialization and always raises `NotImplementedError`.
    """
    message = f"{op} does not have an equivalent truncated version implemented"
    raise NotImplementedError(message)
class TruncationCheck(CheckAndRaise):
    """Check used inside truncated graphs.

    The exception type handed to the parent class is `TruncationError`, so a
    failed check surfaces as that error with the configured message.
    """

    def __init__(self, msg=""):
        # The exception type is fixed; only the message can be customized
        super().__init__(TruncationError, msg)

    def __str__(self):
        """Return the class label followed by the message wrapped in braces."""
        braced_msg = f"{{{self.msg}}}"
        return f"TruncationCheck{braced_msg}"
class Truncated(Distribution):
    r"""
    Truncated distribution.

    The pdf of a Truncated distribution is

    .. math::

        \begin{cases}
            0 & \text{for } x < lower, \\
            \frac{\text{PDF}(x, dist)}{\text{CDF}(upper, dist) - \text{CDF}(lower, dist)}
            & \text{for } lower <= x <= upper, \\
            0 & \text{for } x > upper,
        \end{cases}


    Parameters
    ----------
    dist: unnamed distribution
        Univariate distribution created via the `.dist()` API, which will be truncated.
        This distribution must be a pure RandomVariable and have a logcdf method
        implemented for MCMC sampling.

        .. warning:: dist will be cloned, rendering it independent of the one passed as input.

    lower: tensor_like of float or None
        Lower (left) truncation point. If `None` the distribution will not be left truncated.
    upper: tensor_like of float or None
        Upper (right) truncation point. If `None`, the distribution will not be right truncated.
    max_n_steps: int, defaults 10_000
        Maximum number of resamples that are attempted when performing rejection sampling.
        A `TruncationError` is raised if convergence is not reached after that many steps.

    Returns
    -------
    truncated_distribution: TensorVariable
        Graph representing a truncated `RandomVariable`. A specialized `Op` may be used
        if the `Op` of the dist has a dispatched `_truncated` function. Otherwise, a
        `SymbolicRandomVariable` graph representing the truncation process, via inverse
        CDF sampling (if the underlying dist has a logcdf method), or rejection sampling
        is returned.

    Examples
    --------
    .. code-block:: python

        with pm.Model():
            normal_dist = pm.Normal.dist(mu=0.0, sigma=1.0)
            truncated_normal = pm.Truncated("truncated_normal", normal_dist, lower=-1, upper=1)
    """

    rv_type = TruncatedRV
    rv_op = rv_type.rv_op

    @classmethod
    def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs):
        """Validate the arguments and create the unnamed truncated distribution.

        Parameters
        ----------
        dist
            Variable produced by a `RandomVariable` or `SymbolicRandomVariable` Op.
        lower, upper
            Truncation points; at least one of them must not be `None`.
        max_n_steps : int
            Forwarded, together with `dist`, `lower` and `upper`, to the parent `dist`.
        **kwargs
            Forwarded to the parent `dist`.

        Raises
        ------
        ValueError
            If `dist` is not a variable created by a random variable Op, or if
            both `lower` and `upper` are `None`.
        NotImplementedError
            If `dist` is a `SymbolicRandomVariable` without a ``[size]`` entry in
            its extended signature, or if it is multivariate (``ndim_supp > 0``).
        """
        if not (
            isinstance(dist, TensorVariable)
            and dist.owner is not None
            and isinstance(dist.owner.op, RandomVariable | SymbolicRandomVariable)
        ):
            raise ValueError(
                f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}"
            )

        if (
            isinstance(dist.owner.op, SymbolicRandomVariable)
            and "[size]" not in dist.owner.op.extended_signature
        ):
            # Truncation needs to wrap the underlying dist, but not all SymbolicRandomVariables encapsulate the whole
            # random graph and as such we don't know where the actual inputs begin. This happens mostly for
            # distribution factories like `Censored` and `Mixture` which would have a very complex signature if they
            # encapsulated the random components instead of taking them as inputs like they do now.
            # SymbolicRandomVariables that encapsulate the whole random graph can be identified for having a size parameter.
            raise NotImplementedError(f"Truncation not implemented for {dist.owner.op}")

        if dist.owner.op.ndim_supp > 0:
            raise NotImplementedError("Truncation not implemented for multivariate distributions")

        check_dist_not_registered(dist)

        if lower is None and upper is None:
            raise ValueError("lower and upper cannot both be None")

        return super().dist([dist, lower, upper, max_n_steps], **kwargs)
@_change_dist_size.register(TruncatedRV)
def change_truncated_size(op: TruncatedRV, truncated_rv, new_size, expand):
    """Rebuild a truncated variable with a different size.

    The untruncated variable is recreated from the inputs of the truncated
    node and truncated again with the requested size. When `expand` is true,
    the requested size is prepended to the current shape of `truncated_rv`.
    """
    *base_inputs, lower, upper = truncated_rv.owner.inputs
    base_rv = op.base_rv_op.make_node(*base_inputs).default_output()

    target_size = (*to_tuple(new_size), *truncated_rv.shape) if expand else new_size

    return Truncated.rv_op(
        base_rv,
        lower=lower,
        upper=upper,
        size=target_size,
        max_n_steps=op.max_n_steps,
    )
@_support_point.register(TruncatedRV)
def truncated_support_point(op: TruncatedRV, truncated_rv, *inputs):
    """Return a support point that lies inside the truncation interval.

    The support point of the untruncated variable is used wherever it falls
    within ``[lower, upper]``. Elsewhere a fallback is used:

    * both bounds finite: the midpoint of the interval,
    * only the lower bound finite: ``lower + 1``,
    * only the upper bound finite: ``upper - 1``.

    Parameters
    ----------
    op : TruncatedRV
        Provides `base_rv_op`, used to rebuild the untruncated variable.
    truncated_rv
        The truncated variable (not used in the computation).
    *inputs
        Inputs of the truncated node: the base RV inputs followed by `lower`
        and `upper`.
    """
    *rv_inputs, lower, upper = inputs

    # recreate untruncated rv and respective support_point
    untruncated_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
    untruncated_support_point = support_point(untruncated_rv)

    both_bounds_finite = pt.and_(
        pt.bitwise_not(pt.isinf(lower)),
        pt.bitwise_not(pt.isinf(upper)),
    )
    fallback_support_point = pt.switch(
        both_bounds_finite,
        # Midpoint of the interval. Half the width, `(upper - lower) / 2`, is not
        # guaranteed to be inside [lower, upper] (e.g. lower=5, upper=7 gives 1).
        lower + (upper - lower) / 2,
        pt.switch(
            pt.isinf(upper),
            lower + 1,  # only lower is finite
            upper - 1,  # only upper is finite
        ),
    )

    return pt.switch(
        pt.and_(pt.ge(untruncated_support_point, lower), pt.le(untruncated_support_point, upper)),
        untruncated_support_point,  # untruncated support_point is between lower and upper
        fallback_support_point,
    )
@_default_transform.register(TruncatedRV)
def truncated_default_transform(op, truncated_rv):
    """Pick the default transform of a truncated variable.

    Variables whose dtype name starts with ``"int"`` get no transform. All
    others get the bounded continuous transform, with the bounds taken from
    the last two arguments (lower and upper).
    """
    is_discrete = truncated_rv.type.dtype.startswith("int")
    if not is_discrete:
        # Lower and upper are the arguments at positions -2 and -1
        return bounded_cont_transform(op, truncated_rv, bound_args_indices=(-2, -1))
    # Discrete truncated distributions are left untransformed
    return None
@_logprob.register(TruncatedRV)
def truncated_logprob(op, values, *inputs, **kwargs):
    """Log-probability of a truncated variable.

    The log-probability of the base variable is renormalized by the log of
    the probability mass kept by the truncation, and set to ``-inf`` outside
    the bounds. A bound counts as absent only when it is a constant that is
    entirely infinite. When both bounds are present, the result is wrapped in
    a parameter check that ``lower <= upper``.
    """
    (value,) = values
    *base_inputs, lower, upper = inputs

    base_op = op.base_rv_op
    untruncated = base_op.make_node(*base_inputs).default_output()
    untruncated_logp = logp(untruncated, value)
    logcdf_at_lower, logcdf_at_upper = TruncatedRV._create_logcdf_exprs(
        untruncated, value, lower, upper
    )
    if base_op.name:
        untruncated_logp.name = f"{base_op}_logprob"
        logcdf_at_lower.name = f"{base_op}_lower_logcdf"
        logcdf_at_upper.name = f"{base_op}_upper_logcdf"

    has_lower = not isinstance(lower, TensorConstant) or not np.all(np.isneginf(lower.value))
    has_upper = not isinstance(upper, TensorConstant) or not np.all(np.isinf(upper.value))
    has_both = has_lower and has_upper

    # Log of the probability mass that remains after truncation
    if has_both:
        log_normalizer = logdiffexp(logcdf_at_upper, logcdf_at_lower)
    elif has_lower:
        log_normalizer = TruncatedRV._create_lower_logccdf_expr(untruncated, value, lower)
    elif has_upper:
        log_normalizer = logcdf_at_upper
    else:
        log_normalizer = 0

    result = untruncated_logp - log_normalizer

    # Values outside the bounds have zero probability
    if has_lower:
        result = pt.switch(value < lower, -np.inf, result)
    if has_upper:
        result = pt.switch(value <= upper, result, -np.inf)

    if not has_both:
        return result

    return check_parameters(
        result,
        pt.le(lower, upper),
        msg="lower_bound <= upper_bound",
    )
@_logcdf.register(TruncatedRV)
def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs):
    """Log-CDF of a truncated variable.

    The log-CDF of the base variable, minus the mass below the lower bound
    when one is present, is renormalized by the log of the probability mass
    kept by the truncation. The result is ``-inf`` below the lower bound and
    ``0`` above the upper bound. A bound counts as absent only when it is a
    constant that is entirely infinite. When both bounds are present, the
    result is wrapped in a parameter check that ``lower <= upper``.
    """
    *base_inputs, lower, upper = inputs

    untruncated = op.base_rv_op.make_node(*base_inputs).default_output()
    untruncated_logcdf = logcdf(untruncated, value)
    logcdf_at_lower, logcdf_at_upper = TruncatedRV._create_logcdf_exprs(
        untruncated, value, lower, upper
    )

    has_lower = not isinstance(lower, TensorConstant) or not np.all(np.isneginf(lower.value))
    has_upper = not isinstance(upper, TensorConstant) or not np.all(np.isinf(upper.value))
    has_both = has_lower and has_upper

    # Log of the probability mass that remains after truncation
    if has_both:
        log_normalizer = logdiffexp(logcdf_at_upper, logcdf_at_lower)
    elif has_lower:
        log_normalizer = TruncatedRV._create_lower_logccdf_expr(untruncated, value, lower)
    elif has_upper:
        log_normalizer = logcdf_at_upper
    else:
        log_normalizer = 0

    # Remove the mass that lies below the lower bound before renormalizing
    if has_lower:
        log_numerator = logdiffexp(untruncated_logcdf, logcdf_at_lower)
    else:
        log_numerator = untruncated_logcdf
    result = log_numerator - log_normalizer

    if has_lower:
        result = pt.switch(value < lower, -np.inf, result)
    if has_upper:
        result = pt.switch(value <= upper, result, 0.0)

    if not has_both:
        return result

    return check_parameters(
        result,
        pt.le(lower, upper),
        msg="lower_bound <= upper_bound",
    )
@_truncated.register(NormalRV)
def _truncated_normal(op, lower, upper, size, rng, old_size, mu, sigma):
    """Specialized truncation of a normal variable as a `TruncatedNormal`.

    The `rng` and `old_size` inputs of the original variable are accepted to
    match its input layout but are not used.
    """
    dist_kwargs = {
        "mu": mu,
        "sigma": sigma,
        "lower": lower,
        "upper": upper,
        # Do not reuse rng to avoid weird dependencies
        "rng": None,
        "size": size,
        "dtype": op.dtype,
    }
    return TruncatedNormal.dist(**dist_kwargs)