Harmoniums

Harmonium models: product exponential families over observable and latent variables coupled through an interaction matrix.

The natural parameters of a harmonium are the concatenation [obs_params, int_params, lat_params] of the observable biases, the interaction matrix, and the latent biases. The conjugated subclasses add structure that decomposes the joint distribution into a likelihood and a prior, enabling exact density evaluation and expectation-maximization (EM).
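As a rough illustration of this layout, the sketch below splits a flat parameter list into its three blocks for a toy harmonium with 3 observable and 2 latent statistics. The helper `split_params` and the dimensions are hypothetical, and the row-major ordering of the interaction block is an assumption rather than the library's documented storage order.

```python
def split_params(params, obs_dim, lat_dim):
    """Split a flat parameter list into (obs_params, int_params, lat_params)."""
    int_dim = obs_dim * lat_dim
    if len(params) != obs_dim + int_dim + lat_dim:
        raise ValueError("parameter vector has the wrong length")
    obs_params = params[:obs_dim]
    flat_int = params[obs_dim:obs_dim + int_dim]
    lat_params = params[obs_dim + int_dim:]
    # Assumed row-major layout: one row per observable statistic.
    int_params = [flat_int[i * lat_dim:(i + 1) * lat_dim] for i in range(obs_dim)]
    return obs_params, int_params, lat_params


params = [float(i) for i in range(3 + 3 * 2 + 2)]
obs_params, int_params, lat_params = split_params(params, obs_dim=3, lat_dim=2)
```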

Class Hierarchy

Inheritance diagram of goal.geometry.exponential_family.harmonium

Base Class

class Harmonium

Bases: Gibbs, Triple[Observable, LinearMap, Posterior], ABC, Generic

A product exponential family over observable \(x\) and latent \(z\) variables coupled through an interaction matrix.

Mathematically, the joint log-density is \(\log p(x,z) = \theta_X \cdot \mathbf s_X(x) + \theta_Z \cdot \mathbf s_Z(z) + \mathbf s_X(x) \cdot \Theta_{XZ} \cdot \mathbf s_Z(z) - \psi(\theta)\), where \(\theta_X\), \(\theta_Z\) are observable and latent biases, and \(\Theta_{XZ}\) is the interaction matrix.
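To make the formula concrete, here is a minimal standalone sketch for a toy harmonium with binary units, where \(\mathbf s_X(x) = x\), \(\mathbf s_Z(z) = z\), and the base measure is trivial. All function names and parameter values are hypothetical. The log-partition function is computed by brute-force enumeration, which is only feasible for very small models.

```python
import itertools
import math


def unnormalized_log_density(theta_x, theta_xz, theta_z, x, z):
    """Bias and interaction terms of the joint log-density, with s(x) = x and s(z) = z."""
    obs_term = sum(t * xi for t, xi in zip(theta_x, x))
    lat_term = sum(t * zj for t, zj in zip(theta_z, z))
    int_term = sum(x[i] * theta_xz[i][j] * z[j]
                   for i in range(len(x)) for j in range(len(z)))
    return obs_term + lat_term + int_term


def binary_states(n):
    """All binary vectors of length n."""
    return list(itertools.product([0, 1], repeat=n))


def log_partition(theta_x, theta_xz, theta_z):
    """Brute-force psi(theta): log-sum-exp over every joint state (x, z)."""
    terms = [unnormalized_log_density(theta_x, theta_xz, theta_z, x, z)
             for x in binary_states(len(theta_x))
             for z in binary_states(len(theta_z))]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


def log_density(theta_x, theta_xz, theta_z, x, z):
    """Joint log-density log p(x, z) of the toy model."""
    return (unnormalized_log_density(theta_x, theta_xz, theta_z, x, z)
            - log_partition(theta_x, theta_xz, theta_z))


theta_x = [0.5, -1.0]
theta_xz = [[1.0, 0.0], [-0.5, 2.0]]
theta_z = [0.2, -0.3]
# The densities of all 16 joint states should sum to one.
total = sum(math.exp(log_density(theta_x, theta_xz, theta_z, x, z))
            for x in binary_states(2) for z in binary_states(2))
```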

abstract property int_man: LinearMap[Posterior, Observable]

Manifold of the interaction matrix \(\Theta_{XZ}\).

property obs_man: Observable

Manifold of observable biases.

property pst_man: Posterior

Manifold of posterior-specific latent biases.

property lkl_fun_man: AffineMap[Posterior, Observable]

Manifold of likelihood distributions \(p(x \mid z)\).

property pst_fun_man: AffineMap[Observable, Posterior]

Manifold of conditional posterior distributions \(p(z \mid x)\).

likelihood_function(params: Array) → Array

Extract the likelihood affine map \(\eta \mapsto \theta_X + \Theta_{XZ} \cdot \eta\) from the given natural parameters.

posterior_function(params: Array) → Array

Extract the posterior affine map \(\eta \mapsto \theta_Z + \Theta_{XZ}^\top \cdot \eta\) from the given natural parameters.

likelihood_at(params: Array, z: Array) → Array

Compute the natural parameters of the likelihood \(p(x \mid z)\) at a given latent state \(z\).

posterior_at(params: Array, x: Array) → Array

Compute the natural parameters of the posterior \(p(z \mid x)\) at a given observation \(x\).
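The two affine maps can be sketched in plain Python as follows. The functions below are hypothetical stand-ins that take already-computed sufficient statistics as input. They are not the library's methods, which operate on a flat parameter array.

```python
def likelihood_natural_params(theta_x, theta_xz, s_z):
    """theta_X + Theta_XZ . s_Z(z): natural parameters of p(x | z)."""
    return [theta_x[i] + sum(theta_xz[i][j] * s_z[j] for j in range(len(s_z)))
            for i in range(len(theta_x))]


def posterior_natural_params(theta_z, theta_xz, s_x):
    """theta_Z + Theta_XZ^T . s_X(x): natural parameters of p(z | x)."""
    return [theta_z[j] + sum(theta_xz[i][j] * s_x[i] for i in range(len(s_x)))
            for j in range(len(theta_z))]


theta_x = [0.5, -1.0]
theta_xz = [[1.0, 0.0], [-0.5, 2.0]]  # rows index observable statistics
theta_z = [0.2, -0.3]

lkl = likelihood_natural_params(theta_x, theta_xz, s_z=[1.0, 0.0])
pst = posterior_natural_params(theta_z, theta_xz, s_x=[0.0, 1.0])
```

Note that the posterior map uses the transpose of the same interaction matrix, so no extra parameters are needed to condition in either direction.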

property fst_man: Observable

First component manifold.

property snd_man: LinearMap[Posterior, Observable]

Second component manifold.

property trd_man: Posterior

Third component manifold.

property data_dim: int

Total dimension of data points.

sufficient_statistic(x: Array) → Array

Compute sufficient statistics of a joint observation \([x, z]\).

log_base_measure(x: Array) → Array

Compute log base measure of a joint observation \([x, z]\).

initialize(key: Array, location: float = 0.0, shape: float = 0.1) → Array

Initialize harmonium natural parameters: the biases come from the initialization of each component manifold, and the interaction matrix is scaled by shape / sqrt(int_dim).

initialize_from_sample(key: Array, sample: Array, location: float = 0.0, shape: float = 0.1) → Array

Initialize harmonium natural parameters, using sample data for observable biases.

gibbs_step(key: Array, params: Array, state: Array) → Array

One Gibbs sweep: sample \(x \sim p(x \mid z)\), then \(z \sim p(z \mid x)\).
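A minimal sketch of such a sweep, assuming a toy harmonium whose observable and latent units are independent Bernoulli variables given the other layer. The function `toy_gibbs_step` and its argument layout are hypothetical. It uses Python's `random.Random` in place of a JAX key.

```python
import math
import random


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def toy_gibbs_step(rng, theta_x, theta_xz, theta_z, z):
    """Sample x ~ p(x | z), then z ~ p(z | x), for Bernoulli units."""
    n_obs, n_lat = len(theta_x), len(theta_z)
    # Natural parameters of p(x | z): theta_X + Theta_XZ . z
    lkl = [theta_x[i] + sum(theta_xz[i][j] * z[j] for j in range(n_lat))
           for i in range(n_obs)]
    x = [1 if rng.random() < sigmoid(a) else 0 for a in lkl]
    # Natural parameters of p(z | x): theta_Z + Theta_XZ^T . x
    pst = [theta_z[j] + sum(theta_xz[i][j] * x[i] for i in range(n_obs))
           for j in range(n_lat)]
    z_new = [1 if rng.random() < sigmoid(a) else 0 for a in pst]
    return x, z_new


rng = random.Random(0)
theta_x, theta_z = [0.5, -1.0], [0.2, -0.3]
theta_xz = [[1.0, 0.0], [-0.5, 2.0]]
x, z = [0, 0], [0, 0]
for _ in range(10):  # run a short chain
    x, z = toy_gibbs_step(rng, theta_x, theta_xz, theta_z, z)
```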

contrastive_divergence_step(key: Array, params: Array, x: Array, k: int = 1) → Array

Compute CD-k gradient contribution for a single observation at the given natural parameters.

Positive phase: sample \(z_0\) from \(p(z \mid x)\) with \(x\) clamped. Negative phase: run \(k\) joint Gibbs steps from \((x, z_0)\). Returns negative minus positive sufficient statistics, approximating \(\mathbb{E}_{\text{model}}[\mathbf{s}] - \mathbb{E}_{\text{data}}[\mathbf{s}]\).
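The two phases can be sketched for a toy binary harmonium as below. This is an illustration under stated assumptions, not the library's implementation: the helper names are hypothetical, units are Bernoulli, and the joint sufficient statistics are flattened in the order (observable, interaction, latent).

```python
import math
import random


def sample_units(rng, logits):
    """Sample independent Bernoulli units from their natural parameters."""
    return [1 if rng.random() < 1.0 / (1.0 + math.exp(-a)) else 0 for a in logits]


def lkl_logits(theta_x, theta_xz, z):
    return [theta_x[i] + sum(theta_xz[i][j] * z[j] for j in range(len(z)))
            for i in range(len(theta_x))]


def pst_logits(theta_z, theta_xz, x):
    return [theta_z[j] + sum(theta_xz[i][j] * x[i] for i in range(len(x)))
            for j in range(len(theta_z))]


def joint_stats(x, z):
    """Flattened (s_X, s_X outer s_Z, s_Z) with s(x) = x and s(z) = z."""
    return list(x) + [xi * zj for xi in x for zj in z] + list(z)


def toy_cd_step(rng, theta_x, theta_xz, theta_z, x, k=1):
    # Positive phase: sample z0 ~ p(z | x) with x clamped to the observation.
    z0 = sample_units(rng, pst_logits(theta_z, theta_xz, x))
    positive = joint_stats(x, z0)
    # Negative phase: k joint Gibbs steps starting from (x, z0).
    xk, zk = list(x), z0
    for _ in range(k):
        xk = sample_units(rng, lkl_logits(theta_x, theta_xz, zk))
        zk = sample_units(rng, pst_logits(theta_z, theta_xz, xk))
    negative = joint_stats(xk, zk)
    return [n - p for n, p in zip(negative, positive)]


rng = random.Random(0)
theta_x, theta_z = [0.5, -1.0], [0.2, -0.3]
theta_xz = [[1.0, 0.0], [-0.5, 2.0]]
grad = toy_cd_step(rng, theta_x, theta_xz, theta_z, x=[1, 0], k=1)
```

Because the result approximates the model expectation minus the data expectation, a descent step that subtracts it from the parameters moves them toward higher likelihood.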

mean_contrastive_divergence_gradient(key: Array, params: Array, xs: Array, k: int = 1) → Array

Compute average CD-k gradient over a batch of observations at the given natural parameters.

Conjugation

class Conjugated

Bases: Harmonium, Generative, ABC, Generic

A harmonium whose prior \(p(z)\) belongs to the same exponential family as the posterior \(p(z \mid x)\), enabling exact computation of the prior via conjugation parameters \(\rho\).

abstract property pst_prr_emb: LinearEmbedding[Posterior, Prior]

Embedding of the posterior latent submanifold into the prior latent manifold.

abstractmethod conjugation_parameters(lkl_params: Array) → Array

Compute conjugation parameters \(\rho\) from the given likelihood natural parameters.

property prr_man: Prior

Manifold of the prior distribution.

prior(params: Array) → Array

Compute the natural parameters of the prior \(p(z)\) from the given harmonium natural parameters.
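As a worked illustration, consider a toy harmonium with a single Bernoulli observable and a single Bernoulli latent unit. Here the conjugation parameter has the closed form \(\rho = \psi_X(\theta_X + \Theta_{XZ}) - \psi_X(\theta_X)\) with \(\psi_X\) the softplus function, and the prior natural parameter is \(\theta_Z + \rho\). This closed form is specific to the toy model and is an assumption of the sketch; the names below are hypothetical.

```python
import math


def softplus(a):
    # Log-partition function of a Bernoulli unit (adequate for moderate inputs).
    return math.log1p(math.exp(a))


def toy_conjugation_parameter(theta_x, theta_xz):
    """rho such that psi_X(theta_X + Theta_XZ * z) = rho * z + psi_X(theta_X) for z in {0, 1}."""
    return softplus(theta_x + theta_xz) - softplus(theta_x)


theta_x, theta_xz, theta_z = 0.5, 1.5, -0.7
rho = toy_conjugation_parameter(theta_x, theta_xz)
prior_param = theta_z + rho
prior_p_z1 = 1.0 / (1.0 + math.exp(-prior_param))

# Brute-force check: marginalize the joint over x.
weights = {(x, z): math.exp(theta_x * x + theta_z * z + theta_xz * x * z)
           for x in (0, 1) for z in (0, 1)}
brute_p_z1 = (weights[(0, 1)] + weights[(1, 1)]) / sum(weights.values())
```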

split_conjugated(params: Array) → tuple[Array, Array]

Split harmonium natural parameters into likelihood and prior natural parameters.

extract_likelihood_input(prr_sample: Array) → Array

Extract the variables needed to condition the likelihood from a prior sample.

Returns the full sample by default. Hierarchical models override this method to extract only the immediate child latent (e.g., \(y\) from a joint \(yz\) sample).

observable_sample(key: Array, params: Array, n: int = 1) → Array

Sample from the observable marginal \(p(x)\) by sampling from the joint and discarding the latent components.

sample(key: Array, params: Array, n: int = 1) → Array

Sample from the joint by first sampling \(z \sim p(z)\), then \(x \sim p(x \mid z)\).
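This ancestral scheme can be sketched for a toy harmonium with one Bernoulli observable and one Bernoulli latent unit. The functions are hypothetical, and the closed-form prior parameter \(\theta_Z + \rho\) with a softplus-based \(\rho\) is specific to this toy model. Observable samples are obtained by dropping the latent half of each joint sample.

```python
import math
import random


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def softplus(a):
    return math.log1p(math.exp(a))


def toy_sample(rng, theta_x, theta_xz, theta_z, n=1):
    """Draw n joint samples (x, z): first z ~ p(z), then x ~ p(x | z)."""
    rho = softplus(theta_x + theta_xz) - softplus(theta_x)
    samples = []
    for _ in range(n):
        z = 1 if rng.random() < sigmoid(theta_z + rho) else 0
        x = 1 if rng.random() < sigmoid(theta_x + theta_xz * z) else 0
        samples.append((x, z))
    return samples


def toy_observable_sample(rng, theta_x, theta_xz, theta_z, n=1):
    """Sample from p(x) by discarding the latent component of joint samples."""
    return [x for x, _ in toy_sample(rng, theta_x, theta_xz, theta_z, n)]


rng = random.Random(0)
joint_samples = toy_sample(rng, 0.5, 1.5, -0.7, n=5)
obs_samples = toy_observable_sample(rng, 0.5, 1.5, -0.7, n=5)
```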

class DifferentiableConjugated

Bases: Conjugated, Differentiable, ABC, Generic

A conjugated harmonium with an analytical log-partition function, enabling exact density evaluation and gradient-based optimization.

log_partition_function(params: Array) → Array

Compute \(\psi(\theta) = \psi_Z(\theta_Z + \rho) + \psi_X(\theta_X)\) at the given natural parameters.

log_observable_density(params: Array, x: Array) → Array

Compute log marginal density \(\log p(x)\) at the given natural parameters by integrating out the latent variable analytically.
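For intuition, both quantities can be checked on a toy harmonium with one Bernoulli observable and one Bernoulli latent unit. In that model, summing out \(z\) gives \(\log p(x) = \theta_X x + \psi_Z(\theta_Z + \Theta_{XZ} x) - \psi(\theta)\), with \(\psi(\theta)\) computed from the conjugation parameter. The sketch below is a standalone illustration with hypothetical function names, and its closed forms hold for this toy model only.

```python
import math


def softplus(a):
    # Log-partition function of a Bernoulli unit (adequate for moderate inputs).
    return math.log1p(math.exp(a))


def toy_log_partition(theta_x, theta_xz, theta_z):
    """psi(theta) = psi_Z(theta_Z + rho) + psi_X(theta_X)."""
    rho = softplus(theta_x + theta_xz) - softplus(theta_x)
    return softplus(theta_z + rho) + softplus(theta_x)


def toy_log_observable_density(theta_x, theta_xz, theta_z, x):
    """log p(x) with the binary latent unit summed out analytically."""
    posterior_param = theta_z + theta_xz * x
    return (theta_x * x + softplus(posterior_param)
            - toy_log_partition(theta_x, theta_xz, theta_z))


theta_x, theta_xz, theta_z = 0.5, 1.5, -0.7
# Brute-force log-partition over the four joint states, for comparison.
brute_psi = math.log(sum(math.exp(theta_x * x + theta_z * z + theta_xz * x * z)
                         for x in (0, 1) for z in (0, 1)))
psi = toy_log_partition(theta_x, theta_xz, theta_z)
marginal_total = sum(math.exp(toy_log_observable_density(theta_x, theta_xz, theta_z, x))
                     for x in (0, 1))
```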

observable_density(params: Array, x: Array) → Array

Compute marginal density \(p(x)\) at the given natural parameters.

average_log_observable_density(params: Array, xs: Array, batch_size: int = 2048) → Array

Compute average \(\log p(x)\) over a batch of observations at the given natural parameters.

posterior_statistics(params: Array, x: Array) → Array

Compute expected sufficient statistics \((\mathbf s_X(x),\, \mathbf s_X(x) \otimes \mathbb E[\mathbf s_Z \mid x],\, \mathbb E[\mathbf s_Z \mid x])\) in mean coordinates for a single observation, at the given natural parameters.

mean_posterior_statistics(params: Array, xs: Array, batch_size: int = 256) → Array

Compute average expected sufficient statistics over a batch of observations at the given natural parameters.
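A small sketch of these statistics for a toy harmonium with binary units, where \(\mathbf s_X(x) = x\) and \(\mathbb E[\mathbf s_Z \mid x]\) is the logistic sigmoid of the posterior natural parameters. The function names and the flattened (observable, interaction, latent) ordering are assumptions of the example, and batching is omitted.

```python
import math


def toy_posterior_statistics(theta_xz, theta_z, x):
    """Flattened (s_X, s_X outer E[s_Z | x], E[s_Z | x]) for one observation."""
    pst = [theta_z[j] + sum(theta_xz[i][j] * x[i] for i in range(len(x)))
           for j in range(len(theta_z))]
    # For Bernoulli units, the mean of s_Z(z) = z is the sigmoid of the natural parameter.
    mean_z = [1.0 / (1.0 + math.exp(-a)) for a in pst]
    outer = [xi * mz for xi in x for mz in mean_z]
    return [float(xi) for xi in x] + outer + mean_z


def toy_mean_posterior_statistics(theta_xz, theta_z, xs):
    """Average the per-observation statistics over a batch."""
    stats = [toy_posterior_statistics(theta_xz, theta_z, x) for x in xs]
    return [sum(col) / len(xs) for col in zip(*stats)]


theta_xz = [[0.0, 0.0], [0.0, 0.0]]
theta_z = [0.0, 0.0]
single = toy_posterior_statistics(theta_xz, theta_z, [1, 0])
batch = toy_mean_posterior_statistics(theta_xz, theta_z, [[1, 0], [0, 1]])
```

The observable biases do not appear because the posterior \(p(z \mid x)\) depends only on the latent biases and the interaction matrix.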

class SymmetricConjugated

Bases: Conjugated, ABC, Generic

A conjugated harmonium where the posterior and prior share the same latent manifold (pst_man == prr_man).

abstract property lat_man: Latent

Manifold of latent biases.

property pst_prr_emb: LinearEmbedding[Latent, Latent]

Embedding of the posterior latent submanifold into the prior latent manifold.

property pst_man: Latent

Manifold of posterior-specific latent biases.

property prr_man: Latent

Manifold of the prior distribution.

join_conjugated(lkl_params: Array, prior_params: Array) → Array

Join likelihood and prior natural parameters into harmonium natural parameters.

class AnalyticConjugated

Bases: SymmetricConjugated, DifferentiableConjugated, Analytic, ABC, Generic

A symmetric conjugated harmonium with analytically tractable negative entropy, enabling closed-form KL divergence and expectation-maximization.

abstractmethod to_natural_likelihood(means: Array) → Array

Convert harmonium mean parameters to likelihood natural parameters.

to_natural(means: Array) → Array

Convert harmonium mean parameters to natural parameters.

negative_entropy(means: Array) → Array

Compute negative entropy \(\phi(\eta) = \eta \cdot \theta - \psi(\theta)\) at the given mean parameters, where \(\theta\) is the natural parameter corresponding to \(\eta\).
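The same Legendre-dual identity can be checked on the simplest exponential family, a single Bernoulli variable, where the negative entropy has the familiar closed form \(\eta \log \eta + (1 - \eta) \log(1 - \eta)\). The functions below are a hypothetical standalone illustration for that one-dimensional family and are not the harmonium implementation.

```python
import math


def to_natural(eta):
    """Bernoulli mean parameter to natural parameter (the logit)."""
    return math.log(eta / (1.0 - eta))


def log_partition(theta):
    """Bernoulli log-partition function psi(theta) = log(1 + exp(theta))."""
    return math.log1p(math.exp(theta))


def negative_entropy(eta):
    """phi(eta) = eta * theta - psi(theta), with theta = to_natural(eta)."""
    theta = to_natural(eta)
    return eta * theta - log_partition(theta)


eta = 0.3
dual_form = negative_entropy(eta)
closed_form = eta * math.log(eta) + (1.0 - eta) * math.log(1.0 - eta)
```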

expectation_maximization(params: Array, xs: Array) → Array

Perform one EM iteration: the E-step computes the average expected sufficient statistics over the observations, and the M-step converts these mean parameters to natural parameters.
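The two steps can be sketched for a toy harmonium with one Bernoulli observable and one Bernoulli latent unit. Its mean parameters \((\mathbb E[x], \mathbb E[xz], \mathbb E[z])\) determine the full 2x2 joint table, so the mean-to-natural conversion is available in closed form. Everything below is a hypothetical illustration of the scheme for that toy model, not the library's code.

```python
import math


def expected_stats(params, xs):
    """E-step: average (x, x * E[z | x], E[z | x]) over the data."""
    theta_x, theta_xz, theta_z = params
    qs = [1.0 / (1.0 + math.exp(-(theta_z + theta_xz * x))) for x in xs]
    n = len(xs)
    return (sum(xs) / n, sum(x * q for x, q in zip(xs, qs)) / n, sum(qs) / n)


def to_natural(means):
    """M-step: convert mean parameters to natural parameters via the joint table."""
    eta_x, eta_xz, eta_z = means
    p11, p10, p01 = eta_xz, eta_x - eta_xz, eta_z - eta_xz
    p00 = 1.0 - eta_x - eta_z + eta_xz
    return (math.log(p10 / p00),
            math.log(p11 * p00 / (p10 * p01)),
            math.log(p01 / p00))


def em_step(params, xs):
    return to_natural(expected_stats(params, xs))


def average_log_density(params, xs):
    """Average log p(x), with the latent unit summed out."""
    theta_x, theta_xz, theta_z = params
    psi = math.log(sum(math.exp(theta_x * x + theta_z * z + theta_xz * x * z)
                       for x in (0, 1) for z in (0, 1)))
    return sum(theta_x * x + math.log1p(math.exp(theta_z + theta_xz * x)) - psi
               for x in xs) / len(xs)


xs = [1, 1, 1, 0, 1, 0, 0, 1]
params = (-1.0, 0.8, -0.4)  # ordered as (observable bias, interaction, latent bias)
new_params = em_step(params, xs)
old_ll = average_log_density(params, xs)
new_ll = average_log_density(new_params, xs)
```

In this tiny model a single iteration already matches the empirical frequency of \(x\); in general, each EM iteration only guarantees that the average log-density does not decrease.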