# Source code for goal.models.harmonium.lgm

"""This module provides implementations of linear Gaussian models (LGMs), including factor analysis and principal component analysis. LGMs model linear, Gaussian relationships between observable and latent variables. The conjugacy of LGMs enables exact inference and EM."""

from __future__ import annotations

from abc import ABC
from dataclasses import dataclass
from typing import Any, override

import jax
import jax.numpy as jnp
from jax import Array

from ...geometry import (
    AnalyticConjugated,
    Diagonal,
    DifferentiableConjugated,
    EmbeddedMap,
    Identity,
    IdentityEmbedding,
    LinearEmbedding,
    MatrixRep,
    PositiveDefinite,
    Rectangular,
    Scale,
    SymmetricConjugated,
)
from ..base.gaussian.boltzmann import Boltzmann, DiagonalBoltzmann
from ..base.gaussian.generalized import Euclidean, GeneralizedGaussian
from ..base.gaussian.normal import (
    Covariance,
    FullNormal,
    Normal,
    full_normal,
)

### Embeddings ###


[docs] @dataclass(frozen=True) class GeneralizedGaussianLocationEmbedding[G: GeneralizedGaussian[Any, Any]]( LinearEmbedding[Euclidean, G], ): """Embedding of the Euclidean location component into a GeneralizedGaussian distribution. Projects a GeneralizedGaussian point in mean coordinates to its Euclidean location component, or embeds a location vector in natural coordinates into a GeneralizedGaussian with zero shape parameters. """ gau_man: G """The GeneralizedGaussian distribution.""" @property @override def amb_man(self) -> G: return self.gau_man @property @override def sub_man(self) -> Euclidean: return self.gau_man.loc_man
[docs] @override def project(self, means: Array) -> Array: # pyright: ignore[reportIncompatibleMethodOverride] """Project to Euclidean location component. Works on mean coordinates, extracting the location component from the full generalized Gaussian parameterization. If given a data point (size == data_dim), converts it to sufficient statistics (mean coordinates) first. Parameters ---------- means : Array Mean coordinate parameters or data point in GeneralizedGaussian space. Returns ------- Array Mean parameters in Euclidean space (location only). """ # Convert data points to sufficient statistics (mean coordinates) if means.size == self.gau_man.data_dim and means.size != self.gau_man.dim: means = self.gau_man.sufficient_statistic(means) loc, _ = self.gau_man.split_mean_second_moment(means) return loc
[docs] @override def embed(self, params: Array) -> Array: # pyright: ignore[reportIncompatibleMethodOverride] """Embed Euclidean location into GeneralizedGaussian with zero shape. Parameters ---------- params : Array Natural parameters in Euclidean space. Returns ------- Array Natural parameters in GeneralizedGaussian space. """ zero_shape = self.gau_man.shp_man.zeros() return self.gau_man.join_location_precision(params, zero_shape)
[docs] @override def translate( # pyright: ignore[reportIncompatibleMethodOverride] self, params: Array, delta: Array, ) -> Array: """Translate by adding Euclidean offset to location. Parameters ---------- params : Array Natural parameters in GeneralizedGaussian space. delta : Array Euclidean offset to add. Returns ------- Array Translated natural parameters in GeneralizedGaussian space. """ loc, shape = self.gau_man.split_location_precision(params) new_loc = loc + delta return self.gau_man.join_location_precision(new_loc, shape)
[docs] @dataclass(frozen=True) class NormalCovarianceEmbedding[SubRep: PositiveDefinite, AmbRep: PositiveDefinite]( LinearEmbedding[Normal[SubRep], Normal[AmbRep]] ): """Embedding of a normal distribution with a simpler covariance structure into a more complex one.""" # Fields _sub_man: Normal[SubRep] """The sub-manifold with the simpler covariance structure.""" _amb_man: Normal[AmbRep] """The super-manifold with the more complex covariance structure.""" def __post_init__(self): if not isinstance(self.sub_man.cov_man.rep, type(self.amb_man.cov_man.rep)): raise TypeError( f"Sub-manifold rep {self.sub_man.cov_man.rep} must be simpler than super-manifold rep {self.amb_man.cov_man.rep}" ) @property @override def amb_man(self) -> Normal[AmbRep]: return self._amb_man @property @override def sub_man(self) -> Normal[SubRep]: return self._sub_man
[docs] @override def project(self, means: Array) -> Array: # pyright: ignore[reportIncompatibleMethodOverride] """Project from ambient to sub-manifold representation. Parameters ---------- means : Array Mean parameters in ambient manifold. Returns ------- Array Mean parameters in sub-manifold. """ return self.amb_man.project_rep(self.sub_man, means)
[docs] @override def embed(self, params: Array) -> Array: # pyright: ignore[reportIncompatibleMethodOverride] """Embed from sub-manifold to ambient representation. Parameters ---------- params : Array Natural parameters in sub-manifold. Returns ------- Array Natural parameters in ambient manifold. """ return self.sub_man.embed_rep(self.amb_man, params)
[docs] @override def translate( # pyright: ignore[reportIncompatibleMethodOverride] self, params: Array, delta: Array, ) -> Array: """Translate by embedding and adding. Parameters ---------- params : Array Natural parameters in ambient manifold. delta : Array Natural parameters in sub-manifold to add. Returns ------- Array Translated natural parameters in ambient manifold. """ embedded_q = self.sub_man.embed_rep(self.amb_man, delta) return params + embedded_q
@dataclass(frozen=True)
class BoltzmannEmbedding(LinearEmbedding[DiagonalBoltzmann, Boltzmann]):
    """Embedding of DiagonalBoltzmann (mean-field) into full Boltzmann.

    Connects the mean-field approximation (independent binary units) to the
    full Boltzmann machine (with pairwise coupling).

    - DiagonalBoltzmann has n parameters (biases only, no coupling)
    - Boltzmann has n(n+1)/2 parameters (biases absorbed into diagonal +
      off-diagonal coupling)

    The embedding places DiagonalBoltzmann parameters on the diagonal of the
    coupling matrix (as bias terms) with zero off-diagonal coupling.
    """

    _sub_man: DiagonalBoltzmann
    """The mean-field Boltzmann (diagonal/independent units)."""

    _amb_man: Boltzmann
    """The full Boltzmann machine with coupling."""

    @property
    @override
    def sub_man(self) -> DiagonalBoltzmann:
        return self._sub_man

    @property
    @override
    def amb_man(self) -> Boltzmann:
        return self._amb_man

    @override
    def project(self, means: Array) -> Array:  # pyright: ignore[reportIncompatibleMethodOverride]
        """Extract the first moment (location) from full Boltzmann means.

        Parameters
        ----------
        means : Array
            Mean parameters in Boltzmann space.

        Returns
        -------
        Array
            Mean parameters in DiagonalBoltzmann space (first moments E[x_i]).
        """
        first_moment, _ = self._amb_man.split_mean_second_moment(means)
        return first_moment

    @override
    def embed(self, params: Array) -> Array:  # pyright: ignore[reportIncompatibleMethodOverride]
        """Embed DiagonalBoltzmann params into full Boltzmann with zero coupling.

        Parameters
        ----------
        params : Array
            Natural parameters in DiagonalBoltzmann space (biases).

        Returns
        -------
        Array
            Natural parameters in Boltzmann space (biases on diagonal, zero coupling).
        """
        boltzmann = self._amb_man
        return boltzmann.join_location_precision(params, boltzmann.shp_man.zeros())

    @override
    def translate(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        params: Array,
        delta: Array,
    ) -> Array:
        """Add a DiagonalBoltzmann delta to Boltzmann params (location/bias only).

        Parameters
        ----------
        params : Array
            Natural parameters in Boltzmann space.
        delta : Array
            Natural parameters in DiagonalBoltzmann space to add.

        Returns
        -------
        Array
            Translated natural parameters in Boltzmann space.
        """
        boltzmann = self._amb_man
        bias, coupling = boltzmann.split_location_precision(params)
        # Only the bias moves; the coupling passes through unchanged.
        return boltzmann.join_location_precision(bias + delta, coupling)
[docs] @dataclass(frozen=True) class LGM[ ObsRep: PositiveDefinite, PostGaussian: GeneralizedGaussian[Any, Any], PriorGaussian: GeneralizedGaussian[Any, Any], ]( DifferentiableConjugated[Normal[ObsRep], PostGaussian, PriorGaussian], ABC, ): """A linear Gaussian model (LGM) implemented as a harmonium with Gaussian latent variables. Linear Gaussian Models represent a joint distribution over observable variables $X$ and latent variables $Z$ where both are Gaussian and the relationship between them is linear. In generative terms, this can be viewed as: $$x = Az + \\mu + \\epsilon$$ where: - $z$ is drawn from a multivariate normal (typically a standard normal), - $A$ is the loading matrix mapping latent to observable space, - $\\mu$ is the observable bias term, and - $\\epsilon \\sim \\mathcal{N}(0, \\Sigma)$ is Gaussian noise. **Posterior vs Prior Structure**: The posterior latent distribution (conditioned on observables) uses the `PostGaussian` parameterization, which may employ a restricted covariance structure (e.g., diagonal) for computational efficiency during frequent inference. The prior latent distribution uses the `PriorGaussian` parameterization, whose shape is dictated by the conjugation parameters. When `PostGaussian` is more restricted than `PriorGaussian`, the prior is constructed by embedding the restricted posterior covariance structure into the fuller prior structure, ensuring compatibility with the required conjugation parameter computation. As a harmonium, the joint distribution takes the form $$p(x,z) \\propto \\exp(\\theta_X \\cdot s_X(x) + \\theta_Z \\cdot s_Z(z) + x \\cdot \\Theta^m_{XZ} \\cdot z),$$ where - $s_X(x) = (x, \\text{tril}(x \\otimes x))$ is the sufficient statistic of the observable normal, - $s_Z(z) = (z, \\text{tril}(z \\otimes z))$ is the sufficient statistic of the latent normal, and - and $\\Theta^m_{XZ}$ are the first-order interaction terms between $X$ and $Z$. 
The conjugation parameters are $\\rho = (\\rho^m, P^{\\sigma})$ where - $\\rho^m = -\\frac{1}{2} \\Theta^m_{ZX} \\cdot {\\Theta_X^{\\sigma}}^{-1} \\cdot \\theta^m_X$ - $P^{\\sigma} = -\\frac{1}{4} \\Theta^m_{ZX} \\cdot {\\Theta_X^{\\sigma}}^{-1} \\cdot \\Theta^m_{XZ}$ """ # Fields obs_dim: int """Dimension of the observable variables.""" obs_rep: ObsRep """Covariance structure of the observable variables.""" ### Methods ### # Overrides @property @override def obs_man(self) -> Normal[ObsRep]: """Override to construct directly from fields, avoiding circular dependency.""" return Normal(self.obs_dim, self.obs_rep) @property def int_obs_emb(self) -> GeneralizedGaussianLocationEmbedding[Normal[ObsRep]]: return GeneralizedGaussianLocationEmbedding(self.obs_man) @property def int_pst_emb(self) -> LinearEmbedding[Euclidean, PostGaussian]: """Embedding of Euclidean location into posterior latent - general for all GeneralizedGaussians.""" return GeneralizedGaussianLocationEmbedding(self.pst_man) @property @override def int_man(self) -> EmbeddedMap[PostGaussian, Normal[ObsRep]]: return EmbeddedMap( Rectangular(), self.int_pst_emb, self.int_obs_emb, )
[docs] @override def conjugation_parameters( self, lkl_params: Array, ) -> Array: """Compute conjugation parameters for linear Gaussian model. Parameters ---------- lkl_params : Array Natural parameters for likelihood function. Returns ------- Array Natural parameters for conjugation in PriorGaussian space. """ # Get parameters obs_cov_man = self.obs_man.cov_man obs_bias, int_mat = self.lkl_fun_man.split_coords(lkl_params) obs_loc, obs_prec = self.obs_man.split_location_precision(obs_bias) # Intermediate computations obs_sigma = obs_cov_man.inverse(obs_prec) obs_mean = obs_cov_man(obs_sigma, obs_loc) # Conjugation parameters im = self.int_man int_mat_trn = im.transpose(int_mat) rho_mean = im.trn_man.rep.matvec( im.trn_man.matrix_shape, int_mat_trn, obs_mean ) _, rho_shape = _change_of_basis( im.matrix_shape, im.rep, int_mat, obs_cov_man.rep, obs_sigma, ) rho_shape *= -1 # Join parameters into moment parameters return self.prr_man.join_location_precision(rho_mean, rho_shape)
[docs] @dataclass(frozen=True) class NormalLGM[ObsRep: PositiveDefinite, PstRep: PositiveDefinite]( LGM[ObsRep, Normal[PstRep], FullNormal], ): """Differentiable Linear Gaussian Model with Normal latent variables. Extends the abstract LGM with Normal-specific implementations for computing observable distributions and converting to joint Normal form. """ lat_dim: int """Dimension of the latent variables.""" pst_rep: PstRep # Overrides @property @override def pst_man(self) -> Normal[PstRep]: """Override to construct directly from fields, avoiding circular dependency.""" return Normal(self.lat_dim, self.pst_rep) @property @override def pst_prr_emb(self) -> NormalCovarianceEmbedding[PstRep, PositiveDefinite]: """Embedding of posterior Normal into prior Normal via covariance structure.""" prior_gau = full_normal(self.lat_dim) return NormalCovarianceEmbedding(self.pst_man, prior_gau) # Methods
[docs] def observable_distribution( self, params: Array, ) -> tuple[FullNormal, Array]: # (Normal, Natural[Normal]) """Returns the marginal normal distribution over observable variables. Parameters ---------- params : Array Natural parameters for the linear Gaussian model. Returns ------- tuple[Normal, Array] The Normal manifold and its natural parameters. """ # Build transposed LGM with full covariance observable variables transposed_lgm = NormalAnalyticLGM( obs_dim=self.pst_man.data_dim, # Original latent becomes observable obs_rep=PositiveDefinite(), lat_dim=self.obs_dim, # Original observable becomes latent ) # Construct parameters for transposed model obs_params, int_params, lat_params = self.split_coords(params) nor_man = transposed_lgm.prr_man obs_params_emb = self.obs_man.embed_rep(nor_man, obs_params) lat_params_emb = self.pst_man.embed_rep(transposed_lgm.obs_man, lat_params) # Join parameters with interaction matrix transposed transposed_params = transposed_lgm.join_coords( lat_params_emb, # Original latent becomes observable self.int_man.transpose(int_params), obs_params_emb, ) # Use harmonium prior to get marginal distribution return nor_man, transposed_lgm.prior(transposed_params)
[docs] def whiten_prior(self, means: Array) -> Array: """Reparameterize so latent prior is N(0,I) while preserving the observable marginal. In mean coordinates: - obs_means: unchanged (observable marginal E[s_X(x)] is preserved) - lat_means: set to standard_normal() (mean coords of N(0,I)) - int_means: updated to WL where W \\Sigma_z = E[x \\otimes z] - E[x] \\otimes E[z] and L = chol(\\Sigma_z) """ obs_means, int_means, lat_means = self.split_coords(means) obs_loc, _ = self.obs_man.split_mean_second_moment(obs_means) lat_mean, lat_cov = self.prr_man.split_mean_covariance(lat_means) # W \Sigma_z = E[x \otimes z] - E[x] \otimes E[z] int_mat = self.int_man.to_matrix(int_means) # (obs_dim, lat_dim) cross_cov = int_mat - jnp.outer(obs_loc, lat_mean) # W \Sigma_z # WL = W \Sigma_z @ L^{-T}, L = chol(\Sigma_z) chol = jnp.linalg.cholesky(self.prr_man.cov_man.to_matrix(lat_cov)) wl_mat = jax.scipy.linalg.solve_triangular(chol, cross_cov.T, lower=True).T new_int_means = self.int_man.from_matrix(wl_mat) new_lat_means = self.prr_man.standard_normal() return self.join_coords(obs_means, new_int_means, new_lat_means)
[docs] def to_normal(self, params: Array) -> Array: """Convert a linear model to a normal model. Parameters ---------- params : Array Natural parameters for the linear Gaussian model. Returns ------- Array Natural parameters for the joint Normal distribution. """ lat_dim = self.prr_man.data_dim new_man: NormalLGM[PositiveDefinite, PositiveDefinite] = NormalLGM( obs_dim=self.obs_man.data_dim, obs_rep=PositiveDefinite(), lat_dim=lat_dim, pst_rep=PositiveDefinite(), ) obs_params, int_params, lat_params = self.split_coords(params) emb_obs_params = self.obs_man.embed_rep(new_man.obs_man, obs_params) emb_lat_params = self.pst_man.embed_rep(new_man.prr_man, lat_params) obs_loc, obs_prs = new_man.obs_man.split_location_precision(emb_obs_params) lat_loc, lat_prs = new_man.prr_man.split_location_precision(emb_lat_params) nor_man = full_normal(self.data_dim) nor_loc = jnp.concatenate([obs_loc, lat_loc]) obs_prs_array = new_man.obs_man.cov_man.to_matrix(obs_prs) lat_prs_array = new_man.prr_man.cov_man.to_matrix(lat_prs) int_array = -self.int_man.to_matrix(int_params) joint_shape_array = jnp.block( [[obs_prs_array, int_array], [int_array.T, lat_prs_array]] ) return nor_man.join_location_precision( nor_loc, nor_man.cov_man.from_matrix(joint_shape_array) )
[docs] @dataclass(frozen=True) class BoltzmannLGM[ObsRep: PositiveDefinite]( SymmetricConjugated[Normal[ObsRep], Boltzmann], LGM[ObsRep, Boltzmann, Boltzmann], ): """Differentiable Linear Gaussian Model with Boltzmann latent variables. This model combines a Normal observable distribution with Boltzmann (binary) latent variables. The latent states are discrete binary vectors, making this suitable for discrete representation learning and binary latent factor models. The observable distribution remains Gaussian (continuous), while the latent distribution is a Boltzmann machine (discrete binary). This enables learning discrete latent representations of continuous data. """ lat_dim: int """Number of binary latent units.""" # Overrides @property @override def pst_man(self) -> Boltzmann: """Override to construct directly from fields, avoiding circular dependency.""" return Boltzmann(self.lat_dim) @property @override def pst_prr_emb(self) -> LinearEmbedding[Boltzmann, Boltzmann]: """Embedding of posterior Boltzmann into prior Boltzmann. For Boltzmann machines, both posterior and prior use the same manifold structure (no covariance simplification like in Normal case), so we use the identity embedding. """ return IdentityEmbedding(self.pst_man) @property @override def lat_man(self) -> Boltzmann: """The latent manifold is a Boltzmann machine.""" return Boltzmann(self.lat_dim)
@dataclass(frozen=True) class DifferentiableBoltzmannLGM[ObsRep: PositiveDefinite]( LGM[ObsRep, DiagonalBoltzmann, Boltzmann], ): """Differentiable Linear Gaussian Model with mean-field Boltzmann latent variables. This model combines a Normal observable distribution with Boltzmann (binary) latent variables, using a mean-field (independent units) approximation for the posterior. **Posterior vs Prior Structure**: - Posterior: Uses `DiagonalBoltzmann` (independent binary units) for computational efficiency. This mean-field approximation has O(n) log partition computation. - Prior: Uses full `Boltzmann` with pairwise coupling. The conjugation parameters naturally produce coupling terms even when the posterior is mean-field. This asymmetric structure enables efficient inference while maintaining the expressiveness of the full Boltzmann prior for modeling complex dependencies. """ lat_dim: int """Number of binary latent units.""" # Overrides @property @override def pst_man(self) -> DiagonalBoltzmann: """Mean-field posterior: diagonal Boltzmann (independent units).""" return DiagonalBoltzmann(self.lat_dim) @property @override def pst_prr_emb(self) -> BoltzmannEmbedding: """Embedding from mean-field DiagonalBoltzmann to full Boltzmann.""" return BoltzmannEmbedding(self.pst_man, Boltzmann(self.lat_dim))
[docs] @dataclass(frozen=True) class NormalAnalyticLGM[ObsRep: PositiveDefinite]( AnalyticConjugated[Normal[ObsRep], FullNormal], NormalLGM[ObsRep, PositiveDefinite], ): """Analytic Linear Gaussian Model that extends the differentiable LGM with full analytical tractability, adding conversions between mean and natural coordinates, and a closed-form implementation of EM.""" def __init__(self, obs_dim: int, obs_rep: ObsRep, lat_dim: int): super().__init__( obs_dim=obs_dim, obs_rep=obs_rep, lat_dim=lat_dim, pst_rep=PositiveDefinite(), ) @property @override def lat_man(self) -> FullNormal: """The latent manifold is a full Normal distribution.""" return full_normal(self.lat_dim) @property @override def pst_prr_emb(self) -> NormalCovarianceEmbedding[PositiveDefinite, PositiveDefinite]: """Embedding of posterior Normal into prior Normal via covariance structure.""" prior_gau = full_normal(self.lat_dim) return NormalCovarianceEmbedding(self.pst_man, prior_gau)
[docs] @override def to_natural_likelihood( self, means: Array, # Mean[Self] ) -> Array: """Convert mean parameters to natural likelihood parameters. Parameters ---------- means : Array Mean parameters for the analytic linear Gaussian model. Returns ------- Array Natural parameters for likelihood function. """ # Get relevant manifolds ocm = self.obs_man.cov_man lcm = self.lat_man.cov_man im = self.int_man # Deconstruct parameters obs_means, int_means, lat_means = self.split_coords(means) obs_mean, obs_cov = self.obs_man.split_mean_covariance(obs_means) lat_mean, lat_cov = self.lat_man.split_mean_covariance(lat_means) int_cov = int_means - im.rep.outer_product(obs_mean, lat_mean) # Construct precisions lat_prs = lcm.inverse(lat_cov) im_trn = im.trn_man int_cov_t = im.transpose(int_cov) # cob_man, cob = _change_of_basis(im_trn, int_cov_t, lat_cov_man, lat_prs) cob_man, cob = _change_of_basis( im_trn.matrix_shape, im_trn.rep, int_cov_t, lcm.rep, lat_prs, ) shaped_cob = ocm.from_matrix(cob_man.to_matrix(cob)) obs_prs = ocm.inverse(obs_cov - shaped_cob) sizes = ( ocm.matrix_shape[0], ocm.matrix_shape[1], im.matrix_shape[1], lcm.matrix_shape[1], ) _, _, int_params = _dual_composition( sizes, ocm.rep, obs_prs, im.rep, int_cov, lcm.rep, lat_prs, ) obs_loc0 = ocm(obs_prs, obs_mean) # obs_loc1 = self._int_man_internal(int_params, lat_mean) obs_loc1 = im.rep.matvec(im.matrix_shape, int_params, lat_mean) obs_loc = obs_loc0 - obs_loc1 # Return natural parameters obs_params = self.obs_man.join_location_precision(obs_loc, obs_prs) return self.lkl_fun_man.join_coords(obs_params, int_params)
@dataclass(frozen=True)
class FactorAnalysis(NormalAnalyticLGM[Diagonal]):
    """A factor analysis model with Gaussian latent variables."""

    def __init__(self, obs_dim: int, lat_dim: int):
        """Build the model with a Diagonal observable covariance structure."""
        super().__init__(obs_dim, Diagonal(), lat_dim)

    @override
    def expectation_maximization(
        self,
        params: Array,
        xs: Array,
    ) -> Array:
        """Perform a single iteration of the EM algorithm.

        Without further constraints the latent Normal of FA is not
        identifiable, and so we hold it fixed at standard normal.

        Parameters
        ----------
        params : Array
            Current natural parameters.
        xs : Array
            Observation data.

        Returns
        -------
        Array
            Updated natural parameters.
        """
        # E-step: mean statistics of the model posterior over the data.
        posterior_stats = self.mean_posterior_statistics(params, xs)
        # Whiten so the prior is N(0,I) while preserving the observable marginal.
        whitened_stats = self.whiten_prior(posterior_stats)
        # M-step: convert back to natural coordinates.
        return self.to_natural(whitened_stats)

    def from_loadings(
        self,
        loadings: Array,
        means: Array,
        diags: Array,
    ) -> Array:
        """Convert standard factor analysis parameters to natural parameters.

        Parameters
        ----------
        loadings : Array
            Loading matrix (obs_dim, lat_dim).
        means : Array
            Observation means.
        diags : Array
            Diagonal noise variances.

        Returns
        -------
        Array
            Natural parameters for the factor analysis model.
        """
        obs_man = self.obs_man
        lat_man = self.lat_man

        # Observable natural parameters from the mean and noise variances
        obs_params = obs_man.to_natural(obs_man.join_mean_covariance(means, diags))
        _, obs_prs = obs_man.split_location_precision(obs_params)

        # Interaction matrix: loadings scaled by the observable precision
        scaled_loadings = obs_man.cov_man.to_matrix(obs_prs) @ loadings
        int_params = self.int_man.from_matrix(scaled_loadings)

        # Combine the likelihood with a standard normal latent
        lkl_params = self.lkl_fun_man.join_coords(obs_params, int_params)
        lat_params = lat_man.to_natural(lat_man.standard_normal())
        return self.join_conjugated(lkl_params, lat_params)
@dataclass(frozen=True)
class PrincipalComponentAnalysis(NormalAnalyticLGM[Scale]):
    """A principal component analysis model with Gaussian latent variables."""

    def __init__(self, obs_dim: int, lat_dim: int):
        """Build the model with a Scale observable covariance structure."""
        super().__init__(obs_dim, Scale(), lat_dim)

    @override
    def expectation_maximization(
        self,
        params: Array,
        xs: Array,
    ) -> Array:
        """Perform a single iteration of the EM algorithm.

        Without further constraints the latent Normal of PCA is not
        identifiable, and so we hold it fixed at standard normal.

        Parameters
        ----------
        params : Array
            Current natural parameters.
        xs : Array
            Observation data.

        Returns
        -------
        Array
            Updated natural parameters.
        """
        # E-step: mean statistics of the model posterior over the data.
        posterior_stats = self.mean_posterior_statistics(params, xs)
        # Whiten so the prior is N(0,I) while preserving the observable marginal.
        whitened_stats = self.whiten_prior(posterior_stats)
        # M-step: convert back to natural coordinates.
        return self.to_natural(whitened_stats)
### Helper Functions ###


def _dual_composition(
    sizes: tuple[int, int, int, int],
    h_rep: MatrixRep,
    h_params: Array,  # Parameters in some coordinate system
    g_rep: MatrixRep,
    g_params: Array,  # Parameters in dual coordinates
    f_rep: MatrixRep,
    f_params: Array,  # Parameters in original coordinate system
) -> tuple[
    MatrixRep,
    tuple[int, int],  # Output shape
    Array,  # Parameters in original coordinate system
]:
    """Three-way matrix multiplication that respects coordinate duality.

    Computes h @ g @ f where g is in dual coordinates. `sizes` holds the four
    chained dimensions, so h is (sizes[0], sizes[1]), g is (sizes[1], sizes[2])
    and f is (sizes[2], sizes[3]).
    """
    dim_a, dim_b, dim_c, dim_d = sizes

    # Right-hand product first: g @ f
    gf_rep, gf_shape, gf_params = g_rep.matmat(
        (dim_b, dim_c),
        g_params,
        f_rep,
        (dim_c, dim_d),
        f_params,
    )

    # Then the left-hand factor: h @ (g @ f)
    return h_rep.matmat((dim_a, dim_b), h_params, gf_rep, gf_shape, gf_params)


def _change_of_basis(
    f_size: tuple[int, int],
    f_rep: MatrixRep,
    f_params: Array,  # Parameters in some coordinate system
    g_rep: PositiveDefinite,
    g_params: Array,  # Parameters in dual coordinates
) -> tuple[
    Covariance[Any],
    Array,  # Parameters in original coordinate system
]:
    """Linear change of basis transformation.

    Computes f.T @ g @ f where g is in dual coordinates. The result is always
    symmetric (positive semi-definite if g is positive definite).
    """
    n_rows, n_cols = f_size
    f_transposed = f_rep.transpose(f_size, f_params)

    # f.T is (n_cols, n_rows), g is (n_rows, n_rows), f is (n_rows, n_cols).
    out_rep, out_shape, out_params = _dual_composition(
        (n_cols, n_rows, n_rows, n_cols),
        f_rep,
        f_transposed,
        g_rep,
        g_params,
        f_rep,
        f_params,
    )

    # A diagonal-or-stricter result keeps its own representation.
    if isinstance(out_rep, (Diagonal, Scale, Identity)):
        return Covariance(out_shape[0], out_rep), out_params

    # Otherwise store the (symmetric) product as PositiveDefinite.
    dense_man = Covariance(out_shape[0], PositiveDefinite())
    dense = out_rep.to_matrix(dense_man.matrix_shape, out_params)
    return dense_man, dense_man.from_matrix(dense)