Source code for goal.models.base.gaussian.generalized
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import override
import jax
import jax.numpy as jnp
from jax import Array
from ....geometry import (
Differentiable,
ExponentialFamily,
)
[docs]
@dataclass(frozen=True)
class Euclidean(Differentiable):
    r"""Euclidean space $\mathbb{R}^n$ of dimension $n$.

    Euclidean space consists of $n$-dimensional real vectors with the standard
    Euclidean distance metric

    $$d(x,y) = \sqrt{\sum_{i=1}^n (x_i - y_i)^2}.$$

    Euclidean also serves as the location component of a Normal distribution,
    and on its own we treat it as a normal distribution with unit covariance,
    $N(\theta, I)$.

    As an exponential family:
        - Sufficient statistic: Identity map $s(x) = x$
        - Base measure: $\mu(x) = -\frac{1}{2}\lVert x \rVert^2$
        - Log-partition function:
          $\psi(\theta) = \frac{1}{2}\lVert\theta\rVert^2 + \frac{n}{2}\log(2\pi)$

    With these choices $\theta \cdot s(x) + \mu(x) - \psi(\theta)$ equals
    $-\frac{1}{2}\lVert x - \theta \rVert^2 - \frac{n}{2}\log(2\pi)$, the
    log-density of $N(\theta, I)$, which is also the distribution that
    ``sample`` draws from.
    """

    # Fields

    # Dimension n of the space.
    _dim: int

    # Overrides

    @property
    @override
    def dim(self) -> int:
        """Return the dimension of the space."""
        return self._dim

    @property
    @override
    def data_dim(self) -> int:
        """Return the dimension of a data point, which equals ``dim``."""
        return self.dim

    @override
    def sufficient_statistic(self, x: Array) -> Array:
        """Identity map on the data."""
        return x

    @override
    def log_base_measure(self, x: Array) -> Array:
        r"""Log base measure of the unit-covariance Gaussian, $-\frac{1}{2}\lVert x \rVert^2$.

        Args:
            x: Data point

        Returns:
            Scalar log base measure evaluated at ``x``
        """
        # The quadratic term is what makes exp(theta . x + mu(x)) integrable.
        # The (n/2) log(2 pi) normalizer is carried by log_partition_function,
        # so repeating it here would count it twice.
        return -0.5 * jnp.sum(x**2)

    @override
    def sample(self, key: Array, params: Array, n: int = 1) -> Array:
        """Sample from a unit-covariance normal whose mean is the natural parameters.

        For Euclidean space with unit covariance, the natural parameters are the
        mean, so we sample x ~ N(params, I).

        Args:
            key: JAX random key
            params: Natural parameters (the mean)
            n: Number of samples

        Returns:
            Array of shape (n, dim) with samples
        """
        # Standard-normal noise, shifted by the mean via broadcasting.
        noise = jax.random.normal(key, (n, self.dim))
        return params + noise

    @override
    def log_partition_function(self, params: Array) -> Array:
        r"""Compute the log partition function for unit covariance.

        For a normal distribution $N(\mu, I)$ with sufficient statistic
        $s(x) = x$, natural parameter $\theta = \mu$ and base measure
        $-\frac{1}{2}\lVert x \rVert^2$:

        $$\psi(\theta) = \frac{1}{2}\lVert\theta\rVert^2 + \frac{d}{2}\log(2\pi)$$

        Args:
            params: Natural parameters (the mean)

        Returns:
            Scalar log partition function value
        """
        quadratic = 0.5 * jnp.sum(params**2)
        log_normalizer = 0.5 * self.dim * jnp.log(2 * jnp.pi)
        return quadratic + log_normalizer
class GeneralizedGaussian[L: ExponentialFamily, S: ExponentialFamily](
Differentiable, ABC
):
r"""ABC for exponential families with Gaussian-like sufficient statistics.
This abc captures the shared structure between Normal distributions and
Boltzmann machines, where the sufficient statistics take the form:
$$s(x) = (x, x \otimes x)$$
with appropriate constraints on the second moment term for minimality.
In theory, both Normal distributions and Boltzmann machines share this fundamental
sufficient statistic structure but differ in their constraints:
- Normal: Second moment must be positive definite (continuous domain)
- Boltzmann: Second moment has off-diagonal interactions only (discrete binary domain)
This abstraction enables unified conjugation algorithms for harmoniums (bilinear
exponential family models) that can work with either continuous or discrete
distributions.
Type Parameters:
L: The location component manifold type (e.g., Euclidean for Normal,
Bernoullis for Boltzmann)
S: The shape component manifold type (e.g., Covariance for Normal,
CouplingMatrix for Boltzmann)
"""
@property
@abstractmethod
def loc_man(self) -> L:
"""Return the location component manifold."""
@property
@abstractmethod
def shp_man(self) -> S:
"""Return the shape component manifold."""
# Core split/join operations for harmonium conjugation
@abstractmethod
def split_location_precision(self, params: Array) -> tuple[Array, Array]:
"""Split natural parameters into location and precision in natural coordinates.
For harmonium conjugation, natural coordinates represent:
- Location: $\\theta_1 = \\Sigma^{-1}\\mu$ (Normal) or bias parameters (Boltzmann)
- Precision: $\\theta_2 = -\\frac{1}{2}\\Sigma^{-1}$ (Normal) or $-\\frac{1}{2}J$ (Boltzmann precisions)
Args:
params: Parameters in natural coordinates
Returns:
location: Location parameters in natural coordinates
precision: Precision/interaction parameters in natural coordinates
"""
@abstractmethod
def join_location_precision(self, location: Array, precision: Array) -> Array:
"""Join location and precision in natural coordinates.
Inverse of split_location_precision, combining components back into
full parameter vector.
Args:
location: Location parameters in natural coordinates
precision: Precision/interaction parameters in natural coordinates
Returns:
params: Combined parameters in natural coordinates
"""
@abstractmethod
def split_mean_second_moment(self, means: Array) -> tuple[Array, Array]:
"""Split parameters into mean and second moment in mean coordinates.
For harmonium conjugation, mean coordinates represent:
- Mean: $\\eta_1 = \\mu$ (first moment)
- Second moment: $\\eta_2 = \\mu\\mu^T + \\Sigma$ (Normal) or correlations (Boltzmann)
Args:
params: Parameters in mean coordinates
Returns:
mean: Mean parameters in mean coordinates
second_moment: Second moment parameters in mean coordinates
"""
@abstractmethod
def join_mean_second_moment(self, mean: Array, second_moment: Array) -> Array:
"""Join mean and second moment in mean coordinates.
Inverse of split_mean_second_moment, combining components back into
full parameter vector.
Args:
mean: Mean parameters in mean coordinates
second_moment: Second moment parameters in mean coordinates
Returns:
params: Combined parameters in mean coordinates
"""