"""Combinators for composing exponential families: location-shape products and replicated (independent) products."""
from __future__ import annotations
from abc import ABC
from typing import override
import jax
import jax.numpy as jnp
from jax import Array
from ..manifold.combinators import Pair, Replicated
from .base import (
Analytic,
Differentiable,
ExponentialFamily,
Generative,
)
from .protocols import StatisticalMoments
[docs]
class LocationShape[Location: ExponentialFamily, Shape: ExponentialFamily](
Pair[Location, Shape], ExponentialFamily, ABC
):
"""A product of location and shape exponential families sharing the same data space.
Used for distributions like the Normal that decompose into a location component (e.g. mean vector) and a shape component (e.g. covariance matrix). The sufficient statistic concatenates both components; the base measure comes from the shape component.
"""
# Overrides
@property
@override
def data_dim(self) -> int:
return self.fst_man.data_dim
[docs]
@override
def sufficient_statistic(self, x: Array) -> Array:
loc_stats = self.fst_man.sufficient_statistic(x)
shape_stats = self.snd_man.sufficient_statistic(x)
return self.join_coords(loc_stats, shape_stats)
[docs]
@override
def initialize(
self, key: Array, location: float = 0.0, shape: float = 0.1
) -> Array:
key_loc, key_shp = jax.random.split(key)
fst_loc = self.fst_man.initialize(key_loc, location, shape)
shp_loc = self.snd_man.initialize(key_shp, location, shape)
return self.join_coords(fst_loc, shp_loc)
[docs]
@override
def log_base_measure(self, x: Array) -> Array:
return self.snd_man.log_base_measure(x)
[docs]
class Product[M: ExponentialFamily](Replicated[M], ExponentialFamily):
"""Product of ``n_reps`` independent copies of the same exponential family.
The sufficient statistic and base measure decompose across replicates. If the base family supports ``StatisticalMoments``, the product exposes composed mean and block-diagonal covariance.
"""
# Overrides
@property
@override
def data_dim(self) -> int:
return self.rep_man.data_dim * self.n_reps
[docs]
@override
def sufficient_statistic(self, x: Array) -> Array:
x_reshaped = x.reshape(self.n_reps, -1)
stats = jax.vmap(self.rep_man.sufficient_statistic)(x_reshaped)
return stats.reshape(-1)
[docs]
@override
def log_base_measure(self, x: Array) -> Array:
x_reshaped = x.reshape(self.n_reps, -1)
return jnp.sum(jax.vmap(self.rep_man.log_base_measure)(x_reshaped))
[docs]
@override
def initialize(
self, key: Array, location: float = 0.0, shape: float = 0.1
) -> Array:
keys = jax.random.split(key, self.n_reps)
params_2d = jax.vmap(self.rep_man.initialize, in_axes=(0, None, None))(
keys, location, shape
)
return self.to_1d(params_2d)
[docs]
@override
def initialize_from_sample(
self, key: Array, sample: Array, location: float = 0.0, shape: float = 0.1
) -> Array:
keys = jax.random.split(key, self.n_reps)
# Reshape sample (n_batch, n_reps * data_dim) -> per-replicate samples (n_reps, n_batch, data_dim)
n_batch = sample.shape[0]
rep_samples = sample.reshape(n_batch, self.n_reps, self.rep_man.data_dim)
rep_samples = jnp.moveaxis(rep_samples, 1, 0)
def init_one(rep_key: Array, rep_sample: Array) -> Array:
return self.rep_man.initialize_from_sample(
rep_key, rep_sample, location, shape
)
init_params = jax.vmap(init_one)(keys, rep_samples)
return self.to_1d(init_params)
# Methods
[docs]
def statistical_mean(self, params: Array) -> Array:
"""Compute the mean of the product distribution from its natural parameters.
Requires the base family to satisfy ``StatisticalMoments``.
"""
if not isinstance(self.rep_man, StatisticalMoments):
raise TypeError(
f"Replicated manifold {type(self.rep_man)} does not support statistical moment computation"
)
return self.map(self.rep_man.statistical_mean, params, flatten=True)
[docs]
def statistical_covariance(self, params: Array) -> Array:
"""Compute the block-diagonal covariance of the product distribution from its natural parameters.
Requires the base family to satisfy ``StatisticalMoments``.
"""
if not isinstance(self.rep_man, StatisticalMoments):
raise TypeError(
f"Replicated manifold {type(self.rep_man)} does not support statistical moment computation"
)
component_covs = self.map(self.rep_man.statistical_covariance, params)
# Scalar variances: return diagonal matrix
if component_covs.size == self.n_reps:
return jnp.diag(component_covs.ravel())
# Matrix covariances: build block diagonal
return jax.scipy.linalg.block_diag(*component_covs)
[docs]
class GenerativeProduct[M: Generative](Product[M], Generative):
"""Product of generative exponential families, adding independent sampling across replicates."""
# Overrides
[docs]
@override
def sample(self, key: Array, params: Array, n: int = 1) -> Array:
"""Draw ``n`` samples from the product distribution given its natural parameters."""
rep_keys = jax.random.split(key, self.n_reps)
params_2d = self.to_2d(params)
def sample_rep(rep_key: Array, rep_params: Array) -> Array:
return self.rep_man.sample(rep_key, rep_params, n)
# (n_reps, n, data_dim) -> (n, n_reps * data_dim)
samples = jax.vmap(sample_rep)(rep_keys, params_2d)
return jnp.reshape(jnp.moveaxis(samples, 1, 0), (n, -1))
[docs]
class DifferentiableProduct[M: Differentiable](Differentiable, GenerativeProduct[M]):
"""Product of differentiable exponential families, with log-partition function summed across replicates."""
# Overrides
[docs]
@override
def log_partition_function(self, params: Array) -> Array:
"""Sum of component log-partition functions at the given natural parameters."""
return jnp.sum(self.map(self.rep_man.log_partition_function, params))
[docs]
class AnalyticProduct[M: Analytic](DifferentiableProduct[M], Analytic, ABC):
"""Product of analytic exponential families, with negative entropy summed across replicates."""
# Overrides
[docs]
@override
def initialize_from_sample(
self, key: Array, sample: Array, location: float = 0.0, shape: float = 0.1
) -> Array:
"""Initialize by delegating to each replicate's initialize_from_sample.
This override ensures we use the Product version (which respects
domain-specific initialization in each replicate) rather than
the generic Analytic version.
"""
return Product.initialize_from_sample(self, key, sample, location, shape)
[docs]
@override
def negative_entropy(self, means: Array) -> Array:
"""Sum of component negative entropies at the given mean parameters."""
return jnp.sum(self.map(self.rep_man.negative_entropy, means))