from __future__ import annotations
import logging
from collections.abc import Sequence
from functools import partial
from pathlib import Path
from typing import Any, Literal, Self
import jax
import jax.flatten_util
import jax.numpy as jnp
import liesel.goose as gs
import liesel.model as lsl
import numpy as np
import pandas as pd
import plotnine as p9
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd
from liesel_ptm.gam.roles import Roles as GamRoles
from ..bspline import LogIncKnots, OnionSpline, PTMSpline
from ..cckernel import freeze_parental_submodel
from ..dist import GaussianPseudoTransformationDist, LocScaleTransformationDist
from ..iwls_proposals import (
CholInfo,
GaussianLocCholInfo,
GaussianScaleCholInfo,
ObservedCholInfoOrIdentity,
PTMCholInfoFixed,
)
from ..logprob import FlatLogProb
from ..predictor import (
LocPredictor,
ScalePredictor,
SimplePTMPredictor,
setup_loc_scale,
)
from ..swap_dists import SwapSpec, TemporarilySwapDists, spec_bounded
from ..util.summary import cache_results, summarise_by_samples
from ..var import PTMCoef, ScaleWeibull
Array = Any
KeyArray = Any
InferenceTypes = Any
logger = logging.getLogger(__name__)
HYPERPARAMETER_ROLES = [
"hyperparam",
GamRoles.variance_smooth,
GamRoles.scale_smooth,
]
class PTMDist(lsl.Dist):
"""
Distribution wrapper that builds a location-scale transformation distribution
using a spline-based transformation.
Parameters
----------
knots
Spline knot sequence.
loc
Location variable for the distribution.
scale
Scale variable for the distribution.
shape
Shape variable for the transformation.
centered
If True, use centered parameterization.
scaled
If True, use scaled parameterization.
trafo_lambda
Controls transition sharpness for the transformation spline.
bspline
Which spline variant to use: 'ptm', 'onion', or 'identity'.
trafo_target_slope
How to handle tail slopes for the transformation.
**kwargs
Forwarded to the parent distribution constructor.
Attributes
----------
partial_dist_class
Partial distribution class used to construct per-observation distributions.
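
    Examples
    --------
    A minimal sketch with constant location and scale and the identity
    pseudo-transformation; the variable setup here is illustrative, not
    prescriptive::

        import jax.numpy as jnp
        import liesel.model as lsl

        loc = lsl.Var.new_value(0.0, name="loc")
        scale = lsl.Var.new_value(1.0, name="scale")
        shape = lsl.Var.new_value(0.0, name="shape")
        dist = PTMDist(
            knots=jnp.linspace(-3.0, 3.0, 10),
            loc=loc,
            scale=scale,
            shape=shape,
            bspline="identity",
        )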
"""
def __init__(
self,
knots: Array,
loc: lsl.Var,
scale: lsl.Var,
shape: lsl.Var,
centered: bool = False,
scaled: bool = False,
trafo_lambda: float = 0.1,
bspline: Literal["ptm", "onion", "identity"] = "ptm",
trafo_target_slope: Literal["identity", "continue_linearly"] = "identity",
**kwargs,
) -> None:
match bspline:
case "ptm":
continue_linearly = trafo_target_slope == "continue_linearly"
bspline_inst: PTMSpline | OnionSpline = PTMSpline(
knots=knots, eps=trafo_lambda, continue_linearly=continue_linearly
)
partial_dist_class = partial(
LocScaleTransformationDist,
bspline=bspline_inst,
centered=centered,
scaled=scaled,
)
case "onion":
bspline_inst = OnionSpline(knots)
partial_dist_class = partial(
LocScaleTransformationDist,
bspline=bspline_inst,
centered=centered,
scaled=scaled,
)
case "identity":
partial_dist_class = partial(
GaussianPseudoTransformationDist,
centered=centered,
scaled=scaled,
                )
            case _:
                raise ValueError(f"Unknown bspline variant: {bspline!r}.")
self.partial_dist_class = partial_dist_class
partial_dist_class = partial(self.partial_dist_class, batched=False)
super().__init__(partial_dist_class, loc=loc, scale=scale, coef=shape, **kwargs)
class LocScalePTM:
"""
A Penalized Transformation Model for Location and Scale.
Parameters
----------
response
Array of response values.
knots
Array of equidistant knots. Should correspond to the chosen bspline variant. \
See :class:`.PTMKnots` and :class:`.OnionKnots`.
intercepts
How to handle intercepts in the location and scale model parts.
The options are:
- "compute": Intercepts are assumed constant and re-computed any time a value \
in the location or scale model part changes.
- "pseudo_sample": Intercepts are assumed constant and re-computed once in \
every \
MCMC iteration given the current values of the location and scale model \
parts.
- "sample": Intercepts are treated as ordinary parameters and sampled according\
to their inference specification in the arguments \
``loc_intercept_inference`` \
and ``scale_intercept_inference``. This can lead to identification issues.
- "constant": Intercepts are kept constant.
- "sample_mh": Intercepts are sampled with bespoke Metropolis-Hastings \
proposals. Experimental, undocumented.
loc_intercept_inference, scale_intercept_inference
:class:`liesel.goose.MCMCSpec` objects that define MCMC inference for \
intercepts, if ``intercepts="sample"``.
centered, scaled
Whether the transformation distribution should be centered and scaled to \
negate any side-effects the transformation might have on the location and \
scale of the response distribution. Can be used with ``intercepts="sample"``.
See also :class:`.TransformationDist`.
trafo_lambda
        Parameter controlling the sharpness of the transition to tail extrapolation. \
        It enters as ``eps`` in ``transition_width = eps * (knots[3] - knots[-4])``, \
        where ``transition_width`` is the width of the transition interval.
Relevant only for ``bspline="ptm"``.
trafo_target_slope
If "continue_linearly", there is no transition to the identity function. \
Instead, \
the spline will continue linearly in the tails with the slope fixed \
to the slopes at the boundaries of the core interval for left and right \
extrapolation, respectively. Relevant only for ``bspline="ptm"``.
bspline
Which B-spline formulation to use. The option ``"onion"`` is experimental.
to_float32
Whether to convert appropriate values in the model to 32-bit floats.
Attributes
----------
intercepts
How to handle intercepts in the location and scale model parts.
knots
Array of knots.
response
Response variable, an instance of :class:`liesel.model.Var`.
graph
The model graph, an instance of :class:`liesel.model.Model`. Only available
after :meth:`.build` has been called.
to_float32
Whether to convert appropriate values in the model to 32-bit floats.
interface
An instance of :class:`liesel.goose.LieselInterface` representing
:attr:`.graph`. Only available after :meth:`.build` has been called.
is_initialized
        Boolean, indicating whether the model has been initialized with posterior
modes.
Examples
--------
A basic unconditional model::
import liesel_ptm as ptm
import jax
y = jax.random.normal(jax.random.key(0), (50,))
model = ptm.LocScalePTM.new_ptm(y, a=-4.0, b=4.0, nparam=20)
results = model.run_mcmc(seed=1, warmup=300, posterior=500)
samples = results.get_posterior_samples()
model.plot(samples)
dist = model.init_dist(samples) # initialize a distribution object
A basic linear location-scale model::
import jax
import liesel_ptm as ptm
from liesel_ptm import lin, term
y = jax.random.normal(jax.random.key(0), (50,))
x = jax.random.uniform(jax.random.key(1), (50,))
model = ptm.LocScalePTM.new_ptm(y, a=-4.0, b=4.0, nparam=20)
# location and scale predictors can be filled by adding terms.
xlin = lin(x, xname="x")
model.loc += term.f(xlin, fname="s")
# when adding terms to the scale model part, they are applied additively
# to the log-level automatically
model.scale += term.f(xlin, fname="g")
results = model.run_mcmc(seed=1, warmup=300, posterior=500)
samples = results.get_posterior_samples()
model.plot(samples)
A basic model with one P-spline::
import jax
import liesel_ptm as ptm
from liesel_ptm import term, ps
y = jax.random.normal(jax.random.key(0), (50,))
x = jax.random.uniform(jax.random.key(1), (50,))
model = ptm.LocScalePTM.new_ptm(y, a=-4.0, b=4.0, nparam=20)
xps = ps(x, nbases=20, xname="x")
model.loc += term.f_ig(xps, fname="s")
results = model.run_mcmc(seed=1, warmup=300, posterior=500)
samples = results.get_posterior_samples()
model.plot(samples)
"""
sample_intercepts_under_constant_trafo: bool = True
"""
If True (default), intercepts are sampled under the assumption of an identity
    transformation. Only takes effect when the init argument ``intercepts="sample"``
is used.
"""
def __init__(
self,
response: Array | pd.Series,
knots: Array,
intercepts: Literal[
"compute", "pseudo_sample", "sample", "constant", "sample_mh"
] = "pseudo_sample",
loc_intercept_inference: InferenceTypes = None,
scale_intercept_inference: InferenceTypes = None,
centered: bool = False,
scaled: bool = False,
trafo_lambda: float = 0.1,
trafo_target_slope: Literal["identity", "continue_linearly"] = "identity",
bspline: Literal["ptm", "onion", "identity"] = "ptm",
to_float32: bool = True,
) -> None:
response_name: str = "response"
if isinstance(response, pd.Series):
response = jnp.asarray(response.to_numpy())
if response.var() <= 0.0:
response_val = None
else:
response_val = lsl.Value(response, _name="_response_value_helper")
loc, scale = setup_loc_scale(
loc_intercept=intercepts,
scale_intercept=intercepts,
response_name=response_name,
loc_intercept_inference=loc_intercept_inference,
scale_intercept_inference=scale_intercept_inference,
response_value=response_val,
loc_name="$\\mu$",
scale_name="$\\sigma$",
)
self.intercepts = intercepts
self._loc = loc
self._scale = scale
self._trafo = SimplePTMPredictor.new_sum(name="trafo")
self.knots = knots
dist = PTMDist(
knots=knots,
loc=loc,
scale=scale,
shape=self.trafo,
centered=centered,
scaled=scaled,
trafo_lambda=trafo_lambda,
bspline=bspline,
trafo_target_slope=trafo_target_slope,
)
self._response_value_helper = response_val
self.response = lsl.Var.new_obs(response, dist, name=response_name)
self.graph: lsl.Model | None = None
self.to_float32 = to_float32
self.interface: gs.LieselInterface | None = None
self.is_initialized: bool = False
self._hyperparameter_initial_values: dict[str, Array] = {}
@classmethod
def new_ptm(
cls,
response: Array | pd.Series,
nparam: int = 20,
a: float = -4.0,
b: float = 4.0,
tau2_scale: float = 0.5,
trafo_lambda: float = 0.1,
trafo_target_slope: Literal["identity", "continue_linearly"] = "identity",
to_float32: bool = True,
):
"""
Shortcut for convenient model setup.
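        Creates log-increment knots on ``[a, b]``, attaches a sum-to-zero
        first-order random walk prior with a Weibull hyperprior on its scale to
        the transformation coefficients, and assigns a NUTS kernel to the
        transformation block.

        Examples
        --------
        ::

            import jax
            import liesel_ptm as ptm

            y = jax.random.normal(jax.random.key(0), (50,))
            model = ptm.LocScalePTM.new_ptm(y, a=-4.0, b=4.0, nparam=20)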
"""
kernel_kwargs = {"da_target_accept": 0.9, "mm_diag": False, "max_treedepth": 10}
knots = LogIncKnots(a, b, nparam=nparam)
model = cls(
response=response,
knots=knots.knots,
trafo_lambda=trafo_lambda,
trafo_target_slope=trafo_target_slope,
bspline="ptm",
to_float32=to_float32,
)
trafo_scale = ScaleWeibull(
value=1.0,
scale=tau2_scale,
name="$\\tau_\\delta$",
bijector=tfb.Exp(),
inference=gs.MCMCSpec(
gs.NUTSKernel, kernel_group="trafo", kernel_kwargs=kernel_kwargs
),
)
trafo_scale.variance_param.name = "$\\tau^2_\\delta$"
trafo_scale.variance_param.value_node[0].name = "$\\log(\\tau^2_\\delta)$"
trafo0 = PTMCoef.new_rw1_sumzero(
knots=knots.knots, scale=trafo_scale, name="$\\delta$", noncentered=False
)
trafo0.latent_coef.name = "$\\delta_z$"
model.trafo += trafo0
return model
@classmethod
def new_gaussian(
cls,
response: Array | pd.Series,
loc_intercept_inference: InferenceTypes = gs.MCMCSpec(gs.IWLSKernel),
scale_intercept_inference: InferenceTypes = gs.MCMCSpec(gs.IWLSKernel),
to_float32: bool = True,
) -> LocScalePTM:
"""
Shortcut for initializing a Gaussian model.
Parameters
----------
response
Array of response values.
loc_intercept_inference, scale_intercept_inference
:class:`liesel.goose.MCMCSpec` objects that define MCMC inference for
intercepts.
to_float32
Whether to convert appropriate values in the model to 32-bit floats.
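
        Examples
        --------
        A sketch mirroring the examples in the class docstring::

            import jax
            import liesel_ptm as ptm

            y = jax.random.normal(jax.random.key(0), (50,))
            model = ptm.LocScalePTM.new_gaussian(y)
            results = model.run_mcmc(seed=1, warmup=300, posterior=500)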
"""
return cls(
response=response,
intercepts="sample",
loc_intercept_inference=loc_intercept_inference,
scale_intercept_inference=scale_intercept_inference,
knots=jnp.linspace(-3.0, 3.0, 10),
to_float32=to_float32,
centered=False,
scaled=False,
bspline="identity",
)
@property
def loc(self) -> LocPredictor:
"""Location predictor"""
return self._loc
@loc.setter
def loc(self, value):
if value is not self._loc:
raise ValueError("Cannot overwrite .loc attribute.")
self._loc = value
@property
def scale(self) -> ScalePredictor:
"""Location predictor"""
return self._scale
@scale.setter
def scale(self, value):
if value is not self._scale:
raise ValueError("Cannot overwrite .scale attribute.")
self._scale = value
@property
def trafo(self):
"""Predictor for the transformation function."""
return self._trafo
@trafo.setter
def trafo(self, value):
if value is not self._trafo:
raise ValueError("Cannot overwrite .trafo attribute.")
self._trafo = value
def _set_intercept_inference(self) -> Self:
li = self.loc.loc_intercept
si = self.scale.log_scale_intercept
if li.inference is not None:
li.inference = freeze_parental_submodel(
li.inference, of=self.trafo, exclude_roles=["hyperparam"]
)
if si.inference is not None:
si.inference = freeze_parental_submodel(
si.inference, of=self.trafo, exclude_roles=["hyperparam"]
)
return self
def build(self) -> Self:
"""Build the model graph."""
if self.graph is not None:
raise ValueError("Graph was already built.")
if self.intercepts == "sample" and self.sample_intercepts_under_constant_trafo:
self._set_intercept_inference()
vars_: list[lsl.Var | lsl.Node] = [self.response]
if self._response_value_helper is not None:
vars_.append(self._response_value_helper)
self.graph = lsl.Model(vars_, to_float32=self.to_float32)
self.interface = gs.LieselInterface(self.graph)
return self
def optim(
self,
exclude_roles: list[Literal["hyperparam", "transformation_coef"] | str]
| None = None,
exclude_params: Sequence[str] | None = None,
swap_pairs: Sequence[SwapSpec] | None = None,
stopper: gs.Stopper | None = None,
progress_bar: bool = False,
test_for_positive_definiteness: bool = False,
update_parameters_inplace: bool = False,
**kwargs,
) -> gs.OptimResult:
"""
Optimize model parameters using the selected optimizer.
Parameters
----------
exclude_roles
Roles to exclude from optimization.
exclude_params
Specific parameter names to exclude.
swap_pairs
Swap specifications used to temporarily replace distribution parts during
optimization.
stopper
Optional stopper controlling optimization iterations.
progress_bar
Whether to show a progress bar.
test_for_positive_definiteness
If True, test Fisher information matrices for positive definiteness.
update_parameters_inplace
If True, apply optimized parameters to the model.
Otherwise restore the previous state after optimization.
**kwargs
Forwarded to the underlying optimizer.
Returns
-------
OptimResult
Result of the optimization run.
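
        Examples
        --------
        A sketch, assuming ``model`` is a :class:`.LocScalePTM` instance::

            model.build()
            result = model.optim(progress_bar=True)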
"""
if self.graph is None:
raise ValueError("Model must be built with .build() first.")
state_before = self.graph.state
exclude_roles = exclude_roles if exclude_roles is not None else []
exclude_params = exclude_params if exclude_params is not None else []
logger.debug(f"Exlcuding roles: {exclude_roles}")
logger.debug(f"Exlcuding params: {exclude_params}")
if swap_pairs is None:
swap_pairs = []
hyp = [
p
for p in self.graph.parameters.values()
if p.role in HYPERPARAMETER_ROLES
]
hyp = [p for p in hyp if p.name not in exclude_params]
for param in hyp:
logger.debug(f"Setting up temporary bounding for {param.name}.")
spec = spec_bounded(
param,
lower_bound=0.05**2,
upper_bound=50000.0,
)
swap_pairs.append(spec)
for param in hyp:
            if param.name not in self._hyperparameter_initial_values:
logger.debug(
f"Saving initial value {param.value} for parameter{param.name}."
)
self._hyperparameter_initial_values[param.name] = param.value
tmp = partial(
TemporarilySwapDists, pairs=swap_pairs, to_float32=self.to_float32
)
with tmp(self.graph) as model_:
params = [
var_.name
for var_ in model_.parameters.values()
if var_.role not in exclude_roles and var_.name not in exclude_params
]
logger.debug(f"Optimizing params: {params}")
i = 0
is_positive_definite = False
maxi = 5
while not is_positive_definite and i < maxi:
logger.debug(f"Optimization step: {i} started")
result = gs.optim_flat(
model_train=model_,
params=params,
model_validation=model_,
stopper=stopper,
progress_bar=progress_bar,
**kwargs,
)
logger.debug(f"Optimization step: {i} done")
model_.state = result.model_state
if not test_for_positive_definiteness:
break
positive_definite_tests = []
for param_name in result.position:
logprob = FlatLogProb(
model=gs.LieselInterface(model_),
model_state=result.model_state,
param_names=[param_name],
)
pos = model_.extract_position(
[param_name], model_state=result.model_state
)
flat_position, _ = jax.flatten_util.ravel_pytree(pos)
if len(flat_position) < 2: # skip scalar parameters
continue
finfo = -logprob.hessian(
flat_position, model_state=result.model_state
)
augmentation = 1e-5 * jnp.eye(jnp.shape(flat_position)[-1])
evals = jnp.linalg.eigvalsh(finfo + augmentation)
eps = jnp.finfo(finfo.dtype).eps
tol = 100 * eps * jnp.linalg.norm(finfo, ord=jnp.inf)
is_pd = jnp.all(evals > tol)
positive_definite_tests.append(is_pd)
logger.debug(f"Tested parameter: {param_name}")
logger.debug(f"Positive definite Fisher info: {is_pd}")
is_positive_definite = all(positive_definite_tests)
logger.debug(
"Number of positive definitive "
f"Fisher Infos: {sum(positive_definite_tests)}"
)
logger.debug(
f"All Fisher Infos positive definite: {is_positive_definite}"
)
i += 1
logger.debug("")
if not update_parameters_inplace:
self.graph.state = state_before
return result
def initialize(
self,
exclude_roles: list[Literal["hyperparam", "transformation_coef"] | str]
| None = None,
stopper: gs.Stopper | None = None,
test_for_positive_definiteness: bool = False,
**kwargs,
) -> tuple[gs.OptimResult, gs.OptimResult]:
"""
Two-stage initialization that fits loc-scale, then transformation.
Updates the model state with the resulting estimated parameters.
Parameters
----------
exclude_roles
Roles to exclude from optimization.
stopper
Optional stopper controlling optimization iterations.
test_for_positive_definiteness
If True, check Fisher information matrices for positive definiteness.
**kwargs
Forwarded to the underlying optimization routine.
Returns
-------
tuple
Optimization results for (loc-scale, transformation).
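
        Examples
        --------
        A sketch, assuming ``model`` is a :class:`.LocScalePTM` instance::

            model.build()
            result_locscale, result_trafo = model.initialize()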
"""
if self.graph is None:
raise ValueError("Model must be built with .build() first.")
# step 1: fit only location-scale part
trafo_submodel = self.graph.parental_submodel(self.trafo)
trafo_params = list(trafo_submodel.parameters)
trafo_params_transformed = [n + "_transformed" for n in trafo_params]
result1 = self.optim(
exclude_roles=exclude_roles,
exclude_params=trafo_params + trafo_params_transformed,
stopper=stopper,
test_for_positive_definiteness=test_for_positive_definiteness,
update_parameters_inplace=True,
**kwargs,
)
# step 2: fit only transformation part
locscale_params = [
name for name in list(self.graph.parameters) if name not in trafo_params
]
locscale_params_transformed = [n + "_transformed" for n in locscale_params]
result2 = self.optim(
exclude_roles=exclude_roles,
exclude_params=locscale_params + locscale_params_transformed,
stopper=stopper,
test_for_positive_definiteness=test_for_positive_definiteness,
update_parameters_inplace=True,
**kwargs,
)
self.is_initialized = True
# return both results
return result1, result2
def init_dist(
self,
samples: dict[str, Array],
loc: Array | None = None,
scale: Array | None = None,
newdata: dict[str, Array] | None = None,
) -> LocScaleTransformationDist:
"""
Construct a batched distribution from posterior samples.
Parameters
----------
samples
Posterior samples dict used to build the distribution.
loc, scale
Optional explicit loc/scale arrays; if provided, `newdata` is ignored.
newdata
Optional newdata for prediction when loc/scale are not provided.
Returns
-------
LocScaleTransformationDist
A batched transformation distribution for prediction.
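
        Examples
        --------
        ::

            results = model.run_mcmc(seed=1, warmup=300, posterior=500)
            samples = results.get_posterior_samples()
            dist = model.init_dist(samples)
            log_prob = dist.log_prob(model.response.value)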
"""
if not self.graph:
raise ValueError("Model must be built with .build() first.")
if self.is_gaussian:
# assuming no samples are present, just use current values.
trafo_samples = self.trafo.value
else:
trafo_samples = self.trafo.predict(samples)
        if loc is not None and newdata is not None:
            raise ValueError("Provide either loc or newdata, not both.")
        if scale is not None and newdata is not None:
            raise ValueError("Provide either scale or newdata, not both.")
if loc is None or scale is None:
locscale = self.graph.predict(
samples, predict=[self.loc.name, self.scale.name], newdata=newdata
)
loc_ = locscale[self.loc.name] if loc is None else loc
scale_ = locscale[self.scale.name] if scale is None else scale
else:
loc_ = loc
scale_ = scale
if trafo_samples.ndim > 0:
# protection for the Gaussian case, when trafo_samples is 0.0 (scalar)
trafo_samples = jnp.expand_dims(trafo_samples, -2)
loc_ = jnp.asarray(loc_)
scale_ = jnp.asarray(scale_)
ndim = max(loc_.ndim, scale_.ndim, (trafo_samples.ndim - 1))
if loc_.ndim < ndim:
loc_ = jnp.expand_dims(loc_, -1)
if scale_.ndim < ndim:
scale_ = jnp.expand_dims(scale_, -1)
return self.response.dist_node.partial_dist_class( # type: ignore
loc=loc_, scale=scale_, coef=trafo_samples, batched=True
)
def summarise_dist(
self,
samples: dict[str, Array],
loc: Array | None = None,
scale: Array | None = None,
grid: Array | None = None,
newdata: dict[str, Array] | None = None,
) -> dict[str, Array]:
"""
Return summary arrays (z, prob, log_prob, cdf) for a grid of values.
Parameters
----------
samples
Posterior samples dict used to build the distribution.
loc, scale
Optional loc/scale arrays overriding predictions.
grid
Points to evaluate; if None, uses observed response values.
newdata
Optional newdata for prediction when loc/scale are None.
Returns
-------
dict
Keys: 'z', 'prob', 'log_prob', 'cdf' with arrays over samples.
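
        Examples
        --------
        A sketch on an illustrative grid::

            import jax.numpy as jnp

            grid = jnp.linspace(-3.0, 3.0, 100)
            summary = model.summarise_dist(samples, grid=grid)
            cdf_samples = summary["cdf"]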
"""
grid_ = grid if grid is not None else self.response.value
dist = self.init_dist(samples, loc=loc, scale=scale, newdata=newdata)
z_samples, _ = dist.transformation_and_logdet(grid_)
log_prob_samples = dist.log_prob(grid_)
prob_samples = dist.prob(grid_)
cdf_samples = dist.cdf(grid_)
return {
"z": z_samples,
"prob": prob_samples,
"log_prob": log_prob_samples,
"cdf": cdf_samples,
}
def summarise_trafo_by_samples(
self,
key: KeyArray | int,
grid: Array,
samples: dict[str, Array],
n: int = 100,
) -> pd.DataFrame:
"""
Summarise transformation samples on a grid.
Parameters
----------
key
PRNG key or integer seed used to subsample trajectories.
grid
Points at which the transformation is evaluated.
samples
Posterior samples dictionary used to build the distribution.
n
Number of sampled trajectories to return.
Returns
-------
DataFrame
DataFrame with sampled trajectories and plotting metadata.
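
        Examples
        --------
        A sketch drawing 100 trajectories on an illustrative grid::

            import jax.numpy as jnp

            grid = jnp.linspace(-3.0, 3.0, 200)
            df = model.summarise_trafo_by_samples(
                key=1, grid=grid, samples=samples, n=100
            )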
"""
key = jax.random.PRNGKey(key) if isinstance(key, int) else key
dist = self.init_dist(samples, loc=0.0, scale=1.0)
z_samples, _ = dist.transformation_and_logdet(grid)
pdf_samples = jnp.exp(dist.log_prob(grid))
cdf_samples = dist.cdf(grid)
z_df = summarise_by_samples(key, z_samples, "z", n=n)
cdf_df = summarise_by_samples(key, cdf_samples, "cdf", n=n)
pdf_df = summarise_by_samples(key, pdf_samples, "pdf", n=n)
df = pd.concat([z_df.z, cdf_df.cdf, pdf_df.pdf], axis=1)
df["index"] = z_df.index
df["obs"] = z_df.obs
df["chain"] = z_df.chain
df["sample"] = z_df["sample"]
df["r"] = np.tile(np.squeeze(grid), n)
return df
@property
def is_gaussian(self) -> bool:
"""Whether the model is Gaussian."""
return len(list(self.trafo.terms.values())) == 0
def setup_default_mcmc_kernels(
self,
strategy: Literal[
"iwls_fixed", "iwls_fixed-nuts", "nuts", "iwls-nuts", "iwls-iwls_fixed"
] = "iwls_fixed",
use_fallback_finfos: bool = True,
locscale_kernel_kwargs: dict[str, Any] | None = None,
trafo_kernel_kwargs_nuts: dict[str, Any] | None = None,
trafo_kernel_kwargs_iwls: dict[str, Any] | None = None,
override_existing_inference_on_locscale: bool = False,
override_existing_inference_on_trafo: bool = False,
jitter_dist: tfd.Distribution | None = None,
) -> LocScalePTM:
"""
Configure default MCMC kernels for model parts.
Parameters
----------
strategy
Kernel selection strategy for loc/scale/trafo parts. The strategies are:
- "iwls_fixed": Metropolis-Hastings with proposals generated according to
the iteratively re-weighted least squares kernel. The Fisher information
matrices used here are fixed to the observed Fisher information
matrices at initial estimates of the posterior modes.
- "iwls_fixed-nuts": Uses "iwls_fixed" in the location and scale model parts
and a No-U-turn sampler (NUTS) for the parameters of the transformation
function.
- "nuts": Uses NUTS for location, scale, and transformation. Does not
scale well with increasing sample size.
- "iwls-nuts": Like "iwls_fixed-nuts", but uses expected Fisher information
matrices derived under a Gaussian assumption for the response as an
approximation to generate proposals.
None of the strategies set up MCMC kernels for hyperparameters like
smoothing parameters in the location and scale model parts, these should
be specified manually.
If the random walk variance of the transformation function is transformed
to the real line with a bijector, the default scheme will always set up a
NUTS kernel for this parameter.
use_fallback_finfos
If True, fall back to Gaussian Fisher information matrices in
location and scale model parts when needed.
        locscale_kernel_kwargs, trafo_kernel_kwargs_nuts, trafo_kernel_kwargs_iwls
            Optional keyword arguments forwarded to the kernel constructors.
        override_existing_inference_on_locscale, override_existing_inference_on_trafo
            If True, existing inference specifications are overridden.
        jitter_dist
            Distribution used to jitter chain initial values; defaults to a standard
            normal distribution.
        Returns
        -------
        LocScalePTM
            The model with inference specs set up.
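
        Examples
        --------
        A sketch, assuming the model has been built and initialized::

            model.build()
            model.initialize()
            model.setup_default_mcmc_kernels(strategy="iwls_fixed-nuts")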
"""
jitter_dist = (
tfd.Normal(loc=0.0, scale=1.0) if jitter_dist is None else jitter_dist
)
if strategy not in [
"iwls_fixed",
"iwls_fixed-nuts",
"nuts",
"iwls-nuts",
"iwls-iwls_fixed",
]:
raise ValueError(f"Unknown strategy {strategy=}.")
if "iwls_fixed" in strategy and not self.is_initialized:
raise RuntimeError(
f"Cannot setup MCMC scheme '{strategy}' if the model is not "
"initialized. Please run .initialize() first."
)
locscale_kernel_kwargs = (
locscale_kernel_kwargs if locscale_kernel_kwargs is not None else {}
)
trafo_kernel_kwargs_nuts = (
trafo_kernel_kwargs_nuts if trafo_kernel_kwargs_nuts is not None else {}
)
trafo_kernel_kwargs_iwls = (
trafo_kernel_kwargs_iwls if trafo_kernel_kwargs_iwls is not None else {}
)
iwls_locscale_kernel_kwargs = (
locscale_kernel_kwargs
if locscale_kernel_kwargs
else {
"initial_step_size": 1.0,
"da_target_accept": 0.5,
}
)
for term in self.loc.terms.values():
if term is self.loc.loc_intercept and self.loc.loc_intercept is not None:
continue
if (
not override_existing_inference_on_locscale
and term.coef.inference is not None
):
logger.debug(
f"Did not set up kernel for {term.coef.name}, "
f"because {term.coef.inference=} is not None."
)
continue
if strategy == "nuts":
term.coef.inference = gs.MCMCSpec(
gs.NUTSKernel,
kernel_kwargs=locscale_kernel_kwargs,
jitter_dist=jitter_dist,
)
logger.debug(f"Set up NUTS kernel for {term.coef.name}")
continue
# iwls cases
if self.interface is None:
raise ValueError(f"{self.interface=} must not be None.")
if self.is_gaussian or strategy in ["iwls-nuts", "iwls-iwls_fixed"]:
# this fails if term is a linear term or for another reason has
# no .scale attribute.
try:
                    cinfo = GaussianLocCholInfo.from_smooth(
                        term, model=self.interface, n=int(self.response.value.size)
                    )
except AttributeError:
# this will fall back to an ordinary IWLS kernel using the observed
# fisher information
                    # But if there are NaNs in the Fisher information's Cholesky
# decomposition, it will use an identity matrix instead.
cinfo = ObservedCholInfoOrIdentity.from_smooth(term, self.interface)
term.coef.inference = gs.MCMCSpec(
gs.IWLSKernel,
kernel_kwargs={"chol_info_fn": cinfo.chol_info}
| iwls_locscale_kernel_kwargs,
jitter_dist=jitter_dist,
)
logger.debug(f"Set up IWLS kernel for {term.coef.name}")
continue
else:
cinfo = CholInfo.from_smooth(term, model=self.interface)
if cinfo.nan_in_cholesky_of_unprocessed_finfo:
if use_fallback_finfos:
try:
logger.warning(
"NaNs von in the Cholesky decomposition of the "
"Fisher information matrix at current values for "
f"term {term.name}. Falling back to the expected Fisher "
"information of a Gaussian model. This can lead to less "
"efficient sampling. Consider running "
".initialize() for longer."
)
cinfo = GaussianLocCholInfo.from_smooth(
term, model=self.interface, n=int(self.response.value.size)
)
except Exception:
logger.exception("Failed to use Gaussian fallback. Continuing.")
logger.warning(
"NaNs von in the Cholesky decomposition of the "
"Fisher information matrix at current values for "
f"term {term.name}. Falling back to an augmented Fisher "
"information to ensure positive definitness. This is likely "
"to lead to less efficient sampling. Consider running "
".initialize() for longer."
)
term.coef.inference = gs.MCMCSpec(
gs.IWLSKernel,
kernel_kwargs={"chol_info_fn": cinfo.chol_info}
| iwls_locscale_kernel_kwargs,
jitter_dist=jitter_dist,
)
logger.debug(f"Set up IWLS kernel for {term.coef.name}")
for term in self.scale.terms.values():
if (
term is self.scale.log_scale_intercept
and self.scale.log_scale_intercept is not None
):
continue
if (
not override_existing_inference_on_locscale
and term.coef.inference is not None
):
logger.debug(
f"Did not set up kernel for {term.coef.name}, "
f"because {term.coef.inference=} is not None."
)
continue
if strategy == "nuts":
term.coef.inference = gs.MCMCSpec(
gs.NUTSKernel,
kernel_kwargs=locscale_kernel_kwargs,
jitter_dist=jitter_dist,
)
logger.debug(f"Set up NUTS kernel for {term.coef.name}")
continue
# iwls cases
if self.interface is None:
raise ValueError(f"{self.interface=} must not be None.")
            if self.is_gaussian or strategy in ["iwls-nuts", "iwls-iwls_fixed"]:
# this fails if term is a linear term or for another reason has
# no .scale attribute.
try:
cinfo = GaussianScaleCholInfo.from_smooth(
term, model=self.interface, n=int(self.response.value.size)
)
except AttributeError:
# this will fall back to an ordinary IWLS kernel using the observed
# fisher information
                    # But if there are NaNs in the Fisher information's Cholesky
# decomposition, it will use an identity matrix instead.
cinfo = ObservedCholInfoOrIdentity.from_smooth(term, self.interface)
term.coef.inference = gs.MCMCSpec(
gs.IWLSKernel,
kernel_kwargs={"chol_info_fn": cinfo.chol_info}
| iwls_locscale_kernel_kwargs,
jitter_dist=jitter_dist,
)
logger.debug(f"Set up IWLS kernel for {term.coef.name}")
continue
else:
cinfo = CholInfo.from_smooth(term, model=self.interface)
if cinfo.nan_in_cholesky_of_unprocessed_finfo:
if use_fallback_finfos:
try:
logger.warning(
"NaNs von in the Cholesky decomposition of the "
"Fisher information matrix at current values for "
f"term {term.name}. Falling back to the expected Fisher "
"information of a Gaussian model. This can lead to less "
"efficient sampling. Consider running "
".initialize() for longer."
)
cinfo = GaussianScaleCholInfo.from_smooth(
term, model=self.interface, n=int(self.response.value.size)
)
except Exception:
logger.exception("Failed to use Gaussian fallback. Continuing.")
logger.warning(
"NaNs von in the Cholesky decomposition of the "
"Fisher information matrix at current values for "
f"term {term.name}. Falling back to an augmented Fisher "
"information to ensure positive definitness. This is likely "
"to lead to less efficient sampling. Consider running "
".initialize() for longer."
)
term.coef.inference = gs.MCMCSpec(
gs.IWLSKernel,
kernel_kwargs={"chol_info_fn": cinfo.chol_info}
| iwls_locscale_kernel_kwargs,
jitter_dist=jitter_dist,
)
logger.debug(f"Set up IWLS kernel for {term.coef.name}")
trafo_terms = list(self.trafo.terms.values())
if not trafo_terms:
return self
        trafo_coef = trafo_terms[0]
trafo_scale = trafo_coef.scale # type: ignore
trafo_var = trafo_scale.variance_param # type: ignore
if trafo_kernel_kwargs_nuts:
nuts_kernel_kwargs = trafo_kernel_kwargs_nuts
else:
nuts_kernel_kwargs = {
"da_target_accept": 0.9,
"mm_diag": False,
"max_treedepth": 10,
}
if trafo_kernel_kwargs_iwls:
iwls_kernel_kwargs = trafo_kernel_kwargs_iwls
else:
iwls_kernel_kwargs = {
"initial_step_size": 1.0,
"da_target_accept": 0.5,
}
if trafo_var.weak:
if (
trafo_var.value_node[0].inference is None
or override_existing_inference_on_trafo
):
trafo_var.value_node[0].inference = gs.MCMCSpec(
gs.NUTSKernel,
kernel_group="trafo",
kernel_kwargs=nuts_kernel_kwargs,
)
logger.debug(f"Set up NUTS kernel for {trafo_var.value_node[0].name}")
else:
logger.debug(
f"Did not set up kernel for {trafo_var.value_node[0].name}, "
f"because {trafo_var.value_node[0].inference=} is not None."
)
else:
logger.warning(
f"Did not set up kernel for {trafo_var.name}, "
f"because it is strong. It has inference {trafo_var.inference=}."
)
if (
trafo_coef.inference is not None
and not override_existing_inference_on_trafo
):
logger.debug(
f"Did not set up kernel for {trafo_coef.name}, "
f"because {trafo_coef.inference=} is not None."
)
return self
match strategy:
case "iwls_fixed" | "iwls-iwls_fixed":
if self.interface is None:
raise ValueError(f"{self.interface=} must not be None.")
cinfo = PTMCholInfoFixed.from_coef(
coef=trafo_coef,
model=self.interface, # type: ignore
)
if cinfo.nan_in_cholesky_of_unprocessed_finfo:
logger.warning(
"NaNs von in the Cholesky decomposition of the "
"Fisher information matrix at current values for "
f"term {trafo_coef.name}. Falling back to an augmented Fisher "
"information to ensure positive definitness. This is likely "
"to lead to less efficient sampling. Consider running "
".initialize() for longer."
)
trafo_coef.latent_coef.inference = gs.MCMCSpec(
gs.IWLSKernel,
kernel_kwargs={"chol_info_fn": cinfo.chol_info}
| iwls_kernel_kwargs,
jitter_dist=jitter_dist,
)
logger.debug(
"Set up ILWS kernel with fixed Fisher "
f"information for {trafo_coef.name}"
)
case "iwls_fixed-nuts":
trafo_coef.latent_coef.inference = gs.MCMCSpec(
gs.NUTSKernel,
kernel_group="trafo",
kernel_kwargs=nuts_kernel_kwargs,
jitter_dist=jitter_dist,
)
logger.debug(f"Set up NUTS kernel for {trafo_coef.name}")
case "nuts":
trafo_coef.latent_coef.inference = gs.MCMCSpec(
gs.NUTSKernel,
kernel_group="trafo",
kernel_kwargs=nuts_kernel_kwargs,
jitter_dist=jitter_dist,
)
logger.debug(f"Set up NUTS kernel for {trafo_coef.name}")
case "iwls-nuts":
trafo_coef.latent_coef.inference = gs.MCMCSpec(
gs.NUTSKernel,
kernel_group="trafo",
kernel_kwargs=nuts_kernel_kwargs,
jitter_dist=jitter_dist,
)
logger.debug(f"Set up NUTS kernel for {trafo_coef.name}")
self.show_mcmc()
return self
def show_mcmc(self) -> None:
"""
Logs the current MCMC configuration. If you do not see any output, you need
to set up logging::
import logging
logger = logging.getLogger("liesel_ptm")
logger.setLevel(logging.INFO)
if not logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - %(message)s"
)
handler.setFormatter(formatter)
handler.setLevel(logging.INFO)
logger.addHandler(handler)
"""
if self.graph is None:
raise ValueError("Model must be built with .build() first.")
for param in self.graph.parameters.values():
spec = param.inference
if spec is None:
logger.info(f"MCMC Setup for {param.name}: None.")
continue
if not isinstance(spec, gs.MCMCSpec):
logger.info(f"Inference for {param.name}: {spec}.")
continue
logger.info(
f"MCMC Setup for {param.name}: {spec.kernel} in "
f"group '{spec.kernel_group}'."
)
def run_mcmc(
self,
seed: int,
warmup: int,
posterior: int,
num_chains: int = 4,
fast_warmup: float = 0.5,
thinning_posterior: int = 1,
thinning_warmup: int = 1,
warm_start: bool = True,
which: str | None = None,
strategy: Literal[
"iwls-nuts",
"iwls_fixed",
"iwls_fixed-nuts",
"nuts",
"iwls-iwls_fixed",
"manual",
] = "iwls-nuts",
cache_path: str | Path | None = None,
apply_jitter: bool = False,
**initialization_kwargs,
) -> gs.SamplingResults:
"""Run MCMC sampling and return sampling results.
Parameters
----------
        seed, warmup, posterior
            Random seed and the numbers of warmup and posterior iterations.
        num_chains
            Number of parallel chains to run.
        fast_warmup
            Fraction of the warmup duration allocated to the fast-adaptation phases,
            split evenly between the initial and terminal phases.
        thinning_posterior, thinning_warmup
            Thinning intervals for the posterior and warmup samples.
        which
            Passed to :class:`liesel.goose.LieselMCMC` to select the inference
            specifications to use.
strategy
Which kernel strategy to use for sampling. See
:meth:`.setup_default_mcmc_kernels`.
cache_path
If provided, load/save cached sampling results.
apply_jitter
Whether to apply initial jitter to chain initialisations. Only has an
effect if jittering is specified in the :class:`liesel.goose.MCMCSpec`
            for any one variable. Think of this as an off-switch rather than an
            on-switch.
warm_start
If True, the model will be initialized by finding posterior modes via
:meth:`.initialize`.
**initialization_kwargs
Forwarded to :meth:`.initialize` when ``warm_start`` is True.
Returns
-------
SamplingResults
The :class:`liesel.goose.SamplingResults` sampling results object
containing chains and diagnostics.
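
        Examples
        --------
        ::

            model = ptm.LocScalePTM.new_ptm(y, a=-4.0, b=4.0, nparam=20)
            results = model.run_mcmc(seed=1, warmup=300, posterior=500)
            samples = results.get_posterior_samples()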
"""
if self.graph is None:
self.build()
if cache_path is not None:
fp = Path(cache_path)
if fp.exists():
return gs.engine.SamplingResults.pkl_load(fp)
if warm_start:
self.initialize(**initialization_kwargs)
if strategy != "manual":
self.setup_default_mcmc_kernels(strategy=strategy)
if apply_jitter:
for param_name in self._hyperparameter_initial_values:
old = self._hyperparameter_initial_values[param_name]
logger.debug(f"Restoring initial value for '{param_name}' to {old}.")
if self.graph is None:
raise ValueError("Must build graph first")
current = self.graph.vars[param_name].value
self.graph.vars[param_name].value = jnp.array(
old, dtype=jnp.asarray(current).dtype
)
eb = gs.LieselMCMC(self.graph, which=which).get_engine_builder( # type: ignore
seed=seed, num_chains=num_chains, apply_jitter=apply_jitter
)
if self.interface is None:
raise ValueError(f"{self.interface=} must not be None.")
eb.set_model(self.interface)
fast_warmup_duration = fast_warmup * warmup
init_duration = int(fast_warmup_duration / 2)
term_duration = init_duration
slow_warmup_duration = warmup - init_duration - term_duration
warmup = slow_warmup_duration + init_duration + term_duration
epochs = gs.stan_epochs(
warmup_duration=warmup,
posterior_duration=posterior,
thinning_posterior=thinning_posterior,
thinning_warmup=thinning_warmup,
init_duration=init_duration,
term_duration=term_duration,
)
eb.set_epochs(epochs)
if cache_path is not None:
results = cache_results(eb, filename=cache_path)
else:
engine = eb.build()
engine.sample_all_epochs()
results = engine.get_results()
return results
def plot_qq(
self,
samples: dict[str, Array],
) -> p9.ggplot:
"""Produce a QQ-plot comparing transformed r and its Gaussian reference.
Parameters
----------
samples
Posterior samples dict used to build the distribution.
Returns
-------
ggplot
A ggplot object with QQ comparison of r and h(r) where applicable.
"""
dist = self.init_dist(samples)
r_samples, _ = dist.transformation_and_logdet_parametric(self.response.value)
z_samples, _ = dist.transformation_and_logdet_spline(r_samples)
r_summary = gs.SamplesSummary.from_array(r_samples).to_dataframe()
r_summary["variable"] = "r"
if jnp.asarray(self.trafo.value).ndim == 0:
summary = r_summary
else:
z_summary = gs.SamplesSummary.from_array(z_samples).to_dataframe()
z_summary["variable"] = "h(r)"
summary = pd.concat((r_summary, z_summary), axis=0)
p = (
p9.ggplot()
+ p9.geom_abline(color="black")
+ p9.geom_qq(p9.aes(sample="mean", color="variable"), data=summary)
# + p9.geom_rug(p9.aes(x=self.knots), sides="b")
+ p9.labs(
title="QQ plot of posterior average r and h(r)",
x="Theoretical Quantile",
y="Observed Quantile",
)
)
return p
def plot_trafo(
self,
samples: dict[str, Array],
grid: Array | None = None,
ci_quantiles: tuple[float, float] | None = (0.05, 0.95),
hdi_prob: float | None = None,
show_n_samples: int | None = 50,
seed: int | KeyArray = 1,
) -> p9.ggplot:
"""Plot the posterior mean and credible bands of the transformation h(r).
Parameters
----------
samples
Posterior samples dict used to build the distribution.
grid
Points at which to evaluate h(r); if None, uses a grid over responses.
ci_quantiles
Credible interval quantiles to display as ribbon.
hdi_prob
Optional highest-density-interval probability to annotate.
show_n_samples
Number of sampled trajectories to overlay.
seed
RNG seed for subsampling trajectories.
Returns
-------
ggplot
A ggplot object of the transformation with credible bands.
"""
dist = self.init_dist(samples)
r_train = dist.transformation_and_logdet_parametric(self.response.value)[0]
grid_ = (
grid
if grid is not None
else jnp.linspace(min(r_train.min(), -4.0), max(r_train.max(), 4.0), 300)
)
dist = self.init_dist(samples, loc=0.0, scale=1.0)
z_samples, _ = dist.transformation_and_logdet(grid_)
while z_samples.ndim < 3:
z_samples = jnp.expand_dims(z_samples, 0)
ci_quantiles_ = (0.05, 0.95) if ci_quantiles is None else ci_quantiles
hdi_prob_ = 0.9 if hdi_prob is None else hdi_prob
z_summary = gs.SamplesSummary.from_array(
z_samples, quantiles=ci_quantiles_, hdi_prob=hdi_prob_
).to_dataframe()
z_summary["r"] = grid_
p = (
p9.ggplot()
+ p9.labs(
title="Transformation function h(r)",
subtitle="Dotted: Identity function for reference",
x="r",
y="h(r)",
)
+ p9.geom_abline(linetype="dotted")
)
if ci_quantiles is not None:
p = p + p9.geom_ribbon(
p9.aes(
"r",
ymin=f"q_{str(ci_quantiles[0])}",
ymax=f"q_{str(ci_quantiles[1])}",
),
fill="#56B4E9",
alpha=0.5,
data=z_summary,
)
if hdi_prob is not None:
p = p + p9.geom_line(
p9.aes("r", "hdi_low"),
linetype="dashed",
data=z_summary,
)
p = p + p9.geom_line(
p9.aes("r", "hdi_high"),
linetype="dashed",
data=z_summary,
)
if jnp.asarray(self.trafo.value).ndim == 0:
show_n_samples = 0
if show_n_samples is not None and show_n_samples > 0:
key = jax.random.key(seed) if isinstance(seed, int) else seed
summary_samples_df = self.summarise_trafo_by_samples(
key=key, grid=grid_, samples=samples, n=show_n_samples
)
p = p + p9.geom_line(
p9.aes("r", "z", group="sample"),
color="grey",
data=summary_samples_df,
alpha=0.3,
)
p = p + p9.geom_line(
p9.aes("r", "mean"), data=z_summary, size=1.3, color="blue"
)
return p
def plot_r_density(
self,
samples: dict[str, Array],
grid: Array | None = None,
ci_quantiles: tuple[float, float] | None = (0.05, 0.95),
hdi_prob: float | None = None,
show_n_samples: int | None = 50,
seed: int | KeyArray = 1,
) -> p9.ggplot:
"""Plot the posterior density of the transformed variable r.
Parameters
----------
samples
Posterior samples dict used to build the distribution.
grid
Points at which to evaluate the density; if None, uses a response grid.
ci_quantiles
Credible interval quantiles for ribbons.
hdi_prob
Optional HDI probability to annotate.
show_n_samples
Number of sampled densities to overlay.
seed
RNG seed for subsampling trajectories.
Returns
-------
ggplot
A ggplot object of the density with credible bands.
"""
dist = self.init_dist(samples)
r_train = dist.transformation_and_logdet_parametric(self.response.value)[0]
grid_ = (
grid
if grid is not None
else jnp.linspace(min(r_train.min(), -4.0), max(r_train.max(), 4.0), 300)
)
dist = self.init_dist(samples, loc=0.0, scale=1.0)
prob_samples = dist.prob(grid_)
while prob_samples.ndim < 3:
prob_samples = jnp.expand_dims(prob_samples, 0)
ci_quantiles_ = (0.05, 0.95) if ci_quantiles is None else ci_quantiles
hdi_prob_ = 0.9 if hdi_prob is None else hdi_prob
prob_summary = gs.SamplesSummary.from_array(
prob_samples, quantiles=ci_quantiles_, hdi_prob=hdi_prob_
).to_dataframe()
prob_summary["r"] = grid_
p = p9.ggplot() + p9.labs(
title="Transformation density $f_R(r)$",
subtitle="Dotted: Standard Gaussian PDF for reference",
x="r",
y="$f_R(r)$",
)
pdf_norm = tfd.Normal(loc=0.0, scale=1.0).prob(grid_)
p = p + p9.geom_line(p9.aes(grid_, pdf_norm), linetype="dotted")
if ci_quantiles is not None:
p = p + p9.geom_ribbon(
p9.aes(
"r",
ymin=f"q_{str(ci_quantiles[0])}",
ymax=f"q_{str(ci_quantiles[1])}",
),
fill="#56B4E9",
alpha=0.5,
data=prob_summary,
)
if hdi_prob is not None:
p = p + p9.geom_line(
p9.aes("r", "hdi_low"),
linetype="dashed",
data=prob_summary,
)
p = p + p9.geom_line(
p9.aes("r", "hdi_high"),
linetype="dashed",
data=prob_summary,
)
if jnp.asarray(self.trafo.value).ndim == 0:
show_n_samples = 0
if show_n_samples is not None and show_n_samples > 0:
key = jax.random.key(seed) if isinstance(seed, int) else seed
summary_samples_df = self.summarise_trafo_by_samples(
key=key, grid=grid_, samples=samples, n=show_n_samples
)
p = p + p9.geom_line(
p9.aes("r", "pdf", group="sample"),
color="grey",
data=summary_samples_df,
alpha=0.3,
)
p = p + p9.geom_line(
p9.aes("r", "mean"), data=prob_summary, size=1.3, color="blue"
)
return p
def plot_r_cdf(
self,
samples: dict[str, Array],
grid: Array | None = None,
ci_quantiles: tuple[float, float] | None = (0.05, 0.95),
hdi_prob: float | None = None,
show_n_samples: int | None = 50,
seed: int | KeyArray = 1,
) -> p9.ggplot:
"""Plot the posterior CDF of the transformed variable r.
Parameters
----------
samples
Posterior samples dict used to build the distribution.
grid
Points at which to evaluate the CDF; if None, uses a response grid.
ci_quantiles
Credible interval quantiles for ribbons.
hdi_prob
Optional HDI probability to annotate.
show_n_samples
Number of sampled CDF trajectories to overlay.
seed
RNG seed for subsampling trajectories.
Returns
-------
ggplot
A ggplot object of the CDF with credible bands.
"""
dist = self.init_dist(samples)
r_train = dist.transformation_and_logdet_parametric(self.response.value)[0]
grid_ = (
grid
if grid is not None
else jnp.linspace(min(r_train.min(), -4.0), max(r_train.max(), 4.0), 300)
)
dist = self.init_dist(samples, loc=0.0, scale=1.0)
cdf_samples = dist.cdf(grid_)
while cdf_samples.ndim < 3:
cdf_samples = jnp.expand_dims(cdf_samples, 0)
ci_quantiles_ = (0.05, 0.95) if ci_quantiles is None else ci_quantiles
hdi_prob_ = 0.9 if hdi_prob is None else hdi_prob
cdf_summary = gs.SamplesSummary.from_array(
cdf_samples, quantiles=ci_quantiles_, hdi_prob=hdi_prob_
).to_dataframe()
cdf_summary["r"] = grid_
p = p9.ggplot() + p9.labs(
title="Transformation CDF $F_R(r)$",
subtitle="Dotted: Standard Gaussian CDF for reference",
x="r",
y="$F_R(r)$",
)
        cdf_norm = tfd.Normal(loc=0.0, scale=1.0).cdf(grid_)
        p = p + p9.geom_line(p9.aes(grid_, cdf_norm), linetype="dotted")
if ci_quantiles is not None:
p = p + p9.geom_ribbon(
p9.aes(
"r",
ymin=f"q_{str(ci_quantiles[0])}",
ymax=f"q_{str(ci_quantiles[1])}",
),
fill="#56B4E9",
alpha=0.5,
data=cdf_summary,
)
if hdi_prob is not None:
p = p + p9.geom_line(
p9.aes("r", "hdi_low"),
linetype="dotted",
data=cdf_summary,
)
p = p + p9.geom_line(
p9.aes("r", "hdi_high"),
linetype="dotted",
data=cdf_summary,
)
if jnp.asarray(self.trafo.value).ndim == 0:
show_n_samples = 0
if show_n_samples is not None and show_n_samples > 0:
key = jax.random.key(seed) if isinstance(seed, int) else seed
summary_samples_df = self.summarise_trafo_by_samples(
key=key, grid=grid_, samples=samples, n=show_n_samples
)
p = p + p9.geom_line(
p9.aes("r", "cdf", group="sample"),
color="grey",
data=summary_samples_df,
alpha=0.3,
)
p = p + p9.geom_line(
p9.aes("r", "mean"), data=cdf_summary, size=1.3, color="blue"
)
return p
def plot(
self,
samples: dict[str, Array],
grid: Array | None = None,
ci_quantiles: tuple[float, float] | None = (0.05, 0.95),
hdi_prob: float | None = None,
show_n_samples: int | None = 50,
seed: int | KeyArray = 1,
show: bool = True,
) -> tuple[p9.ggplot, p9.ggplot, p9.ggplot, p9.ggplot]:
"""Produce a set of diagnostic plots (qq, trafo, pdf, cdf).
Parameters
----------
samples
Posterior samples dict used to build the distributions.
grid, ci_quantiles, hdi_prob
Plotting options forwarded to individual plot functions.
show_n_samples, seed
Controls subsampled trajectories overlay.
show
            If True, call ``.show()`` on each plot before returning.
Returns
-------
tuple
Tuple of ggplot objects in order: (qq, trafo, pdf, cdf).
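
        Examples
        --------
        ::

            qq, trafo, pdf, cdf = model.plot(samples, show=False)
            qq.show()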
"""
qq = self.plot_qq(samples)
trafo = self.plot_trafo(
samples,
grid=grid,
ci_quantiles=ci_quantiles,
hdi_prob=hdi_prob,
show_n_samples=show_n_samples,
seed=seed,
)
pdf = self.plot_r_density(
samples,
grid=grid,
ci_quantiles=ci_quantiles,
hdi_prob=hdi_prob,
show_n_samples=show_n_samples,
seed=seed,
)
cdf = self.plot_r_cdf(
samples,
grid=grid,
ci_quantiles=ci_quantiles,
hdi_prob=hdi_prob,
show_n_samples=show_n_samples,
seed=seed,
)
if show:
qq.show()
trafo.show()
pdf.show()
cdf.show()
return qq, trafo, pdf, cdf