Source code for goal.models.harmonium.mixture

"""Mixture Models as Conjugate Harmoniums.

This module implements mixture models using a harmonium structure where

- the observable manifold is an exponential family, and
- the latent manifold is a `Categorical` distribution over mixture components.
"""

from __future__ import annotations

from dataclasses import dataclass, replace
from typing import override

import jax
import jax.numpy as jnp
from jax import Array

from ...geometry import (
    Analytic,
    AnalyticConjugated,
    Differentiable,
    DifferentiableConjugated,
    EmbeddedMap,
    IdentityEmbedding,
    LinearEmbedding,
    Manifold,
    Product,
    Rectangular,
    StatisticalMoments,
    SymmetricConjugated,
)
from ..base.categorical import (
    Categorical,
)


[docs] @dataclass(frozen=True) class Mixture[Observable: Differentiable]( SymmetricConjugated[Observable, Categorical], DifferentiableConjugated[Observable, Categorical, Categorical], ): """Mixture models with exponential family observations. A mixture model represents a weighted combination of component distributions: $$p(x) = \\sum_{k=1}^K \\pi_k p(x|\\theta_k)$$ where $\\pi_k$ are mixture weights summing to 1, and $p(x|\\theta_k)$ are component distributions from the same exponential family. In the harmonium framework, mixture models are implemented using: - A categorical latent variable $z$ representing component assignment - An observable distribution family for component distributions - An interaction matrix between $z$ and $x$ that contain component specific parameters The interaction matrix structure is fixed: Rectangular matrix representation with identity domain embedding. Only the observable embedding varies. """ # Fields n_categories: int """Number of mixture components.""" obs_emb: LinearEmbedding[Manifold, Observable] """Observable embedding - determines which observable parameters are mixed.""" # Template Methods @property def cmp_man(self) -> Product[Observable]: """Manifold for all components of mixture.""" return Product(self.obs_man, self.n_categories)
[docs] def join_mean_mixture( self, components: Array, weights: Array, ) -> Array: """Create a mixture model in mean coordinates from components and weights. In mean coordinates, projections are the correct operation because expectations of sufficient statistics are preserved: $\\mathbb{E}[s_{\\text{sub}}(x)] = \\text{project}(\\mathbb{E}[s_{\\text{full}}(x)])$. This method projects component parameters to the interaction submanifold, storing only the submanifold portion while averaging out statistics outside it. **Technical note:** When using restricted submanifolds (non-identity embeddings), component statistics outside the interaction submanifold are irreversibly lost (aggregated into observable biases). This is mathematically correct for mean coordinates but makes the operation non-invertible without additional information. Args: components: Flat 1D array of component parameters (shape ``[n_categories * obs_dim]``) weights: Categorical weights """ probs = self.pst_man.to_probs(weights) comps_2d = self.cmp_man.to_2d(components) weighted_comps = comps_2d * probs[:, None] obs_means = jnp.sum(weighted_comps, axis=0) # Project components (excluding first) to interaction subspace projected_comps = jax.vmap(self.int_man.cod_emb.project)(weighted_comps[1:]) # [n_categories-1, sub_obs_dim] # Transpose and convert to int_man storage format int_means = projected_comps.T.ravel() return self.join_coords(obs_means, int_means, weights)
# Overrides @property @override def lat_man(self) -> Categorical: return Categorical(self.n_categories) @property @override def int_man(self) -> EmbeddedMap[Categorical, Observable]: """Construct interaction matrix from observable embedding. Structure is fixed: Rectangular matrix with identity domain embedding. """ return EmbeddedMap( Rectangular(), IdentityEmbedding(self.lat_man), self.obs_emb, ) # Methods
[docs] def split_natural_mixture( self, natural_params: Array, ) -> tuple[Array, Array]: """Split a mixture model in natural coordinates into components and prior. In natural coordinates, embeddings are the correct operation because they preserve the log-density form of exponential families. The interaction parameters in $\\text{IntObservable}$ space are embedded into $\\text{Observable}$ space by translating the observable bias: $\\theta_k = \\text{embed}(\\theta_X, \\delta_k)$ ensures that $\\log p(x|z=k) = \\theta_k \\cdot s(x) - \\psi(\\theta_k) + \\text{const}$. This operation is invertible for general $\\text{IntObservable} \\subsetneq \\text{Observable}$ because zero-padding in natural space preserves the exponential family structure. Returns: Tuple of (components, prior) where components is a flat 1D array of shape ``[n_categories * obs_dim]`` representing parameters on ``cmp_man``. """ lkl_params, prr_params = self.split_conjugated(natural_params) obs_bias, int_mat = self.lkl_fun_man.split_coords(lkl_params) # Convert to 2D matrix and transpose to get columns as rows int_cols = self.int_man.to_matrix(int_mat).T # [n_categories-1, sub_obs_dim] # Translate each column from subspace to full observable space def translate_col(col: Array) -> Array: return self.int_man.cod_emb.translate(obs_bias, col) translated = jax.vmap(translate_col)(int_cols) # Stack with first component and flatten to match cmp_man convention components = jnp.vstack([obs_bias[None, :], translated]).ravel() return components, prr_params
[docs] def observable_mean_covariance( self, natural_params: Array, ) -> tuple[Array, Array]: """Compute mean and covariance of the observable variables.""" if not isinstance(self.obs_man, StatisticalMoments): raise TypeError( f"Observable manifold {type(self.obs_man)} does not support statistical moment computation" ) comp_params, cat_params = self.split_natural_mixture(natural_params) weights = self.lat_man.to_probs(self.lat_man.to_mean(cat_params)) # Use statistical_mean/covariance for component statistics cmp_means = self.cmp_man.map(self.obs_man.statistical_mean, comp_params) cmp_covs = self.cmp_man.map(self.obs_man.statistical_covariance, comp_params) # Compute mixture mean mean = jnp.einsum("k,ki->i", weights, cmp_means) # Compute mixture covariance using law of total variance # First term: weighted sum of component covariances cov = jnp.einsum("k,kij->ij", weights, cmp_covs) # Second term: weighted sum of outer products of component means mean_outer = jnp.einsum("ki,kj->kij", cmp_means, cmp_means) cov += jnp.einsum("k,kij->ij", weights, mean_outer) # Third term: subtract outer product of mixture mean cov -= jnp.outer(mean, mean) return mean, cov
# Overrides
[docs] @override def conjugation_parameters( self, lkl_params: Array, ) -> Array: """Compute conjugation parameters for categorical mixture. In particular, $$\\rho_k = \\psi(\\theta_X + \\theta_{XZ,k}) - \\psi(\\theta_X)$$ """ # Compute base term from observable bias obs_bias, int_mat = self.lkl_fun_man.split_coords(lkl_params) rho_0 = self.obs_man.log_partition_function(obs_bias) # Convert to 2D matrix and transpose to get columns as rows int_comps = self.int_man.to_matrix(int_mat).T # [n_categories-1, sub_obs_dim] def compute_rho(comp_params: Array) -> Array: adjusted_obs = self.int_man.cod_emb.translate(obs_bias, comp_params) return self.obs_man.log_partition_function(adjusted_obs) - rho_0 return jax.vmap(compute_rho)(int_comps) # [n_categories-1]
[docs] class CompleteMixture[Observable: Differentiable]( Mixture[Observable], ): """Mixture models where the complete observable manifold is mixed.""" # Constructor def __init__(self, obs_man: Observable, n_categories: int): # Use identity observable embedding for complete mixture obs_emb = IdentityEmbedding(obs_man) super().__init__(n_categories, obs_emb)
[docs] def split_mean_mixture( self, means: Array, ) -> tuple[Array, Array]: """Split a mixture model in mean coordinates into components and weights. **Constraint:** This method is restricted to $\\text{Observable} = \\text{IntObservable}$. Since `join_mean_mixture` projects component statistics to the interaction submanifold, information outside that submanifold is irreversibly lost. Only when the full Observable space is used for interactions can the decomposition perfectly invert the join operation. Returns: Tuple of (components, weights) where components is a flat 1D array of shape ``[n_categories * obs_dim]`` representing parameters on ``cmp_man``. """ obs_means, int_means, cat_means = self.split_coords(means) probs = self.lat_man.to_probs(cat_means) # shape: (n_categories,) # Convert to 2D matrix and transpose to get columns as rows [n_categories-1, obs_dim] int_dense = self.int_man.to_matrix(int_means) # [obs_dim, n_categories-1] int_cols = int_dense.T # [n_categories-1, obs_dim] # Compute first component # Sum interaction columns - shape: (obs_dim,) sum_interactions = jnp.sum(int_cols, axis=0) first_comp = (obs_means - sum_interactions) / probs[0] # Scale remaining components by their probabilities # shape: (n_categories-1, obs_dim) other_comps = int_cols / probs[1:, None] # Combine components and flatten to match cmp_man convention components = jnp.vstack([first_comp[None, :], other_comps]).ravel() return components, cat_means
[docs] def join_natural_mixture( self, components: Array, prior: Array, ) -> Array: """Create a mixture model in natural coordinates from components and prior. **Constraint:** This method is restricted to $\\text{Observable} = \\text{IntObservable}$. Component differences $\\theta_k - \\theta_0$ are computed in Observable space and stored directly. For general submanifolds, these differences would need to be projected to $\\text{IntObservable}$ space, but projection in natural coordinates does not preserve the exponential family structure (would corrupt the log-density). Algorithm: Uses the first component as an anchor, storing differences in the interaction matrix: $\\Theta_{XZ,k} = \\theta_k - \\theta_0$. Args: components: Flat 1D array of component parameters (shape ``[n_categories * obs_dim]``) prior: Prior parameters for categorical distribution """ # Get anchor (first component) - shape: (obs_dim,) obs_bias = self.cmp_man.get_replicate(components, 0) def to_interaction(comp: Array) -> Array: return comp - obs_bias # Get remaining components (flat 1D) remaining_start = self.obs_man.dim components_rest = components[remaining_start:] cmp_man_minus = replace(self.cmp_man, n_reps=self.n_categories - 1) # projected_comps shape: (n_categories-1, obs_dim) projected_comps = cmp_man_minus.map(to_interaction, components_rest) # Transpose to [obs_dim, n_categories-1] and convert to int_man storage int_mat = self.int_man.rep.from_matrix(projected_comps.T) lkl_params = self.lkl_fun_man.join_coords(obs_bias, int_mat) return self.join_conjugated(lkl_params, prior)
[docs] class AnalyticMixture[Observable: Analytic]( CompleteMixture[Observable], AnalyticConjugated[Observable, Categorical], ): """Mixture model with analytically tractable components. Amongst other things, this enables closed-form implementation of expectation-maximization.""" # Overrides
[docs] @override def to_natural_likelihood( self, means: Array, ) -> Array: """Map mean harmonium parameters to natural likelihood parameters. **Constraint:** This method requires $\\text{Observable} = \\text{IntObservable}$. This is necessary because converting from mean to natural coordinates requires decomposing the mixture into individual component parameters. In the general case where interactions are restricted to a submanifold, information outside that submanifold is irreversibly lost during `join_mean_mixture`, making the decomposition impossible. Algorithm: 1. Recover component means using `split_mean_mixture` (requires full Observable recovery) 2. Convert each component mean parameter to natural parameters 3. Compute differences relative to first component (interaction matrix) """ # Get component means (flat 1D) comp_means, _ = self.split_mean_mixture(means) # Convert each component to natural parameters (flat 1D) def to_natural(mean: Array) -> Array: return self.obs_man.to_natural(mean) nat_comps = self.cmp_man.map(to_natural, comp_means, flatten=True) # Get anchor (first component) - shape: (obs_dim,) obs_bias = self.cmp_man.get_replicate(nat_comps, 0) def to_interaction(nat: Array) -> Array: return nat - obs_bias # Convert remaining components to interactions remaining_start = self.obs_man.dim nat_comps_rest = nat_comps[remaining_start:] cmp_man1 = replace(self.cmp_man, n_reps=self.n_categories - 1) # int_cols shape: (n_categories-1, obs_dim) int_cols = cmp_man1.map(to_interaction, nat_comps_rest) # Transpose to [obs_dim, n_categories-1] and convert to int_man storage int_mat = self.int_man.rep.from_matrix(int_cols.T) return self.lkl_fun_man.join_coords(obs_bias, int_mat)