"""This module provides implementations of linear Gaussian models (LGMs), including factor analysis and principal component analysis. LGMs model linear relationships between Gaussian observable variables and latent variables, which are Gaussian or, in the Boltzmann variants, binary. The conjugacy of LGMs enables exact inference and EM."""
from __future__ import annotations
from abc import ABC
from dataclasses import dataclass
from typing import Any, override
import jax
import jax.numpy as jnp
from jax import Array
from ...geometry import (
AnalyticConjugated,
Diagonal,
DifferentiableConjugated,
EmbeddedMap,
Identity,
IdentityEmbedding,
LinearEmbedding,
MatrixRep,
PositiveDefinite,
Rectangular,
Scale,
SymmetricConjugated,
)
from ..base.gaussian.boltzmann import Boltzmann, DiagonalBoltzmann
from ..base.gaussian.generalized import Euclidean, GeneralizedGaussian
from ..base.gaussian.normal import (
Covariance,
FullNormal,
Normal,
full_normal,
)
### Embeddings ###
[docs]
@dataclass(frozen=True)
class GeneralizedGaussianLocationEmbedding[G: GeneralizedGaussian[Any, Any]](
LinearEmbedding[Euclidean, G],
):
"""Embedding of the Euclidean location component into a GeneralizedGaussian distribution.
Projects a GeneralizedGaussian point in mean coordinates to its Euclidean location component, or embeds
a location vector in natural coordinates into a GeneralizedGaussian with zero shape parameters.
"""
gau_man: G
"""The GeneralizedGaussian distribution."""
@property
@override
def amb_man(self) -> G:
return self.gau_man
@property
@override
def sub_man(self) -> Euclidean:
return self.gau_man.loc_man
[docs]
@override
def project(self, means: Array) -> Array: # pyright: ignore[reportIncompatibleMethodOverride]
"""Project to Euclidean location component.
Works on mean coordinates, extracting the location component from the full
generalized Gaussian parameterization. If given a data point (size == data_dim),
converts it to sufficient statistics (mean coordinates) first.
Parameters
----------
means : Array
Mean coordinate parameters or data point in GeneralizedGaussian space.
Returns
-------
Array
Mean parameters in Euclidean space (location only).
"""
# Convert data points to sufficient statistics (mean coordinates)
if means.size == self.gau_man.data_dim and means.size != self.gau_man.dim:
means = self.gau_man.sufficient_statistic(means)
loc, _ = self.gau_man.split_mean_second_moment(means)
return loc
[docs]
@override
def embed(self, params: Array) -> Array: # pyright: ignore[reportIncompatibleMethodOverride]
"""Embed Euclidean location into GeneralizedGaussian with zero shape.
Parameters
----------
params : Array
Natural parameters in Euclidean space.
Returns
-------
Array
Natural parameters in GeneralizedGaussian space.
"""
zero_shape = self.gau_man.shp_man.zeros()
return self.gau_man.join_location_precision(params, zero_shape)
[docs]
@override
def translate( # pyright: ignore[reportIncompatibleMethodOverride]
self,
params: Array,
delta: Array,
) -> Array:
"""Translate by adding Euclidean offset to location.
Parameters
----------
params : Array
Natural parameters in GeneralizedGaussian space.
delta : Array
Euclidean offset to add.
Returns
-------
Array
Translated natural parameters in GeneralizedGaussian space.
"""
loc, shape = self.gau_man.split_location_precision(params)
new_loc = loc + delta
return self.gau_man.join_location_precision(new_loc, shape)
[docs]
@dataclass(frozen=True)
class NormalCovarianceEmbedding[SubRep: PositiveDefinite, AmbRep: PositiveDefinite](
LinearEmbedding[Normal[SubRep], Normal[AmbRep]]
):
"""Embedding of a normal distribution with a simpler covariance structure into a more complex one."""
# Fields
_sub_man: Normal[SubRep]
"""The sub-manifold with the simpler covariance structure."""
_amb_man: Normal[AmbRep]
"""The super-manifold with the more complex covariance structure."""
def __post_init__(self):
if not isinstance(self.sub_man.cov_man.rep, type(self.amb_man.cov_man.rep)):
raise TypeError(
f"Sub-manifold rep {self.sub_man.cov_man.rep} must be simpler than super-manifold rep {self.amb_man.cov_man.rep}"
)
@property
@override
def amb_man(self) -> Normal[AmbRep]:
return self._amb_man
@property
@override
def sub_man(self) -> Normal[SubRep]:
return self._sub_man
[docs]
@override
def project(self, means: Array) -> Array: # pyright: ignore[reportIncompatibleMethodOverride]
"""Project from ambient to sub-manifold representation.
Parameters
----------
means : Array
Mean parameters in ambient manifold.
Returns
-------
Array
Mean parameters in sub-manifold.
"""
return self.amb_man.project_rep(self.sub_man, means)
[docs]
@override
def embed(self, params: Array) -> Array: # pyright: ignore[reportIncompatibleMethodOverride]
"""Embed from sub-manifold to ambient representation.
Parameters
----------
params : Array
Natural parameters in sub-manifold.
Returns
-------
Array
Natural parameters in ambient manifold.
"""
return self.sub_man.embed_rep(self.amb_man, params)
[docs]
@override
def translate( # pyright: ignore[reportIncompatibleMethodOverride]
self,
params: Array,
delta: Array,
) -> Array:
"""Translate by embedding and adding.
Parameters
----------
params : Array
Natural parameters in ambient manifold.
delta : Array
Natural parameters in sub-manifold to add.
Returns
-------
Array
Translated natural parameters in ambient manifold.
"""
embedded_q = self.sub_man.embed_rep(self.amb_man, delta)
return params + embedded_q
@dataclass(frozen=True)
class BoltzmannEmbedding(LinearEmbedding[DiagonalBoltzmann, Boltzmann]):
    """Embedding of the mean-field DiagonalBoltzmann into the full Boltzmann machine.

    The sub-manifold describes independent binary units (biases only); the
    ambient manifold additionally carries pairwise coupling.

    - DiagonalBoltzmann has n parameters (biases only, no coupling)
    - Boltzmann has n(n+1)/2 parameters (biases absorbed into diagonal + off-diagonal coupling)

    Embedding joins the DiagonalBoltzmann parameters as the location part of the
    Boltzmann parameters together with an all-zero precision (coupling) part.
    Projection keeps only the first moments.
    """

    _sub_man: DiagonalBoltzmann
    """The mean-field Boltzmann (diagonal/independent units)."""

    _amb_man: Boltzmann
    """The full Boltzmann machine with coupling."""

    @property
    @override
    def sub_man(self) -> DiagonalBoltzmann:
        return self._sub_man

    @property
    @override
    def amb_man(self) -> Boltzmann:
        return self._amb_man

    @override
    def project(self, means: Array) -> Array:  # pyright: ignore[reportIncompatibleMethodOverride]
        """Keep only the first moments E[x_i] of full Boltzmann mean parameters.

        Parameters
        ----------
        means : Array
            Mean parameters in Boltzmann space.

        Returns
        -------
        Array
            Mean parameters in DiagonalBoltzmann space (first moments E[x_i]).
        """
        full = self._amb_man
        first_moments, _ = full.split_mean_second_moment(means)
        return first_moments

    @override
    def embed(self, params: Array) -> Array:  # pyright: ignore[reportIncompatibleMethodOverride]
        """Lift DiagonalBoltzmann biases into the full Boltzmann with zero coupling.

        Parameters
        ----------
        params : Array
            Natural parameters in DiagonalBoltzmann space (biases).

        Returns
        -------
        Array
            Natural parameters in Boltzmann space (location = biases, zero precision part).
        """
        full = self._amb_man
        no_coupling = full.shp_man.zeros()
        return full.join_location_precision(params, no_coupling)

    @override
    def translate(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        params: Array,
        delta: Array,
    ) -> Array:
        """Shift the location (bias) part of Boltzmann parameters; coupling is unchanged.

        Parameters
        ----------
        params : Array
            Natural parameters in Boltzmann space.
        delta : Array
            Natural parameters in DiagonalBoltzmann space to add.

        Returns
        -------
        Array
            Translated natural parameters in Boltzmann space.
        """
        full = self._amb_man
        bias, coupling = full.split_location_precision(params)
        shifted_bias = bias + delta
        return full.join_location_precision(shifted_bias, coupling)
