Source code for liesel_ptm.bsplines

from __future__ import annotations

import logging
from collections.abc import Callable
from functools import partial

import jax
import jax.numpy as jnp

from .custom_types import Array
from .liesel_internal import splines

kn = splines.create_equidistant_knots

logger = logging.getLogger(__name__)


[docs] @partial(jax.jit, static_argnums=2) @partial(jnp.vectorize, excluded=(1, 2), signature="(n)->(n,p)") def bspline_basis(x, knots, order): """ Vectorized B-spline basis function evaluation. Parameters ---------- x Input array. knots Array of knots. order Order of the spline (``order=3`` for a cubic spline). Returns ------- B-spline basis matrix. """ min_knot = knots[order] max_knot = knots[-(order + 1)] basis = splines.build_design_matrix_b_spline(x, knots, order) mask = jnp.logical_or(x < min_knot, x > max_knot) mask = jnp.expand_dims(mask, -1) return jnp.where(mask, 0.0, basis)
@partial(jax.jit, static_argnums=2) @partial(jnp.vectorize, excluded=(1, 2), signature="(n)->(n,p)") def bspline_basis_deriv(x, knots, order): min_knot = knots[order] max_knot = knots[-(order + 1)] basis = splines.build_design_matrix_b_spline(x, knots[1:-1], order - 1) dknots = jnp.diff(knots).mean() D = jnp.diff(jnp.identity(jnp.shape(knots)[-1] - order - 1)).T basis_grad = basis @ (D / dknots) mask = jnp.logical_or(x < min_knot, x > max_knot) mask = jnp.expand_dims(mask, -1) return jnp.where(mask, 0.0, basis_grad) @partial(jax.jit, static_argnums=2) @partial(jnp.vectorize, excluded=(1, 2), signature="(n)->(n,p)") def bspline_basis_deriv2(x, knots, order): min_knot = knots[order] max_knot = knots[-(order + 1)] basis = splines.build_design_matrix_b_spline(x, knots[2:-2], order - 2) dknots = jnp.diff(knots).mean() D = jnp.diff(jnp.identity(jnp.shape(knots)[-1] - order - 1)).T basis_grad = basis @ D[1::, 1:] @ (D / (dknots**2)) mask = jnp.logical_or(x < min_knot, x > max_knot) mask = jnp.expand_dims(mask, -1) return jnp.where(mask, 0.0, basis_grad) class BSplineApprox: def __init__( self, knots: Array, order: int, approx: bool = True, ngrid: int = 1000, postmultiply_by: Array | None = None, ) -> None: self.knots = knots self.dknots = jnp.mean(jnp.diff(knots)) self.order = order self.nparam = jnp.shape(knots)[0] - order - 1 self.min_knot = self.knots[order] self.max_knot = self.knots[-(order + 1)] grid = jnp.linspace(self.min_knot, self.max_knot, ngrid) self.step = (self.max_knot - self.min_knot) / ngrid prepend = jnp.array([self.min_knot - self.step]) append = jnp.array([self.max_knot + self.step]) self.grid = jnp.concatenate((prepend, grid, append)) Z = jnp.eye(self.nparam) if postmultiply_by is None else postmultiply_by self.postmultiply_by = Z basis_grids = self._compute_basis_and_deriv2(self.grid) self.basis = basis_grids[0] self.basis_deriv = basis_grids[1] self.basis_deriv2 = basis_grids[2] self.approx = approx if self.approx: self.get_basis = self._approx_basis self.get_basis_and_deriv = self._approx_basis_and_deriv self.get_basis_deriv_and_deriv2 = self._approx_basis_deriv_and_deriv2 else: self.get_basis = self._compute_basis # type: ignore self.get_basis_and_deriv = self._compute_basis_and_deriv # type: ignore self.get_basis_deriv_and_deriv2 = ( self._compute_basis_and_deriv2 # type: ignore ) def _compute_basis(self, x: Array) -> Array: return bspline_basis(x, self.knots, self.order) @ self.postmultiply_by def _compute_basis_and_deriv(self, x: Array) -> tuple[Array, Array]: basis = bspline_basis(x, self.knots, self.order) @ self.postmultiply_by deriv = bspline_basis_deriv(x, self.knots, self.order) @ self.postmultiply_by return basis, deriv def _compute_basis_and_deriv2(self, x: Array) -> tuple[Array, Array, Array]: basis = bspline_basis(x, self.knots, self.order) @ self.postmultiply_by deriv = bspline_basis_deriv(x, self.knots, self.order) @ self.postmultiply_by deriv2 = bspline_basis_deriv2(x, self.knots, self.order) @ self.postmultiply_by return basis, deriv, deriv2 @partial(jax.jit, static_argnums=0) @partial(jnp.vectorize, excluded=[0], signature="(n)->(n,p)") def _approx_basis(self, x: Array) -> Array: i = jnp.searchsorted(self.grid, x, side="right") - 1 lo = self.grid[i] k = jnp.expand_dims((x - lo) / self.step, -1) basis = (1.0 - k) * self.basis[i, :] + (k * self.basis[i + 1, :]) mask = jnp.logical_or(x < self.min_knot, x > self.max_knot) mask = jnp.expand_dims(mask, -1) return jnp.where(mask, 0.0, basis) @partial(jax.jit, static_argnums=0) @partial(jnp.vectorize, excluded=[0], signature="(n)->(n,p),(n,p)") def _approx_basis_and_deriv(self, x: Array) -> tuple[Array, Array]: """ Returns the basis matrix approximation and its gradient with respect to the data. """ i = jnp.searchsorted(self.grid, x, side="right") - 1 lo = self.grid[i] k = jnp.expand_dims((x - lo) / self.step, -1) basis = (1.0 - k) * self.basis[i, :] + (k * self.basis[i + 1, :]) basis_deriv = (1.0 - k) * self.basis_deriv[i, :] + ( k * self.basis_deriv[i + 1, :] ) mask = jnp.logical_or(x < self.min_knot, x > self.max_knot) mask = jnp.expand_dims(mask, -1) return jnp.where(mask, 0.0, basis), jnp.where(mask, 0.0, basis_deriv) @partial(jax.jit, static_argnums=0) @partial(jnp.vectorize, excluded=[0], signature="(n)->(n,p),(n,p),(n,p)") def _approx_basis_deriv_and_deriv2(self, x: Array) -> tuple[Array, Array, Array]: """ Returns the basis matrix approximation and its first and second derivative with respect to the data. """ i = jnp.searchsorted(self.grid, x, side="right") - 1 lo = self.grid[i] k = jnp.expand_dims((x - lo) / self.step, -1) basis = (1.0 - k) * self.basis[i, :] + (k * self.basis[i + 1, :]) basis_deriv = (1.0 - k) * self.basis_deriv[i, :] + ( k * self.basis_deriv[i + 1, :] ) basis_deriv2 = (1.0 - k) * self.basis_deriv2[i, :] + ( k * self.basis_deriv2[i + 1, :] ) mask = jnp.logical_or(x < self.min_knot, x > self.max_knot) mask = jnp.expand_dims(mask, -1) basis = jnp.where(mask, 0.0, basis) basis_deriv = jnp.where(mask, 0.0, basis_deriv) basis_deriv2 = jnp.where(mask, 0.0, basis_deriv2) return basis, basis_deriv, basis_deriv2 def _get_basis_dot_fn(self) -> Callable[[Array, Array], Array]: @jax.custom_jvp def _basis_dot( x: Array, coef: Array, ) -> Array: x = jnp.atleast_1d(x) basis = self.get_basis(x) smooth = jnp.einsum("...ip,...p->...i", basis, coef) return smooth @_basis_dot.defjvp def _basis_dot_jvp(primals, tangents): x, coef = primals x_dot, coef_dot = tangents basis, basis_deriv = self.get_basis_and_deriv(x) smooth = jnp.einsum("...ip,...p->...i", basis, coef) tangent_x = jnp.einsum("...ip,...p->...i", basis_deriv, coef) * x_dot tangent_coef = jnp.einsum("...ip,...p->...i", basis, coef_dot) tangent = tangent_x + tangent_coef return smooth, tangent return jax.jit(_basis_dot) def _get_basis_dot_and_deriv_fn( self, ) -> Callable[[Array, Array], tuple[Array, Array]]: @jax.custom_jvp def _basis_dot_and_deriv( x: Array, coef: Array, ) -> tuple[Array, Array]: x = jnp.atleast_1d(x) basis, basis_deriv = self.get_basis_and_deriv(x) smooth = jnp.einsum("...ip,...p->...i", basis, coef) smooth_deriv = jnp.einsum("...ip,...p->...i", basis_deriv, coef) return smooth, smooth_deriv @_basis_dot_and_deriv.defjvp def _basis_dot_and_deriv_jvp(primals, tangents): x, coef = primals x = jnp.atleast_1d(x) x_dot, coef_dot = tangents basis, basis_deriv, basis_deriv2 = self.get_basis_deriv_and_deriv2(x) smooth = jnp.einsum("...ip,...p->...i", basis, coef) smooth_deriv = jnp.einsum("...ip,...p->...i", basis_deriv, coef) smooth_deriv2 = jnp.einsum("...ip,...p->...i", basis_deriv2, coef) primal_out = (smooth, smooth_deriv) tangent_bdot_x = jnp.einsum("...ip,...p->...i", basis_deriv, coef) * x_dot tangent_bdot_coef = jnp.einsum("...ip,...p->...i", basis, coef_dot) tangent_bdot = tangent_bdot_x + tangent_bdot_coef tangent_deriv_x = smooth_deriv2 * x_dot tangent_deriv_coef = jnp.einsum("...ip,...p->...i", basis_deriv, coef_dot) tangent_deriv = tangent_deriv_x + tangent_deriv_coef tangent_out = (tangent_bdot, tangent_deriv) return primal_out, tangent_out return jax.jit(_basis_dot_and_deriv) class BSpline(BSplineApprox): """ B-Spline with linear extrapolation. Beyond the range of interior knots, this B-Spline can smoothly transition to a linear function. Params ------ knots The knots of the B-Spline. Assumed to be equidistant. order The order of the B-Spline. A cubic B-Spline is given by ``order=3``. approx If ``True`` (default), the B-Spline will be evaluated with an approximated \ basis matrix. ngrid The number of grid points used for the approximation. extrapolate If ``True`` (default), the B-Spline will smoothly transition to a linear \ function beyond the ranger of interior knots. The parameter ``eps`` controls \ the width of the transition segment, and ``target_slope`` defines the slope of the linear extrapolation. eps Controls the width of the transition from the B-Spline to linear \ extrapolation. This is a factor applied to the range of the knots. \ The default of ``0.1`` means that the transition interval width is \ ``0.1 * (max_knot - min_knot)``. target_slope Target slope for the linear extrapolation of the spline beyond the range \ of interior knots. If ``None`` (default), the target slope is set to the \ average slope of the spline over the range of the interior knots. postmultiply_by An array to be post-multiplied to the basis matrix, such that :meth:`.dot` \ effectively evaluates ``basis_matrix(x) @ Z @ coef``, where ``Z`` is the \ value of ``postmultiply_by``. If ``None`` (default), the identity matrix is \ used for ``Z``. """ def __init__( self, knots: Array, order: int = 3, approx: bool = True, ngrid: int = 1000, postmultiply_by: Array | None = None, extrapolate: bool = True, eps: float = 0.3, target_slope: float | None = None, ) -> None: super().__init__( knots=knots, order=order, approx=approx, ngrid=ngrid, postmultiply_by=postmultiply_by, ) self.basis_min, self.basis_grad_min = self.get_basis_and_deriv( jnp.atleast_1d(self.knots[order]) ) self.basis_max, self.basis_grad_max = self.get_basis_and_deriv( jnp.atleast_1d(self.knots[-(order + 1)]) ) self.eps = eps * (self.max_knot - self.min_knot) self.min_eps = self.min_knot - self.eps self.max_eps = self.max_knot + self.eps self.extrapolate = extrapolate self.target_slope = target_slope if extrapolate: self._basis_dot = self._get_extrap_basis_dot_fn(target_slope) self._basis_dot_and_deriv = self._get_extrap_basis_dot_and_deriv_fn( target_slope ) else: self._basis_dot = self._get_basis_dot_fn() self._basis_dot_and_deriv = self._get_basis_dot_and_deriv_fn() def _get_extrap_basis_dot_fn( self, target_slope: float | None = None ) -> Callable[[Array, Array], Array]: basis_dot_and_deriv_fn = self._get_basis_dot_and_deriv_fn() basis_dot_fn = self._get_basis_dot_fn() target_slope_fn = ( (lambda knots, coef, order: target_slope) if target_slope is not None else avg_slope_bspline ) @partial(jnp.vectorize, signature="(n),(p)->(n)") def basis_dot(x: Array, coef: Array) -> Array: basis_dot_left, deriv_left = basis_dot_and_deriv_fn(self.min_knot, coef) basis_dot_right, deriv_right = basis_dot_and_deriv_fn(self.max_knot, coef) target_slope = target_slope_fn(self.knots, coef, self.order) # ----------------------------------------------------------------- # Helper functions # ----------------------------------------------------------------- def _extrap_left_transition(self, x: Array) -> Array: def _unshifted_extrap(x: Array) -> Array: polynomial = x * self.min_knot - (x**2) / 2 term1 = (target_slope / self.eps) * polynomial term2 = jnp.squeeze(deriv_left) * (x - polynomial / self.eps) return term1 + term2 const = jnp.squeeze(basis_dot_left) - jnp.squeeze( _unshifted_extrap(self.min_knot) ) return _unshifted_extrap(x) + const def _extrap_right_transition(self, x: Array) -> Array: def _unshifted_extrap(x: Array) -> Array: polynomial = (x**2) / 2 - x * self.max_knot term1 = (target_slope / self.eps) * polynomial term2 = jnp.squeeze(deriv_right) * (x - polynomial / self.eps) return term1 + term2 const = jnp.squeeze(basis_dot_right) - jnp.squeeze( _unshifted_extrap(self.max_knot) ) return _unshifted_extrap(x) + const # ----------------------------------------------------------------- # Function segments # ----------------------------------------------------------------- outl = x < self.min_eps outr = x > self.max_eps transitl = (self.min_eps <= x) & (x < self.min_knot) transitr = (self.max_knot < x) & (x <= self.max_eps) center = (self.min_knot <= x) & (x <= self.max_knot) # ----------------------------------------------------------------- # Core spline # ----------------------------------------------------------------- value_center = center * basis_dot_fn(x, coef) # ----------------------------------------------------------------- # Main values # ----------------------------------------------------------------- # start points of linear extrapolation linl_start = _extrap_left_transition(self, self.min_eps) linr_start = _extrap_right_transition(self, self.max_eps) # linear transition val_linl = outl * (linl_start - target_slope * (self.min_eps - x)) val_linr = outr * (linr_start + target_slope * (x - self.max_eps)) # transition val_transitl = transitl * _extrap_left_transition(self, x) val_transitr = transitr * _extrap_right_transition(self, x) # put everything together value = val_linl + val_transitl + value_center + val_transitr + val_linr return value # Return jitted function return jax.jit(basis_dot) def _get_extrap_basis_dot_and_deriv_fn( self, target_slope: float | None = None ) -> Callable[[Array, Array], tuple[Array, Array]]: basis_dot_and_deriv_fn = self._get_basis_dot_and_deriv_fn() minmax = jnp.array([self.min_knot, self.max_knot]) target_slope_fn = ( (lambda knots, coef, order: target_slope) if target_slope is not None else avg_slope_bspline ) @partial(jnp.vectorize, signature="(n),(p)->(n),(n)") def basis_dot_and_deriv(x: Array, coef: Array) -> tuple[Array, Array]: basis_dot_left, deriv_left = basis_dot_and_deriv_fn(self.min_knot, coef) basis_dot_right, deriv_right = basis_dot_and_deriv_fn(self.max_knot, coef) target_slope = target_slope_fn(self.knots, coef, self.order) # ----------------------------------------------------------------- # Helper functions # ----------------------------------------------------------------- def _extrap_left_transition(self, x: Array) -> Array: def _unshifted_extrap(x: Array) -> Array: polynomial = x * self.min_knot - (x**2) / 2 term1 = (target_slope / self.eps) * polynomial term2 = jnp.squeeze(deriv_left) * (x - polynomial / self.eps) return term1 + term2 const = jnp.squeeze(basis_dot_left) - jnp.squeeze( _unshifted_extrap(self.min_knot) ) return _unshifted_extrap(x) + const def _extrap_right_transition(self, x: Array) -> Array: def _unshifted_extrap(x: Array) -> Array: polynomial = (x**2) / 2 - x * self.max_knot term1 = (target_slope / self.eps) * polynomial term2 = jnp.squeeze(deriv_right) * (x - polynomial / self.eps) return term1 + term2 const = jnp.squeeze(basis_dot_right) - jnp.squeeze( _unshifted_extrap(self.max_knot) ) return _unshifted_extrap(x) + const # ----------------------------------------------------------------- # Function segments # ----------------------------------------------------------------- outl = x < self.min_eps outr = x > self.max_eps transitl = (self.min_eps <= x) & (x < self.min_knot) transitr = (self.max_knot < x) & (x <= self.max_eps) center = (self.min_knot <= x) & (x <= self.max_knot) # ----------------------------------------------------------------- # Core spline and derivative # ----------------------------------------------------------------- value_center, deriv_center = basis_dot_and_deriv_fn(x, coef) value_center = center * value_center deriv_center = center * deriv_center # ----------------------------------------------------------------- # Main values # ----------------------------------------------------------------- # start points of linear extrapolation linl_start = _extrap_left_transition(self, self.min_eps) linr_start = _extrap_right_transition(self, self.max_eps) # linear transition val_linl = outl * (linl_start - target_slope * (self.min_eps - x)) val_linr = outr * (linr_start + target_slope * (x - self.max_eps)) # transition val_transitl = transitl * _extrap_left_transition(self, x) val_transitr = transitr * _extrap_right_transition(self, x) # put everything together value = val_linl + val_transitl + value_center + val_transitr + val_linr # ----------------------------------------------------------------- # Derivative # ----------------------------------------------------------------- # parts for evaluating transition derivatives _, (deriv_at_min, deriv_at_max) = basis_dot_and_deriv_fn(minmax, coef) distl = (self.min_knot - x) / self.eps distr = (x - self.max_knot) / self.eps # transition derivatives derivl = transitl * ((1.0 - distl) * deriv_at_min + target_slope * distl) derivr = transitr * ((1.0 - distr) * deriv_at_max + target_slope * distr) # put everything together deriv = deriv_center + (outl + outr) * target_slope + derivl + derivr return value, deriv # Return jitted function return jax.jit(basis_dot_and_deriv) def __call__(self, x: Array, coef: Array) -> Array: return self.dot(x, coef) def dot(self, x: Array, coef: Array) -> Array: """ Evaluates the B-Spline for given input and coefficient arrays. Essentially, this computes ``basis_matrix(x) @ coef``, albeit with linear extrapolation. Parameters ---------- x Input array. coef Coefficient array. """ return self._basis_dot(x, coef) def dot_and_deriv(self, x: Array, coef: Array) -> tuple[Array, Array]: """ Evaluates the B-Spline and its derivative for given input and coefficient arrays. Essentially, this computes: .. code-block:: python def dot_and_deriv(x, coef): dot = basis_matrix(x) @ coef deriv = basis_matrix_deriv(x) @ coef return dot, deriv The derivative is taken with respect to ``x``. Parameters ---------- x Input array. coef Coefficient array. """ return self._basis_dot_and_deriv(x, coef) class ExtrapBSplineApprox(BSpline): """ Alias for :class:`.BSpline` for compatibility. """ def get_extrap_basis_dot_fn( self, target_slope: float | None = None ) -> Callable[[Array, Array], Array]: return self._get_extrap_basis_dot_fn(target_slope) def get_extrap_basis_dot_and_deriv_fn( self, target_slope: float | None = None ) -> Callable[[Array, Array], tuple[Array, Array]]: return self._get_extrap_basis_dot_and_deriv_fn(target_slope) def avg_slope_bspline(knots: Array, coef: Array, order: int): dk = jnp.diff(knots).mean() p = jnp.shape(coef)[-1] coef = jnp.diff(coef) outer_border = coef[..., jnp.array([0, -1])] / 6 inner_border = 5 * coef[..., jnp.array([1, -2])] / 6 middle = coef[..., 2:-2] summed_coef = ( outer_border.sum(axis=-1, keepdims=True) + inner_border.sum(axis=-1, keepdims=True) + middle.sum(axis=-1, keepdims=True) ) return summed_coef / (dk * (p - order)) def _extrapolate_bspline_linearly_left( x, smooth: Array, coef: Array, knots: Array, slope: float | Array = 1.0, order: int = 3, reparam_matrix: Array | None = None, ): """ Can handle batched x and coef. The batch dimensions must be leading dimensions. Batching of x and smooth must be the same. """ reparam_matrix = ( reparam_matrix if reparam_matrix is not None else jnp.eye(coef.shape[-1]) ) min_knot = knots[order] min_output = _basis_dot( min_knot, knots, coef # , order=order, reparam_matrix=reparam_matrix ) # min_basis = bspline_basis(min_knot, knots, order) # min_basis = jnp.einsum("...ij,...jp->...ip", min_basis, reparam_matrix) # min_output = jnp.einsum("...ij,...j->...i", min_basis, coef) diffmin = min_knot - x linear_min = min_output - slope * diffmin smaller_than_min = x < min_knot smooth = jnp.where(smaller_than_min, linear_min, smooth) return smooth def _extrapolate_bspline_linearly_right( x, smooth: Array, coef: Array, knots: Array, slope: float | Array = 1.0, order: int = 3, reparam_matrix: Array | None = None, ): """ Can handle batched x and coef. The batch dimensions must be leading dimensions. Batching of x and smooth must be the same. """ reparam_matrix = ( reparam_matrix if reparam_matrix is not None else jnp.eye(coef.shape[-1]) ) max_knot = jnp.atleast_1d(knots[-(order + 1)]) max_basis = bspline_basis(max_knot, knots, order) max_basis = jnp.einsum("...ij,...jp->...ip", max_basis, reparam_matrix) max_output = jnp.einsum("...ij,...j->...i", max_basis, coef) diffmax = x - max_knot linear_max = max_output + slope * diffmax larger_than_max = x > max_knot smooth = jnp.where(larger_than_max, linear_max, smooth) return smooth def _extrapolate_bspline_linearly( x: Array, smooth: Array, coef: Array, knots: Array, slope: float = 1.0, order: int = 3, reparam_matrix: Array | None = None, ) -> Array: smooth = _extrapolate_bspline_linearly_left( x, smooth, coef, knots, slope=slope, order=order, reparam_matrix=reparam_matrix ) smooth = _extrapolate_bspline_linearly_right( x, smooth, coef, knots, slope=slope, order=order, reparam_matrix=reparam_matrix ) return smooth def _extrapolate_bspline_grad_constant_left( x: Array, smooth: Array, knots: Array, const: float, order: int = 3 ): min_knot = knots[order] smaller_than_min = x < min_knot smooth = jnp.where(smaller_than_min, const, smooth) return smooth def _extrapolate_bspline_grad_constant_right( x: Array, smooth: Array, knots: Array, const: float, order: int = 3 ): max_knot = knots[-(order + 1)] larger_than_max = x > max_knot smooth = jnp.where(larger_than_max, const, smooth) return smooth def _extrapolating_basis_dot_fixed_slope( x: Array, knots: Array, coef: Array, order: int = 3, slope: float = 1.0, reparam_matrix: Array | None = None, ): """ Can handle batched x and coef. The batch dimensions must be leading dimensions. """ # smooth = _basis_dot(x, knots, coef, order=order, reparam_matrix=reparam_matrix) smooth = _basis_dot(x, knots, coef) smooth = _extrapolate_bspline_linearly( x, smooth, coef, knots, slope=slope, order=3, reparam_matrix=reparam_matrix ) return smooth def _average_slope_in_segment(coef: Array, knots: Array) -> Array: """ Important! Assumes B-Spline of order 3! Also important! Assume to work on the coefs "g" of a full spline "Bg", where "B" is the basis matrix and "g" are the coefs. """ dknots = jnp.diff(knots).mean() dcoef = jnp.diff(coef) weighted_coefs = dcoef[..., 0] / 6 + 2 * dcoef[..., 1] / 3 + dcoef[..., 2] / 6 return jnp.expand_dims(weighted_coefs / dknots, -1) def _average_slope_left(coef: Array, knots: Array) -> Array: return _average_slope_in_segment(coef[:4], knots) def _average_slope_right(coef: Array, knots: Array) -> Array: return _average_slope_in_segment(coef[-4:], knots) _average_slope_left_jac = jax.jacobian(_average_slope_left) _average_slope_right_jac = jax.jacobian(_average_slope_right) def _extrapolating_dot_jac_coef(x: Array, knots: Array, coef: Array) -> Array: left = jnp.expand_dims(x < knots[3], -1) right = jnp.expand_dims(x > knots[-4], -1) inside = jnp.logical_not(jnp.logical_or(left, right)) # jax.jacobian(_basis_dot, argnums=2)(jnp.atleast_1d(knots[3]), knots, coef) bl = bspline_basis(jnp.atleast_1d(knots[3]), knots, 3) sl = _average_slope_left_jac(coef, knots) left_jac = bl - sl * jnp.expand_dims((knots[3] - x), -1) br = bspline_basis(jnp.atleast_1d(knots[-4]), knots, 3) sr = _average_slope_right_jac(coef, knots) right_jac = br + sr * jnp.expand_dims((x - knots[-4]), -1) jac = jnp.where( inside, bspline_basis(x, knots, 3), jnp.where( left, left_jac, right_jac, ), ) return jac @jax.custom_jvp def _extrapolating_basis_dot_continue_average_slope( x: Array, knots: Array, coef: Array, ): """ Extrapolates the B-Spline with straight lines with the slopes set to the average slopes in the boundary segments. """ # smooth = _basis_dot(x, knots, coef, order=order, reparam_matrix=reparam_matrix) smooth = _basis_dot(x, knots, coef) slope_left = _average_slope_left(coef, knots) slope_right = _average_slope_right(coef, knots) # slope_left = _average_slope_in_segment(coef[:4], knots) # slope_right = _average_slope_in_segment(coef[-4:], knots) smooth = _extrapolate_bspline_linearly_left( x, smooth, coef, knots, slope=slope_left, order=3, reparam_matrix=None ) smooth = _extrapolate_bspline_linearly_right( x, smooth, coef, knots, slope=slope_right, order=3, reparam_matrix=None, ) return smooth @_extrapolating_basis_dot_continue_average_slope.defjvp def _extrapolating_basis_dot_continue_average_slope_jvp(primals, tangents): x, knots, coef = primals x_dot, _, coef_dot = tangents x = jnp.atleast_1d(x) primal_out = _extrapolating_basis_dot_continue_average_slope(x, knots, coef) tangent_x = ( _extrapolating_basis_dot_grad_continue_average_slope(x, knots, coef) * x_dot ) tangent_coef = _extrapolating_dot_jac_coef(x, knots, coef) @ coef_dot tangent_out = tangent_x + tangent_coef return primal_out, tangent_out def _extrapolating_basis_dot_continue_point_slope( x: Array, knots: Array, coef: Array, order: int = 3, reparam_matrix: Array | None = None, ): """ Extrapolates the B-Spline with straight lines with the slopes at the boundary knots. """ # smooth = _basis_dot(x, knots, coef, order=order, reparam_matrix=reparam_matrix) smooth = _basis_dot(x, knots, coef) min_basis = bspline_basis_deriv(jnp.atleast_1d(knots[order]), knots, order) max_basis = bspline_basis_deriv(jnp.atleast_1d(knots[-(order + 1)]), knots, order) slope_left = jnp.einsum("...ip,...p->...i", min_basis, coef) slope_right = jnp.einsum("...ip,...p->...i", max_basis, coef) smooth = _extrapolate_bspline_linearly_left( x, smooth, coef, knots, slope=slope_left, order=3, reparam_matrix=reparam_matrix ) smooth = _extrapolate_bspline_linearly_right( x, smooth, coef, knots, slope=slope_right, order=3, reparam_matrix=reparam_matrix, ) return smooth @jax.custom_jvp def _basis_dot( x: Array, knots: Array, coef: Array, ): x = jnp.atleast_1d(x) basis = bspline_basis(x, knots, 3) smooth = jnp.einsum("...ip,...p->...i", basis, coef) return smooth @_basis_dot.defjvp def _basis_dot_jvp(primals, tangents): x, knots, coef = primals x_dot, _, coef_dot = tangents x = jnp.atleast_1d(x) primal_out = _basis_dot(x, knots, coef) tangent_x = _basis_dot_grad(x, knots, coef) * x_dot tangent_coef = bspline_basis(x, knots, 3) @ coef_dot tangent_out = tangent_x + tangent_coef return primal_out, tangent_out def pad0(original_array, num_zeros_begin, num_zeros_end): # Get the shape of the original array shape = original_array.shape # Construct the pad width based on the number of zeros to prepend at each end pad_width = [ (0, 0) if i < len(shape) - 1 else (num_zeros_begin, num_zeros_end) for i in range(len(shape)) ] # Pad the last axis with the specified number of zeros at each end zero_prepended_array = jnp.pad(original_array, tuple(pad_width), mode="constant") return zero_prepended_array @jax.custom_jvp def _basis_dot_grad( x: Array, knots: Array, coef: Array, ): x = jnp.atleast_1d(x) dcoef = jnp.diff(coef) / jnp.diff(knots).mean() basis = bspline_basis(x, knots, 2) mask = jnp.logical_or(x < knots[3], x > knots[-4]) mask = jnp.expand_dims(mask, -1) basis = jnp.where(mask, 0.0, basis) smooth = jnp.einsum("...ip,...p->...i", basis, pad0(dcoef, 1, 1)) return smooth @_basis_dot_grad.defjvp def _basis_dot_grad_jvp(primals, tangents): x, knots, coef = primals x_dot, _, coef_dot = tangents x = jnp.atleast_1d(x) primal_out = _basis_dot_grad(x, knots, coef) tangent_x = _basis_dot_grad2(x, knots, coef) * x_dot tangent_coef = bspline_basis_deriv(x, knots, 3) @ coef_dot tangent_out = tangent_x + tangent_coef return primal_out, tangent_out def _basis_dot_grad2( x: Array, knots: Array, coef: Array, ): # TODO Übergang stimmt nicht! x = jnp.atleast_1d(x) dcoef = jnp.diff(coef) / jnp.diff(knots).mean() dcoef = jnp.diff(dcoef) / jnp.diff(knots).mean() basis = bspline_basis(x, knots, 1) mask = jnp.logical_or(x < knots[3], x > knots[-4]) mask = jnp.expand_dims(mask, -1) basis = jnp.where(mask, 0.0, basis) smooth = jnp.einsum("...ip,...p->...i", basis, pad0(dcoef, 2, 2)) return smooth # x = jnp.atleast_1d(x) # basis = bspline_basis_grad2(x, knots, 3) # smooth = jnp.einsum("...ip,...p->...i", basis, coef) # return smooth def _extrapolating_basis_dot_grad_fixed_slope( x: Array, knots: Array, coef: Array, order: int = 3, slope_left: float | Array = 1.0, slope_right: float | Array = 1.0, reparam_matrix: Array | None = None, ): # smooth = _basis_dot_grad(x, knots, coef, order=order, reparam_matrix=reparam_matrix) smooth = _basis_dot_grad(x, knots, coef) smooth = _extrapolate_bspline_grad_constant_left( x, smooth, knots, slope_left, order=order ) smooth = _extrapolate_bspline_grad_constant_right( x, smooth, knots, slope_right, order=order ) return smooth @jax.custom_jvp def _extrapolating_basis_dot_grad_continue_average_slope( x: Array, knots: Array, coef: Array, ): slope_left = _average_slope_left(coef, knots) slope_right = _average_slope_right(coef, knots) smooth = _extrapolating_basis_dot_grad_fixed_slope( x, knots, coef, order=3, slope_left=slope_left, slope_right=slope_right, reparam_matrix=None, ) return smooth @_extrapolating_basis_dot_grad_continue_average_slope.defjvp def _extrapolating_basis_dot_grad_continue_average_slope_jvp(primals, tangents): x, knots, coef = primals x_dot, _, coef_dot = tangents x = jnp.atleast_1d(x) primal_out = _extrapolating_basis_dot_grad_continue_average_slope(x, knots, coef) tangent_x = _basis_dot_grad2(x, knots, coef) * x_dot sl = _average_slope_left_jac(coef, knots) b = bspline_basis_deriv(x, knots, 3) sr = _average_slope_right_jac(coef, knots) left = jnp.expand_dims(x < knots[3], -1) right = jnp.expand_dims(x > knots[-4], -1) inside = jnp.logical_not(jnp.logical_or(left, right)) tangent_coef = jnp.where(inside, b, jnp.where(left, sl, sr)) @ coef_dot tangent_out = tangent_x + tangent_coef return primal_out, tangent_out