from __future__ import annotations
from collections.abc import Callable
import jax
import jax.numpy as jnp
from jax import Array
from jax.tree_util import Partial as partial
from .approx import bspline_basis, equidistant_knots
from .util import TransformationSpline
[docs]
class PTMKnots:
    """
    Knot sequence for a monotonically increasing PTM spline.

    Parameters
    ----------
    a
        Left boundary of the core interval.
    b
        Right boundary of the core interval.
    nparam
        Number of parameters for the spline.
    order
        Spline order.
    eps
        Stretch factor for knot spacing; passed through unchanged to
        ``equidistant_knots``.

    Attributes
    ----------
    a
        Left boundary of the core interval.
    b
        Right boundary of the core interval.
    nparam
        Number of parameters for the spline.
    order
        Spline order.
    knots
        Array of spline knots returned by ``equidistant_knots``.
    step
        Mean difference between consecutive knots.
    """

    def __init__(
        self, a: float, b: float, nparam: int, order: int = 3, eps: float = 0.0
    ) -> None:
        self.a, self.b = a, b
        self.nparam, self.order = nparam, order

        # The knot helper is asked for one more parameter than ``nparam``.
        core_interval = jnp.array([a, b])
        self.knots = equidistant_knots(
            core_interval, order=order, n_param=nparam + 1, eps=eps
        )

        # Average spacing of the knot sequence.
        self.step = jnp.mean(jnp.diff(self.knots))
LogIncKnots = PTMKnots  # Alias: ``LogIncKnots`` is bound to the same class object as ``PTMKnots``.
def sfn(exp_shape):
    """
    Compute normalization factor for PTM spline coefficients.

    The entries along the last axis are combined in a weighted sum: the first
    and last entry are weighted by 1/6, the second and second-to-last entry by
    5/6, and everything in between by 1. The sum is then divided by
    ``n + 1 - 3``, where ``n`` is the length of the last axis.

    Parameters
    ----------
    exp_shape
        Exponentiated shape parameters.

    Returns
    -------
    Normalization factor. The last axis is kept with length one.
    """
    spline_order = 3
    n_coef = jnp.shape(exp_shape)[-1] + 1

    def _sum_last_axis(values):
        # Reduce the last axis but keep it, so the result broadcasts
        # against the input.
        return values.sum(axis=-1, keepdims=True)

    edge_idx = jnp.array([0, -1])
    near_edge_idx = jnp.array([1, -2])

    edge_part = _sum_last_axis(exp_shape[..., edge_idx] / 6)
    near_edge_part = _sum_last_axis(5 * exp_shape[..., near_edge_idx] / 6)
    interior_part = _sum_last_axis(exp_shape[..., 2:-2])

    weighted_total = edge_part + near_edge_part + interior_part
    return (1 / (n_coef - spline_order)) * weighted_total
def log_sfn(shape):
    """
    Compute log normalization factor for PTM spline coefficients.

    This is the log-space counterpart of ``sfn``: a weighted log-sum-exp of
    ``shape`` along the last axis, minus ``log(n + 1 - 3)`` where ``n`` is the
    length of the last axis. For at least four entries the weights are
    1/6, 5/6, 1, ..., 1, 5/6, 1/6.

    The reduction is always taken over the last axis only, so leading batch
    axes are normalized independently of each other.

    Parameters
    ----------
    shape
        Shape parameters. The last axis indexes the parameters; any leading
        axes are treated as batch axes.

    Returns
    -------
    Log normalization factor. A scalar for one-dimensional input; for input
    with leading batch axes, an array whose last axis has length one so that
    it broadcasts against ``shape``.
    """
    order = 3
    n_shape = jnp.shape(shape)[-1]
    J = n_shape + 1

    # Three shifted weight bands; their sum gives the per-entry weight.
    a = jnp.full((n_shape,), fill_value=1.0 / 6.0).at[-2:].set(0.0)
    b = jnp.full((n_shape,), fill_value=2.0 / 3.0).at[0].set(0.0).at[-1].set(0.0)
    c = jnp.full((n_shape,), fill_value=1.0 / 6.0).at[:2].set(0.0)
    log_w = jnp.log(a + b + c)

    # Reduce over the parameter axis only; without an explicit axis a batched
    # input would be collapsed into a single scalar across all batch entries.
    log_T = jax.scipy.special.logsumexp(shape + log_w, axis=-1, keepdims=True)
    log_s = log_T - jnp.log(J - order)

    if jnp.ndim(shape) == 1:
        # Keep the scalar return value for unbatched input.
        return log_s.squeeze(-1)
    return log_s