[docs]
@dataclass(frozen=True)
class LGM[
ObsRep: PositiveDefinite,
PostGaussian: GeneralizedGaussian[Any, Any],
PriorGaussian: GeneralizedGaussian[Any, Any],
](
DifferentiableConjugated[Normal[ObsRep], PostGaussian, PriorGaussian],
ABC,
):
"""A linear Gaussian model (LGM) implemented as a harmonium with Gaussian latent variables.
Linear Gaussian Models represent a joint distribution over observable variables $X$ and latent variables $Z$ where both are Gaussian and the relationship between them is linear. In generative terms, this can be viewed as:
$$x = Az + \\mu + \\epsilon$$
where:
- $z$ is drawn from a multivariate normal (typically a standard normal),
- $A$ is the loading matrix mapping latent to observable space,
- $\\mu$ is the observable bias term, and
- $\\epsilon \\sim \\mathcal{N}(0, \\Sigma)$ is Gaussian noise.
**Posterior vs Prior Structure**: The posterior latent distribution (conditioned on observables) uses the `PostGaussian` parameterization, which may employ a restricted covariance structure (e.g., diagonal) for computational efficiency during frequent inference. The prior latent distribution uses the `PriorGaussian` parameterization, whose shape is dictated by the conjugation parameters. When `PostGaussian` is more restricted than `PriorGaussian`, the prior is constructed by embedding the restricted posterior covariance structure into the fuller prior structure, ensuring compatibility with the required conjugation parameter computation.
As a harmonium, the joint distribution takes the form
$$p(x,z) \\propto \\exp(\\theta_X \\cdot s_X(x) + \\theta_Z \\cdot s_Z(z) + x \\cdot \\Theta^m_{XZ} \\cdot z),$$
where
- $s_X(x) = (x, \\text{tril}(x \\otimes x))$ is the sufficient statistic of the observable normal,
- $s_Z(z) = (z, \\text{tril}(z \\otimes z))$ is the sufficient statistic of the latent normal, and
- and $\\Theta^m_{XZ}$ are the first-order interaction terms between $X$ and $Z$.
The conjugation parameters are $\\rho = (\\rho^m, P^{\\sigma})$ where
- $\\rho^m = -\\frac{1}{2} \\Theta^m_{ZX} \\cdot {\\Theta_X^{\\sigma}}^{-1} \\cdot \\theta^m_X$
- $P^{\\sigma} = -\\frac{1}{4} \\Theta^m_{ZX} \\cdot {\\Theta_X^{\\sigma}}^{-1} \\cdot \\Theta^m_{XZ}$
"""
# Fields
obs_dim: int
"""Dimension of the observable variables."""
obs_rep: ObsRep
"""Covariance structure of the observable variables."""
### Methods ###
# Overrides
@property
@override
def obs_man(self) -> Normal[ObsRep]:
"""Override to construct directly from fields, avoiding circular dependency."""
return Normal(self.obs_dim, self.obs_rep)
@property
def int_obs_emb(self) -> GeneralizedGaussianLocationEmbedding[Normal[ObsRep]]:
return GeneralizedGaussianLocationEmbedding(self.obs_man)
@property
def int_pst_emb(self) -> LinearEmbedding[Euclidean, PostGaussian]:
"""Embedding of Euclidean location into posterior latent - general for all GeneralizedGaussians."""
return GeneralizedGaussianLocationEmbedding(self.pst_man)
@property
@override
def int_man(self) -> EmbeddedMap[PostGaussian, Normal[ObsRep]]:
return EmbeddedMap(
Rectangular(),
self.int_pst_emb,
self.int_obs_emb,
)
[docs]
@override
def conjugation_parameters(
self,
lkl_params: Array,
) -> Array:
"""Compute conjugation parameters for linear Gaussian model.
Parameters
----------
lkl_params : Array
Natural parameters for likelihood function.
Returns
-------
Array
Natural parameters for conjugation in PriorGaussian space.
"""
# Get parameters
obs_cov_man = self.obs_man.cov_man
obs_bias, int_mat = self.lkl_fun_man.split_coords(lkl_params)
obs_loc, obs_prec = self.obs_man.split_location_precision(obs_bias)
# Intermediate computations
obs_sigma = obs_cov_man.inverse(obs_prec)
obs_mean = obs_cov_man(obs_sigma, obs_loc)
# Conjugation parameters
im = self.int_man
int_mat_trn = im.transpose(int_mat)
rho_mean = im.trn_man.rep.matvec(
im.trn_man.matrix_shape, int_mat_trn, obs_mean
)
_, rho_shape = _change_of_basis(
im.matrix_shape,
im.rep,
int_mat,
obs_cov_man.rep,
obs_sigma,
)
rho_shape *= -1
# Join parameters into moment parameters
return self.prr_man.join_location_precision(rho_mean, rho_shape)
[docs]
@dataclass(frozen=True)
class NormalLGM[ObsRep: PositiveDefinite, PstRep: PositiveDefinite](
LGM[ObsRep, Normal[PstRep], FullNormal],
):
"""Differentiable Linear Gaussian Model with Normal latent variables.
Extends the abstract LGM with Normal-specific implementations for computing
observable distributions and converting to joint Normal form.
"""
lat_dim: int
"""Dimension of the latent variables."""
pst_rep: PstRep
# Overrides
@property
@override
def pst_man(self) -> Normal[PstRep]:
"""Override to construct directly from fields, avoiding circular dependency."""
return Normal(self.lat_dim, self.pst_rep)
@property
@override
def pst_prr_emb(self) -> NormalCovarianceEmbedding[PstRep, PositiveDefinite]:
"""Embedding of posterior Normal into prior Normal via covariance structure."""
prior_gau = full_normal(self.lat_dim)
return NormalCovarianceEmbedding(self.pst_man, prior_gau)
# Methods
[docs]
def observable_distribution(
self,
params: Array,
) -> tuple[FullNormal, Array]: # (Normal, Natural[Normal])
"""Returns the marginal normal distribution over observable variables.
Parameters
----------
params : Array
Natural parameters for the linear Gaussian model.
Returns
-------
tuple[Normal, Array]
The Normal manifold and its natural parameters.
"""
# Build transposed LGM with full covariance observable variables
transposed_lgm = NormalAnalyticLGM(
obs_dim=self.pst_man.data_dim, # Original latent becomes observable
obs_rep=PositiveDefinite(),
lat_dim=self.obs_dim, # Original observable becomes latent
)
# Construct parameters for transposed model
obs_params, int_params, lat_params = self.split_coords(params)
nor_man = transposed_lgm.prr_man
obs_params_emb = self.obs_man.embed_rep(nor_man, obs_params)
lat_params_emb = self.pst_man.embed_rep(transposed_lgm.obs_man, lat_params)
# Join parameters with interaction matrix transposed
transposed_params = transposed_lgm.join_coords(
lat_params_emb, # Original latent becomes observable
self.int_man.transpose(int_params),
obs_params_emb,
)
# Use harmonium prior to get marginal distribution
return nor_man, transposed_lgm.prior(transposed_params)
[docs]
def whiten_prior(self, means: Array) -> Array:
"""Reparameterize so latent prior is N(0,I) while preserving the observable marginal.
In mean coordinates:
- obs_means: unchanged (observable marginal E[s_X(x)] is preserved)
- lat_means: set to standard_normal() (mean coords of N(0,I))
- int_means: updated to WL where W \\Sigma_z = E[x \\otimes z] - E[x] \\otimes E[z] and L = chol(\\Sigma_z)
"""
obs_means, int_means, lat_means = self.split_coords(means)
obs_loc, _ = self.obs_man.split_mean_second_moment(obs_means)
lat_mean, lat_cov = self.prr_man.split_mean_covariance(lat_means)
# W \Sigma_z = E[x \otimes z] - E[x] \otimes E[z]
int_mat = self.int_man.to_matrix(int_means) # (obs_dim, lat_dim)
cross_cov = int_mat - jnp.outer(obs_loc, lat_mean) # W \Sigma_z
# WL = W \Sigma_z @ L^{-T}, L = chol(\Sigma_z)
chol = jnp.linalg.cholesky(self.prr_man.cov_man.to_matrix(lat_cov))
wl_mat = jax.scipy.linalg.solve_triangular(chol, cross_cov.T, lower=True).T
new_int_means = self.int_man.from_matrix(wl_mat)
new_lat_means = self.prr_man.standard_normal()
return self.join_coords(obs_means, new_int_means, new_lat_means)
[docs]
def to_normal(self, params: Array) -> Array:
"""Convert a linear model to a normal model.
Parameters
----------
params : Array
Natural parameters for the linear Gaussian model.
Returns
-------
Array
Natural parameters for the joint Normal distribution.
"""
lat_dim = self.prr_man.data_dim
new_man: NormalLGM[PositiveDefinite, PositiveDefinite] = NormalLGM(
obs_dim=self.obs_man.data_dim,
obs_rep=PositiveDefinite(),
lat_dim=lat_dim,
pst_rep=PositiveDefinite(),
)
obs_params, int_params, lat_params = self.split_coords(params)
emb_obs_params = self.obs_man.embed_rep(new_man.obs_man, obs_params)
emb_lat_params = self.pst_man.embed_rep(new_man.prr_man, lat_params)
obs_loc, obs_prs = new_man.obs_man.split_location_precision(emb_obs_params)
lat_loc, lat_prs = new_man.prr_man.split_location_precision(emb_lat_params)
nor_man = full_normal(self.data_dim)
nor_loc = jnp.concatenate([obs_loc, lat_loc])
obs_prs_array = new_man.obs_man.cov_man.to_matrix(obs_prs)
lat_prs_array = new_man.prr_man.cov_man.to_matrix(lat_prs)
int_array = -self.int_man.to_matrix(int_params)
joint_shape_array = jnp.block(
[[obs_prs_array, int_array], [int_array.T, lat_prs_array]]
)
return nor_man.join_location_precision(
nor_loc, nor_man.cov_man.from_matrix(joint_shape_array)
)
[docs]
@dataclass(frozen=True)
class BoltzmannLGM[ObsRep: PositiveDefinite](
SymmetricConjugated[Normal[ObsRep], Boltzmann],
LGM[ObsRep, Boltzmann, Boltzmann],
):
"""Differentiable Linear Gaussian Model with Boltzmann latent variables.
This model combines a Normal observable distribution with Boltzmann (binary)
latent variables. The latent states are discrete binary vectors, making this
suitable for discrete representation learning and binary latent factor models.
The observable distribution remains Gaussian (continuous), while the latent
distribution is a Boltzmann machine (discrete binary). This enables learning
discrete latent representations of continuous data.
"""
lat_dim: int
"""Number of binary latent units."""
# Overrides
@property
@override
def pst_man(self) -> Boltzmann:
"""Override to construct directly from fields, avoiding circular dependency."""
return Boltzmann(self.lat_dim)
@property
@override
def pst_prr_emb(self) -> LinearEmbedding[Boltzmann, Boltzmann]:
"""Embedding of posterior Boltzmann into prior Boltzmann.
For Boltzmann machines, both posterior and prior use the same manifold
structure (no covariance simplification like in Normal case), so we use
the identity embedding.
"""
return IdentityEmbedding(self.pst_man)
@property
@override
def lat_man(self) -> Boltzmann:
"""The latent manifold is a Boltzmann machine."""
return Boltzmann(self.lat_dim)
@dataclass(frozen=True)
class DifferentiableBoltzmannLGM[ObsRep: PositiveDefinite](
LGM[ObsRep, DiagonalBoltzmann, Boltzmann],
):
"""Differentiable Linear Gaussian Model with mean-field Boltzmann latent variables.
This model combines a Normal observable distribution with Boltzmann (binary)
latent variables, using a mean-field (independent units) approximation for
the posterior.
**Posterior vs Prior Structure**:
- Posterior: Uses `DiagonalBoltzmann` (independent binary units) for computational
efficiency. This mean-field approximation has O(n) log partition computation.
- Prior: Uses full `Boltzmann` with pairwise coupling. The conjugation parameters
naturally produce coupling terms even when the posterior is mean-field.
This asymmetric structure enables efficient inference while maintaining the
expressiveness of the full Boltzmann prior for modeling complex dependencies.
"""
lat_dim: int
"""Number of binary latent units."""
# Overrides
@property
@override
def pst_man(self) -> DiagonalBoltzmann:
"""Mean-field posterior: diagonal Boltzmann (independent units)."""
return DiagonalBoltzmann(self.lat_dim)
@property
@override
def pst_prr_emb(self) -> BoltzmannEmbedding:
"""Embedding from mean-field DiagonalBoltzmann to full Boltzmann."""
return BoltzmannEmbedding(self.pst_man, Boltzmann(self.lat_dim))
[docs]
@dataclass(frozen=True)
class NormalAnalyticLGM[ObsRep: PositiveDefinite](
AnalyticConjugated[Normal[ObsRep], FullNormal],
NormalLGM[ObsRep, PositiveDefinite],
):
"""Analytic Linear Gaussian Model that extends the differentiable LGM with full analytical tractability, adding conversions between mean and natural coordinates, and a closed-form implementation of EM."""
def __init__(self, obs_dim: int, obs_rep: ObsRep, lat_dim: int):
super().__init__(
obs_dim=obs_dim,
obs_rep=obs_rep,
lat_dim=lat_dim,
pst_rep=PositiveDefinite(),
)
@property
@override
def lat_man(self) -> FullNormal:
"""The latent manifold is a full Normal distribution."""
return full_normal(self.lat_dim)
@property
@override
def pst_prr_emb(self) -> NormalCovarianceEmbedding[PositiveDefinite, PositiveDefinite]:
"""Embedding of posterior Normal into prior Normal via covariance structure."""
prior_gau = full_normal(self.lat_dim)
return NormalCovarianceEmbedding(self.pst_man, prior_gau)
[docs]
@override
def to_natural_likelihood(
self,
means: Array, # Mean[Self]
) -> Array:
"""Convert mean parameters to natural likelihood parameters.
Parameters
----------
means : Array
Mean parameters for the analytic linear Gaussian model.
Returns
-------
Array
Natural parameters for likelihood function.
"""
# Get relevant manifolds
ocm = self.obs_man.cov_man
lcm = self.lat_man.cov_man
im = self.int_man
# Deconstruct parameters
obs_means, int_means, lat_means = self.split_coords(means)
obs_mean, obs_cov = self.obs_man.split_mean_covariance(obs_means)
lat_mean, lat_cov = self.lat_man.split_mean_covariance(lat_means)
int_cov = int_means - im.rep.outer_product(obs_mean, lat_mean)
# Construct precisions
lat_prs = lcm.inverse(lat_cov)
im_trn = im.trn_man
int_cov_t = im.transpose(int_cov)
# cob_man, cob = _change_of_basis(im_trn, int_cov_t, lat_cov_man, lat_prs)
cob_man, cob = _change_of_basis(
im_trn.matrix_shape,
im_trn.rep,
int_cov_t,
lcm.rep,
lat_prs,
)
shaped_cob = ocm.from_matrix(cob_man.to_matrix(cob))
obs_prs = ocm.inverse(obs_cov - shaped_cob)
sizes = (
ocm.matrix_shape[0],
ocm.matrix_shape[1],
im.matrix_shape[1],
lcm.matrix_shape[1],
)
_, _, int_params = _dual_composition(
sizes,
ocm.rep,
obs_prs,
im.rep,
int_cov,
lcm.rep,
lat_prs,
)
obs_loc0 = ocm(obs_prs, obs_mean)
# obs_loc1 = self._int_man_internal(int_params, lat_mean)
obs_loc1 = im.rep.matvec(im.matrix_shape, int_params, lat_mean)
obs_loc = obs_loc0 - obs_loc1
# Return natural parameters
obs_params = self.obs_man.join_location_precision(obs_loc, obs_prs)
return self.lkl_fun_man.join_coords(obs_params, int_params)
@dataclass(frozen=True)
class FactorAnalysis(NormalAnalyticLGM[Diagonal]):
    """Factor analysis: an analytic LGM with a Diagonal observation covariance and Gaussian latent variables."""

    def __init__(self, obs_dim: int, lat_dim: int):
        super().__init__(obs_dim, Diagonal(), lat_dim)

    @override
    def expectation_maximization(
        self,
        params: Array,
        xs: Array,
    ) -> Array:
        """Run one EM iteration with the latent prior pinned to a standard normal.

        The latent Normal of FA is not identifiable without further constraints,
        so after the E-step the statistics are whitened: the prior becomes N(0, I)
        and the observable marginal is left unchanged.

        Parameters
        ----------
        params : Array
            Current natural parameters.
        xs : Array
            Observation data.

        Returns
        -------
        Array
            Updated natural parameters.
        """
        posterior_stats = self.mean_posterior_statistics(params, xs)
        whitened_stats = self.whiten_prior(posterior_stats)
        return self.to_natural(whitened_stats)

    def from_loadings(
        self,
        loadings: Array,
        means: Array,
        diags: Array,
    ) -> Array:
        """Build natural parameters from loadings, observation means, and diagonal noise variances.

        The interaction block is the noise precision times the loading matrix,
        and the latent prior is a standard normal.

        Parameters
        ----------
        loadings : Array
            Loading matrix (obs_dim, lat_dim).
        means : Array
            Observation means.
        diags : Array
            Diagonal noise variances.

        Returns
        -------
        Array
            Natural parameters for the factor analysis model.
        """
        obs = self.obs_man
        obs_natural = obs.to_natural(obs.join_mean_covariance(means, diags))
        _, noise_precision = obs.split_location_precision(obs_natural)

        # Scale the loadings by the (dense) noise precision
        scaled_loadings = obs.cov_man.to_matrix(noise_precision) @ loadings
        interaction = self.int_man.from_matrix(scaled_loadings)

        likelihood = self.lkl_fun_man.join_coords(obs_natural, interaction)
        prior_natural = self.lat_man.to_natural(self.lat_man.standard_normal())
        return self.join_conjugated(likelihood, prior_natural)
