# Source code for goal.models.base.categorical
"""Categorical distributions over finite sets.
This module provides:
- `Categorical`: Distribution over n states with probabilities summing to 1
- `Bernoulli`: Special case for binary variables (equivalent to Categorical(2))
- `Bernoullis`: Product of n independent Bernoulli distributions
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import override
import jax
import jax.numpy as jnp
from jax import Array
from ...geometry import Analytic
from ...geometry.exponential_family.combinators import AnalyticProduct
@dataclass(frozen=True)
class Bernoulli(Analytic):
    """Bernoulli distribution over a single binary variable $x \\in \\{0, 1\\}$.

    This is the two-state special case of `Categorical` (``n_categories=2``):

    .. math::
        p(x; \\theta) = \\sigma(\\theta)^x (1 - \\sigma(\\theta))^{1-x}

    with $\\sigma(\\theta) = 1/(1 + \\exp(-\\theta))$ the logistic sigmoid.

    Exponential-family structure:

    - Sufficient statistic: s(x) = x (identity)
    - Log base measure: \\log \\mu(x) = 0
    - Natural parameter: \\theta = log(p/(1-p)) (log odds)
    - Mean parameter: \\eta = p = P(x=1)
    - Log partition: \\psi(\\theta) = log(1 + exp(\\theta)) = softplus(\\theta)
    - Negative entropy: \\phi(\\eta) = \\eta*log(\\eta) + (1-\\eta)*log(1-\\eta)
    """

    @property
    @override
    def dim(self) -> int:
        """A single natural parameter (the log odds)."""
        return 1

    @property
    @override
    def data_dim(self) -> int:
        """Each observation is one binary value."""
        return 1

    @override
    def sufficient_statistic(self, x: Array) -> Array:
        """Return ``x`` itself, as a float32 array of at least one dimension."""
        as_vector = jnp.atleast_1d(x)
        return as_vector.astype(jnp.float32)

    @override
    def log_base_measure(self, x: Array) -> Array:
        """Constant base measure, i.e. zero in log space."""
        return jnp.array(0.0)

    @override
    def log_partition_function(self, params: Array) -> Array:
        """Softplus of the log odds: log(1 + exp(\\theta))."""
        log_odds = params[0]
        return jax.nn.softplus(log_odds)

    @override
    def negative_entropy(self, means: Array) -> Array:
        """Negative entropy \\eta*log(\\eta) + (1-\\eta)*log(1-\\eta).

        A tiny epsilon inside each log keeps the result finite when \\eta is
        exactly 0 or 1; the vanishing term then contributes exactly zero.
        """
        eps = 1e-10
        p_one = means[0]
        p_zero = 1 - p_one
        return p_one * jnp.log(p_one + eps) + p_zero * jnp.log(p_zero + eps)

    @override
    def sample(self, key: Array, params: Array, n: int = 1) -> Array:
        """Draw ``n`` independent binary samples.

        Args:
            key: JAX random key
            params: Natural parameters (log odds)
            n: Number of samples

        Returns:
            Float32 array of shape (n, 1) holding 0.0 / 1.0 values.
        """
        p_one = self.to_mean(params)[0]
        draws = jax.random.bernoulli(key, p_one, shape=(n, 1))
        return draws.astype(jnp.float32)

    @override
    def initialize_from_sample(
        self,
        key: Array,
        sample: Array,
        location: float = 0.0,
        shape: float = 0.1,
    ) -> Array:
        """Initialize natural parameters (log odds) from observed binary data.

        The empirical mean is pulled 1% of the way toward 0.5, mapped to
        natural parameters, and then perturbed with Gaussian noise in that
        natural (logit) space.

        Args:
            key: Random key
            sample: Sample data (binary values)
            location: Mean of noise distribution
            shape: Std dev of noise distribution

        Returns:
            Natural parameters (log-odds).
        """
        shrinkage = 0.01
        empirical = self.average_sufficient_statistic(sample)
        # The shrinkage keeps the mean inside [0.005, 0.995] even when the data
        # are all zeros or all ones, so the conversion below stays finite.
        regularized = (1 - shrinkage) * empirical + shrinkage * 0.5
        logits = self.to_natural(regularized)
        jitter = jax.random.normal(key, shape=(self.dim,)) * shape + location
        return logits + jitter

    # Convenience methods

    def to_prob(self, means: Array) -> Array:
        """Read P(x=1) off the mean parameters."""
        return means[0]

    def from_prob(self, prob: float) -> Array:
        """Pack P(x=1) into a length-1 mean-parameter vector."""
        return jnp.array([prob])
@dataclass(frozen=True)
class Categorical(Analytic):
    """Categorical distribution over $n$ states.

    The categorical distribution describes discrete probability distributions
    over $n$ states $k \\in \\{0, \\dots, n-1\\}$ with probabilities $\\eta_k$,
    where $\\sum_{k=0}^{n-1} \\eta_k = 1$:

    $$p(k; \\eta) = \\eta_k$$

    State 0 is the reference state. Its probability is implied by the sum-to-one
    constraint, so the parameters only cover states $1, \\dots, n-1$ and have
    dimension $d = n - 1$.

    As an exponential family:

    - Log base measure: $\\log \\mu(k) = 0$
    - Sufficient statistic: one-hot encoding of $k$ over states $1, \\dots, d$
      (state 0 maps to the zero vector)
    - Log partition: $\\psi(\\theta) = \\log(1 + \\sum_{i=1}^d e^{\\theta_i})$
    - Negative entropy: $\\phi(\\eta) = \\sum_{i=0}^d \\eta_i \\log(\\eta_i)$,
      with $\\eta_0 = 1 - \\sum_{i=1}^d \\eta_i$ and $0 \\log 0 = 0$
    """

    n_categories: int
    """Number of categories."""

    @property
    @override
    def dim(self) -> int:
        """Dimension $d$ is `n_categories - 1` due to the sum-to-one constraint."""
        return self.n_categories - 1

    @property
    @override
    def data_dim(self) -> int:
        """Dimension of the data space (a single category label)."""
        return 1

    # Categorical methods

    def from_probs(self, probs: Array) -> Array:
        """Construct the mean parameters from the complete probabilities.

        Args:
            probs: Probabilities of all ``n_categories`` states.

        Returns:
            ``probs`` with the first (reference-state) element dropped.
        """
        return probs[1:]

    def to_probs(self, means: Array) -> Array:
        """Return the probabilities of all labels.

        The state-0 probability is recovered from the sum-to-one constraint and
        prepended to ``means``.
        """
        prob0 = 1 - jnp.sum(means)
        return jnp.concatenate([jnp.array([prob0]), means])

    # Overrides

    @override
    def sufficient_statistic(self, x: Array) -> Array:
        """One-hot encoding of label ``x`` over states ``1..n_categories-1``.

        Shifting by one sends state 0 to index -1, which ``jax.nn.one_hot``
        encodes as the all-zeros vector.
        """
        return jax.nn.one_hot(x - 1, self.n_categories - 1).reshape(-1)

    @override
    def log_base_measure(self, x: Array) -> Array:
        """Base measure is constant (zero in log space)."""
        return jnp.array(0.0)

    @override
    def log_partition_function(self, params: Array) -> Array:
        """Log partition $\\log(1 + \\sum_i e^{\\theta_i})$.

        The reference state contributes an implicit natural parameter of 0.
        ``jax.nn.logsumexp`` already performs the max-shift for numerical
        stability, so no manual shift is needed.
        """
        logits = jnp.concatenate([jnp.array([0.0]), params])
        return jax.nn.logsumexp(logits)

    @override
    def negative_entropy(self, means: Array) -> Array:
        """Negative entropy $\\sum_k \\eta_k \\log \\eta_k$ over all states.

        A plain ``p * log(p)`` evaluates to ``0 * -inf = nan`` for any state
        with probability exactly 0. The small epsilon inside the log (the same
        guard `Bernoulli` uses) makes such states contribute 0 and keeps the
        gradient finite at the boundary of the simplex.
        """
        eps = 1e-10
        probs = self.to_probs(means)
        return jnp.sum(probs * jnp.log(probs + eps))

    @override
    def sample(
        self,
        key: Array,
        params: Array,
        n: int = 1,
    ) -> Array:
        """Sample category labels.

        Args:
            key: JAX random key
            params: Natural parameters
            n: Number of samples

        Returns:
            Integer array of shape (n, 1) with labels in ``0..n_categories-1``.
        """
        probs = self.to_probs(self.to_mean(params))
        # Gumbel-Max trick: argmax(log(p) + Gumbel(0,1)) ~ Categorical(p)
        gumbel = jax.random.gumbel(key, shape=(n, self.n_categories))
        return jnp.argmax(jnp.log(probs) + gumbel, axis=-1)[..., None]
class Bernoullis(AnalyticProduct[Bernoulli]):
    """Product of ``n_neurons`` independent Bernoulli distributions.

    Describes a vector of binary random variables with no coupling between
    them. The parameters represent the bias/activation of each binary unit.
    Typical uses:

    - Mean-field approximation of Boltzmann machines
    - Observable/latent layers in Restricted Boltzmann Machines

    Attributes:
        n_neurons: Number of binary units
    """

    def __init__(self, n_neurons: int):
        """Build the product from ``n_neurons`` replicas of a single `Bernoulli`.

        Args:
            n_neurons: Number of binary units
        """
        unit = Bernoulli()
        super().__init__(unit, n_neurons)

    @property
    def n_neurons(self) -> int:
        """Number of binary units (the product's replication count ``n_reps``)."""
        return self.n_reps