def cumsum_leading_zero(exp_shape: Array) -> Array:
    """
    Cumulative sum along the last axis, starting from zero.

    A single zero is prepended along the last axis before the cumulative sum
    is taken, so the result has one more entry along that axis than the input
    and its first entry is zero.
    """
    batch_dims = jnp.shape(exp_shape)[:-1]
    leading_zero = jnp.zeros((*batch_dims, 1))
    return jnp.cumsum(
        jnp.concatenate([leading_zero, exp_shape], axis=-1),
        axis=-1,
    )
def normalization_coef(shape: Array, dknots: Array) -> Array:
    """
    Construct spline coefficients with average slope one over the domain.

    The shape parameters are exponentiated, accumulated with a leading zero
    (``cumsum_leading_zero``) and rescaled by ``dknots / sfn(exp(shape))``.

    Parameters
    ----------
    shape
        Shape parameters.
    dknots
        Knot differences.

    Returns
    -------
    Spline coefficients; one entry more along the last axis than ``shape``.
    """
    increments = jnp.exp(shape)
    scale = dknots / sfn(increments)
    return scale * cumsum_leading_zero(increments)
def normalization_coef_log(shape: Array, dknots: Array) -> Array:
    """
    Construct log spline coefficients with average slope one over the domain.

    Works entirely in log space: the log normalization factor from
    ``log_sfn`` is subtracted from ``shape`` and ``log(dknots)`` is added.

    Parameters
    ----------
    shape
        Shape parameters.
    dknots
        Knot differences.

    Returns
    -------
    Log spline coefficients.
    """
    log_norm = log_sfn(shape)
    log_spacing = jnp.log(dknots)
    return shape - log_norm + log_spacing
class PTMCoef:
    """
    Turn log increments into full PTM spline coefficient vectors.

    Assumes knots were created for a cubic spline.

    Parameters
    ----------
    knots
        Spline knot sequence for a cubic spline.

    Attributes
    ----------
    knots
        Array of spline knots.
    k1
        Value of ``knots[3]``.
    B
        Cubic basis evaluated at ``knots[3]``, multiplied by a lower-triangular
        matrix of ones so that applying it to a coefficient vector evaluates
        the spline on the cumulative sum of that vector.
    B0
        Same construction as ``B``, with the basis evaluated at zero.
    step
        Mean difference between consecutive knots.
    """

    def __init__(self, knots: Array) -> None:
        self.knots = knots
        self.k1 = knots[3]
        self.step = jnp.diff(knots).mean()

        basis_at_k1 = bspline_basis(jnp.atleast_1d(self.k1), knots, 3)
        basis_at_zero = bspline_basis(jnp.zeros(1), knots, 3)

        # Right-multiplying by a lower-triangular matrix of ones folds the
        # cumulative sum over the coefficients into the basis row.
        n_basis = basis_at_k1.shape[-1]
        cumsum_matrix = jnp.tril(jnp.ones((n_basis, n_basis)))
        self.B = basis_at_k1 @ cumsum_matrix
        self.B0 = basis_at_zero @ cumsum_matrix

    def _add_intercept_and_exponentiate(self, log_increments, intercept, log_slope):
        """
        Add intercept and exponentiate log increments.

        Builds the coefficient vector ``[first, exp(log_increments + log_slope)]``
        where ``first`` is chosen such that the value obtained through ``B0``
        is the same as for the vector built without ``log_slope``.
        """
        unit_increments = jnp.exp(log_increments)

        # Value obtained through ``B`` with a zero first coefficient, relative
        # to ``k1``; this is removed from the first coefficient.
        zero_start = jnp.concatenate((jnp.zeros(1), unit_increments), axis=-1)
        shift = (self.B @ zero_start) - self.k1
        first_coef = -shift + intercept

        reference_coef = jnp.concatenate((first_coef, unit_increments), axis=-1)
        value_at_zero_reference = self.B0 @ reference_coef

        scaled_increments = jnp.exp(log_increments + log_slope)
        scaled_coef = jnp.concatenate((first_coef, scaled_increments), axis=-1)
        value_at_zero_scaled = self.B0 @ scaled_coef

        # Undo the change at zero that the slope scaling introduced.
        drift_at_zero = (value_at_zero_scaled - value_at_zero_reference).squeeze()
        corrected_first = scaled_coef[..., 0] - drift_at_zero
        return scaled_coef.at[..., 0].set(corrected_first)

    def get_ptm_fn(self) -> Callable[[Array, Array, Array], Array]:
        """
        Get function to compute PTM spline coefficients.

        The returned function takes ``log_increments``, ``intercept`` and
        ``log_slope`` and returns a two-dimensional array with one coefficient
        vector per row.
        """
        n_increments = len(self.knots) - 4 - 1
        template = jnp.zeros((n_increments,))
        finalize_rows = jax.vmap(self._add_intercept_and_exponentiate)

        def compute_coef(log_increments, intercept, log_slope):
            # Adding the zero template broadcasts the input to the expected
            # trailing length before it is promoted to two dimensions.
            log_increments = jnp.atleast_2d(log_increments + template)
            intercept = jnp.expand_dims(jnp.atleast_1d(intercept), -1)
            log_slope = jnp.expand_dims(jnp.atleast_1d(log_slope), -1)

            normalize_rows = jax.vmap(normalization_coef_log, (0, None))
            normalized = normalize_rows(log_increments, self.step)
            return finalize_rows(normalized, intercept, log_slope)

        return compute_coef

    def get_ptm_fn_squeeze(self) -> Callable[[Array, Array, Array], Array]:
        """
        Get function to compute PTM spline coefficients and squeeze output.

        Same as the function from ``get_ptm_fn``, with the second-to-last axis
        of the result squeezed away.
        """
        batched_fn = self.get_ptm_fn()

        def compute_coef(log_increments, intercept, log_slope):
            batched = batched_fn(log_increments, intercept, log_slope)
            return batched.squeeze(-2)

        return compute_coef
