Exponential Families

Exponential family hierarchy: ExponentialFamily, Gibbs, Generative, Differentiable, Analytic.

Each level adds one capability: sufficient statistics, Gibbs sampling, i.i.d. sampling, the log-partition function, or the negative entropy. Each addition unlocks progressively more powerful inference algorithms.

Variable names encode the coordinate system throughout: params for natural parameters (with prefixed variants like obs_params for slices), means for mean parameters, and coords for coordinate-system-agnostic arrays in the manifold layer.

Class Hierarchy

Inheritance diagram of goal.geometry.exponential_family.base

class ExponentialFamily

Bases: Manifold, ABC

A statistical manifold whose points are probability distributions in an exponential family.

Subclasses define the sufficient statistic \(\mathbf{s}(x)\) and base measure \(\mu(x)\). Higher levels of the hierarchy add the log-partition function (the logarithm of the normalizing constant) and its convex dual, the negative entropy.

Mathematically, an exponential family is a set of distributions whose densities share the form \(p(x; \theta) \propto \mu(x)\exp(\theta \cdot \mathbf{s}(x))\). Here \(\theta \in \mathbb{R}^n\) are the natural parameters; \(\mathbf{s}(x)\) is the sufficient statistic, a fixed mapping from data to \(\mathbb{R}^n\) that captures all the information the data carry about \(\theta\); and \(\mu(x)\) is the base measure, a fixed reference density that does not depend on \(\theta\).
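For example, the Poisson distribution with rate \(\lambda\) has this form with \(\mathbf{s}(x) = x\), \(\mu(x) = 1/x!\) and \(\theta = \log\lambda\). The plain-Python sketch below is illustrative only (its function names mirror the abstract methods here but are not the library's API). It normalizes \(\mu(x)\exp(\theta x)\) with a truncated sum and compares the result with the textbook pmf \(\lambda^x e^{-\lambda}/x!\).

```python
import math

# Hypothetical stand-alone Poisson pieces mirroring the abstract methods above.
def sufficient_statistic(x: int) -> float:
    return float(x)                      # s(x) = x

def log_base_measure(x: int) -> float:
    return -math.lgamma(x + 1)           # log mu(x) = -log x!

def unnormalized_log_density(theta: float, x: int) -> float:
    # theta * s(x) + log mu(x); the normalizer is deliberately omitted
    return theta * sufficient_statistic(x) + log_base_measure(x)

rate = 3.0
theta = math.log(rate)                   # natural parameter of Poisson(rate)
support = range(100)                     # truncated support; the tail mass is negligible for rate 3
weights = [math.exp(unnormalized_log_density(theta, x)) for x in support]
total = math.fsum(weights)               # normalizing constant, approximately exp(rate)

# Normalizing the exponential-family form recovers rate**x * exp(-rate) / x!
for x in range(6):
    textbook = rate**x * math.exp(-rate) / math.factorial(x)
    print(x, weights[x] / total, textbook)
```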

abstract property data_dim: int

Dimension of the data space.

abstractmethod sufficient_statistic(x: Array) → Array

Compute the sufficient statistic \(\mathbf{s}(x)\) of an observation.

abstractmethod log_base_measure(x: Array) → Array

Compute \(\log \mu(x)\) for an observation.

average_sufficient_statistic(xs: Array, batch_size: int = 256) → Array

Average sufficient statistics over a batch of observations.
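Concretely, the result is \(\frac{1}{N}\sum_{i=1}^{N}\mathbf{s}(x_i)\). The plain-Python sketch below assumes (this is not stated above) that batch_size only bounds how many observations are processed at once and does not change the result. The univariate-normal statistic \(\mathbf{s}(x) = (x, x^2)\) and the stat argument are illustrative choices, not part of the library's signature.

```python
import math
from typing import Callable, Sequence

def average_sufficient_statistic(
    stat: Callable[[float], Sequence[float]],
    xs: Sequence[float],
    batch_size: int = 256,
) -> list[float]:
    """Mean of stat(x) over xs, accumulated one chunk of batch_size at a time."""
    totals = None
    for start in range(0, len(xs), batch_size):
        chunk = [stat(x) for x in xs[start:start + batch_size]]
        # Column-wise sums of the statistics in this chunk
        chunk_sums = [math.fsum(col) for col in zip(*chunk)]
        totals = chunk_sums if totals is None else [a + b for a, b in zip(totals, chunk_sums)]
    return [t / len(xs) for t in totals]

def normal_stat(x: float) -> list[float]:
    return [x, x * x]  # s(x) = (x, x^2) for a univariate normal

xs = [1.0, 2.0, 3.0, 4.0, 5.0]
print(average_sufficient_statistic(normal_stat, xs, batch_size=2))  # [3.0, 11.0]
```

Chunking keeps memory bounded for large datasets while producing the same average as a single pass.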

check_natural_parameters(params: Array) → Array

Check if the given natural parameters are valid (all finite).

initialize(key: Array, location: float = 0.0, shape: float = 0.1) → Array

Generate random natural parameters from a Gaussian perturbation.

initialize_from_sample(key: Array, sample: Array, location: float = 0.0, shape: float = 0.1) → Array

Generate random natural parameters, optionally informed by data.

Default: ignores the sample. Analytic overrides this to use average sufficient statistics.

class Generative

Bases: Gibbs, ABC

Adds i.i.d. sampling to a Gibbs-capable exponential family, enabling Monte Carlo estimation when closed-form expressions are unavailable.

abstractmethod sample(key: Array, params: Array, n: int = 1) → Array

Draw n samples from the distribution with the given natural parameters.

gibbs_step(key: Array, params: Array, state: Array) → Array

Perform one Gibbs sampling step given natural parameters and a current state.

Default: samples independently, ignoring state. Override for models with efficient conditional sampling.
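An independent draw from the target distribution leaves that distribution invariant, so it is a valid, if trivial, Gibbs transition. The plain-Python sketch below illustrates this default for a Poisson family. The class name, the integer seed standing in for a JAX PRNG key, the choice of n = len(state), and Knuth's sampler are all assumptions made for illustration.

```python
import math
import random

class PoissonSketch:
    """Hypothetical stand-in for a Generative family; not the library's class."""

    def sample(self, seed: int, theta: float, n: int = 1) -> list[int]:
        rng = random.Random(seed)
        rate = math.exp(theta)               # theta = log(rate)
        threshold = math.exp(-rate)
        draws = []
        for _ in range(n):
            # Knuth's method: multiply uniforms until the product drops below e^-rate
            k, prod = 0, rng.random()
            while prod > threshold:
                k += 1
                prod *= rng.random()
            draws.append(k)
        return draws

    def gibbs_step(self, seed: int, theta: float, state: list[int]) -> list[int]:
        # Default behaviour described above: ignore `state`, draw fresh i.i.d. samples
        return self.sample(seed, theta, n=len(state))

model = PoissonSketch()
a = model.gibbs_step(0, math.log(3.0), state=[0, 0, 0])
b = model.gibbs_step(0, math.log(3.0), state=[9, 9, 9])
print(a == b)  # True: the result does not depend on the current state
```

A model with tractable conditionals would override gibbs_step to update only part of state, conditioned on the rest.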

stochastic_to_mean(key: Array, params: Array, n: int) → Array

Estimate the mean parameters by Monte Carlo: draw n samples at the given natural parameters and average their sufficient statistics.
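For a Poisson family the exact mean parameter is \(e^{\theta} = \lambda\), which makes the Monte Carlo estimate easy to check. The plain-Python sketch below uses illustrative names and an integer seed in place of a JAX key. It is not the library's implementation.

```python
import math
import random

def sample_poisson(rng: random.Random, rate: float) -> int:
    # Knuth's multiplication method; adequate for small rates
    threshold, k, prod = math.exp(-rate), 0, rng.random()
    while prod > threshold:
        k += 1
        prod *= rng.random()
    return k

def stochastic_to_mean(seed: int, theta: float, n: int) -> float:
    """Average s(x) = x over n draws from Poisson(exp(theta))."""
    rng = random.Random(seed)
    rate = math.exp(theta)
    return math.fsum(sample_poisson(rng, rate) for _ in range(n)) / n

theta = math.log(3.0)
estimate = stochastic_to_mean(seed=1, theta=theta, n=20_000)
print(estimate, math.exp(theta))  # the estimate is close to the exact mean, 3.0
```

The standard error shrinks like \(1/\sqrt{n}\); here it is about \(\sqrt{3/20000} \approx 0.012\).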

class Differentiable

Bases: Generative, ABC

Adds an analytic log-partition function, enabling exact density evaluation and gradient-based optimization.

Mathematically, the log-partition function \(\psi(\theta) = \log \int \mu(x)\exp(\theta \cdot \mathbf{s}(x))\,dx\) normalizes the density. Its gradient defines the mean parameters \(\eta = \nabla\psi(\theta) = \mathbb{E}_{p(x;\theta)}[\mathbf{s}(x)]\), providing a dual coordinate system on the manifold.
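For the Poisson family, \(\psi(\theta) = \log\sum_{x} e^{\theta x}/x! = e^{\theta}\), so \(\eta = \nabla\psi(\theta) = e^{\theta} = \lambda\), the familiar Poisson mean. The plain-Python sketch below checks both identities numerically, using a truncated sum for \(\psi\) and a central finite difference for its gradient. The finite difference only stands in for whatever differentiation the library actually uses.

```python
import math

def log_partition(theta: float, support: int = 200) -> float:
    # psi(theta) = log sum_x mu(x) exp(theta * x) with mu(x) = 1/x! (Poisson),
    # computed by a truncated sum; the exact value is exp(theta)
    terms = [math.exp(theta * x - math.lgamma(x + 1)) for x in range(support)]
    return math.log(math.fsum(terms))

def expected_statistic(theta: float, support: int = 200) -> float:
    # E[s(x)] = sum_x x * p(x; theta), with p normalized by psi
    psi = log_partition(theta, support)
    return math.fsum(x * math.exp(theta * x - math.lgamma(x + 1) - psi) for x in range(support))

theta, h = math.log(3.0), 1e-5
grad_psi = (log_partition(theta + h) - log_partition(theta - h)) / (2 * h)  # central difference
print(grad_psi, expected_statistic(theta), math.exp(theta))  # all approximately 3.0
```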

abstractmethod log_partition_function(params: Array) → Array

Compute the log-partition function \(\psi\) at the given natural parameters.

to_mean(params: Array) → Array

Convert natural parameters to mean parameters via \(\eta = \nabla \psi(\theta)\).

log_density(params: Array, x: Array) → Array

Evaluate log-density at observation \(x\) under the given natural parameters.

Computes \(\log p(x;\theta) = \theta \cdot \mathbf{s}(x) + \log \mu(x) - \psi(\theta)\).
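The plain-Python sketch below applies this formula to the Poisson family, where \(\log\mu(x) = -\log x!\) and \(\psi(\theta) = e^{\theta}\). It checks the result against the textbook log-pmf and confirms that the densities sum to one. The names mirror the methods here but are illustrative, not the library's API.

```python
import math

def log_density(theta: float, x: int) -> float:
    # theta * s(x) + log mu(x) - psi(theta) for the Poisson family:
    # s(x) = x, log mu(x) = -log x!, psi(theta) = exp(theta)
    return theta * x - math.lgamma(x + 1) - math.exp(theta)

def density(theta: float, x: int) -> float:
    return math.exp(log_density(theta, x))

rate = 2.5
theta = math.log(rate)
for x in (0, 1, 4):
    # Compare with the textbook form x*log(rate) - rate - log(x!)
    print(x, log_density(theta, x), x * math.log(rate) - rate - math.lgamma(x + 1))

total_mass = math.fsum(density(theta, x) for x in range(100))
print(total_mass)  # approximately 1.0
```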

density(params: Array, x: Array) → Array

Evaluate density at observation \(x\) under the given natural parameters.

relative_entropy(p_params: Array, q_params: Array) → Array

Compute KL divergence \(D(p \| q)\) between two distributions given their natural parameters.

Uses the Bregman divergence form: \(D(p \| q) = \psi(\theta_q) - \psi(\theta_p) + \eta_p \cdot (\theta_p - \theta_q)\).
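For two Poisson distributions the Bregman form reduces to \(\lambda_q - \lambda_p + \lambda_p \log(\lambda_p/\lambda_q)\). The plain-Python sketch below, with illustrative names, compares it with a direct summation of \(\sum_x p(x)\log(p(x)/q(x))\).

```python
import math

def psi(theta: float) -> float:
    return math.exp(theta)            # Poisson log-partition function

def to_mean(theta: float) -> float:
    return math.exp(theta)            # eta = psi'(theta)

def relative_entropy(theta_p: float, theta_q: float) -> float:
    # Bregman form: psi(theta_q) - psi(theta_p) + eta_p * (theta_p - theta_q)
    return psi(theta_q) - psi(theta_p) + to_mean(theta_p) * (theta_p - theta_q)

def log_pmf(theta: float, x: int) -> float:
    return theta * x - math.lgamma(x + 1) - psi(theta)

theta_p, theta_q = math.log(2.0), math.log(5.0)
direct = math.fsum(
    math.exp(log_pmf(theta_p, x)) * (log_pmf(theta_p, x) - log_pmf(theta_q, x))
    for x in range(100)
)
print(relative_entropy(theta_p, theta_q), direct)  # both approximately 1.1674
```

The Bregman form needs only \(\psi\) and \(\eta_p\), so no sum or integral over the data space is required.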

average_log_density(params: Array, xs: Array, batch_size: int = 2048) → Array

Average log-density over a batch of observations under the given natural parameters.
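Because \(\log p(x;\theta)\) is linear in \(\mathbf{s}(x)\), the batch average depends on the data only through the average sufficient statistic and the average log base measure: \(\frac{1}{N}\sum_i \log p(x_i;\theta) = \theta\cdot\bar{\mathbf{s}} + \overline{\log\mu} - \psi(\theta)\). The plain-Python Poisson sketch below checks this identity. The data and names are illustrative.

```python
import math

def log_density(theta: float, x: int) -> float:
    return theta * x - math.lgamma(x + 1) - math.exp(theta)  # Poisson

xs = [0, 2, 3, 3, 7, 1]
theta = math.log(2.5)

# Direct average of per-observation log-densities
direct = math.fsum(log_density(theta, x) for x in xs) / len(xs)

# Same quantity from averaged statistics
mean_stat = math.fsum(xs) / len(xs)                                   # average s(x)
mean_log_base = math.fsum(-math.lgamma(x + 1) for x in xs) / len(xs)  # average log mu(x)
via_stats = theta * mean_stat + mean_log_base - math.exp(theta)

print(direct, via_stats)  # equal up to rounding
```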

class Analytic

Bases: Differentiable, ABC

Adds a closed-form negative entropy \(\phi(\eta)\), completing the duality between natural and mean coordinates.

Mathematically, \(\phi(\eta) = \sup_{\theta}\{\theta \cdot \eta - \psi(\theta)\}\) is the Legendre conjugate of \(\psi\), and \(\theta = \nabla\phi(\eta)\) inverts the natural-to-mean mapping.
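For the Poisson family, maximizing \(\theta\eta - e^{\theta}\) gives \(\theta = \log\eta\), hence \(\phi(\eta) = \eta\log\eta - \eta\), and \(\nabla\phi(\eta) = \log\eta\) undoes \(\eta = e^{\theta}\). The plain-Python sketch below checks the closed form against a numerical supremum. Ternary search works here because the objective is concave in \(\theta\). The names are illustrative, not the library's API.

```python
import math

def psi(theta: float) -> float:
    return math.exp(theta)                     # Poisson log-partition function

def negative_entropy(eta: float) -> float:
    return eta * math.log(eta) - eta           # closed-form conjugate for the Poisson

def conjugate_numerically(eta: float, lo: float = -20.0, hi: float = 20.0) -> float:
    # sup over theta of theta*eta - psi(theta); the objective is concave,
    # so ternary search on [lo, hi] converges to the maximizer
    def objective(t: float) -> float:
        return t * eta - psi(t)

    for _ in range(200):
        m1, m2 = lo + (hi - lo) / 3, hi - (hi - lo) / 3
        if objective(m1) < objective(m2):
            lo = m1
        else:
            hi = m2
    return objective((lo + hi) / 2)

def to_natural(eta: float) -> float:
    return math.log(eta)                       # theta = phi'(eta) = log(eta)

eta = 3.0
print(negative_entropy(eta), conjugate_numerically(eta))  # agree to high precision
print(math.exp(to_natural(eta)))                          # to_mean(to_natural(eta)) recovers eta
```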

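NB: This negative entropy is the convex conjugate of the log-partition function and does not include the base measure term, so it generally differs from the negative entropy \(-H(p)\) of information theory. Taking the expectation of \(\log p(x;\theta)\) gives \(\phi(\eta) = -H(p) - \mathbb{E}_{p(x;\theta)}[\log \mu(x)]\). The two coincide only when \(\mathbb{E}_{p(x;\theta)}[\log \mu(x)] = 0\), for example when \(\mu(x) = 1\).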
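The Poisson family shows the gap concretely, because its base measure \(\mu(x) = 1/x!\) is not constant. The plain-Python sketch below compares \(\phi(\eta) = \eta\log\eta - \eta\) with the negative Shannon entropy, computed by a truncated sum over the support. Only after the expected log base measure is subtracted do the two agree.

```python
import math

rate = 3.0                       # eta = rate for the Poisson family
theta = math.log(rate)
support = range(100)             # truncated support; the tail mass is negligible for rate 3
log_probs = [theta * x - math.lgamma(x + 1) - rate for x in support]
probs = [math.exp(lp) for lp in log_probs]

phi = rate * math.log(rate) - rate                             # conjugate of psi at eta = rate
neg_shannon = math.fsum(p * lp for p, lp in zip(probs, log_probs))  # -H(p)
mean_log_base = math.fsum(-p * math.lgamma(x + 1) for p, x in zip(probs, support))  # E[log mu]

print(phi, neg_shannon - mean_log_base)  # equal: phi = -H(p) - E[log mu(x)]
print(phi, neg_shannon)                  # differ, because log mu(x) = -log x! is not zero
```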

abstractmethod negative_entropy(means: Array) → Array

Compute negative entropy \(\phi\) at the given mean parameters.

initialize_from_sample(key: Array, sample: Array, location: float = 0.0, shape: float = 0.1) → Array

Initialize natural parameters from noisy average sufficient statistics of the sample.

to_natural(means: Array) → Array

Convert mean parameters to natural parameters via \(\theta = \nabla\phi(\eta)\).