"""Linear and affine maps between manifolds.
A ``LinearMap`` is itself a ``Manifold`` whose points are the parameters of the map. The main concrete implementation is ``EmbeddedMap``, which combines a ``MatrixRep`` (from ``matrix.py``) with domain/codomain embeddings to handle the project-multiply-embed pipeline. ``BlockMap`` sums several such maps for heterogeneous block structure, and ``AffineMap`` adds a bias term.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import override
import jax.numpy as jnp
from jax import Array
from .base import Manifold
from .combinators import Pair
from .embedding import IdentityEmbedding, LinearComposedEmbedding, LinearEmbedding
from .matrix import MatrixRep, Square
### Linear Maps ###
[docs]
@dataclass(frozen=True)
class LinearMap[Domain: Manifold, Codomain: Manifold](Manifold, ABC):
"""A linear transformation between manifolds, which is itself a manifold (its points are the map's parameters).
The key pattern: a ``LinearMap`` knows its domain, codomain, and how to apply itself, transpose itself, and compute outer products. Concrete implementations choose how to store and execute the map.
Mathematically, a linear map $L: V \\to W$ satisfies $L(\\alpha x + \\beta y) = \\alpha L(x) + \\beta L(y)$.
"""
# Contract
@property
@abstractmethod
def dom_man(self) -> Domain:
"""The domain manifold."""
@property
@abstractmethod
def cod_man(self) -> Codomain:
"""The codomain manifold."""
@property
@abstractmethod
def trn_man(self) -> LinearMap[Codomain, Domain]:
"""Manifold of transposed linear maps."""
@abstractmethod
def __call__(self, f_coords: Array, v_coords: Array) -> Array:
"""Apply the map: takes map parameters and a domain point, returns a codomain point."""
[docs]
@abstractmethod
def transpose(self, f_coords: Array) -> Array:
"""Return parameters of the transposed map."""
[docs]
@abstractmethod
def outer_product(self, w_coords: Array, v_coords: Array) -> Array:
"""Outer product $w \\otimes v$, returned as map parameters."""
[docs]
@abstractmethod
def map_domain_embedding[NewDomain: Manifold](
self,
f: Callable[
[LinearEmbedding[Manifold, Domain]], LinearEmbedding[Manifold, NewDomain]
],
) -> LinearMap[NewDomain, Codomain]:
"""Transform the domain embedding(s) using the given function.
This enables operations like tensoring with a new factor by wrapping
the domain embedding in a more complex structure.
"""
[docs]
@abstractmethod
def map_codomain_embedding[NewCodomain: Manifold](
self,
f: Callable[
[LinearEmbedding[Manifold, Codomain]],
LinearEmbedding[Manifold, NewCodomain],
],
) -> LinearMap[Domain, NewCodomain]:
"""Transform the codomain embedding(s) using the given function.
This enables operations like tensoring with a new factor by wrapping
the codomain embedding in a more complex structure.
"""
# Methods
[docs]
def transpose_apply(self, f_coords: Array, w_coords: Array) -> Array:
"""Apply the transpose: takes map parameters and a codomain point, returns a domain point."""
f_trn_coords = self.transpose(f_coords)
return self.trn_man(f_trn_coords, w_coords)
[docs]
def prepend_embedding[NewDomain: Manifold](
self,
emb: LinearEmbedding[Domain, NewDomain],
) -> LinearMap[NewDomain, Codomain]:
"""Prepend an embedding to the domain."""
return self.map_domain_embedding(
lambda dom_emb: LinearComposedEmbedding(dom_emb, emb)
)
[docs]
def append_embedding[NewCodomain: Manifold](
self,
emb: LinearEmbedding[Codomain, NewCodomain],
) -> LinearMap[Domain, NewCodomain]:
"""Append an embedding to the codomain."""
return self.map_codomain_embedding(
lambda cod_emb: LinearComposedEmbedding(cod_emb, emb)
)
@dataclass(frozen=True)
class EmbeddedMap[Domain: Manifold, Codomain: Manifold](LinearMap[Domain, Codomain]):
"""A linear map backed by a ``MatrixRep`` and domain/codomain embeddings.
Application follows the pipeline: project input to internal domain, multiply by the matrix, embed result into external codomain. This lets the matrix operate on a lower-dimensional internal space while the map's external interface matches the full manifold dimensions.
"""
# Fields
rep: MatrixRep
"""The matrix representation strategy for this linear map."""
dom_emb: LinearEmbedding[Manifold, Domain]
"""Embedding from internal domain to external domain manifold."""
cod_emb: LinearEmbedding[Manifold, Codomain]
"""Embedding from internal codomain to external codomain manifold."""
# Overrides
@property
@override
def dom_man(self) -> Domain:
return self.dom_emb.amb_man
@property
@override
def cod_man(self) -> Codomain:
return self.cod_emb.amb_man
@property
@override
def dim(self) -> int:
return self.rep.num_params(self.matrix_shape)
@property
@override
def trn_man(self) -> EmbeddedMap[Codomain, Domain]:
"""Manifold of transposed linear maps."""
return EmbeddedMap(self.rep, self.cod_emb, self.dom_emb)
@override
def __call__(self, f_coords: Array, v_coords: Array) -> Array:
internal_v = self.dom_emb.project(v_coords)
internal_result = self.rep.matvec(self.matrix_shape, f_coords, internal_v)
return self.cod_emb.embed(internal_result)
@override
def transpose(self, f_coords: Array) -> Array:
return self.rep.transpose(self.matrix_shape, f_coords)
@override
def outer_product(self, w_coords: Array, v_coords: Array) -> Array:
internal_w = self.cod_emb.project(w_coords)
internal_v = self.dom_emb.project(v_coords)
return self.rep.outer_product(internal_w, internal_v)
@override
def map_domain_embedding[NewDomain: Manifold](
self,
f: Callable[
[LinearEmbedding[Manifold, Domain]], LinearEmbedding[Manifold, NewDomain]
],
) -> EmbeddedMap[NewDomain, Codomain]:
new_dom_emb = f(self.dom_emb)
return EmbeddedMap(self.rep, new_dom_emb, self.cod_emb)
@override
def map_codomain_embedding[NewCodomain: Manifold](
self,
f: Callable[
[LinearEmbedding[Manifold, Codomain]],
LinearEmbedding[Manifold, NewCodomain],
],
) -> EmbeddedMap[Domain, NewCodomain]:
new_cod_emb = f(self.cod_emb)
return EmbeddedMap(self.rep, self.dom_emb, new_cod_emb)
@override
def prepend_embedding[NewDomain: Manifold](
self,
emb: LinearEmbedding[Domain, NewDomain],
) -> EmbeddedMap[NewDomain, Codomain]:
"""Prepend an embedding to the domain."""
return self.map_domain_embedding(
lambda dom_emb: LinearComposedEmbedding(dom_emb, emb)
)
@override
def append_embedding[NewCodomain: Manifold](
self,
emb: LinearEmbedding[Codomain, NewCodomain],
) -> EmbeddedMap[Domain, NewCodomain]:
"""Append an embedding to the codomain."""
return self.map_codomain_embedding(
lambda cod_emb: LinearComposedEmbedding(cod_emb, emb)
)
# Methods
@property
def matrix_shape(self) -> tuple[int, int]:
"""Shape of the matrix operating on internal dimensions.
Returns the shape $(\\dim(internal\\_codomain), \\dim(internal\\_domain))$
of the underlying matrix in the internal coordinate system.
"""
return (self.cod_emb.sub_man.dim, self.dom_emb.sub_man.dim)
def from_matrix(self, matrix: Array) -> Array:
"""Pack a dense 2D matrix (in internal dimensions) into flat parameters."""
return self.rep.from_matrix(matrix)
def to_matrix(self, f_coords: Array) -> Array:
"""Unpack flat parameters into a dense 2D matrix (in internal dimensions)."""
return self.rep.to_matrix(self.matrix_shape, f_coords)
def get_diagonal(self, f_coords: Array) -> Array:
"""Extract diagonal elements from the matrix."""
return self.rep.get_diagonal(self.matrix_shape, f_coords)
def map_diagonal(
self, f_coords: Array, diagonal_f: Callable[[Array], Array]
) -> Array:
"""Apply a function to the diagonal elements, preserving matrix structure."""
return self.rep.map_diagonal(self.matrix_shape, f_coords, diagonal_f)
def embed_rep(
self, f_coords: Array, target_rep: MatrixRep
) -> tuple[EmbeddedMap[Domain, Codomain], Array]:
"""Embed into a more general representation (e.g. Diagonal -> Symmetric)."""
target_man = EmbeddedMap(target_rep, self.dom_emb, self.cod_emb)
coords = self.rep.embed_params(self.matrix_shape, f_coords, target_rep)
return target_man, coords
def project_rep(
self, f_coords: Array, target_rep: MatrixRep
) -> tuple[EmbeddedMap[Domain, Codomain], Array]:
"""Project to a more constrained representation (e.g. Symmetric -> Diagonal)."""
target_man = EmbeddedMap(target_rep, self.dom_emb, self.cod_emb)
coords = self.rep.project_params(self.matrix_shape, f_coords, target_rep)
return target_man, coords
@dataclass(frozen=True)
class BlockMap[Domain: Manifold, Codomain: Manifold](LinearMap[Domain, Codomain]):
"""Sum of independent linear maps (blocks), each potentially with a different ``MatrixRep``.
Used when different parts of a transformation have different structure --- e.g. a harmonium interaction matrix with one dense block and one diagonal block. Parameters are the concatenation of each block's parameters.
"""
# Fields
blocks: list[LinearMap[Domain, Codomain]]
"""The embedded linear maps that compose this block map."""
# Overrides
@property
@override
def dom_man(self) -> Domain:
return self.blocks[0].dom_man
@property
@override
def cod_man(self) -> Codomain:
return self.blocks[0].cod_man
@property
@override
def dim(self) -> int:
return sum(block.dim for block in self.blocks)
@property
@override
def trn_man(self) -> BlockMap[Codomain, Domain]:
"""Manifold of transposed linear maps."""
transposed_blocks = [block.trn_man for block in self.blocks]
return BlockMap(transposed_blocks)
@override
def __call__(self, f_coords: Array, v_coords: Array) -> Array:
result = self.cod_man.zeros()
for block, block_coords in zip(self.blocks, self.coord_blocks(f_coords)):
result = result + block(block_coords, v_coords)
return result
@override
def transpose(self, f_coords: Array) -> Array:
transposed_coords = [
block.transpose(block_coords)
for block, block_coords in zip(self.blocks, self.coord_blocks(f_coords))
]
return jnp.concatenate(transposed_coords)
@override
def outer_product(self, w_coords: Array, v_coords: Array) -> Array:
outer_params = [
block.outer_product(w_coords, v_coords) for block in self.blocks
]
return jnp.concatenate(outer_params)
@override
def map_domain_embedding[NewDomain: Manifold](
self,
f: Callable[
[LinearEmbedding[Manifold, Domain]], LinearEmbedding[Manifold, NewDomain]
],
) -> BlockMap[NewDomain, Codomain]:
return BlockMap([block.map_domain_embedding(f) for block in self.blocks])
@override
def map_codomain_embedding[NewCodomain: Manifold](
self,
f: Callable[
[LinearEmbedding[Manifold, Codomain]],
LinearEmbedding[Manifold, NewCodomain],
],
) -> BlockMap[Domain, NewCodomain]:
return BlockMap([block.map_codomain_embedding(f) for block in self.blocks])
# Methods
def coord_blocks(self, coords: Array) -> list[Array]:
"""Split flat parameters into per-block slices."""
sections = []
offset = 0
for block in self.blocks:
dim = block.dim
sections.append(coords[offset : offset + dim])
offset += dim
return sections
class AmbientMap[Domain: Manifold, Codomain: Manifold](EmbeddedMap[Domain, Codomain]):
"""Convenience wrapper: an ``EmbeddedMap`` with identity embeddings on both sides.
Use this when the map operates on the full domain and codomain without restriction.
"""
def __init__(self, rep: MatrixRep, dom_man: Domain, cod_man: Codomain):
super().__init__(rep, IdentityEmbedding(dom_man), IdentityEmbedding(cod_man))
[docs]
@dataclass(frozen=True)
class SquareMap[M: Manifold](AmbientMap[M, M]):
"""Square ``AmbientMap`` (domain = codomain), exposing inverse, log-determinant, and positive-definiteness checks."""
# Fields
rep: Square
def __init__(self, rep: MatrixRep, dom_man: M):
# Check that the representation is square
if not issubclass(type(rep), Square):
raise TypeError("SquareMap requires a square matrix representation.")
super().__init__(rep, dom_man, dom_man)
# Methods
[docs]
def inverse(self, f_coords: Array) -> Array:
"""Parameters of the inverse matrix."""
return self.rep.inverse(self.matrix_shape, f_coords)
[docs]
def logdet(self, f_coords: Array) -> Array:
"""Log determinant."""
return self.rep.logdet(self.matrix_shape, f_coords)
[docs]
def is_positive_definite(self, f_coords: Array) -> Array:
"""Check positive definiteness."""
return self.rep.is_positive_definite(self.matrix_shape, f_coords)
### Affine Maps ###
[docs]
@dataclass(frozen=True)
class AffineMap[
Domain: Manifold,
Codomain: Manifold,
](
Pair[Codomain, LinearMap[Domain, Codomain]],
):
"""A linear map plus a bias: $A(x) = L(x) + b$.
Stored as a ``Pair`` of the bias $b$ (on the codomain) and the linear map $L$. This is the natural parameter space for exponential family likelihoods (bias = observable natural parameters, linear part = interaction).
"""
# Fields
map_man: LinearMap[Domain, Codomain]
"""The linear transformation for this affine map."""
dom_man: Domain
"""The domain of the affine map."""
# Overrides
@property
@override
def fst_man(self) -> Codomain:
return self.map_man.cod_man
@property
@override
def snd_man(self) -> LinearMap[Domain, Codomain]:
return self.map_man
# Methods
def __call__(self, f_coords: Array, v_coords: Array) -> Array:
"""Apply the affine map: $L(v) + b$."""
bias, linear = self.split_coords(f_coords)
return bias + self.snd_man(linear, v_coords)