class PTMSpline(TransformationSpline):
    """
    PTM spline transformation using given knots.

    Inside ``[min_knot, max_knot]`` the B-spline is evaluated directly. In a
    transition interval of width ``transition_width`` on either side, the
    derivative is interpolated linearly from the spline's boundary derivative
    to a target slope of one; beyond the transition interval the function
    continues linearly with that target slope.

    Parameters
    ----------
    knots
        Increasing, equidistant spline knot sequence.
    eps
        Parameter controlling the sharpness of transition to tail
        extrapolation. Is used to compute
        ``transition_width = eps * (max_knot - min_knot)``, where
        ``transition_width`` indicates the width of the transition interval.
        Values below ``1e-6`` are rejected with a ``ValueError``.
    continue_linearly
        If True, ``eps`` is overridden with ``100000.0``. The transition
        interval then becomes very wide, so that close to the boundary knots
        the slope stays near the spline's slope at the respective boundary.

    Attributes
    ----------
    knots
        Array of spline knots.
    transition_width
        Width of the transition interval.
    min_eps
        Left boundary for tail transition.
    max_eps
        Right boundary for tail transition.

    Notes
    -----
    ``min_knot``, ``max_knot`` and ``bspline`` are read from the instance and
    are presumably provided by ``TransformationSpline`` -- verify against the
    base class.
    """

    def __init__(
        self,
        knots: Array,
        eps: float = 0.1,
        continue_linearly: bool = False,
    ) -> None:
        if eps < 1e-6:
            raise ValueError(f"{eps=} is < 1e-6; that is numerically unstable.")

        super().__init__(knots)

        # Coefficient function with intercept and log-slope pinned to zero.
        coef_fn = PTMCoef(knots).get_ptm_fn_squeeze()
        self._compute_coef = jax.jit(
            partial(coef_fn, intercept=0.0, log_slope=0.0)
        )  # type: ignore

        if continue_linearly:
            # NOTE: linear continuation is emulated with a very wide
            # transition interval; the target slope itself stays at one.
            eps = 100000.0

        self.transition_width = eps * (self.max_knot - self.min_knot)
        self.min_eps = self.min_knot - self.transition_width
        self.max_eps = self.max_knot + self.transition_width
        self._boundaries = jnp.array([self.min_knot, self.max_knot])

        def target_slope(coef):
            return 1.0

        self._target_slope_left = target_slope
        self._target_slope_right = target_slope

    def _left_transition_and_deriv(self, x, coef, value_left, deriv_left):
        """
        Compute left transition value and derivative.

        The derivative moves linearly from ``deriv_left`` at ``min_knot`` to
        the left target slope at distance ``transition_width``; the value is
        the matching antiderivative, shifted to equal ``value_left`` at
        ``min_knot``.
        """
        slope = self._target_slope_left(coef)
        width = self.transition_width

        def unshifted(t):
            # Antiderivative of the interpolated slope, up to a constant.
            poly = t * self.min_knot - 0.5 * t * t
            return (slope / width) * poly + deriv_left * (t - poly / width)

        # Constant that makes the value continuous at the left boundary knot.
        shift = value_left - unshifted(self.min_knot)
        value = unshifted(x) + shift

        rel_dist = (self.min_knot - x) / width
        deriv = (1.0 - rel_dist) * deriv_left + slope * rel_dist
        return value, deriv

    def _right_transition_and_deriv(self, x, coef, value_right, deriv_right):
        """
        Compute right transition value and derivative.

        Mirror image of the left transition: the derivative moves linearly
        from ``deriv_right`` at ``max_knot`` to the right target slope, and
        the value equals ``value_right`` at ``max_knot``.
        """
        slope = self._target_slope_right(coef)
        width = self.transition_width

        def unshifted(t):
            # Antiderivative of the interpolated slope, up to a constant.
            poly = 0.5 * t * t - t * self.max_knot
            return (slope / width) * poly + deriv_right * (t - poly / width)

        # Constant that makes the value continuous at the right boundary knot.
        shift = value_right - unshifted(self.max_knot)
        value = unshifted(x) + shift

        rel_dist = (x - self.max_knot) / width
        deriv = (1.0 - rel_dist) * deriv_right + slope * rel_dist
        return value, deriv

    def _left_tail_and_deriv(self, x, coef, fx_at_linear_start):
        """
        Compute left tail value and derivative.

        Straight line through ``(min_eps, fx_at_linear_start)`` with the left
        target slope.
        """
        slope = self._target_slope_left(coef)
        value = fx_at_linear_start - slope * (self.min_eps - x)
        return value, slope

    def _right_tail_and_deriv(self, x, coef, fx_at_linear_start):
        """
        Compute right tail value and derivative.

        Straight line through ``(max_eps, fx_at_linear_start)`` with the right
        target slope.
        """
        slope = self._target_slope_right(coef)
        value = fx_at_linear_start + slope * (x - self.max_eps)
        return value, slope

    def _dot_and_deriv_n_fullbatch(self, x: Array, coef: Array) -> tuple[Array, Array]:
        """
        Compute dot product and derivative for batch.

        Every element of ``x`` is routed to one of five regimes (left tail,
        left transition, spline core, right transition, right tail) and the
        corresponding value and derivative are returned.
        """
        core_value, core_deriv = self.bspline.dot_and_deriv_n(x, coef)
        edge_values, edge_derivs = self.bspline.dot_and_deriv_n(
            self._boundaries, coef
        )

        # Transitions start from the spline's value/derivative at the
        # boundary knots.
        left_transition = partial(
            self._left_transition_and_deriv,
            value_left=edge_values[0],
            deriv_left=edge_derivs[0],
        )
        right_transition = partial(
            self._right_transition_and_deriv,
            value_right=edge_values[1],
            deriv_right=edge_derivs[1],
        )

        # Tails start where the transitions end.
        left_tail = partial(
            self._left_tail_and_deriv,
            fx_at_linear_start=left_transition(self.min_eps, coef)[0],
        )
        right_tail = partial(
            self._right_tail_and_deriv,
            fx_at_linear_start=right_transition(self.max_eps, coef)[0],
        )

        def branches(xi, value_i, deriv_i):
            def spline_core(_x, _coef):
                # Inside the knots the precomputed spline result is used.
                return value_i, deriv_i

            inside_core = (xi >= self.min_knot) & (xi <= self.max_knot)
            # Index into the branch tuple below; the core case is checked
            # first because it is the most common one.
            right_side = jnp.where(xi < self.max_eps, 3, 4)
            outside_left_tail = jnp.where(xi < self.min_knot, 1, right_side)
            outside_core = jnp.where(xi < self.min_eps, 0, outside_left_tail)
            regime = jnp.where(inside_core, 2, outside_core)

            handlers = (
                left_tail,
                left_transition,
                spline_core,
                right_transition,
                right_tail,
            )
            return jax.lax.switch(regime, handlers, xi, coef)

        return jax.vmap(branches)(x, core_value, core_deriv)