"""Poisson distributions as exponential families.
This module provides two count models in the exponential family framework:
1. Base Distributions:
- `Poisson`: Standard Poisson distribution for count data
- `CoMPoisson`: Conway-Maxwell-Poisson distribution for flexible dispersion
2. Components:
- `CoMShape`: Shape component for the COM-Poisson distribution
The Poisson distribution models count data with a single rate parameter, where the mean equals the variance. The Conway-Maxwell-Poisson extends this with a dispersion parameter, allowing for both over- and under-dispersed count data.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import override
import jax
import jax.numpy as jnp
from jax import Array
from ...geometry import (
Analytic,
Differentiable,
ExponentialFamily,
LocationShape,
)
from ...geometry.exponential_family.combinators import AnalyticProduct
### Classes ###
[docs]
@dataclass(frozen=True)
class Poisson(Analytic):
    """Poisson distribution over non-negative integer counts.

    The distribution has a single rate parameter $\\eta > 0$, and the
    probability mass function at a count $k \\in \\mathbb{N}$ is

    $$p(k; \\eta) = \\frac{\\eta^k e^{-\\eta}}{k!}.$$

    As an exponential family:

    - Natural parameter: $\\theta = \\log(\\eta)$
    - Probability mass function: $p(k; \\theta) = e^{\\theta k - e^{\\theta} - \\log(k!)}$
    - Sufficient statistic: $s(x) = x$
    - Log base measure: $\\mu(k) = -\\log(k!)$
    - Log-partition function: $\\psi(\\theta) = e^{\\theta}$
    - Negative entropy: $\\phi(\\eta) = \\eta\\log(\\eta) - \\eta$

    Properties:

    - Mean = Variance = $\\eta$
    - Mode = $\\lfloor \\eta \\rfloor$
    """

    @property
    @override
    def dim(self) -> int:
        """Number of parameters: the single (log-)rate."""
        return 1

    @property
    @override
    def data_dim(self) -> int:
        """Dimension of one observation: a single count."""
        return 1

    @override
    def sufficient_statistic(self, x: Array) -> Array:
        """Return the count itself, promoted to at least one dimension."""
        return jnp.atleast_1d(x)

    @override
    def log_base_measure(self, x: Array) -> Array:
        """Return $-\\log(x!)$, evaluated after casting `x` to float32."""
        count = jnp.asarray(x, dtype=jnp.float32)
        log_fact = _log_factorial(count)
        return -log_fact

    @override
    def log_partition_function(self, params: Array) -> Array:
        """Return $e^{\\theta}$ with singleton axes squeezed away."""
        rate = jnp.exp(params)
        return jnp.squeeze(rate)

    @override
    def negative_entropy(self, means: Array) -> Array:
        """Return $\\eta(\\log\\eta - 1)$ with singleton axes squeezed away."""
        log_rate_minus_one = jnp.log(means) - 1
        return jnp.squeeze(means * log_rate_minus_one)

    @override
    def sample(self, key: Array, params: Array, n: int = 1) -> Array:
        """Draw `n` counts and return them with a trailing axis of length 1.

        Args:
            key: Random key
            params: Natural parameters (log rate)
            n: Number of draws

        Returns:
            Sampled counts with one trailing singleton axis appended.
        """
        # jax.random.poisson is parameterised by the rate, i.e. the mean
        # coordinates, so convert away from natural parameters first.
        rate = self.to_mean(params)
        draws = jax.random.poisson(key, rate, shape=(n,))
        return draws[..., None]

    # Methods

    def statistical_mean(self, params: Array) -> Array:
        """Return the mean (the rate) as an array of shape `(1,)`."""
        return self.to_mean(params).reshape((1,))

    def statistical_covariance(self, params: Array) -> Array:
        """Return the variance (also the rate) as an array of shape `(1, 1)`."""
        return self.to_mean(params).reshape((1, 1))

    @override
    def initialize_from_sample(
        self,
        key: Array,
        sample: Array,
        location: float = 0.0,
        shape: float = 0.1,
    ) -> Array:
        """Initialize natural parameters from observed counts.

        The averaged sufficient statistic is floored at 0.1 before taking the
        logarithm, so a sample whose average is zero still yields a finite
        log-rate. Gaussian noise is then added to the log-rate.

        Args:
            key: Random key
            sample: Sample data (count values)
            location: Mean of noise distribution
            shape: Std dev of noise distribution

        Returns:
            Natural parameters (log rate).
        """
        mean_stat = self.average_sufficient_statistic(sample)
        # A floor of 0.1 keeps log() finite without forcing an extreme rate.
        floored = jnp.clip(mean_stat, 0.1, None)
        jitter = location + shape * jax.random.normal(key, shape=(self.dim,))
        return jnp.log(floored) + jitter
[docs]
@dataclass(frozen=True)
class CoMShape(ExponentialFamily):
    """Shape (dispersion) component of a CoMPoisson distribution.

    Its only sufficient statistic is $\\log(x!)$, which captures deviations
    from the standard Poisson variance-mean relationship.

    The dispersion parameter $\\nu$ determines the regime:

    - Equidispersed ($\\nu = 1$): Variance = Mean (standard Poisson)
    - Underdispersed ($\\nu > 1$): Variance < Mean
    - Overdispersed ($\\nu < 1$): Variance > Mean
    """

    def __init__(self):
        super().__init__()

    @property
    @override
    def dim(self) -> int:
        """Number of parameters: one."""
        return 1

    @property
    @override
    def data_dim(self) -> int:
        """Dimension of one observation: a single count."""
        return 1

    @override
    def sufficient_statistic(self, x: Array) -> Array:
        """Return $\\log(x!)$, promoted to at least one dimension."""
        log_fact = _log_factorial(x)
        return jnp.atleast_1d(log_fact)

    @override
    def log_base_measure(self, x: Array) -> Array:
        """Return a constant zero: this component has a trivial base measure."""
        return jnp.array(0.0)
[docs]
@dataclass(frozen=True)
class CoMPoisson(LocationShape[Poisson, CoMShape], Differentiable):
    """Conway-Maxwell-Poisson distribution in mode/dispersion form.

    A generalization of the Poisson distribution that can model both over- and
    under-dispersed count data. With the parameterization used by this class
    its probability mass function is:

    $$p(x; \\mu, \\nu) = \\frac{1}{Z(\\mu, \\nu)} \\left(\\frac{\\mu^x}{x!}\\right)^\\nu$$

    where:

    - $\\mu > 0$ is related to the mode of the distribution
    - $\\nu > 0$ is the dispersion parameter (pseudo-precision)
    - $Z(\\mu, \\nu)$ is the normalizing constant:

    $$Z(\\mu, \\nu) = \\sum_{j=0}^{\\infty} \\left(\\frac{\\mu^j}{j!}\\right)^\\nu$$

    Writing $\\lambda = \\mu^\\nu$ recovers the classical form
    $p(x) \\propto \\lambda^x / (x!)^\\nu$, for which:

    - When $\\nu = 1$: Standard Poisson distribution
    - When $\\nu < 1$: Over-dispersed (variance > mean)
    - When $\\nu > 1$: Under-dispersed (variance < mean)
    - When $\\nu \\to \\infty$: Bernoulli distribution
    - When $\\nu = 0$ (and $\\lambda < 1$): Geometric distribution

    As an exponential family:

    - Natural parameters: $\\theta_1 = \\nu\\log(\\mu)$, $\\theta_2 = -\\nu$
    - Sufficient statistics: $s(x) = (x, \\log(x!))$
    - Log-partition function: log of the normalizing constant $Z$
    """

    # Fields

    window_size: int = 200
    """Fixed number of terms to evaluate in series expansions."""

    # Methods

    def split_mode_dispersion(self, params: Array) -> tuple[Array, Array]:
        """Convert from natural parameters to mode-shape parameters.

        The COM-Poisson distribution can be parameterized by either natural
        parameters $(\\theta_1, \\theta_2)$ or by mode-shape parameters
        $(\\mu, \\nu)$. The conversion is given by:

        $$\\nu = -\\theta_2$$

        $$\\mu = \\exp(-\\theta_1/\\theta_2)$$

        Args:
            params: Natural parameters $(\\theta_1, \\theta_2)$

        Returns:
            The pair $(\\mu, \\nu)$.
        """
        theta1, theta2 = params[0], params[1]
        nu = -theta2
        mu = jnp.exp(-theta1 / theta2)
        return mu, nu

    def join_mode_dispersion(self, mu: Array, nu: Array) -> Array:
        """Convert from mode-shape parameters to natural parameters.

        The conversion is given by:

        - $\\theta_1 = \\nu\\log(\\mu)$
        - $\\theta_2 = -\\nu$

        Args:
            mu: Mode parameter $\\mu$
            nu: Dispersion parameter $\\nu$

        Returns:
            Flattened array $(\\theta_1, \\theta_2)$.
        """
        theta1 = nu * jnp.log(mu)
        theta2 = -nu
        return jnp.array([theta1, theta2]).ravel()

    def approximate_mean_variance(self, params: Array) -> tuple[Array, Array]:
        """Compute approximate mean and variance of COM-Poisson distribution.

        Given mode $\\mu$ and shape $\\nu$ parameters, the approximations are:

        $E(X) \\approx \\mu + 1/(2\\nu) - 1/2$

        $\\text{Var}(X) \\approx \\mu / \\nu$

        Args:
            params: Natural parameters $(\\theta_1, \\theta_2)$

        Returns:
            The pair (approximate mean, approximate variance).
        """
        mu, nu = self.split_mode_dispersion(params)
        approx_mean = mu + 1 / (2 * nu) - 0.5
        approx_var = mu / nu
        return approx_mean, approx_var

    def _window_indices(self, mu: Array) -> Array:
        """Return `window_size` consecutive counts roughly centred on `mu`.

        The window starts at $\\lfloor \\mu - W/2 \\rfloor$, floored at 0 so that
        it never contains negative counts.
        """
        start = jnp.maximum(0, jnp.floor(mu - self.window_size / 2)).astype(
            jnp.int32
        )
        return jnp.arange(self.window_size) + start

    def numerical_mean_variance(
        self,
        params: Array,
    ) -> tuple[Array, Array]:
        """Compute mean and variance by truncated summation.

        Uses a fixed-width window of counts centered on the mode to evaluate:

        $$E[X] = \\sum_{x} x p(x)$$

        $$E[X^2] = \\sum_{x} x^2 p(x)$$

        $$\\text{Var}(X) = E[X^2] - E[X]^2$$

        Probability mass outside the window is ignored.

        Args:
            params: Natural parameters $(\\theta_1, \\theta_2)$

        Returns:
            The pair (mean, variance).
        """
        # Get mode for window centering
        mu, _ = self.split_mode_dispersion(params)
        indices = self._window_indices(mu)

        # Compute probabilities of every count in the window
        log_probs = jax.vmap(self.log_density, in_axes=(None, 0))(params, indices)
        probs = jnp.exp(log_probs)

        # Square in floating point: squaring the int32 indices would overflow
        # once the window reaches counts beyond ~46340.
        counts = indices.astype(probs.dtype)
        mean = jnp.sum(counts * probs)
        second_moment = jnp.sum(counts**2 * probs)

        variance = second_moment - mean**2
        return mean, variance

    def statistical_mean(self, params: Array) -> Array:
        """Numerical approximation of the mean, shaped `(1,)`."""
        mean, _ = self.numerical_mean_variance(params)
        return mean.reshape([1])

    def statistical_covariance(self, params: Array) -> Array:
        """Numerical approximation of the covariance, shaped `(1, 1)`."""
        _, var = self.numerical_mean_variance(params)
        return var.reshape([1, 1])

    # Override

    @property
    @override
    def fst_man(self) -> Poisson:
        """Location component: a plain Poisson."""
        return Poisson()

    @property
    @override
    def snd_man(self) -> CoMShape:
        """Shape component carrying the $\\log(x!)$ statistic."""
        return CoMShape()

    @override
    def log_base_measure(self, x: Array) -> Array:
        """Return a constant zero.

        The $\\log(x!)$ term is a sufficient statistic in this family, not part
        of the base measure.
        """
        return jnp.asarray(0.0)

    @override
    def log_partition_function(self, params: Array) -> Array:
        """Compute log partition function using fixed-width window strategy.

        Evaluates:

        $$\\psi(\\theta) = \\log\\sum_{j=0}^{\\infty} \\exp(\\theta_1 j + \\theta_2 \\log(j!))$$

        using a fixed number of terms centered on the mode.

        Args:
            params: Array of natural parameters $(\\theta_1, \\theta_2)$

        Returns:
            Value of log partition function $\\psi(\\theta)$
        """
        # Estimate mode and center the summation window around it
        mu, _ = self.split_mode_dispersion(params)
        indices = self._window_indices(mu)

        def _compute_log_partition_terms(index: Array) -> Array:
            return jnp.dot(params, self.sufficient_statistic(index))

        # Compute terms and use log-sum-exp for numerical stability
        log_terms = jax.vmap(_compute_log_partition_terms)(indices)
        return jax.nn.logsumexp(log_terms)

    # TODO: Come up with a better scheme than rejection sampling
    @override
    def sample(self, key: Array, params: Array, n: int = 1) -> Array:
        """Generate random COM-Poisson samples using Algorithm 2 from Benson & Friel (2021).

        Each draw is produced by rejection sampling. The proposal is a Poisson
        with rate $\\mu$ when $\\nu \\geq 1$, and a geometric distribution
        otherwise.

        Args:
            key: Random key
            params: Natural parameters $(\\theta_1, \\theta_2)$
            n: Number of draws

        Returns:
            Accepted draws with one trailing singleton axis appended.
        """
        mu, nu = self.split_mode_dispersion(params)
        mode = jnp.floor(mu)

        # Envelope terms for both Poisson and Geometric cases
        # Underdispersed case (nu >= 1): Poisson envelope
        log_pois_scale = (nu - 1) * (mode * jnp.log(mu) - _log_factorial(mode))

        # Overdispersed case (nu < 1): Geometric envelope
        p_geo = (2 * nu) / (2 * nu * mu + 1 + nu)
        ratio = jnp.floor(mu / ((1 - p_geo) ** (1 / nu)))
        log_geo_scale = (
            -jnp.log(p_geo)
            + nu * ratio * jnp.log(mu)
            - (ratio * jnp.log(1 - p_geo) + nu * _log_factorial(ratio))
        )

        def sample_one(key: Array) -> Array:
            def cond_fn(val: tuple[Array, Array, Array]) -> Array:
                # Keep looping until a proposal has been accepted
                _, _, accept = val
                return jnp.logical_not(accept)

            def body_fn(val: tuple[Array, Array, Array]) -> tuple[Array, Array, Array]:
                key, _, _ = val
                key, key_prop, key_u = jax.random.split(key, 3)

                # NOTE: key_prop feeds both proposals below. Only one of them
                # is selected by the `nu >= 1` switch, so each returned draw
                # consumes the key once.

                # Underdispersed case (nu >= 1): Poisson proposal
                pois_y = jax.random.poisson(key_prop, mu)
                log_alpha_pois = nu * (
                    pois_y * jnp.log(mu) - _log_factorial(pois_y)
                ) - (log_pois_scale + pois_y * jnp.log(mu) - _log_factorial(pois_y))

                # Overdispersed case (nu < 1): Geometric proposal, drawn by
                # inverting the geometric CDF at a uniform variate
                u0 = jax.random.uniform(key_prop)
                geo_y = jnp.floor(jnp.log(u0) / jnp.log(1 - p_geo))
                log_alpha_geo = nu * (geo_y * jnp.log(mu) - _log_factorial(geo_y)) - (
                    log_geo_scale + geo_y * jnp.log(1 - p_geo) + jnp.log(p_geo)
                )

                # Select proposal and ratio based on nu
                y = jnp.where(nu >= 1, pois_y, geo_y)
                log_alpha = jnp.where(nu >= 1, log_alpha_pois, log_alpha_geo)

                # Accept/reject step
                u = jax.random.uniform(key_u)
                accept = jnp.asarray(jnp.log(u) <= log_alpha).reshape(())
                return key, y.reshape(()), accept

            init_val = (key, jnp.array(0), jnp.asarray(False).reshape(()))
            _, sample, _ = jax.lax.while_loop(cond_fn, body_fn, init_val)
            return sample

        keys = jax.random.split(key, n)
        samples = jax.vmap(sample_one)(keys)
        return samples[..., None]

    @override
    def check_natural_parameters(self, params: Array) -> Array:
        """Check if natural parameters are valid for COM-Poisson.

        For parameters $(\\theta_1, \\theta_2)$, the following conditions must hold:

        - $\\theta_1$ is finite, $\\theta_2 < 0$
        """
        finite = super().check_natural_parameters(params)
        theta2_valid = params[1] < 0
        return finite & theta2_valid

    @override
    def initialize(
        self,
        key: Array,
        location: float = 0.0,
        shape: float = 0.1,
    ) -> Array:
        """Initialize COM-Poisson parameters.

        Args:
            key: Random key
            location: Mean of the Gaussian noise on $\\log\\mu$
            shape: Scale of the noise on both $\\log\\mu$ and $\\nu$

        Returns:
            Natural parameters $(\\theta_1, \\theta_2)$.
        """
        key_mu, key_nu = jax.random.split(key)
        # Ensure mu stays positive by using exp
        mu_init = jnp.exp(jax.random.normal(key_mu) * shape + location)
        # Keep nu in a reasonable range: at least 1, perturbed upwards only
        nu_init = 1.0 + jnp.abs(jax.random.normal(key_nu)) * shape
        return self.join_mode_dispersion(mu_init, nu_init)

    @override
    def initialize_from_sample(
        self, key: Array, sample: Array, location: float = 0.0, shape: float = 0.1
    ) -> Array:
        """Initialize COM-Poisson parameters from sample.

        Estimates mode and shape parameters using method of moments based on
        sample mean and variance, with added noise for regularization.

        Substituting $\\mu = \\text{Var}\\cdot\\nu$ into
        $E(X) \\approx \\mu + 1/(2\\nu) - 1/2$ gives the quadratic
        $\\text{Var}\\,\\nu^2 - (E + 1/2)\\,\\nu + 1/2 = 0$, whose larger root is
        used for $\\nu$. The moments and the discriminant are clamped so that
        the result is always finite with $\\mu > 0$ and $\\nu > 0$.

        Args:
            key: Random key
            sample: Sample data (count values)
            location: Mean of noise distribution
            shape: Std dev of noise distribution

        Returns:
            Natural parameters $(\\theta_1, \\theta_2)$.
        """
        min_var = 1e-6

        # Compute sample statistics
        sample_mean = jnp.mean(sample)
        sample_var = jnp.var(sample)

        # Add noise for regularization
        noise = jax.random.normal(key, shape=(2,)) * shape + location

        # Noise (or a constant sample) can push the moments out of range:
        # a non-positive variance makes the quadratic degenerate, and a
        # negative mean can flip the sign of the root.
        mean = jnp.maximum(sample_mean + noise[0], 0.0)
        var = jnp.maximum(sample_var + noise[1], min_var)

        a = var
        b = -(mean + 0.5)
        c = 0.5
        # Strongly over-dispersed data makes the discriminant negative; clamp
        # it so the square root stays real instead of producing NaN.
        discriminant = jnp.maximum(b**2 - 4 * a * c, 0.0)
        nu = (-b + jnp.sqrt(discriminant)) / (2 * a)
        mu = var * nu

        # Convert to natural parameters
        return self.join_mode_dispersion(mu, nu)
class Poissons(AnalyticProduct[Poisson]):
    """Product of independent, identically structured Poisson units.

    Suited to vectors of counts such as image pixel intensities. In contrast
    to a Binomial model, each Poisson count is unbounded above.

    Attributes:
        n_neurons: Number of independent Poisson units
    """

    def __init__(self, n_neurons: int):
        """Build the product from `n_neurons` copies of a `Poisson`.

        Args:
            n_neurons: Number of units
        """
        unit = Poisson()
        super().__init__(unit, n_neurons)

    @property
    def n_neurons(self) -> int:
        """Number of Poisson units (read from the parent's `n_reps`)."""
        return self.n_reps
### Helper Functions ###
def _log_factorial(k: Array) -> Array:
    """Return $\\log(k!)$ element-wise, computed as $\\log\\Gamma(k + 1)$.

    Args:
        k: Count or array of counts. Python scalars and sequences are accepted
            as well as arrays; the input is coerced to an array first.

    Returns:
        Floating-point array of the same shape as `k`.
    """
    # Coerce before casting so that plain Python numbers (which have no
    # `.astype`) are handled the same way as arrays.
    counts = jnp.asarray(k).astype(float)
    return jax.lax.lgamma(counts + 1)