@dataclass(frozen=True)
class PrincipalComponentAnalysis(NormalAnalyticLGM[Scale]):
    """Principal component analysis: an analytic LGM with a Scale observation covariance and Gaussian latent variables."""

    def __init__(self, obs_dim: int, lat_dim: int):
        super().__init__(obs_dim, Scale(), lat_dim)

    @override
    def expectation_maximization(
        self,
        params: Array,
        xs: Array,
    ) -> Array:
        """Run one EM iteration with the latent prior pinned to a standard normal.

        The latent Normal of PCA is not identifiable without further
        constraints, so after the E-step the statistics are whitened: the prior
        becomes N(0, I) and the observable marginal is left unchanged.

        Parameters
        ----------
        params : Array
            Current natural parameters.
        xs : Array
            Observation data.

        Returns
        -------
        Array
            Updated natural parameters.
        """
        posterior_stats = self.mean_posterior_statistics(params, xs)
        whitened_stats = self.whiten_prior(posterior_stats)
        return self.to_natural(whitened_stats)
### Helper Functions ###
def _dual_composition(
    sizes: tuple[int, int, int, int],
    h_rep: MatrixRep,
    h_params: Array,  # Parameters in some coordinate system
    g_rep: MatrixRep,
    g_params: Array,  # Parameters in dual coordinates
    f_rep: MatrixRep,
    f_params: Array,  # Parameters in original coordinate system
) -> tuple[
    MatrixRep,
    tuple[int, int],  # Output shape
    Array,  # Parameters in original coordinate system
]:
    """Three-way product ``h @ g @ f`` that respects coordinate duality (``g`` is in dual coordinates).

    ``sizes = (a, b, c, d)`` fixes the shapes h: (a, b), g: (b, c), f: (c, d).
    The product is formed right to left as ``h @ (g @ f)``. The result is the
    (rep, shape, params) triple returned by the final ``matmat`` call.
    """
    rows_h, inner_hg, inner_gf, cols_f = sizes
    gf_rep, gf_shape, gf_params = g_rep.matmat(
        (inner_hg, inner_gf), g_params, f_rep, (inner_gf, cols_f), f_params
    )
    return h_rep.matmat((rows_h, inner_hg), h_params, gf_rep, gf_shape, gf_params)
