# Source code for goal.models.graphical.hmog

"""Hierarchical Mixture of Gaussians (HMoG) models.

This module provides concrete implementations of hierarchical Gaussian models that combine
linear Gaussian dimensionality reduction with Gaussian mixture clustering, enabling joint
learning of latent factor representations and cluster assignments.

**Model structure**: HMoG models have two levels:

- **Lower harmonium**: Maps observations :math:`X \\in \\mathbb{R}^p` to first-level latent factors
  :math:`Y \\in \\mathbb{R}^d` using a linear Gaussian relationship (factor analysis/PCA)
- **Upper harmonium**: Models a mixture of Gaussians over the latent space :math:`Y`

The joint distribution factors as:

.. math::

    p(X, Y, Z) = p(Z) \\cdot p(Y | Z) \\cdot p(X | Y)

where :math:`Z \\in \\{1,\\ldots,K\\}` are discrete cluster assignments.

**Variants**: Three implementations with different analytical properties:

- **DifferentiableHMoG**: Gradient-based optimization, uses restricted posterior covariance
  for efficiency (e.g., diagonal)
- **SymmetricHMoG**: Symmetric posterior/prior structure, additional functionality like
  `join_conjugated`, but slower due to full covariance matrix operations
- **AnalyticHMoG**: Fully analytic, enables closed-form EM and bidirectional parameter conversion

Factory functions (`differentiable_hmog`, `symmetric_hmog`, `analytic_hmog`) provide
convenient construction for common configurations.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import override

import jax
import jax.numpy as jnp
from jax import Array

from ...geometry import (
    AnalyticHierarchical,
    DifferentiableHierarchical,
    PositiveDefinite,
    SymmetricHierarchical,
)
from ..base.gaussian.normal import FullNormal, Normal, full_normal
from ..harmonium.lgm import (
    NormalAnalyticLGM,
    NormalCovarianceEmbedding,
    NormalLGM,
)
from ..harmonium.mixture import AnalyticMixture, Mixture

### HMoG Classes ###


# TODO: Somehow figure out how to share more code between these classes while preserving the type system.
[docs] @dataclass(frozen=True) class DifferentiableHMoG[ObsRep: PositiveDefinite, PstRep: PositiveDefinite]( DifferentiableHierarchical[ NormalLGM[ObsRep, PstRep], AnalyticMixture[Normal[PstRep]], Mixture[FullNormal], ] ): """Differentiable Hierarchical Mixture of Gaussians. This model combines: 1. A linear Gaussian model (factor analysis) mapping observations to latents 2. A Gaussian mixture model over the latent space Supports gradient-based optimization via log-likelihood descent. Uses full covariance Gaussians in the latent space. **Posterior vs Prior Structure**: The posterior latent mixture (`pst_upr_hrm`) uses an AnalyticMixture with a restricted covariance structure for computational efficiency. The prior latent mixture (`prr_upr_hrm`) embeds the restricted structure into full covariance for conjugation parameter computation. """
[docs] def whiten_prior(self, means: Array) -> Array: """Reparameterize the latent Y-space to have zero mean and identity covariance. Preserves p(x) by updating both: - The lower LGM interaction (loading matrix + observable bias adjustment) - Each GMM component (via the existing Normal.whiten relative to GMM marginal) """ obs_means, lwr_int_means, lat_means = self.split_coords(means) # GMM marginal statistics (obs_means_gmm = E[s_Y(y)] w.r.t. joint) obs_means_gmm, _, cat_means = self.pst_upr_hrm.split_coords(lat_means) lat_mean_y, lat_cov_y = self.pst_upr_hrm.obs_man.split_mean_covariance( obs_means_gmm ) chol = jnp.linalg.cholesky( self.pst_upr_hrm.obs_man.cov_man.to_matrix(lat_cov_y) ) # Whiten each GMM component using existing Normal.whiten comp_means, _ = self.pst_upr_hrm.split_mean_mixture(lat_means) new_comp_means = self.pst_upr_hrm.cmp_man.map( lambda c: self.pst_upr_hrm.obs_man.relative_whiten(c, obs_means_gmm), comp_means, flatten=True, ) new_lat_means = self.pst_upr_hrm.join_mean_mixture(new_comp_means, cat_means) # Update lower LGM cross-statistics (same transform as LGM whitening) obs_loc, _ = self.obs_man.split_mean_second_moment(obs_means) lwr_int_mat = self.lwr_hrm.int_man.to_matrix(lwr_int_means) cross_cov = lwr_int_mat - jnp.outer(obs_loc, lat_mean_y) # W Cov(Y) new_lwr_int_mat = jax.scipy.linalg.solve_triangular( chol, cross_cov.T, lower=True ).T new_lwr_int_means = self.lwr_hrm.int_man.from_matrix(new_lwr_int_mat) return self.join_coords(obs_means, new_lwr_int_means, new_lat_means)
# HMoG-specific methods
[docs] def posterior_categorical(self, params: Array, x: Array) -> Array: """Compute posterior categorical distribution p(Z|x) in natural coordinates. Returns the natural parameters of the categorical distribution over mixture components in the latent space given an observation. Args: params: Model parameters (natural coordinates) x: Observable data point Returns: Array of shape (n_components-1,) with categorical natural parameters """ # Compute posterior latent mixture posterior = self.posterior_at(params, x) # Extract categorical marginal from the mixture return self.pst_upr_hrm.prior(posterior)
[docs] def posterior_soft_assignments(self, params: Array, x: Array) -> Array: """Compute posterior assignment probabilities p(Z|x). Returns the posterior probability distribution over mixture components in the latent space given an observation. Args: params: Model parameters (natural coordinates) x: Observable data point Returns: Array of shape (n_components,) giving p(z_k|x) for each component k """ cat_natural = self.posterior_categorical(params, x) cat_means = self.pst_upr_hrm.lat_man.to_mean(cat_natural) return self.pst_upr_hrm.lat_man.to_probs(cat_means)
[docs] def posterior_hard_assignment(self, params: Array, x: Array) -> Array: """Compute hard posterior assignments p(Z|x). Returns the index of the most probable mixture component in the latent space given an observation. Args: params: Model parameters (natural coordinates) x: Observable data point Returns: Integer index of the most probable component """ soft_assignments = self.posterior_soft_assignments(params, x) return jnp.argmax(soft_assignments)
[docs] @dataclass(frozen=True) class SymmetricHMoG[ObsRep: PositiveDefinite]( SymmetricHierarchical[NormalAnalyticLGM[ObsRep], Mixture[FullNormal]] ): """Symmetric Hierarchical Mixture of Gaussians. This model supports gradient-based optimization with additional functionality (e.g., join_conjugated) not available in DifferentiableHMoG. The symmetric structure means the posterior and conjugated latent spaces are the same, enabling bidirectional parameter transformations. Trade-off: Matrix inversions happen in the space of full covariance matrices over the latent space, which can be slower than DifferentiableHMoG. """ # HMoG-specific methods
[docs] def posterior_categorical(self, params: Array, x: Array) -> Array: """Compute posterior categorical distribution p(Z|x) in natural coordinates. Returns the natural parameters of the categorical distribution over mixture components in the latent space given an observation. Args: params: Model parameters (natural coordinates) x: Observable data point Returns: Array of shape (n_components-1,) with categorical natural parameters """ # Compute posterior latent mixture posterior = self.posterior_at(params, x) # Extract categorical marginal from the mixture return self.upr_hrm.prior(posterior)
[docs] def posterior_assignments(self, params: Array, x: Array) -> Array: """Compute posterior assignment probabilities p(Z|x). Returns the posterior probability distribution over mixture components in the latent space given an observation. Args: params: Model parameters (natural coordinates) x: Observable data point Returns: Array of shape (n_components,) giving p(z_k|x) for each component k """ cat_natural = self.posterior_categorical(params, x) return self.upr_hrm.lat_man.to_probs(cat_natural)
[docs] @dataclass(frozen=True) class AnalyticHMoG[ObsRep: PositiveDefinite]( AnalyticHierarchical[NormalAnalyticLGM[ObsRep], AnalyticMixture[FullNormal]] ): """Analytic Hierarchical Mixture of Gaussians. This model enables: - Closed-form EM algorithm for learning (from AnalyticConjugated) - Bidirectional parameter conversion (mean <-> natural) - Full analytical tractability Requires full covariance Gaussians in the latent space. """
[docs] @override def expectation_maximization(self, params: Array, xs: Array) -> Array: """Perform a single iteration of EM with latent-prior whitening. HMoG has the same latent-space non-identifiability as FA/PCA. After the E-step, whiten the latent prior in mean coordinates before mapping back to natural coordinates. """ q = self.mean_posterior_statistics(params, xs) return self.to_natural(self.whiten_prior(q))
[docs] def whiten_prior(self, means: Array) -> Array: """Reparameterize the latent Y-space to have zero mean and identity covariance. Preserves p(x) by updating both: - The lower LGM interaction (loading matrix + observable bias adjustment) - Each GMM component (via the existing Normal.whiten relative to GMM marginal) """ obs_means, lwr_int_means, lat_means = self.split_coords(means) # GMM marginal statistics (obs_means_gmm = E[s_Y(y)] w.r.t. joint) obs_means_gmm, _, cat_means = self.upr_hrm.split_coords(lat_means) lat_mean_y, lat_cov_y = self.upr_hrm.obs_man.split_mean_covariance( obs_means_gmm ) chol = jnp.linalg.cholesky(self.upr_hrm.obs_man.cov_man.to_matrix(lat_cov_y)) # Whiten each GMM component using existing Normal.whiten comp_means, _ = self.upr_hrm.split_mean_mixture(lat_means) new_comp_means = self.upr_hrm.cmp_man.map( lambda c: self.upr_hrm.obs_man.relative_whiten(c, obs_means_gmm), comp_means, flatten=True, ) new_lat_means = self.upr_hrm.join_mean_mixture(new_comp_means, cat_means) # Update lower LGM cross-statistics (same transform as LGM whitening) obs_loc, _ = self.obs_man.split_mean_second_moment(obs_means) lwr_int_mat = self.lwr_hrm.int_man.to_matrix(lwr_int_means) cross_cov = lwr_int_mat - jnp.outer(obs_loc, lat_mean_y) # W Cov(Y) new_lwr_int_mat = jax.scipy.linalg.solve_triangular( chol, cross_cov.T, lower=True ).T new_lwr_int_means = self.lwr_hrm.int_man.from_matrix(new_lwr_int_mat) return self.join_coords(obs_means, new_lwr_int_means, new_lat_means)
# HMoG-specific methods
[docs] def posterior_categorical(self, params: Array, x: Array) -> Array: """Compute posterior categorical distribution p(Z|x) in natural coordinates. Returns the natural parameters of the categorical distribution over mixture components in the latent space given an observation. Args: params: Model parameters (natural coordinates) x: Observable data point Returns: Array of shape (n_components-1,) with categorical natural parameters """ # Compute posterior latent mixture posterior = self.posterior_at(params, x) # Extract categorical marginal from the mixture return self.upr_hrm.prior(posterior)
[docs] def posterior_assignments(self, params: Array, x: Array) -> Array: """Compute posterior assignment probabilities p(Z|x). Returns the posterior probability distribution over mixture components in the latent space given an observation. Args: params: Model parameters (natural coordinates) x: Observable data point Returns: Array of shape (n_components,) giving p(z_k|x) for each component k """ cat_natural = self.posterior_categorical(params, x) return self.upr_hrm.lat_man.to_probs(cat_natural)
## Factory Functions ##
[docs] def differentiable_hmog[ObsRep: PositiveDefinite, PstRep: PositiveDefinite]( obs_dim: int, obs_rep: ObsRep, lat_dim: int, pst_rep: PstRep, n_components: int, ) -> DifferentiableHMoG[ObsRep, PstRep]: """Create a differentiable hierarchical mixture of Gaussians model. This function constructs a hierarchical model combining: 1. A bottom layer with a linear Gaussian model reducing observables to first-level latents 2. A top layer with a Gaussian mixture model for modelling the latent distribution This model supports optimization via log-likelihood gradient descent. Uses full covariance Gaussians in the latent space. """ pst_y_man = Normal(lat_dim, pst_rep) prr_y_man = full_normal(lat_dim) lwr_hrm = NormalLGM(obs_dim, obs_rep, lat_dim, pst_rep) mix_sub = NormalCovarianceEmbedding(pst_y_man, prr_y_man) pst_upr_hrm = AnalyticMixture(pst_y_man, n_components) # Use the embedding directly for the mixture prr_upr_hrm = Mixture(n_components, mix_sub) return DifferentiableHMoG( lwr_hrm, pst_upr_hrm, prr_upr_hrm, )
[docs] def symmetric_hmog[ObsRep: PositiveDefinite]( obs_dim: int, obs_rep: ObsRep, lat_dim: int, lat_rep: PositiveDefinite, n_components: int, ) -> SymmetricHMoG[ObsRep]: """Create a symmetric hierarchical mixture of Gaussians model. Supports optimization via log-likelihood gradient descent with additional functionality (e.g., `join_conjugated`) not available in `DifferentiableHMoG`. The symmetric structure means posterior and prior use the same latent parameterization. Trade-off: Matrix inversions happen in the space of full covariance matrices over the latent space, which can be slower than `DifferentiableHMoG`. """ mid_lat_man = full_normal(lat_dim) sub_lat_man = Normal(lat_dim, lat_rep) mix_sub = NormalCovarianceEmbedding(sub_lat_man, mid_lat_man) lwr_hrm = NormalAnalyticLGM(obs_dim, obs_rep, lat_dim) # Use the embedding directly for the mixture upr_hrm = Mixture(n_components, mix_sub) return SymmetricHMoG( lwr_hrm, upr_hrm, )
[docs] def analytic_hmog[ObsRep: PositiveDefinite]( obs_dim: int, obs_rep: ObsRep, lat_dim: int, n_components: int, ) -> AnalyticHMoG[ObsRep]: """Create an analytic hierarchical mixture of Gaussians model. Enables closed-form expectation-maximization for learning and bidirectional parameter conversion between natural and mean coordinates. Requires full covariance Gaussians in the latent space for complete analytical tractability. """ lat_man = full_normal(lat_dim) lwr_hrm = NormalAnalyticLGM(obs_dim, obs_rep, lat_dim) upr_hrm = AnalyticMixture(lat_man, n_components) return AnalyticHMoG( lwr_hrm, upr_hrm, )