Source code for goal.models.base.poisson

"""Poisson distributions as exponential families.

This module provides two count models in the exponential family framework:

1. Base Distributions:
   - `Poisson`: Standard Poisson distribution for count data
   - `CoMPoisson`: Conway-Maxwell-Poisson distribution for flexible dispersion

2. Components:
   - `CoMShape`: Shape component for the COM-Poisson distribution

3. Products:
   - `Poissons`: Product of independent Poisson distributions

The Poisson distribution models count data with a single rate parameter, where the mean equals the variance. The Conway-Maxwell-Poisson extends this with a dispersion parameter, allowing for both over- and under-dispersed count data.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import override

import jax
import jax.numpy as jnp
from jax import Array

from ...geometry import (
    Analytic,
    Differentiable,
    ExponentialFamily,
    LocationShape,
)
from ...geometry.exponential_family.combinators import AnalyticProduct

### Classes ###


@dataclass(frozen=True)
class Poisson(Analytic):
    """Poisson distribution over counts, with a single rate parameter $\\eta > 0$.

    The probability mass function at count $k \\in \\mathbb{N}$ is

    $$p(k; \\eta) = \\frac{\\eta^k e^{-\\eta}}{k!}.$$

    As an exponential family:

    - Natural parameter: $\\theta = \\log(\\eta)$
    - Probability mass function: $p(k; \\theta) = e^{\\theta k - e^{\\theta} - \\log(k!)}$
    - Sufficient statistic: $s(k) = k$
    - Log base measure: $\\log \\mu(k) = -\\log(k!)$
    - Log-partition function: $\\psi(\\theta) = e^{\\theta}$
    - Negative entropy: $\\phi(\\eta) = \\eta\\log(\\eta) - \\eta$

    Properties:

    - Mean = Variance = $\\eta$
    - Mode = $\\lfloor \\eta \\rfloor$
    """

    @property
    @override
    def dim(self) -> int:
        """Single rate parameter."""
        return 1

    @property
    @override
    def data_dim(self) -> int:
        """Each observation is a single count."""
        return 1

    @override
    def sufficient_statistic(self, x: Array) -> Array:
        """Return the count itself, as an array of at least one dimension."""
        return jnp.atleast_1d(x)

    @override
    def log_base_measure(self, x: Array) -> Array:
        """Return $-\\log(k!)$ for the count ``x``."""
        # ``_log_factorial`` already casts to the default float dtype; forcing
        # float32 here first would discard precision when 64-bit mode is on.
        return -_log_factorial(jnp.asarray(x))

    @override
    def log_partition_function(self, params: Array) -> Array:
        """Return $\\psi(\\theta) = e^{\\theta}$, squeezed to a scalar."""
        return jnp.squeeze(jnp.exp(params))

    @override
    def negative_entropy(self, means: Array) -> Array:
        """Return $\\phi(\\eta) = \\eta(\\log\\eta - 1)$, squeezed to a scalar.

        Note that this evaluates to NaN at $\\eta = 0$ (``0 * log(0)``).
        """
        rate = means
        return jnp.squeeze(rate * (jnp.log(rate) - 1))

    @override
    def sample(self, key: Array, params: Array, n: int = 1) -> Array:
        """Draw ``n`` counts, returned with shape ``(n, 1)``."""
        # The mean parameter of a Poisson is its rate, which is what
        # jax.random.poisson expects.
        rate = self.to_mean(params)
        return jax.random.poisson(key, rate, shape=(n,))[..., None]

    # Methods

    def statistical_mean(self, params: Array) -> Array:
        """Return the mean (equal to the rate) with shape ``(1,)``."""
        return self.to_mean(params).reshape([1])

    def statistical_covariance(self, params: Array) -> Array:
        """Return the variance (equal to the rate) with shape ``(1, 1)``."""
        return self.to_mean(params).reshape([1, 1])

    @override
    def initialize_from_sample(
        self,
        key: Array,
        sample: Array,
        location: float = 0.0,
        shape: float = 0.1,
    ) -> Array:
        """Initialize Poisson parameters from sample data.

        Handles the case where some observations are 0 by clipping the mean to
        a small positive value before converting to natural parameters.

        Args:
            key: Random key
            sample: Sample data (count values)
            location: Mean of noise distribution
            shape: Std dev of noise distribution

        Returns:
            Natural parameters (log rate).
        """
        avg_suff = self.average_sufficient_statistic(sample)
        # Clip mean to small positive value to avoid log(0) = -inf
        # Use 0.1 as minimum rate (small but not too small)
        clipped_mean = jnp.clip(avg_suff, 0.1, None)
        natural = jnp.log(clipped_mean)
        noise = jax.random.normal(key, shape=(self.dim,)) * shape + location
        return natural + noise
@dataclass(frozen=True)
class CoMShape(ExponentialFamily):
    """Dispersion (shape) component of a CoMPoisson distribution.

    The sufficient statistic is $\\log(x!)$. It captures how the distribution
    departs from the Poisson variance-mean relationship, as governed by the
    dispersion parameter $\\nu$:

    - $\\nu = 1$: equidispersed, variance = mean (standard Poisson)
    - $\\nu > 1$: underdispersed, variance < mean
    - $\\nu < 1$: overdispersed, variance > mean
    """

    def __init__(self):
        super().__init__()

    @property
    @override
    def data_dim(self) -> int:
        """Each observation is a single count."""
        return 1

    @property
    @override
    def dim(self) -> int:
        """One natural parameter."""
        return 1

    @override
    def sufficient_statistic(self, x: Array) -> Array:
        """Return $\\log(x!)$ as an array of at least one dimension."""
        log_fact = _log_factorial(x)
        return jnp.atleast_1d(log_fact)

    @override
    def log_base_measure(self, x: Array) -> Array:
        """Return a constant zero log base measure; ``x`` is ignored."""
        return jnp.array(0.0)
@dataclass(frozen=True)
class CoMPoisson(LocationShape[Poisson, CoMShape], Differentiable):
    """Conway-Maxwell-Poisson (COM-Poisson) distribution over counts.

    A generalization of the Poisson distribution that can model both over- and
    under-dispersed count data. In the mode-dispersion parameterization used
    here, its probability mass function is

    $$p(x; \\mu, \\nu) = \\frac{1}{Z(\\mu, \\nu)} \\left(\\frac{\\mu^x}{x!}\\right)^{\\nu}$$

    where:

    - $\\mu > 0$ is related to the mode of the distribution
    - $\\nu > 0$ is the dispersion parameter (pseudo-precision)
    - $Z(\\mu, \\nu)$ is the normalizing constant:
      $$Z(\\mu, \\nu) = \\sum_{j=0}^{\\infty} \\left(\\frac{\\mu^j}{j!}\\right)^{\\nu}$$

    Equivalently, with $\\lambda = \\mu^{\\nu}$, $p(x) \\propto \\lambda^x / (x!)^{\\nu}$.

    Special cases:

    - When $\\nu = 1$: Standard Poisson distribution with rate $\\mu$
    - When $\\nu < 1$: Over-dispersed (variance > mean)
    - When $\\nu > 1$: Under-dispersed (variance < mean)
    - When $\\nu \\to \\infty$ with $\\lambda$ fixed: Bernoulli distribution with
      success probability $\\lambda / (1 + \\lambda)$
    - When $\\nu = 0$ and $\\lambda < 1$: Geometric distribution (excluded by
      `check_natural_parameters`, which requires $\\theta_2 < 0$)

    As an exponential family:

    - Natural parameters: $\\theta_1 = \\nu\\log(\\mu)$, $\\theta_2 = -\\nu$
    - Sufficient statistics: $s(x) = (x, \\log(x!))$
    - Log-partition function: $\\log Z$, evaluated numerically over a window
      of `window_size` consecutive counts around $\\mu$
    """

    # Fields

    window_size: int = 200
    """Fixed number of terms to evaluate in series expansions."""

    # Methods

    def split_mode_dispersion(self, params: Array) -> tuple[Array, Array]:
        """Convert from natural parameters to mode-shape parameters.

        The COM-Poisson distribution can be parameterized by either natural
        parameters $(\\theta_1, \\theta_2)$ or by mode-shape parameters
        $(\\mu, \\nu)$. The conversion is given by:

        $$\\nu = -\\theta_2$$
        $$\\mu = \\exp(-\\theta_1/\\theta_2)$$
        """
        theta1, theta2 = params[0], params[1]
        nu = -theta2
        mu = jnp.exp(-theta1 / theta2)
        return mu, nu

    def join_mode_dispersion(self, mu: Array, nu: Array) -> Array:
        """Convert from mode-shape parameters to natural parameters.

        The conversion is given by:

        - $\\theta_1 = \\nu\\log(\\mu)$
        - $\\theta_2 = -\\nu$
        """
        theta1 = nu * jnp.log(mu)
        theta2 = -nu
        return jnp.array([theta1, theta2]).ravel()

    def approximate_mean_variance(self, params: Array) -> tuple[Array, Array]:
        """Compute approximate mean and variance of the COM-Poisson distribution.

        Given mode $\\mu$ and shape $\\nu$ parameters, the approximations are:

        $E(X) \\approx \\mu + 1/(2\\nu) - 1/2$

        $\\text{Var}(X) \\approx \\mu / \\nu$
        """
        mu, nu = self.split_mode_dispersion(params)
        approx_mean = mu + 1 / (2 * nu) - 0.5
        approx_var = mu / nu
        return approx_mean, approx_var

    def _window_indices(self, params: Array) -> Array:
        """Return `window_size` consecutive counts centered on $\\mu$, clamped at 0."""
        mu, _ = self.split_mode_dispersion(params)
        mode_shift = jnp.maximum(0, jnp.floor(mu - self.window_size / 2)).astype(
            jnp.int32
        )
        return jnp.arange(self.window_size) + mode_shift

    def numerical_mean_variance(
        self,
        params: Array,
    ) -> tuple[Array, Array]:
        """Compute mean and variance by summing over a window of counts.

        Uses a fixed window of `window_size` counts centered on the mode to
        compute:

        $$E[X] = \\sum_{x} x\\, p(x)$$
        $$\\text{Var}(X) = \\sum_{x} (x - E[X])^2 p(x)$$
        """
        indices = self._window_indices(params)

        log_probs = jax.vmap(self.log_density, in_axes=(None, 0))(params, indices)
        probs = jnp.exp(log_probs)

        mean = jnp.sum(indices * probs)
        # Centered second moment: avoids the catastrophic cancellation of
        # E[X^2] - E[X]^2 when the mean is large relative to the spread.
        variance = jnp.sum((indices - mean) ** 2 * probs)
        return mean, variance

    def statistical_mean(self, params: Array) -> Array:
        """Numerical approximation of the mean, shape ``(1,)``."""
        mean, _ = self.numerical_mean_variance(params)
        return mean.reshape([1])

    def statistical_covariance(self, params: Array) -> Array:
        """Numerical approximation of the variance, shape ``(1, 1)``."""
        _, var = self.numerical_mean_variance(params)
        return var.reshape([1, 1])

    # Override

    @property
    @override
    def fst_man(self) -> Poisson:
        """Location component: a Poisson over the count statistic."""
        return Poisson()

    @property
    @override
    def snd_man(self) -> CoMShape:
        """Shape component: the log-factorial statistic."""
        return CoMShape()

    @override
    def log_base_measure(self, x: Array) -> Array:
        """Return a constant zero log base measure; ``x`` is ignored."""
        return jnp.asarray(0.0)

    @override
    def log_partition_function(self, params: Array) -> Array:
        """Compute the log partition function over a fixed-width window.

        Evaluates:

        $$\\psi(\\theta) = \\log\\sum_{j=0}^{\\infty} \\exp(\\theta_1 j + \\theta_2 \\log(j!))$$

        using `window_size` consecutive terms centered on the mode. Terms
        outside the window are dropped, so the value is a lower bound that is
        accurate only when the window holds essentially all of the mass.

        Args:
            params: Array of natural parameters $(\\theta_1, \\theta_2)$

        Returns:
            Value of log partition function $\\psi(\\theta)$
        """
        indices = self._window_indices(params)

        def _compute_log_partition_terms(index: Array) -> Array:
            return jnp.dot(params, self.sufficient_statistic(index))

        # Log-sum-exp for numerical stability
        log_terms = jax.vmap(_compute_log_partition_terms)(indices)
        return jax.nn.logsumexp(log_terms)

    # TODO: Come up with a better scheme than rejection sampling
    @override
    def sample(self, key: Array, params: Array, n: int = 1) -> Array:
        """Generate random COM-Poisson samples using Algorithm 2 from Benson & Friel (2021).

        Rejection sampling with a Poisson envelope when $\\nu \\geq 1$ and a
        geometric envelope when $\\nu < 1$. Each of the ``n`` draws runs its
        own ``jax.lax.while_loop`` until a proposal is accepted.

        Returns:
            Array of shape ``(n, 1)`` of sampled counts.
        """
        mu, nu = self.split_mode_dispersion(params)
        mode = jnp.floor(mu)

        # Log envelope bounds for both cases
        # Underdispersed case (nu >= 1): Poisson envelope with rate mu
        log_pois_scale = (nu - 1) * (mode * jnp.log(mu) - _log_factorial(mode))
        # Overdispersed case (nu < 1): Geometric envelope with success prob p_geo
        p_geo = (2 * nu) / (2 * nu * mu + 1 + nu)
        ratio = jnp.floor(mu / ((1 - p_geo) ** (1 / nu)))
        log_geo_scale = (
            -jnp.log(p_geo)
            + nu * ratio * jnp.log(mu)
            - (ratio * jnp.log(1 - p_geo) + nu * _log_factorial(ratio))
        )

        def sample_one(key: Array) -> Array:
            def cond_fn(val: tuple[Array, Array, Array]) -> Array:
                _, _, accept = val
                return jnp.logical_not(accept)

            def body_fn(val: tuple[Array, Array, Array]) -> tuple[Array, Array, Array]:
                key, _, _ = val
                key, key_prop, key_u = jax.random.split(key, 3)

                # Underdispersed case (nu >= 1): Poisson envelope
                pois_y = jax.random.poisson(key_prop, mu)
                log_alpha_pois = nu * (
                    pois_y * jnp.log(mu) - _log_factorial(pois_y)
                ) - (log_pois_scale + pois_y * jnp.log(mu) - _log_factorial(pois_y))

                # Overdispersed case (nu < 1): Geometric envelope, sampled by
                # inverting the geometric CDF
                u0 = jax.random.uniform(key_prop)
                geo_y = jnp.floor(jnp.log(u0) / jnp.log(1 - p_geo))
                log_alpha_geo = nu * (geo_y * jnp.log(mu) - _log_factorial(geo_y)) - (
                    log_geo_scale + geo_y * jnp.log(1 - p_geo) + jnp.log(p_geo)
                )

                # Both proposals are computed; keep the one that matches nu
                y = jnp.where(nu >= 1, pois_y, geo_y)
                log_alpha = jnp.where(nu >= 1, log_alpha_pois, log_alpha_geo)

                # Accept/reject step
                u = jax.random.uniform(key_u)
                accept = jnp.asarray(jnp.log(u) <= log_alpha).reshape(())
                return key, y.reshape(()), accept

            init_val = (key, jnp.array(0), jnp.asarray(False).reshape(()))
            _, draw, _ = jax.lax.while_loop(cond_fn, body_fn, init_val)
            return draw

        keys = jax.random.split(key, n)
        samples = jax.vmap(sample_one)(keys)
        return samples[..., None]

    @override
    def check_natural_parameters(self, params: Array) -> Array:
        """Check if natural parameters are valid for COM-Poisson.

        For parameters $(\\theta_1, \\theta_2)$, the following conditions must
        hold:

        - $\\theta_1$ is finite, $\\theta_2 < 0$
        """
        finite = super().check_natural_parameters(params)
        theta2_valid = params[1] < 0
        return finite & theta2_valid

    @override
    def initialize(
        self,
        key: Array,
        location: float = 0.0,
        shape: float = 0.1,
    ) -> Array:
        """Initialize COM-Poisson parameters.

        $\\mu$ is log-normal and $\\nu = 1 + |\\epsilon| \\cdot$ ``shape``, so
        the result always satisfies $\\nu \\geq 1$.
        """
        key_mu, key_nu = jax.random.split(key)
        # Ensure mu stays positive by using exp
        mu_init = jnp.exp(jax.random.normal(key_mu) * shape + location)
        # Keep nu in a reasonable range
        nu_init = 1.0 + jnp.abs(jax.random.normal(key_nu)) * shape
        return self.join_mode_dispersion(mu_init, nu_init)

    @override
    def initialize_from_sample(
        self, key: Array, sample: Array, location: float = 0.0, shape: float = 0.1
    ) -> Array:
        """Initialize COM-Poisson parameters from a sample by the method of moments.

        This inverts the approximations of `approximate_mean_variance`. With
        $m$ and $v$ the noised sample mean and variance, $\\nu$ is the larger
        root of $v\\nu^2 - (m + 1/2)\\nu + 1/2 = 0$, and $\\mu = v\\nu$.

        Degenerate statistics are clamped so that the result always has
        $\\nu > 0$ and $\\mu > 0$:

        - the mean is clipped at 0;
        - the variance is clipped at 0.1;
        - a negative discriminant (no real root, i.e. strong overdispersion)
          is clipped at 0.

        Args:
            key: Random key
            sample: Sample data (count values)
            location: Mean of the noise added to the sample statistics
            shape: Std dev of the noise added to the sample statistics

        Returns:
            Natural parameters $(\\theta_1, \\theta_2)$.
        """
        # Compute sample statistics
        mean = jnp.mean(sample)
        var = jnp.var(sample)

        # Add noise for regularization, then keep the statistics in a range
        # where the moment equations have a valid (positive) solution
        noise = jax.random.normal(key, shape=(2,)) * shape + location
        mean = jnp.maximum(mean + noise[0], 0.0)
        var = jnp.maximum(var + noise[1], 0.1)

        a = var
        b = -(mean + 0.5)
        c = 0.5
        discriminant = jnp.maximum(b**2 - 4 * a * c, 0.0)

        nu = (-b + jnp.sqrt(discriminant)) / (2 * a)
        mu = var * nu

        # Convert to natural parameters
        return self.join_mode_dispersion(mu, nu)
class Poissons(AnalyticProduct[Poisson]):
    """Product of n independent Poisson distributions.

    Useful for modeling count data like image pixel intensities. Unlike
    Binomial, Poisson counts are unbounded (they can exceed any fixed value).

    Attributes:
        n_neurons: Number of independent Poisson units
    """

    def __init__(self, n_neurons: int):
        """Create a product of n independent Poissons.

        Args:
            n_neurons: Number of units
        """
        super().__init__(Poisson(), n_neurons)

    @property
    def n_neurons(self) -> int:
        """Number of Poisson units."""
        return self.n_reps


### Helper Functions ###


def _log_factorial(k: Array) -> Array:
    """Return $\\log(k!) = \\log\\Gamma(k + 1)$ elementwise.

    Accepts anything convertible to an array, including Python scalars, and
    computes in the default floating-point dtype.
    """
    # jnp.asarray lets plain ints/floats through; they have no ``.astype``.
    return jax.lax.lgamma(jnp.asarray(k).astype(float) + 1)