def _change_of_basis(
    f_size: tuple[int, int],
    f_rep: MatrixRep,
    f_params: Array,  # Parameters in some coordinate system
    g_rep: PositiveDefinite,
    g_params: Array,  # Parameters in dual coordinates
) -> tuple[
    Covariance[Any],
    Array,  # Parameters in original coordinate system
]:
    """Linear change of basis ``f.T @ g @ f``, returned as a covariance manifold and its parameters.

    ``f`` has shape ``f_size = (m, n)`` and ``g`` is (m, m), so the result is
    (n, n). It is always symmetric, and positive semi-definite if ``g`` is
    positive definite. When the product's rep is Diagonal, Scale, or Identity,
    that rep is kept. Any other rep is densified and re-packed as
    PositiveDefinite parameters.
    """
    rows, cols = f_size
    f_t_params = f_rep.transpose(f_size, f_params)
    out_rep, out_shape, out_params = _dual_composition(
        (cols, rows, rows, cols),
        f_rep,
        f_t_params,
        g_rep,
        g_params,
        f_rep,
        f_params,
    )

    if not isinstance(out_rep, (Diagonal, Scale, Identity)):
        # General product rep: the result is symmetric, so store it as PositiveDefinite
        cov_man = Covariance(out_shape[0], PositiveDefinite())
        dense = out_rep.to_matrix(cov_man.matrix_shape, out_params)
        return cov_man, cov_man.from_matrix(dense)

    return Covariance(out_shape[0], out_rep), out_params