# Source code for goal.geometry.exponential_family.base

"""Exponential family hierarchy: ExponentialFamily, Gibbs, Generative, Differentiable, Analytic.

Each level adds capabilities --- sufficient statistics, Gibbs sampling, i.i.d. sampling, log-partition function, or negative entropy --- that unlock progressively more powerful inference algorithms.

Variable names encode the coordinate system throughout: ``params`` for natural parameters (with prefixed variants like ``obs_params`` for slices), ``means`` for mean parameters, and ``coords`` for coordinate-system-agnostic arrays in the manifold layer.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import override

import jax
import jax.numpy as jnp
from jax import Array

from ..manifold.base import Manifold
from ..manifold.util import batched_mean

### Exponential Families ###


class ExponentialFamily(Manifold, ABC):
    """Statistical manifold of distributions belonging to one exponential family.

    Every member has a density of the form
    $p(x; \\theta) \\propto \\mu(x)\\exp(\\theta \\cdot \\mathbf{s}(x))$, where

    - $\\theta \\in \\mathbb{R}^n$ are the natural parameters,
    - $\\mathbf{s}(x)$ is the sufficient statistic --- a fixed map from data to
      $\\mathbb{R}^n$ carrying all the information the data holds about $\\theta$,
    - $\\mu(x)$ is the base measure, a reference density independent of $\\theta$.

    Subclasses provide $\\mathbf{s}(x)$ and $\\mu(x)$; later levels of the
    hierarchy contribute the normalizing constant and its dual.
    """

    # Contract

    @property
    @abstractmethod
    def data_dim(self) -> int:
        """Dimension of the data space."""

    @abstractmethod
    def sufficient_statistic(self, x: Array) -> Array:
        """Return the sufficient statistic $\\mathbf{s}(x)$ of one observation."""

    @abstractmethod
    def log_base_measure(self, x: Array) -> Array:
        """Return $\\log \\mu(x)$ for one observation."""

    # Methods

    def average_sufficient_statistic(self, xs: Array, batch_size: int = 256) -> Array:
        """Mean of the sufficient statistic over the observations ``xs``.

        The work is delegated to ``batched_mean``, which receives ``batch_size``
        unchanged.
        """
        statistic_fn = self.sufficient_statistic
        return batched_mean(statistic_fn, xs, batch_size)

    def check_natural_parameters(self, params: Array) -> Array:
        """Validity flag for natural parameters.

        Returns an ``int32`` scalar: 1 when every entry of ``params`` is finite,
        0 otherwise.
        """
        finite_mask = jnp.isfinite(params)
        everything_finite = jnp.all(finite_mask)
        return everything_finite.astype(jnp.int32)

    def initialize(
        self,
        key: Array,
        location: float = 0.0,
        shape: float = 0.1,
    ) -> Array:
        """Draw random natural parameters around ``location``.

        A standard-normal vector of length ``self.dim`` is scaled by ``shape``
        and then shifted by ``location``.
        """
        standard_normal = jax.random.normal(key, shape=(self.dim,))
        return standard_normal * shape + location

    def initialize_from_sample(
        self,
        key: Array,
        sample: Array,  # pyright: ignore[reportUnusedParameter]
        location: float = 0.0,
        shape: float = 0.1,
    ) -> Array:
        """Draw random natural parameters, with a hook for data-driven variants.

        This base implementation does not look at ``sample``; it simply forwards
        to ``initialize``. ``Analytic`` overrides it to start from average
        sufficient statistics.
        """
        # The sample is accepted only so overrides share the same signature.
        return self.initialize(key, location, shape)
class Gibbs(ExponentialFamily, ABC):
    """Exponential family equipped with a Gibbs sampling step.

    This makes MCMC-based inference available for models that cannot be
    sampled i.i.d. directly.
    """

    # Contract

    @abstractmethod
    def gibbs_step(self, key: Array, params: Array, state: Array) -> Array:
        """Advance ``state`` by one Gibbs step under natural parameters ``params``."""

    # Methods

    def gibbs_chain(
        self,
        key: Array,
        params: Array,
        initial_state: Array,
        k: int,
    ) -> Array:
        """Apply ``k`` consecutive Gibbs steps and return only the last state.

        Each step uses its own random key, derived by folding the step index
        into ``key``.
        """

        def _advance(current: Array, index: Array) -> tuple[Array, None]:
            # Per-step key: deterministic in (key, index).
            step_key = jax.random.fold_in(key, index)
            updated = self.gibbs_step(step_key, params, current)
            # No per-step output is collected, only the carried state.
            return updated, None

        step_indices = jnp.arange(k)
        scan_result = jax.lax.scan(_advance, initial_state, step_indices)
        return scan_result[0]
class Generative(Gibbs, ABC):
    """Gibbs-capable exponential family that can also be sampled i.i.d.

    Direct sampling allows Monte Carlo estimation whenever closed-form
    expressions are not available.
    """

    # Contract

    @abstractmethod
    def sample(self, key: Array, params: Array, n: int = 1) -> Array:
        """Return ``n`` draws from the distribution with natural parameters ``params``."""

    # Overrides

    @override
    def gibbs_step(
        self,
        key: Array,
        params: Array,
        state: Array,
    ) -> Array:
        """One Gibbs step under natural parameters ``params``.

        The default draws a single fresh sample and does not use ``state``.
        Models with an efficient conditional sampler should override this.
        """
        single_draw = self.sample(key, params, n=1)
        return single_draw[0]

    # Methods

    def stochastic_to_mean(self, key: Array, params: Array, n: int) -> Array:
        """Monte Carlo estimate of the average sufficient statistic.

        Draws ``n`` samples under ``params`` and averages their sufficient
        statistics.
        """
        return self.average_sufficient_statistic(self.sample(key, params, n))
class Differentiable(Generative, ABC):
    """Exponential family with an analytic log-partition function.

    The log-partition function
    $\\psi(\\theta) = \\log \\int \\mu(x)\\exp(\\theta \\cdot \\mathbf{s}(x))\\,dx$
    normalizes the density, which enables exact density evaluation and
    gradient-based optimization. Its gradient gives the mean parameters
    $\\eta = \\nabla\\psi(\\theta) = \\mathbb{E}_{p(x;\\theta)}[\\mathbf{s}(x)]$,
    a dual coordinate system on the manifold.
    """

    # Contract

    @abstractmethod
    def log_partition_function(self, params: Array) -> Array:
        """Return the log-partition function $\\psi$ evaluated at ``params``."""

    # Methods

    def to_mean(self, params: Array) -> Array:
        """Map natural parameters to mean parameters, $\\eta = \\nabla \\psi(\\theta)$."""
        grad_psi = jax.grad(self.log_partition_function)
        return grad_psi(params)

    def log_density(self, params: Array, x: Array) -> Array:
        """Log-density of observation $x$ under natural parameters ``params``.

        Evaluates
        $\\log p(x;\\theta) = \\theta \\cdot \\mathbf{s}(x) + \\log \\mu(x) - \\psi(\\theta)$.
        """
        statistic = self.sufficient_statistic(x)
        # Unnormalized part first; the normalizer is subtracted last.
        unnormalized = jnp.dot(params, statistic) + self.log_base_measure(x)
        return unnormalized - self.log_partition_function(params)

    def density(self, params: Array, x: Array) -> Array:
        """Density of observation $x$ under natural parameters ``params``."""
        log_p = self.log_density(params, x)
        return jnp.exp(log_p)

    def relative_entropy(self, p_params: Array, q_params: Array) -> Array:
        """KL divergence $D(p \\| q)$ from the natural parameters of $p$ and $q$.

        Evaluated in Bregman-divergence form:
        $D(p \\| q) = \\psi(\\theta_q) - \\psi(\\theta_p) + \\eta_p \\cdot (\\theta_p - \\theta_q)$.
        """
        eta_p = self.to_mean(p_params)
        log_z_p = self.log_partition_function(p_params)
        log_z_q = self.log_partition_function(q_params)
        normalizer_gap = log_z_q - log_z_p
        linear_term = jnp.dot(eta_p, p_params - q_params)
        return normalizer_gap + linear_term

    def average_log_density(
        self, params: Array, xs: Array, batch_size: int = 2048
    ) -> Array:
        """Mean log-density of the observations ``xs`` under ``params``.

        The per-observation log-density is averaged by ``batched_mean``, which
        receives ``batch_size`` unchanged.
        """
        return batched_mean(lambda x: self.log_density(params, x), xs, batch_size)
class Analytic(Differentiable, ABC):
    """Exponential family with a closed-form negative entropy $\\phi(\\eta)$.

    $\\phi(\\eta) = \\sup_{\\theta}\\{\\theta \\cdot \\eta - \\psi(\\theta)\\}$ is the
    Legendre conjugate of $\\psi$, and $\\theta = \\nabla\\phi(\\eta)$ inverts the
    natural-to-mean mapping, completing the duality between the two
    coordinate systems.

    **NB:** This negative entropy is the convex conjugate of the log-partition
    function and excludes the base measure term, so it may differ from the
    entropy as defined in information theory.
    """

    # Contract

    @abstractmethod
    def negative_entropy(self, means: Array) -> Array:
        """Return the negative entropy $\\phi$ evaluated at mean parameters ``means``."""

    # Overrides

    @override
    def initialize_from_sample(
        self,
        key: Array,
        sample: Array,
        location: float = 0.0,
        shape: float = 0.1,
    ) -> Array:
        """Initialize natural parameters from a perturbed data average.

        The average sufficient statistic of ``sample`` is perturbed with Gaussian
        noise (standard normal of length ``self.dim``, scaled by ``shape`` and
        shifted by ``location``), then mapped to natural coordinates.
        """
        empirical_means = self.average_sufficient_statistic(sample)
        standard_normal = jax.random.normal(key, shape=(self.dim,))
        perturbation = standard_normal * shape + location
        return self.to_natural(empirical_means + perturbation)

    # Methods

    def to_natural(self, means: Array) -> Array:
        """Map mean parameters to natural parameters, $\\theta = \\nabla\\phi(\\eta)$."""
        grad_phi = jax.grad(self.negative_entropy)
        return grad_phi(means)