"""Von Mises distribution over the circle."""
from __future__ import annotations
from dataclasses import dataclass
from typing import override
import jax
import jax.numpy as jnp
from jax import Array
from jax.scipy.special import i0e
from ...geometry import Differentiable
from ...geometry.exponential_family.combinators import DifferentiableProduct
[docs]
@dataclass(frozen=True)
class VonMises(Differentiable):
"""The von Mises distribution is a continuous probability distribution on the circle, analogous to the normal distribution on the line. The probability density function is:
$$p(x; \\mu, \\kappa) = \\frac{1}{2\\pi I_0(\\kappa)}\\exp(\\kappa \\cos(x - \\mu))$$
where:
- $\\mu$ is the mean direction
- $\\kappa$ is the concentration parameter
- $I_0(\\kappa)$ is the modified Bessel function of the first kind of order 0
As an exponential family:
- Sufficient statistic: $\\mathbf{s}(x) = (\\cos(x), \\sin(x))$
- Base measure: $\\mu(x) = -\\log(2\\pi)$
- Natural parameters: $\\theta = \\kappa(\\cos(\\mu), \\sin(\\mu))$
"""
# Methods
[docs]
def split_mean_concentration(self, p: Array) -> tuple[Array, Array]:
"""Split the natural parameters into mean and concentration parameters.
Args:
p: Natural parameters array
Returns:
Tuple of (mean_angle, concentration)
"""
theta = p
kappa = jnp.sqrt(jnp.sum(theta**2))
mu = jnp.arctan2(theta[1], theta[0])
return mu, kappa
[docs]
def join_mean_concentration(self, mu0: float, kappa0: float) -> Array:
"""Join the mean and concentration parameters into natural parameters.
Args:
mu0: Mean angle
kappa0: Concentration parameter
Returns:
Natural parameters array
"""
mu = jnp.atleast_1d(mu0)
kappa = jnp.atleast_1d(kappa0)
return kappa * jnp.concatenate([jnp.cos(mu), jnp.sin(mu)])
# Overrides
@property
@override
def dim(self) -> int:
return 2
@property
@override
def data_dim(self) -> int:
return 1
[docs]
@override
def sufficient_statistic(self, x: Array) -> Array:
"""Compute sufficient statistics: (cos(x), sin(x)).
Args:
x: Data point
Returns:
Sufficient statistics array
"""
return jnp.array([jnp.cos(x), jnp.sin(x)]).ravel()
[docs]
@override
def log_base_measure(self, x: Array) -> Array:
"""Log base measure: -log(2\\pi).
Args:
x: Data point
Returns:
Log base measure (scalar)
"""
return -jnp.log(2 * jnp.pi)
[docs]
@override
def log_partition_function(self, params: Array) -> Array:
"""Compute log partition function.
Args:
params: Natural parameters
Returns:
Log partition function value (scalar)
"""
kappa = jnp.sqrt(jnp.sum(params**2))
# Explicitly cast i0e output to Array
return jnp.log(i0e(kappa)) + kappa
[docs]
@override
def sample(self, key: Array, params: Array, n: int = 1) -> Array:
"""Generate n samples from the Von Mises distribution.
Uses batched rejection sampling with wrapped Cauchy proposal,
adapted from NumPyro's implementation (Devroye, 1986).
For very small concentration (kappa < 0.01), samples uniformly
on the circle since the distribution is nearly uniform.
Args:
key: JAX random key
params: Natural parameters
n: Number of samples
Returns:
Array of n samples with shape (n, 1)
"""
mu, kappa = self.split_mean_concentration(params)
# For very small kappa, the distribution is nearly uniform
# Use uniform sampling to avoid numerical issues
kappa_threshold = 0.01
use_uniform = kappa < kappa_threshold
# Clamp kappa for rejection sampling (used when kappa >= threshold)
kappa_clamped = jnp.maximum(kappa, kappa_threshold)
# Compute shape parameter s with numerical stability
# For small kappa, use approximate formula; for large kappa, use exact
s_cutoff = 1.2e-4 # Cutoff for float64
r = 1.0 + jnp.sqrt(1.0 + 4.0 * kappa_clamped**2)
rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * kappa_clamped)
s_exact = (1.0 + rho**2) / (2.0 * rho)
s_approximate = 1.0 / kappa_clamped
s = jnp.where(kappa_clamped > s_cutoff, s_exact, s_approximate)
# Broadcast s to sample shape
shape = (n,)
s = jnp.broadcast_to(s, shape)
kappa_broadcast = jnp.broadcast_to(kappa_clamped, shape)
def cond_fn(val: tuple[Array, Array, Array, Array, Array]) -> Array:
"""Check if all samples done or reached max iterations."""
i, _, done, _, _ = val
return jnp.logical_and(i < 100, jnp.logical_not(jnp.all(done)))
def body_fn(
val: tuple[Array, Array, Array, Array, Array],
) -> tuple[Array, Array, Array, Array, Array]:
i, key, done, u, w = val
key, key_u, key_v = jax.random.split(key, 3)
# Sample uniform in [-1, 1] and [0, 1]
u_new = jax.random.uniform(key_u, shape=shape, minval=-1.0, maxval=1.0)
v = jax.random.uniform(key_v, shape=shape)
# Wrapped Cauchy proposal
z = jnp.cos(jnp.pi * u_new)
w_new = (1.0 + s * z) / (s + z)
# Acceptance criterion
y = kappa_broadcast * (s - w_new)
# Add small epsilon to avoid log(0) or division issues
y_safe = jnp.maximum(y, 1e-10)
v_safe = jnp.maximum(v, 1e-10)
accept = jnp.logical_or(
y_safe * (2.0 - y_safe) >= v_safe,
jnp.log(y_safe / v_safe) + 1.0 >= y_safe,
)
# Update only where not already done
u = jnp.where(done, u, u_new)
w = jnp.where(done, w, w_new)
done = jnp.logical_or(done, accept)
return i + 1, key, done, u, w
# Initialize rejection sampling loop
init_done = jnp.zeros(shape, dtype=bool)
init_u = jnp.zeros(shape)
init_w = jnp.zeros(shape)
init_val = (jnp.array(0), key, init_done, init_u, init_w)
_, _, _, u, w = jax.lax.while_loop(cond_fn, body_fn, init_val)
# Convert to angle and shift by mean (rejection sampling result)
rejection_samples = jnp.sign(u) * jnp.arccos(jnp.clip(w, -1.0, 1.0)) + mu
# Uniform samples on the circle
key, uniform_key = jax.random.split(key)
uniform_samples = jax.random.uniform(
uniform_key, shape=shape, minval=-jnp.pi, maxval=jnp.pi
)
# Use uniform samples when kappa is very small
samples = jnp.where(use_uniform, uniform_samples, rejection_samples)
return samples[..., None]
class VonMisesProduct(DifferentiableProduct[VonMises]):
"""Product of n independent Von Mises distributions.
Useful for modeling multiple circular/angular latent variables.
Each component has natural parameters (\\kappa \\cos(\\mu), \\kappa \\sin(\\mu)) where
\\mu is the mean direction and \\kappa is the concentration.
The sufficient statistic for n components is a 2n-dimensional vector:
[cos(\\theta_1), sin(\\theta_1), cos(\\theta_2), sin(\\theta_2), ..., cos(\\theta_n), sin(\\theta_n)]
Attributes:
n_components: Number of independent Von Mises variables
"""
def __init__(self, n_components: int):
"""Create a product of n independent Von Mises distributions.
Args:
n_components: Number of Von Mises variables
"""
super().__init__(VonMises(), n_components)
@property
def n_components(self) -> int:
"""Number of Von Mises components."""
return self.n_reps