# Source code for goal.models.base.categorical

"""Categorical distributions over finite sets.

This module provides:
- `Categorical`: Distribution over n states with probabilities summing to 1
- `Bernoulli`: Special case for binary variables (equivalent to Categorical(2))
- `Bernoullis`: Product of n independent Bernoulli distributions
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import override

import jax
import jax.numpy as jnp
from jax import Array

from ...geometry import Analytic
from ...geometry.exponential_family.combinators import AnalyticProduct


@dataclass(frozen=True)
class Bernoulli(Analytic):
    """Bernoulli distribution over a single binary variable $x \\in \\{0, 1\\}$.

    This is the two-state special case of `Categorical` (``n_categories=2``),
    parameterized by the log-odds of the state ``x = 1``:

    .. math::

        p(x; \\theta) = \\sigma(\\theta)^x (1 - \\sigma(\\theta))^{1-x}

    with $\\sigma(\\theta) = 1/(1 + \\exp(-\\theta))$ the logistic sigmoid.

    Exponential-family structure:
        - Sufficient statistic: $s(x) = x$
        - Log base measure: $\\log \\mu(x) = 0$
        - Natural parameter: $\\theta = \\log(p / (1 - p))$ (log-odds)
        - Mean parameter: $\\eta = p = P(x = 1)$
        - Log partition: $\\psi(\\theta) = \\log(1 + e^{\\theta})$ (softplus)
        - Negative entropy: $\\phi(\\eta) = \\eta \\log \\eta + (1 - \\eta) \\log(1 - \\eta)$
    """

    @property
    @override
    def dim(self) -> int:
        """A single natural parameter (the log-odds)."""
        return 1

    @property
    @override
    def data_dim(self) -> int:
        """Each observation is one binary value."""
        return 1

    @override
    def sufficient_statistic(self, x: Array) -> Array:
        """Return ``x`` itself as a float32 array of at least one dimension."""
        as_vector = jnp.atleast_1d(x)
        return as_vector.astype(jnp.float32)

    @override
    def log_base_measure(self, x: Array) -> Array:
        """Constant base measure: 0 in log space for every ``x``."""
        return jnp.array(0.0)

    @override
    def log_partition_function(self, params: Array) -> Array:
        """Softplus of the log-odds, $\\log(1 + e^{\\theta})$."""
        log_odds = params[0]
        return jax.nn.softplus(log_odds)

    @override
    def negative_entropy(self, means: Array) -> Array:
        """Negative entropy $\\eta \\log \\eta + (1 - \\eta) \\log(1 - \\eta)$.

        A small ``1e-10`` offset inside each logarithm keeps the result finite
        when $\\eta$ is exactly 0 or 1.
        """
        eps = 1e-10
        p_one = means[0]
        p_zero = 1 - p_one
        return p_one * jnp.log(p_one + eps) + p_zero * jnp.log(p_zero + eps)

    @override
    def sample(self, key: Array, params: Array, n: int = 1) -> Array:
        """Draw ``n`` independent binary samples.

        Args:
            key: JAX random key.
            params: Natural parameters (log-odds).
            n: Number of samples to draw.

        Returns:
            Float32 array of shape ``(n, 1)`` holding 0.0 / 1.0 values.
        """
        p_one = self.to_mean(params)[0]
        draws = jax.random.bernoulli(key, p_one, shape=(n, 1))
        return draws.astype(jnp.float32)

    @override
    def initialize_from_sample(
        self,
        key: Array,
        sample: Array,
        location: float = 0.0,
        shape: float = 0.1,
    ) -> Array:
        """Build noisy initial log-odds from observed binary data.

        The empirical mean is pulled slightly toward 0.5 so that samples made of
        only 0s or only 1s still produce finite log-odds. The result is mapped
        to natural parameters, and Gaussian jitter is added in that space.

        Args:
            key: Random key for the jitter.
            sample: Observed binary values.
            location: Mean of the Gaussian jitter.
            shape: Standard deviation of the Gaussian jitter.

        Returns:
            Natural parameters (log-odds).
        """
        observed = self.average_sufficient_statistic(sample)
        # Blend 1% of the way toward 0.5 to stay away from the 0/1 boundary.
        shrinkage = 0.01
        smoothed = (1 - shrinkage) * observed + shrinkage * 0.5
        logits = self.to_natural(smoothed)
        jitter = jax.random.normal(key, shape=(self.dim,)) * shape + location
        return logits + jitter

    # Convenience helpers

    def to_prob(self, means: Array) -> Array:
        """Read P(x=1) off the mean parameters."""
        return means[0]

    def from_prob(self, prob: float) -> Array:
        """Wrap P(x=1) as a length-1 mean-parameter vector."""
        return jnp.array([prob])


@dataclass(frozen=True)
class Categorical(Analytic):
    """Categorical distribution over $n$ states.

    The categorical distribution describes discrete probability distributions
    over $n$ states with probabilities $\\eta_i$ where
    $\\sum_{i=0}^{n-1} \\eta_i = 1$.

    $$p(k; \\eta) = \\eta_k$$

    As an exponential family with $d = n - 1$ free parameters (state 0 is the
    reference category; its natural parameter is fixed at 0):
        - Log base measure: $\\log \\mu(k) = 0$
        - Sufficient statistic: One-hot encoding for $k > 0$
        - Log partition: $\\psi(\\theta) = \\log(1 + \\sum_{i=1}^d e^{\\theta_i})$
        - Negative entropy: $\\phi(\\eta) = \\sum_{i=0}^d \\eta_i \\log(\\eta_i)$,
          with the convention $0 \\log 0 = 0$.
    """

    n_categories: int
    """Number of categories."""

    @property
    @override
    def dim(self) -> int:
        """Dimension $d$ is `n_categories - 1` due to the sum-to-one constraint."""
        return self.n_categories - 1

    @property
    @override
    def data_dim(self) -> int:
        """Dimension of the data space."""
        return 1

    # Categorical methods

    def from_probs(self, probs: Array) -> Array:
        """Construct the mean parameters from the complete probabilities, dropping the first element."""
        return probs[1:]

    def to_probs(self, means: Array) -> Array:
        """Return the probabilities of all labels.

        The probability of state 0 is recovered from the sum-to-one constraint.
        """
        prob0 = 1 - jnp.sum(means)
        return jnp.concatenate([jnp.array([prob0]), means])

    # Overrides

    @override
    def sufficient_statistic(self, x: Array) -> Array:
        """One-hot encoding of states 1..n-1; state 0 maps to the zero vector."""
        # one_hot yields all zeros for the out-of-range index -1, so x == 0
        # becomes the zero vector.
        return jax.nn.one_hot(x - 1, self.n_categories - 1).reshape(-1)

    @override
    def log_base_measure(self, x: Array) -> Array:
        """Base measure is constant (zero in log space)."""
        return jnp.array(0.0)

    @override
    def log_partition_function(self, params: Array) -> Array:
        """Log partition $\\log(1 + \\sum_i e^{\\theta_i})$, computed stably."""
        # Prepend the reference category's fixed logit of 0. logsumexp already
        # subtracts the maximum internally, so no manual shift is needed.
        logits = jnp.concatenate([jnp.array([0.0]), params])
        return jax.nn.logsumexp(logits)

    @override
    def negative_entropy(self, means: Array) -> Array:
        """Negative entropy $\\sum_i \\eta_i \\log \\eta_i$, treating $0 \\log 0$ as 0.

        The naive ``probs * log(probs)`` is NaN as soon as any probability is
        exactly 0 (``0 * -inf``). Entries that are not strictly positive
        therefore contribute 0.
        """
        probs = self.to_probs(means)
        positive = probs > 0
        # Double `where`: the inner one keeps log's argument valid, so neither
        # the value nor its gradient becomes NaN at a zero probability.
        safe_probs = jnp.where(positive, probs, 1.0)
        return jnp.sum(jnp.where(positive, probs * jnp.log(safe_probs), 0.0))

    @override
    def sample(
        self,
        key: Array,
        params: Array,
        n: int = 1,
    ) -> Array:
        """Draw ``n`` states with the Gumbel-max trick.

        Args:
            key: JAX random key.
            params: Natural parameters (logits of states 1..n-1 relative to state 0).
            n: Number of samples.

        Returns:
            Integer array of shape ``(n, 1)`` with state indices in ``[0, n_categories)``.
        """
        key = jnp.asarray(key)
        # argmax(logits + G) with G ~ Gumbel(0, 1) is distributed as
        # Categorical(softmax(logits)). Working on the logits directly avoids
        # the round trip through probabilities: there, `1 - sum(means)` can
        # round to a tiny negative value and make log(probs) NaN.
        logits = jnp.concatenate([jnp.array([0.0]), params])
        g = jax.random.gumbel(key, shape=(n, self.n_categories))
        return jnp.argmax(logits + g, axis=-1)[..., None]
class Bernoullis(AnalyticProduct[Bernoulli]):
    """Joint distribution of ``n_neurons`` independent binary units.

    Each unit is an independent `Bernoulli`; the parameters represent the
    bias/activation of each unit. Typical uses:
        - Mean-field approximations of Boltzmann machines
        - Observable/latent layers of Restricted Boltzmann Machines

    Attributes:
        n_neurons: Number of binary units.
    """

    def __init__(self, n_neurons: int):
        """Replicate a single `Bernoulli` ``n_neurons`` times.

        Args:
            n_neurons: Number of binary units.
        """
        unit = Bernoulli()
        super().__init__(unit, n_neurons)

    @property
    def n_neurons(self) -> int:
        """Number of binary units (the product's replication count, ``n_reps``)."""
        return self.n_reps