# Source code for goal.models.harmonium.poisson

"""Poisson-related models: population codes, COM-Poisson, and variational wrappers.

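This module provides:
- PoissonVonMisesHarmonium: Harmonium with Poisson observables and VonMises latents
- VariationalPoissonVonMises: Variational wrapper for the Poisson-VonMises harmonium
- VonMisesPopulationCode: Population code with Poisson observables and Von Mises stimulus
- CoMPoissonPopulation: Population of COM-Poisson units with shared dispersion
- PopulationLocationEmbedding: Embedding of a Poisson population into a COM-Poisson population
- Poisson mixture type aliases and factory functions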
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import override

import jax
import jax.numpy as jnp
from jax import Array

from ...geometry import (
    AnalyticProduct,
    DifferentiableProduct,
    EmbeddedMap,
    IdentityEmbedding,
    LocationShape,
    Product,
    Rectangular,
    TupleEmbedding,
)
from ...geometry.exponential_family.harmonium import (
    Harmonium,
    SymmetricConjugated,
)
from ...geometry.exponential_family.variational import (
    VariationalConjugated,
)
from ..base.poisson import CoMPoisson, CoMShape, Poisson, Poissons
from ..base.von_mises import VonMises, VonMisesProduct
from .mixture import AnalyticMixture, Mixture

### Poisson-VonMises Harmonium ###


@dataclass(frozen=True)
class PoissonVonMisesHarmonium(Harmonium[Poissons, VonMisesProduct]):
    """Harmonium pairing a Poisson observable population with VonMises latents.

    Attributes:
        _int_man: Interaction map from the VonMises latent product to the
            Poisson observables; exposed through the ``int_man`` property.
    """

    _int_man: EmbeddedMap[VonMisesProduct, Poissons]

    @property
    @override
    def int_man(self) -> EmbeddedMap[VonMisesProduct, Poissons]:
        """Interaction map that was supplied at construction."""
        return self._int_man

    @property
    def n_neurons(self) -> int:
        """Replication count of the observable (Poisson) manifold."""
        observable = self.obs_man
        return observable.n_reps

    @property
    def n_latent(self) -> int:
        """Replication count of the posterior (VonMises) manifold."""
        posterior = self.pst_man
        return posterior.n_reps

    def observable_rates(self, params: Array, z: Array) -> Array:
        """Return mean coordinates of the likelihood at latent state ``z``.

        Args:
            params: Harmonium parameters.
            z: Latent state at which the likelihood is evaluated.

        Returns:
            The likelihood parameters at ``z`` mapped through
            ``obs_man.to_mean`` (the expected firing rates).
        """
        return self.obs_man.to_mean(self.likelihood_at(params, z))

    def get_weight_matrix(self, params: Array) -> Array:
        """Return the interaction block reshaped to ``(n_neurons, 2 * n_latent)``.

        Args:
            params: Harmonium parameters.

        Returns:
            The middle component of ``split_coords(params)`` as a 2D array.
        """
        _obs, interaction, _lat = self.split_coords(params)
        n_rows = self.n_neurons
        n_cols = 2 * self.n_latent
        return interaction.reshape(n_rows, n_cols)


def poisson_vonmises_harmonium(
    n_neurons: int, n_latent: int
) -> PoissonVonMisesHarmonium:
    """Build a PoissonVonMisesHarmonium with a rectangular interaction map.

    Args:
        n_neurons: Size passed to the ``Poissons`` observable manifold.
        n_latent: Size passed to the ``VonMisesProduct`` latent manifold.

    Returns:
        A harmonium whose interaction map uses identity embeddings on both
        the latent and the observable side.
    """
    observable = Poissons(n_neurons)
    latent = VonMisesProduct(n_latent)
    interaction = EmbeddedMap(
        Rectangular(),
        IdentityEmbedding(latent),
        IdentityEmbedding(observable),
    )
    return PoissonVonMisesHarmonium(interaction)


### Variational Poisson-VonMises ###


@dataclass(frozen=True)
class VariationalPoissonVonMises(
    VariationalConjugated[Poissons, VonMisesProduct, VonMisesProduct]
):
    """Variational conjugated model over a Poisson-VonMises harmonium.

    VonMises samples are drawn by rejection sampling, which is not
    reparameterizable, so a score-function gradient correction is used.

    Attributes:
        hrm: The wrapped PoissonVonMisesHarmonium
    """

    hrm: PoissonVonMisesHarmonium

    @property
    @override
    def rho_emb(self) -> IdentityEmbedding[VonMisesProduct]:
        """Identity embedding on the harmonium's posterior manifold.

        With the identity embedding, rho is subtracted directly from the
        posterior parameters.
        """
        posterior_manifold = self.hrm.pst_man
        return IdentityEmbedding(posterior_manifold)

    @property
    def n_neurons(self) -> int:
        """Number of Poisson observable neurons (delegates to ``hrm``)."""
        harmonium = self.hrm
        return harmonium.n_neurons

    @property
    def n_latent(self) -> int:
        """Number of VonMises latent dimensions (delegates to ``hrm``)."""
        harmonium = self.hrm
        return harmonium.n_latent

    def get_weight_matrix(self, params: Array) -> Array:
        """Return the harmonium interaction as ``(n_neurons, 2 * n_latent)``.

        Args:
            params: Full model parameters; the second component of
                ``split_coords`` is forwarded to the harmonium.
        """
        _rho, harmonium_params = self.split_coords(params)
        return self.hrm.get_weight_matrix(harmonium_params)

    def reconstruct(self, params: Array, x: Array) -> Array:
        """Reconstruct one observation through the approximate posterior.

        The approximate posterior at ``x`` is converted to mean coordinates,
        those are pushed through the harmonium's likelihood function, and the
        resulting natural parameters are mapped to observable means.

        Args:
            params: Full model parameters.
            x: A single observation.

        Returns:
            Mean-field reconstruction of ``x`` in observable mean coordinates.
        """
        posterior_natural = self.approximate_posterior_at(params, x)
        latent_means = self.pst_man.to_mean(posterior_natural)
        _rho, harmonium_params = self.split_coords(params)
        likelihood = self.hrm.likelihood_function(harmonium_params)
        observable_natural = self.hrm.lkl_fun_man(likelihood, latent_means)
        return self.obs_man.to_mean(observable_natural)

    def reconstruction_error(self, params: Array, xs: Array) -> Array:
        """Return the mean squared error between ``xs`` and its reconstructions.

        Args:
            params: Full model parameters (shared across the batch).
            xs: Batch of observations, batched along the leading axis.
        """
        reconstruct_batch = jax.vmap(lambda x: self.reconstruct(params, x))
        residuals = xs - reconstruct_batch(xs)
        return jnp.mean(residuals**2)

    def normalized_reconstruction_error(
        self, params: Array, xs: Array, scale: float = 1.0
    ) -> Array:
        """Return the reconstruction MSE after dividing both sides by ``scale``.

        Args:
            params: Full model parameters (shared across the batch).
            xs: Batch of observations, batched along the leading axis.
            scale: Divisor applied to both data and reconstructions.
        """
        reconstruct_batch = jax.vmap(lambda x: self.reconstruct(params, x))
        scaled_data = xs / scale
        scaled_recons = reconstruct_batch(xs) / scale
        residuals = scaled_data - scaled_recons
        return jnp.mean(residuals**2)


def variational_poisson_vonmises(
    n_neurons: int, n_latent: int
) -> VariationalPoissonVonMises:
    """Build a VariationalPoissonVonMises around a freshly created harmonium.

    Args:
        n_neurons: Number of Poisson observable neurons.
        n_latent: Number of VonMises latent dimensions.
    """
    return VariationalPoissonVonMises(
        hrm=poisson_vonmises_harmonium(n_neurons, n_latent)
    )


### Population Codes ###

# Replicated Poisson manifold (an AnalyticProduct of Poisson units) used as the
# observable population of the population-code and mixture models below.
type PoissonPopulation = AnalyticProduct[Poisson]


@dataclass(frozen=True)
class VonMisesPopulationCode(SymmetricConjugated[PoissonPopulation, VonMises]):
    """Population code with Poisson observables and Von Mises stimulus.

    Models neural responses where each neuron has a tuning curve over a
    circular stimulus (e.g., orientation).

    The model is only approximately conjugate: the observable log-partition
    function, viewed as a function of the stimulus, is approximated by an
    affine function of the Von Mises sufficient statistics. The affine fit is
    obtained by least squares on a uniform grid over ``[0, 2*pi)``; its slope
    gives the conjugation parameters (rho) and its intercept the conjugation
    constant (chi).

    Attributes:
        n_neurons: Number of neurons in the population
        n_grid_points: Number of grid points for regression (default 100)
    """

    n_neurons: int
    n_grid_points: int = 100

    @property
    @override
    def lat_man(self) -> VonMises:
        """Latent (stimulus) manifold: a single Von Mises distribution."""
        return VonMises()

    @property
    @override
    def obs_man(self) -> PoissonPopulation:
        """Observable manifold: ``n_neurons`` replicated Poisson units."""
        return AnalyticProduct(Poisson(), self.n_neurons)

    @property
    @override
    def int_man(self) -> EmbeddedMap[VonMises, PoissonPopulation]:
        """Rectangular interaction map with identity embeddings on both sides."""
        return EmbeddedMap(
            Rectangular(),
            IdentityEmbedding(self.lat_man),
            IdentityEmbedding(self.obs_man),
        )

    def _fit_log_partition(
        self, obs_params: Array, int_params: Array
    ) -> tuple[Array, Array, Array, Array]:
        """Fit the observable log-partition function on a stimulus grid.

        Shared by every method that needs the regression so that rho, chi and
        the diagnostics are always derived from the same fit.

        Args:
            obs_params: Observable bias parameters (one per neuron).
            int_params: Flat interaction parameters, reshaped internally to
                ``(n_neurons, 2)``.

        Returns:
            Tuple ``(grid, log_partitions, design, coeffs)`` where ``grid`` are
            the stimulus values, ``log_partitions`` the summed ``exp`` of the
            log rates at each grid point, ``design`` the regression matrix
            ``[1, sufficient_statistic(z)]`` and ``coeffs`` the least-squares
            solution (``coeffs[0]`` is the intercept, ``coeffs[1:]`` the slope).
        """
        int_matrix = int_params.reshape(self.n_neurons, 2)
        # endpoint=False: 0 and 2*pi are the same stimulus on the circle.
        grid = jnp.linspace(0, 2 * jnp.pi, self.n_grid_points, endpoint=False)

        def log_partition_at(z: Array) -> Array:
            suff = self.lat_man.sufficient_statistic(z)
            log_rates = obs_params + int_matrix @ suff
            return jnp.sum(jnp.exp(log_rates))

        log_partitions = jax.vmap(log_partition_at)(grid)
        suff_stats = jax.vmap(self.lat_man.sufficient_statistic)(grid)
        design = jnp.column_stack([jnp.ones(self.n_grid_points), suff_stats])
        coeffs, _, _, _ = jnp.linalg.lstsq(design, log_partitions, rcond=None)
        return grid, log_partitions, design, coeffs

    @override
    def conjugation_parameters(self, lkl_params: Array) -> Array:
        """Compute approximate conjugation parameters via regression.

        Args:
            lkl_params: Likelihood parameters (observable biases joined with
                the interaction parameters).

        Returns:
            The slope of the affine fit, i.e. the regression coefficients with
            the intercept dropped.
        """
        obs_params, int_params = self.lkl_fun_man.split_coords(lkl_params)
        _, _, _, coeffs = self._fit_log_partition(obs_params, int_params)
        return coeffs[1:]

    def regression_diagnostics(self, params: Array) -> tuple[Array, Array, Array]:
        """Get regression fit diagnostics for visualization.

        Args:
            params: Full model parameters.

        Returns:
            Tuple ``(grid, actual, fitted)``: the stimulus grid, the observable
            log-partition values on it, and the affine fit evaluated on it.
        """
        obs_params, int_params, _ = self.split_coords(params)
        grid, actual, design, coeffs = self._fit_log_partition(
            obs_params, int_params
        )
        return grid, actual, design @ coeffs

    def join_tuning_parameters(
        self, gains: Array, preferred: Array, baselines: Array, prior_params: Array
    ) -> Array:
        """Construct population code from tuning curve specification.

        Args:
            gains: Per-neuron tuning gain (norm of the interaction row).
            preferred: Per-neuron preferred stimulus angle.
            baselines: Per-neuron bias, used directly as observable parameters.
            prior_params: Natural parameters of the desired stimulus prior.

        Returns:
            Full model parameters whose latent component is
            ``prior_params - rho``.
        """
        obs_params = baselines
        int_col_1 = gains * jnp.cos(preferred)
        int_col_2 = gains * jnp.sin(preferred)
        # Row-major (n_neurons, 2) layout, matching the reshapes used elsewhere.
        int_params = jnp.stack([int_col_1, int_col_2], axis=1).ravel()
        lkl_params = self.lkl_fun_man.join_coords(obs_params, int_params)
        rho = self.conjugation_parameters(lkl_params)
        lat_params = prior_params - rho
        return self.join_coords(obs_params, int_params, lat_params)

    def split_tuning_parameters(
        self, params: Array
    ) -> tuple[Array, Array, Array, Array]:
        """Extract tuning curve parameters from natural parameters.

        Inverse of ``join_tuning_parameters``.

        Args:
            params: Full model parameters.

        Returns:
            Tuple ``(gains, preferred, baselines, prior_params)``.
        """
        obs_params, int_params, lat_params = self.split_coords(params)
        baselines = obs_params
        int_matrix = int_params.reshape(self.n_neurons, 2)
        gains = jnp.sqrt(int_matrix[:, 0] ** 2 + int_matrix[:, 1] ** 2)
        preferred = jnp.arctan2(int_matrix[:, 1], int_matrix[:, 0])
        lkl_params = self.lkl_fun_man.join_coords(obs_params, int_params)
        rho = self.conjugation_parameters(lkl_params)
        prior_params = lat_params + rho
        return gains, preferred, baselines, prior_params

    def tuning_curve(self, params: Array, z: Array) -> Array:
        """Evaluate log firing rates at a stimulus.

        Args:
            params: Full model parameters.
            z: Stimulus value.

        Returns:
            Observable biases plus the interaction matrix applied to the
            sufficient statistic of ``z``.
        """
        obs_params, int_params, _ = self.split_coords(params)
        int_matrix = int_params.reshape(self.n_neurons, 2)
        suff = self.lat_man.sufficient_statistic(z)
        return obs_params + int_matrix @ suff

    def firing_rates(self, params: Array, z: Array) -> Array:
        """Evaluate firing rates at a stimulus."""
        log_rates = self.tuning_curve(params, z)
        return self.obs_man.to_mean(log_rates)

    def sample_stimulus(self, key: Array, params: Array) -> Array:
        """Sample stimulus z from prior."""
        prior_params = self.prior(params)
        return self.lat_man.sample(key, prior_params, n=1)[0]

    def sample_spikes(self, key: Array, params: Array, z: Array) -> Array:
        """Sample spike counts given a stimulus.

        Args:
            key: JAX PRNG key.
            params: Full model parameters.
            z: Stimulus value.

        Returns:
            Poisson draws with rates ``firing_rates(params, z)``.
        """
        return jax.random.poisson(key, self.firing_rates(params, z))

    def sample_joint(
        self, key: Array, params: Array, n: int = 1
    ) -> tuple[Array, Array]:
        """Sample (spike_counts, stimulus) pairs from the joint distribution.

        Args:
            key: JAX PRNG key; split once for stimuli and once for spikes.
            params: Full model parameters.
            n: Number of pairs to draw.

        Returns:
            Tuple ``(xs, zs)`` of spike counts and the stimuli that generated
            them.
        """
        key_z, key_x = jax.random.split(key)
        prior_params = self.prior(params)
        # Drop the trailing axis so each stimulus is passed on as a scalar.
        zs = self.lat_man.sample(key_z, prior_params, n=n).squeeze(-1)
        keys_x = jax.random.split(key_x, n)
        xs = jax.vmap(lambda k, z: self.sample_spikes(k, params, z))(keys_x, zs)
        return xs, zs

    def posterior_mean(self, params: Array, x: Array) -> Array:
        """Compute posterior mean stimulus given observations."""
        post_params = self.posterior_at(params, x)
        mu, _ = self.lat_man.split_mean_concentration(post_params)
        return mu

    def posterior_concentration(self, params: Array, x: Array) -> Array:
        """Compute posterior concentration given observations."""
        post_params = self.posterior_at(params, x)
        _, kappa = self.lat_man.split_mean_concentration(post_params)
        return kappa

    def log_observable_density(self, params: Array, x: Array) -> Array:
        """Compute log density of the observable distribution p(x).

        Uses the conjugation identity
        ``log p(x) = obs_params . s(x) + psi_Z(posterior) - psi_Z(prior) - chi
        + log_base_measure(x)``, which is exact only to the extent that the
        affine fit of the observable log-partition function is exact.

        Args:
            params: Full model parameters.
            x: A single observation (spike counts).

        Returns:
            Approximate log marginal density of ``x``.
        """
        obs_params, int_params, _ = self.split_coords(params)
        # chi is the intercept of the same regression that yields rho.  The
        # observable log-partition at obs_params alone is only the right
        # constant when some latent value has zero sufficient statistics;
        # (cos z, sin z) never vanishes, so it is not valid for Von Mises.
        _, _, _, coeffs = self._fit_log_partition(obs_params, int_params)
        chi = coeffs[0]
        obs_stats = self.obs_man.sufficient_statistic(x)
        prr = self.prior(params)
        log_density = jnp.dot(obs_params, obs_stats)
        log_density += self.lat_man.log_partition_function(self.posterior_at(params, x))
        log_density -= self.lat_man.log_partition_function(prr) + chi
        return log_density + self.obs_man.log_base_measure(x)

    def observable_density(self, params: Array, x: Array) -> Array:
        """Compute density of the observable distribution p(x)."""
        return jnp.exp(self.log_observable_density(params, x))


def von_mises_population_code(
    n_neurons: int,
    gains: Array,
    preferred: Array,
    baselines: Array,
    prior_mean: float = 0.0,
    prior_concentration: float = 0.0,
) -> tuple[VonMisesPopulationCode, Array]:
    """Build a Von Mises population code and its parameters from tuning curves.

    Args:
        n_neurons: Number of neurons in the population.
        gains: Per-neuron tuning gains.
        preferred: Per-neuron preferred stimulus angles.
        baselines: Per-neuron baseline (bias) parameters.
        prior_mean: Mean of the stimulus prior.
        prior_concentration: Concentration of the stimulus prior.

    Returns:
        Tuple ``(model, params)`` of the population code and the parameters
        produced by ``join_tuning_parameters``.
    """
    code = VonMisesPopulationCode(n_neurons)
    natural_prior = code.lat_man.join_mean_concentration(
        prior_mean, prior_concentration
    )
    return code, code.join_tuning_parameters(
        gains, preferred, baselines, natural_prior
    )


### Poisson Mixture Models ###

# Shape component of a COM-Poisson population: a product of CoMShape manifolds
# (presumably the per-unit dispersion parameters -- confirm against CoMShape).
type PopulationShape = Product[CoMShape]
# Mixture whose observable is a Poisson population (see poisson_mixture).
type PoissonMixture = AnalyticMixture[PoissonPopulation]
# Mixture whose observable is a COM-Poisson population (see com_poisson_mixture).
type CoMPoissonMixture = Mixture[CoMPoissonPopulation]


@dataclass(frozen=True)
class CoMPoissonPopulation(
    DifferentiableProduct[CoMPoisson], LocationShape[PoissonPopulation, PopulationShape]
):
    """A population of independent Conway-Maxwell-Poisson (COM-Poisson) units.

    Coordinates are stored interleaved: each unit contributes two consecutive
    entries, the first belonging to the location (Poisson) component and the
    second to the shape component.
    """

    def __init__(self, n_reps: int):
        """Create a population of ``n_reps`` COM-Poisson units."""
        super().__init__(CoMPoisson(), n_reps)

    @property
    @override
    def fst_man(self) -> PoissonPopulation:
        """Location component: a Poisson population with ``n_reps`` units."""
        return AnalyticProduct(Poisson(), n_reps=self.n_reps)

    @property
    @override
    def snd_man(self) -> PopulationShape:
        """Shape component: a product of ``n_reps`` CoMShape manifolds."""
        return Product(CoMShape(), n_reps=self.n_reps)

    @override
    def join_coords(self, fst_coords: Array, snd_coords: Array) -> Array:
        """Interleave location and shape coordinates into one flat vector.

        Args:
            fst_coords: Location coordinates, one per unit.
            snd_coords: Shape coordinates, one per unit.

        Returns:
            Flat array ordered as ``[fst_0, snd_0, fst_1, snd_1, ...]``.
        """
        params_matrix = jnp.stack([fst_coords, snd_coords])
        # Transpose to (n_reps, 2) so each unit's pair is contiguous when flattened.
        return params_matrix.T.reshape(-1)

    @override
    def split_coords(self, coords: Array) -> tuple[Array, Array]:
        """Split interleaved coordinates into location and shape parts.

        Inverse of ``join_coords``.

        Args:
            coords: Flat array holding two entries per unit.

        Returns:
            Tuple ``(fst_coords, snd_coords)``.
        """
        matrix = coords.reshape(self.n_reps, 2).T
        return matrix[0], matrix[1]


@dataclass(frozen=True)
class PopulationLocationEmbedding(
    TupleEmbedding[PoissonPopulation, CoMPoissonPopulation]
):
    """Embedding that projects COM-Poisson to Poisson via location parameters.

    Attributes:
        n_neurons: Number of units in both the Poisson and COM-Poisson
            populations.
    """

    n_neurons: int

    @property
    @override
    def tup_idx(self) -> int:
        """Index of the embedded component; 0 selects the first (location) part."""
        return 0

    @property
    @override
    def amb_man(self) -> CoMPoissonPopulation:
        """Ambient manifold: the COM-Poisson population."""
        return CoMPoissonPopulation(n_reps=self.n_neurons)

    @property
    @override
    def sub_man(self) -> PoissonPopulation:
        """Sub-manifold: the Poisson population."""
        return AnalyticProduct(Poisson(), n_reps=self.n_neurons)


def poisson_mixture(n_neurons: int, n_components: int) -> PoissonMixture:
    """Create a mixture of independent Poisson populations.

    Args:
        n_neurons: Number of Poisson units in the observable population.
        n_components: Number of mixture components.

    Returns:
        An ``AnalyticMixture`` over a Poisson population of ``n_neurons`` units.
    """
    pop_man = AnalyticProduct(Poisson(), n_reps=n_neurons)
    return AnalyticMixture(pop_man, n_components)


def com_poisson_mixture(n_neurons: int, n_components: int) -> CoMPoissonMixture:
    """Create a COM-Poisson mixture with shared dispersion parameters.

    The mixture is built on a ``PopulationLocationEmbedding``, so only the
    location (Poisson) part of the COM-Poisson population is embedded.

    Args:
        n_neurons: Number of COM-Poisson units in the observable population.
        n_components: Number of mixture components.

    Returns:
        A ``Mixture`` over a COM-Poisson population of ``n_neurons`` units.
    """
    obs_emb = PopulationLocationEmbedding(n_neurons)
    return Mixture(n_components, obs_emb)