Normal Distribution
This module provides implementations of multivariate normal distributions in an exponential family framework. Each normal is built out of two base components:
- Euclidean: The location component (\(\mathbb{R}^n\))
- Covariance: The shape component with flexible structure
Class Hierarchy
Base Components
- class Euclidean(_dim: int)
Bases: Differentiable
Euclidean space \(\mathbb{R}^n\) of dimension \(n\). Euclidean space consists of \(n\)-dimensional real vectors with the standard Euclidean distance metric
\[d(x,y) = \sqrt{\sum_{i=1}^n (x_i - y_i)^2}.\]
Euclidean also serves as the location component of a Normal distribution, and on its own we treat it as a normal distribution with unit covariance.
- As an exponential family:
Sufficient statistic: Identity map \(s(x) = x\)
Log base measure: \(\log\mu(x) = -\frac{1}{2}\|x\|^2 - \frac{n}{2}\log(2\pi)\)
- log_base_measure(x: Array) → Array
Standard normal base measure including normalizing constant.
- sample(key: Array, params: Array, n: int = 1) → Array
Sample from a standard normal distribution with mean given by natural parameters.
For Euclidean space with unit covariance, the natural parameters are the mean, so we sample x ~ N(params, I).
- Parameters:
key – JAX random key
params – Natural parameters (the mean)
n – Number of samples
- Returns:
Array of shape (n, dim) with samples
- log_partition_function(params: Array) → Array
Compute log partition function for standard normal with unit covariance.
For a normal distribution N(mu, I) with sufficient statistic s(x) = x and natural parameter theta = mu:
psi(theta) = 0.5 ||theta||^2, with the (d/2) log(2pi) normalizing constant carried by the base measure
- Parameters:
params – Natural parameters (the mean)
- Returns:
Scalar log partition function value
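The following NumPy sketch checks the exponential-family decomposition for the unit-covariance case. It assumes the split in which the base measure carries both \(-\frac{1}{2}\|x\|^2\) and the \(-\frac{n}{2}\log(2\pi)\) constant, so that \(\psi(\theta) = \frac{1}{2}\|\theta\|^2\). The function names mirror the methods above, but these are standalone illustrations, not the library's implementation.

```python
import numpy as np

def log_base_measure(x):
    # Assumed split: -||x||^2 / 2 plus the Gaussian normalizing constant
    n = x.shape[-1]
    return -0.5 * np.dot(x, x) - 0.5 * n * np.log(2 * np.pi)

def log_partition_function(theta):
    # With the 2*pi constant in the base measure, psi(theta) = ||theta||^2 / 2
    return 0.5 * np.dot(theta, theta)

def log_density(x, theta):
    # Exponential-family form theta . s(x) + log mu(x) - psi(theta), with s(x) = x
    return np.dot(theta, x) + log_base_measure(x) - log_partition_function(theta)

def reference_log_density(x, mean):
    # Direct log density of N(mean, I)
    diff = x - mean
    n = x.shape[-1]
    return -0.5 * np.dot(diff, diff) - 0.5 * n * np.log(2 * np.pi)

rng = np.random.default_rng(0)
theta = rng.normal(size=3)  # natural parameters = mean for unit covariance
x = rng.normal(size=3)
print(np.isclose(log_density(x, theta), reference_log_density(x, theta)))
```

Completing the square, \(\theta \cdot x - \frac{1}{2}\|x\|^2 - \frac{1}{2}\|\theta\|^2 = -\frac{1}{2}\|x - \theta\|^2\), which is why the two functions agree.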
- class Covariance(data_dim: int, rep: Rep)
Bases: SquareMap[Euclidean], ExponentialFamily, Generic
Shape component of a Normal distribution.
- This represents the covariance structure of a Normal distribution through different matrix representations:
PositiveDefinite: Full covariance matrix
Diagonal: Diagonal covariance matrix (per-dimension variances only)
Scale: Scalar multiple of identity
Identity: Unit covariance
The Rep parameter determines both:
How parameters are stored (symmetric matrix, diagonal, scalar)
The efficiency of operations (matrix multiply, inversion, etc.)
- As an exponential family:
Sufficient statistic: Outer product \(s(x) = x \otimes x\)
Log base measure: \(\log\mu(x) = -\frac{n}{2}\log(2\pi)\)
- rep: Rep
The matrix representation strategy for this linear map.
- sufficient_statistic(x: Array) → Array
Outer product with appropriate covariance structure.
- Returns:
Sufficient statistic in mean parameters (outer product of x).
- check_natural_parameters(params: Array) → Array
Check if natural parameters (precision matrix) are valid.
For covariance/precision matrices, we check that:
1. All parameters are finite.
2. The precision matrix is numerically positive definite.
- Parameters:
params – Natural parameters (precision matrix).
- Returns:
Boolean array indicating validity.
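A minimal NumPy sketch of the two checks, assuming the precision is available as a dense matrix. The `tol` argument is a hypothetical addition, and the function returns a Python bool rather than the boolean array the method returns. Testing the smallest eigenvalue avoids relying on an exception from a failed Cholesky factorization.

```python
import numpy as np

def is_valid_precision(precision, tol=0.0):
    # 1. All parameters are finite
    if not np.all(np.isfinite(precision)):
        return False
    # 2. Numerically positive definite: smallest eigenvalue above tol
    sym = 0.5 * (precision + precision.T)  # guard against round-off asymmetry
    return bool(np.linalg.eigvalsh(sym).min() > tol)

print(is_valid_precision(np.eye(2)))
print(is_valid_precision(np.array([[1.0, 2.0], [2.0, 1.0]])))  # eigenvalues 3 and -1
print(is_valid_precision(np.array([[np.nan, 0.0], [0.0, 1.0]])))
```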
- initialize(key: Array, location: float = 0.0, shape: float = 0.1) → Array
Initialize covariance matrix with random perturbation from identity.
Uses a low-rank perturbation I + LL^T where L has entries drawn from N(0, shape/sqrt(dim)) to ensure reasonable condition numbers.
- Returns:
Natural parameters (covariance matrix).
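A sketch of this initialization in NumPy. The docstring does not state the rank of L, so the `rank` argument below is hypothetical, and the second argument of N(0, shape/sqrt(dim)) is read as a standard deviation. Because LL^T is positive semidefinite, every eigenvalue of I + LL^T is at least 1, which keeps the condition number close to 1 for small `shape`.

```python
import numpy as np

def initialize_covariance(rng, dim, shape=0.1, rank=None):
    # Hypothetical default rank; the docstring only says "low-rank"
    if rank is None:
        rank = max(1, dim // 2)
    # Entries of L drawn with standard deviation shape / sqrt(dim)
    L = rng.normal(loc=0.0, scale=shape / np.sqrt(dim), size=(dim, rank))
    # Identity plus a positive semidefinite low-rank term
    return np.eye(dim) + L @ L.T

rng = np.random.default_rng(0)
matrix = initialize_covariance(rng, dim=4)
eigenvalues = np.linalg.eigvalsh(matrix)
print(eigenvalues.min(), eigenvalues.max() / eigenvalues.min())
```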
Normal Distribution
- class Normal(_data_dim: int, rep: Rep)
Bases: GeneralizedGaussian[Euclidean, Covariance], LocationShape[Euclidean, Covariance], Analytic, Generic
(Multivariate) Normal distributions.
The standard expression for the Normal density is
\[p(x; \mu, \Sigma) = (2\pi)^{-d/2}|\Sigma|^{-1/2}e^{-\frac{1}{2}(x-\mu) \cdot \Sigma^{-1} \cdot (x-\mu)},\]
where
\(\mu\) is the mean vector,
\(\Sigma\) is the covariance matrix, and
\(d\) is the dimension of the data.
- As an exponential family:
Sufficient statistic: \(\mathbf{s}(x) = (x, x \otimes x)\)
Log base measure: \(\log\mu(x) = -\frac{d}{2}\log(2\pi)\)
Natural parameters: \(\theta_1 = \Sigma^{-1}\mu\), \(\theta_2 = -\frac{1}{2}\Sigma^{-1}\)
Mean parameters: \(\eta_1 = \mu\), \(\eta_2 = \mu\mu^T + \Sigma\)
Log-partition: \(\psi(\theta) = -\frac{1}{4}\theta_1 \cdot \theta_2^{-1} \cdot \theta_1 - \frac{1}{2}\log|-2\theta_2|\)
Negative entropy: \(\phi(\eta) = -\frac{1}{2}\log|\eta_2 - \eta_1\eta_1^T| - \frac{d}{2}(1 + \log(2\pi))\)
Different covariance structures are handled through the Rep parameter, providing appropriate trade-offs between flexibility and computational efficiency.
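The exponential-family formulas above can be checked numerically. This NumPy sketch works with dense matrices and a Frobenius inner product for the \(x \otimes x\) term (it ignores the packed parameter storage of the actual representations), builds \(\theta\) and \(\eta\) from a random \(\mu\) and \(\Sigma\), and compares both the exponential-family log density and the duality relation \(\phi(\eta) = \theta \cdot \eta - \psi(\theta) + \log\mu(x)\) (the log base measure is constant here) against the textbook expressions.

```python
import numpy as np

d = 3
rng = np.random.default_rng(1)
mu = rng.normal(size=d)
A = rng.normal(size=(d, d))
sigma = A @ A.T + d * np.eye(d)  # a well-conditioned covariance
prec = np.linalg.inv(sigma)

# Natural and mean parameters in dense matrix form (no packed storage)
theta1, theta2 = prec @ mu, -0.5 * prec
eta1, eta2 = mu, np.outer(mu, mu) + sigma
log_base = -0.5 * d * np.log(2 * np.pi)  # constant log base measure

def log_partition(theta1, theta2):
    # psi = -1/4 theta1 . theta2^{-1} . theta1 - 1/2 log|-2 theta2|
    _, logdet = np.linalg.slogdet(-2.0 * theta2)
    return -0.25 * theta1 @ np.linalg.inv(theta2) @ theta1 - 0.5 * logdet

def negative_entropy(eta1, eta2):
    # phi = -1/2 log|eta2 - eta1 eta1^T| - d/2 (1 + log 2 pi)
    _, logdet = np.linalg.slogdet(eta2 - np.outer(eta1, eta1))
    return -0.5 * logdet - 0.5 * d * (1.0 + np.log(2 * np.pi))

def log_density(x):
    # theta . s(x) + log mu(x) - psi(theta); the x x^T term uses a Frobenius inner product
    quad = np.sum(theta2 * np.outer(x, x))
    return theta1 @ x + quad + log_base - log_partition(theta1, theta2)

def reference_log_density(x):
    # Textbook N(mu, Sigma) log density
    diff = x - mu
    _, logdet = np.linalg.slogdet(sigma)
    return -0.5 * (d * np.log(2 * np.pi) + logdet + diff @ prec @ diff)

x = rng.normal(size=d)
print(np.isclose(log_density(x), reference_log_density(x)))

# Duality: phi(eta) = theta . eta - psi(theta) + log mu (constant here)
dual = theta1 @ eta1 + np.sum(theta2 * eta2) - log_partition(theta1, theta2) + log_base
print(np.isclose(dual, negative_entropy(eta1, eta2)))
```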
- rep: Rep
Covariance representation type.
- sufficient_statistic(x: Array) → Array
Compute the sufficient statistic \((x, x \otimes x)\).
- Returns:
Sufficient statistic in mean parameters.
- log_partition_function(params: Array) → Array
Compute log partition function from natural parameters.
- Parameters:
params – Natural parameters (location, precision).
- Returns:
Log partition function value.
- negative_entropy(means: Array) → Array
Compute negative entropy from mean parameters.
- Parameters:
means – Mean parameters (mean, second moment).
- Returns:
Negative entropy value.
- sample(key: Array, params: Array, n: int = 1) → Array
Sample from the distribution.
- Parameters:
key – Random key.
params – Natural parameters.
n – Number of samples.
- Returns:
Samples from the distribution.
- initialize(key: Array, location: float = 0.0, shape: float = 0.1) → Array
Initialize the mean with normally distributed values and the covariance matrix with a random diagonal structure.
- Returns:
Natural parameters.
- initialize_from_sample(key: Array, sample: Array, location: float = 0.0, shape: float = 0.1) → Array
Initialize Normal parameters from sample data.
Computes mean and second moments from sample data, then adds regularizing noise to avoid degenerate cases. The noise is scaled relative to the observed variance to maintain reasonable parameter ranges.
- Parameters:
key – Random key
sample – Sample data to initialize from
location – Scale for additive noise to mean (relative to observed std dev)
shape – Scale for multiplicative noise to covariance
- Returns:
Natural parameters.
- check_natural_parameters(params: Array) → Array
Check if natural parameters are valid.
Delegates to Covariance check after extracting precision component.
- Parameters:
params – Natural parameters to check.
- Returns:
Boolean array indicating validity.
- property cov_man: Covariance[Rep]
Covariance manifold.
- join_mean_covariance(mean: Array, covariance: Array) → Array
Construct mean parameters from the mean mu and covariance Sigma.
- Parameters:
mean – Mean vector (in mean parameters).
covariance – Covariance matrix (in mean parameters).
- Returns:
Combined mean parameters.
- split_mean_covariance(means: Array) → tuple[Array, Array]
Extract the mean mu and covariance Sigma from mean parameters.
- Parameters:
means – Mean parameters (mean, second moment).
- Returns:
Tuple of (mean vector, covariance matrix) in mean parameters.
- split_mean_second_moment(means: Array) → tuple[Array, Array]
Split parameters into mean and second-moment components.
- Parameters:
means – Mean parameters.
- Returns:
Tuple of (mean vector, second moment) in mean parameters.
- join_mean_second_moment(mean: Array, second_moment: Array) → Array
Join mean and second-moment parameters.
- Parameters:
mean – Mean vector (in mean parameters).
second_moment – Second moment (in mean parameters).
- Returns:
Combined mean parameters.
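In dense form these helpers reduce to \(\eta_2 = \mu\mu^T + \Sigma\) and its inverse \(\Sigma = \eta_2 - \eta_1\eta_1^T\). The sketch below returns tuples of dense arrays rather than the single flat parameter array the methods return (an assumption made for readability), and shows that empirical moments split this way reproduce the biased sample covariance.

```python
import numpy as np

def join_mean_covariance(mean, covariance):
    # eta_1 = mu, eta_2 = mu mu^T + Sigma
    return mean, np.outer(mean, mean) + covariance

def split_mean_covariance(eta1, eta2):
    # Sigma = eta_2 - eta_1 eta_1^T
    return eta1, eta2 - np.outer(eta1, eta1)

rng = np.random.default_rng(2)
samples = rng.normal(size=(500, 3))
# Empirical mean parameters: E[x] and E[x x^T]
eta1 = samples.mean(axis=0)
eta2 = samples.T @ samples / samples.shape[0]
mean, covariance = split_mean_covariance(eta1, eta2)
print(np.allclose(covariance, np.cov(samples, rowvar=False, bias=True)))
```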
- split_location_precision(params: Array) → tuple[Array, Array]
Split natural location and precision (inverse covariance) parameters.
There is some subtle rescaling that has to happen to ensure that the natural parameters behave correctly when used either as a vector in a dot product or as a precision matrix.
For a multivariate normal distribution, the natural parameters (theta_1, theta_2) are related to the standard parameters (mu, Sigma) by theta_1 = Sigma^{-1} mu and theta_2 = -1/2 Sigma^{-1}. Matrix representations require different scaling to maintain these relationships:
- Diagonal case:
No additional rescaling needed as parameters directly represent diagonal elements
- Full (PositiveDefinite) case:
Off-diagonal elements appear twice in the precision matrix but once in the natural parameters
For i != j, element theta_{2,ij} is stored as double its matrix value, to account for the symmetric entry theta_{2,ji} that does not appear in the dot product
When converting to precision Sigma^{-1}, vector elements corresponding to off-diagonal elements are halved
- Scale case:
The exponential family dot product has to be scaled by 1/d
This scale needs to be stored either in the sufficient statistic or in the natural parameters
We store it in the sufficient statistic (hence it is defined as an average), which requires that we divide the natural parameters by d when converting to precision
- Parameters:
params – Natural parameters.
- Returns:
Tuple of (location, precision) in natural parameters.
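The bookkeeping for the full and scale cases can be sketched in NumPy. The helper names are hypothetical, and the row-major upper-triangle ordering produced by `np.triu_indices` is an assumption about the storage layout. The matrices handled here are the natural \(\theta_2\) blocks; under the convention \(\theta_2 = -\frac{1}{2}\Sigma^{-1}\), multiplying by \(-2\) gives \(\Sigma^{-1}\).

```python
import numpy as np

def packed_statistic(x):
    # Full case: products x_i x_j for i <= j, each stored once
    rows, cols = np.triu_indices(x.shape[0])
    return x[rows] * x[cols]

def pack_theta2(matrix):
    # Matrix -> vector; off-diagonals doubled to stand in for the omitted (j, i) entries
    rows, cols = np.triu_indices(matrix.shape[0])
    vec = matrix[rows, cols].copy()
    vec[rows != cols] *= 2.0
    return vec

def unpack_theta2(vec, d):
    # Vector -> symmetric matrix; off-diagonals halved
    rows, cols = np.triu_indices(d)
    vals = np.where(rows != cols, 0.5 * vec, vec)
    matrix = np.zeros((d, d))
    matrix[rows, cols] = vals
    matrix[cols, rows] = vals
    return matrix

def scale_theta2(theta, d):
    # Scale case: the statistic is mean(x**2), so theta * mean(x**2) = (theta / d) * sum(x**2)
    return (theta / d) * np.eye(d)

rng = np.random.default_rng(3)
d = 4
x = rng.normal(size=d)
B = rng.normal(size=(d, d))
theta2 = -0.5 * (B @ B.T + np.eye(d))  # a symmetric negative definite theta_2

# The packed dot product reproduces the quadratic form x^T theta_2 x
print(np.isclose(pack_theta2(theta2) @ packed_statistic(x), x @ theta2 @ x))
print(np.allclose(unpack_theta2(pack_theta2(theta2), d), theta2))
```

Doubling on the way in and halving on the way out are exact inverses, so the vector form can be used in dot products while the matrix form is recovered for linear algebra.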
- join_location_precision(location: Array, precision: Array) → Array
Join natural location and precision (inverse covariance) parameters.
Inverts the scaling in split_location_precision.
- Parameters:
location – Location vector (in natural parameters).
precision – Precision matrix (in natural parameters).
- Returns:
Combined natural parameters.
- embed_rep(trg_man: Normal, params: Array) → Array
Embed natural parameters into a more complex representation.
For example, a diagonal matrix can be embedded as a full matrix with zeros off the diagonal.
- Parameters:
trg_man – Target manifold with more complex representation.
params – Natural parameters to embed.
- Returns:
Embedded natural parameters.
- project_rep(trg_man: Normal, means: Array) → Array
Project mean parameters to a simpler representation.
For example, a full matrix can be projected to a diagonal one. In mean coordinates this corresponds to the moment-matching (m-)projection, which returns the distribution q in the simpler family that minimizes \(\mathrm{KL}(p \,\|\, q)\).
- Parameters:
trg_man – Target manifold with simpler representation.
means – Mean parameters to project.
- Returns:
Projected mean parameters.
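For the full-to-diagonal case, moment matching keeps \(\mathbb{E}[x]\) and the diagonal second moments \(\mathbb{E}[x_i^2]\) and drops the cross moments. The NumPy sketch below (dense arrays, hypothetical helper names) performs this projection and checks that perturbing the projected variances only increases \(\mathrm{KL}(p \,\|\, q)\).

```python
import numpy as np

def project_full_to_diagonal(mean, second_moment):
    # Keep the first moment and the diagonal of the second moment
    return mean, np.diag(second_moment).copy()

def kl_full_to_diagonal(mean_p, cov_p, mean_q, var_q):
    # KL(N(mean_p, cov_p) || N(mean_q, diag(var_q)))
    d = mean_p.shape[0]
    diff = mean_q - mean_p
    _, logdet_p = np.linalg.slogdet(cov_p)
    return 0.5 * (np.sum(np.diag(cov_p) / var_q) + np.sum(diff**2 / var_q)
                  - d + np.sum(np.log(var_q)) - logdet_p)

rng = np.random.default_rng(4)
mean = rng.normal(size=3)
A = rng.normal(size=(3, 3))
cov = A @ A.T + np.eye(3)
mean_q, diag_second = project_full_to_diagonal(mean, np.outer(mean, mean) + cov)
var_q = diag_second - mean_q**2  # recovers the diagonal of cov
best = kl_full_to_diagonal(mean, cov, mean_q, var_q)
worse = kl_full_to_diagonal(mean, cov, mean_q, 1.1 * var_q)
print(np.allclose(var_q, np.diag(cov)), best < worse)
```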
- regularize_covariance(means: Array, jitter: float = 0, min_var: float = 0) → Array
Regularize covariance matrix to ensure numerical stability and reasonable variances.
- This method applies two forms of regularization to the covariance matrix:
A minimum variance constraint that prevents any dimension from having variance below a specified threshold
A jitter term that adds a small positive value to all diagonal elements, improving numerical stability
The regularization preserves the correlation structure while ensuring the covariance matrix remains well-conditioned.
- Parameters:
means – Mean parameters to regularize.
jitter – Value to add to diagonal.
min_var – Minimum variance.
- Returns:
Regularized mean parameters.
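One way to implement the two steps, sketched in NumPy on a dense covariance matrix: variances below `min_var` are raised by rescaling through the correlation matrix, which leaves correlations unchanged, and `jitter` is then added to the diagonal. This is a hypothetical reconstruction; the method's actual order of operations and its handling of mean parameters may differ. In this sketch only the minimum-variance step is exactly correlation-preserving, since adding jitter slightly shrinks correlations.

```python
import numpy as np

def regularize_covariance_dense(covariance, jitter=0.0, min_var=0.0):
    std = np.sqrt(np.diag(covariance))
    # Avoid division by zero for degenerate dimensions
    safe_std = np.where(std > 0, std, 1.0)
    corr = covariance / np.outer(safe_std, safe_std)
    np.fill_diagonal(corr, 1.0)
    # Raise variances below min_var while keeping the correlation matrix fixed
    new_std = np.sqrt(np.maximum(std**2, min_var))
    regularized = corr * np.outer(new_std, new_std)
    # Jitter on the diagonal improves conditioning
    return regularized + jitter * np.eye(covariance.shape[0])

cov = np.array([[1e-4, 5e-5], [5e-5, 1.0]])
reg = regularize_covariance_dense(cov, jitter=1e-6, min_var=1e-2)
print(np.diag(reg))
```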
- relative_whiten(given_means: Array, relative_means: Array) → Array
Whiten a normal distribution relative to another normal distribution.
Transforms the given normal distribution by the affine map that turns the reference (relative) distribution into a standard normal:
1. Extract the mean and covariance of both distributions.
2. Transform the mean: new_mean = L^{-1} @ (given_mean - relative_mean).
3. Transform the covariance: new_cov = L^{-1} @ given_cov @ L^{-T}.
Here L is the Cholesky factor of the reference covariance (relative_cov = L @ L^T), so L^{-1} plays the role of precision^(1/2).
- Parameters:
given_means – Mean parameters to whiten.
relative_means – Mean parameters of reference distribution.
- Returns:
Whitened mean parameters.
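A NumPy sketch of the whitening map \(x \mapsto L^{-1}(x - \mu_{\text{rel}})\), written on dense means and covariances rather than on mean parameters (the function name and its four-argument signature are hypothetical). Whitening the reference against itself must give a standard normal, which makes a convenient sanity check.

```python
import numpy as np

def relative_whiten_dense(given_mean, given_cov, rel_mean, rel_cov):
    # Cholesky factor of the reference covariance: rel_cov = L @ L.T
    L = np.linalg.cholesky(rel_cov)
    # new_mean = L^{-1} (given_mean - rel_mean)
    new_mean = np.linalg.solve(L, given_mean - rel_mean)
    # new_cov = L^{-1} given_cov L^{-T}: solve for L^{-1} given_cov, then again on its transpose
    left = np.linalg.solve(L, given_cov)
    new_cov = np.linalg.solve(L, left.T)
    return new_mean, new_cov

rng = np.random.default_rng(5)
rel_mean = rng.normal(size=3)
M = rng.normal(size=(3, 3))
rel_cov = M @ M.T + np.eye(3)
new_mean, new_cov = relative_whiten_dense(rel_mean, rel_cov, rel_mean, rel_cov)
print(np.allclose(new_mean, 0.0), np.allclose(new_cov, np.eye(3)))
```

`np.linalg.solve` does not exploit the triangular structure of L; `scipy.linalg.solve_triangular` would, if SciPy is available.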
- standard_normal() → Array
Return the standard normal distribution.
- Returns:
Mean parameters for standard normal (zero mean, identity covariance).
- statistical_mean(params: Array) → Array
Compute the mean of the distribution.
- Parameters:
params – Natural parameters.
- Returns:
Mean vector.
- statistical_covariance(params: Array) → Array
Compute the covariance of the distribution.
- Parameters:
params – Natural parameters.
- Returns:
Covariance matrix (as dense array).
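Both statistics follow from inverting the natural parameterization: \(\Sigma = (-2\theta_2)^{-1}\) and \(\mu = \Sigma\theta_1\). A dense NumPy sketch, assuming \(\theta_2\) has already been unpacked into a full matrix (the helper name is hypothetical):

```python
import numpy as np

def mean_and_covariance_from_natural(theta1, theta2):
    # theta2 = -1/2 Sigma^{-1}  =>  Sigma = (-2 theta2)^{-1}
    covariance = np.linalg.inv(-2.0 * theta2)
    # theta1 = Sigma^{-1} mu  =>  mu solves (-2 theta2) mu = theta1
    mean = np.linalg.solve(-2.0 * theta2, theta1)
    return mean, covariance

mu = np.array([1.0, -2.0])
sigma = np.array([[2.0, 0.5], [0.5, 1.0]])
prec = np.linalg.inv(sigma)
mean, covariance = mean_and_covariance_from_natural(prec @ mu, -0.5 * prec)
print(np.allclose(mean, mu), np.allclose(covariance, sigma))
```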
- property shp_man: Covariance[Rep]
Shape component: the covariance manifold.
- property snd_man: Covariance[Rep]
Second component: the shape manifold (covariance structure).
Specialized Normal Types
- goal.models.base.gaussian.normal.FullNormal
Type alias: Normal[PositiveDefinite]
Normal distribution with an unrestricted positive definite covariance matrix.
This variant can represent arbitrary correlations between dimensions with a full covariance matrix. It is the most flexible variant, but requires \(O(d^2)\) parameters for \(d\)-dimensional data and \(O(d^3)\) operations for key computations.
- goal.models.base.gaussian.normal.DiagonalNormal
Type alias: Normal[Diagonal]
Normal distribution with a diagonal covariance matrix.
This variant assumes independence between dimensions, with a diagonal covariance matrix representing only per-dimension variances. Requires \(O(d)\) parameters and \(O(d)\) operations for key computations.
- goal.models.base.gaussian.normal.IsotropicNormal
Type alias: Normal[Scale]
Normal distribution whose covariance matrix is a scalar multiple of the identity.
This variant has equal variance in all dimensions and no correlations. Requires only a single scale parameter regardless of dimensionality and enables highly efficient computations.
- goal.models.base.gaussian.normal.StandardNormal
Type alias: Normal[Identity]
Normal distribution with an identity covariance matrix.
This variant has unit variance in all dimensions and no correlations. Requires no shape parameters and enables optimal computational efficiency.
Covariance Types
- goal.models.base.gaussian.normal.FullCovariance
Type alias: Covariance[PositiveDefinite]
Unrestricted positive definite covariance matrix representation.
Stores all \(\frac{d(d+1)}{2}\) unique elements of a symmetric positive definite matrix. Supports arbitrary correlation structures with maximum flexibility, at the cost of a higher parameter count and computational complexity.
- goal.models.base.gaussian.normal.DiagonalCovariance
Type alias: Covariance[Diagonal]
Diagonal covariance matrix representation.
Stores only \(d\) diagonal elements representing per-dimension variances. Assumes independence between dimensions, enabling \(O(d)\) storage and operations while still allowing heterogeneous variances across dimensions.
- goal.models.base.gaussian.normal.IsotropicCovariance
Type alias: Covariance[Scale]
Scalar multiple of identity covariance matrix representation.
Represents the covariance matrix using a single scalar value, allowing for extremely efficient \(O(1)\) storage and fast vectorized computations. Models equal variance in all dimensions with no correlations.
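The storage figures quoted for each type can be tallied for a concrete dimension. This sketch assumes that a Normal's parameters consist of the \(d\) location entries plus the covariance entries; the function name is hypothetical and only the per-type counts come from the descriptions above.

```python
def parameter_counts(d: int) -> dict[str, int]:
    # Number of stored covariance (shape) parameters per representation
    shape = {
        "PositiveDefinite": d * (d + 1) // 2,  # upper triangle including the diagonal
        "Diagonal": d,                          # per-dimension variances
        "Scale": 1,                             # single scalar
        "Identity": 0,                          # no shape parameters
    }
    # Add the d location parameters to get each Normal's total
    return {rep: d + n for rep, n in shape.items()}

print(parameter_counts(10))
# {'PositiveDefinite': 65, 'Diagonal': 20, 'Scale': 11, 'Identity': 10}
```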