Source code for goal.geometry.manifold.matrix

"""Storage-efficient matrix representations for use as strategies in linear maps.

A ``MatrixRep`` defines how to store a matrix as a flat parameter array and how to perform linear algebra (matvec, transpose, inverse, etc.) while respecting structural constraints. ``EmbeddedMap`` in ``linear.py`` plugs a rep into the manifold system; this module is purely about the matrix operations themselves.

The hierarchy from most to least general is::

    Rectangular > Square > Symmetric > PositiveDefinite > Diagonal > Scale > Identity

Each level exploits additional structure for cheaper storage and operations:

+----------------+---------------+---------------+---------------+
| Representation | Storage       | Matmul        | Inverse/Det   |
+================+===============+===============+===============+
| Rectangular    | $O(n^2)$      | $O(n^2)$      | $O(n^3)$      |
+----------------+---------------+---------------+---------------+
| Symmetric      | $O(n^2/2)$    | $O(n^2)$      | $O(n^3)$      |
+----------------+---------------+---------------+---------------+
| Pos. Definite  | $O(n^2/2)$    | $O(n^2)$      | $O(n^3)$      |
|                |               |               | (Cholesky)    |
+----------------+---------------+---------------+---------------+
| Diagonal       | $O(n)$        | $O(n)$        | $O(n)$        |
+----------------+---------------+---------------+---------------+
| Scale          | $O(1)$        | $O(n)$        | $O(1)$        |
+----------------+---------------+---------------+---------------+
| Identity       | $O(1)$        | $O(1)$        | $O(1)$        |
+----------------+---------------+---------------+---------------+

TODO: A ``Convolutional`` rep could fit naturally here. It would store a kernel and
implement ``matvec`` via convolution on a flat array (i.e. a compactly-stored Toeplitz
matrix), with ``shape = (output_len, input_len)`` preserving the existing contract.
Multi-channel and 2D structure would be handled in ``linear.py`` via ``BlockMap``
(one block per channel pair) or embeddings that reshape between flat and spatial layouts.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import override

import jax
import jax.numpy as jnp
from jax import Array

### Helper Functions ###


def _matmat(
    rep: MatrixRep,
    shape: tuple[int, int],
    params: Array,
    right_rep: MatrixRep,
    right_shape: tuple[int, int],
    right_params: Array,
) -> tuple[MatrixRep, tuple[int, int], Array]:
    """Compute matrix-matrix product, choosing the tightest representation for the result.

    Dispatches on both operands' representations: Identity and Scale short-circuit,
    Diagonal * Diagonal stays Diagonal, and mixed cases fall back to dense multiplication
    with the result stored as Square or Rectangular based on output shape.
    """
    out_shape = (shape[0], right_shape[1])
    out_rep: MatrixRep
    out_params: Array
    match rep:
        case Identity():
            # Identity * B = B, so result has same rep as right matrix
            out_rep = right_rep
            out_params = right_params
        case Scale():
            # Scale (aI) * B = aB, so just scale the right matrix parameters
            out_rep = right_rep
            out_params = params[0] * right_params
        case Diagonal():
            # Diagonal multiplication: D * X
            match right_rep:
                case Diagonal():
                    # Diagonal * Diagonal = Diagonal (element-wise product of diagonals)
                    out_rep = Diagonal()
                    out_params = params * right_params
                case _:
                    # Diagonal * (non-Diagonal): result is dense
                    # Convert right matrix to dense and multiply: diag(params) @ right_dense
                    # Broadcasting: params[:, None] * right_dense treats params as column vector
                    right_dense = right_rep.to_matrix(right_shape, right_params)
                    out_dense = params[:, None] * right_dense
                    # Determine output representation based on output shape
                    # If output is square, use Square; otherwise use Rectangular
                    out_rep = (
                        Square() if out_shape[0] == out_shape[1] else Rectangular()
                    )
                    out_params = out_rep.from_matrix(out_dense)
        case _:
            # General matrix (Rectangular/Square/Symmetric/PositiveDefinite) multiplication
            match right_rep:
                case Identity() | Scale() | Diagonal():
                    # A * S = (S * A^T)^T — S is symmetric so S^T = S
                    left_params_t = rep.transpose(shape, params)
                    left_shape_t = (shape[1], shape[0])
                    result_rep, result_shape, result_params = _matmat(
                        right_rep, right_shape, right_params, rep, left_shape_t, left_params_t
                    )
                    result_params_t = result_rep.transpose(result_shape, result_params)
                    result_shape_t = (result_shape[1], result_shape[0])
                    return result_rep, result_shape_t, result_params_t
                case _:
                    # Both general matrices: fall back to dense multiplication
                    # Determine output representation based on output shape:
                    # - If m == p, result is Square (can support additional operations)
                    # - Otherwise, result is Rectangular (general m x p matrix)
                    left_dense = rep.to_matrix(shape, params)
                    right_dense = right_rep.to_matrix(right_shape, right_params)
                    out_rep = (
                        Square() if out_shape[0] == out_shape[1] else Rectangular()
                    )
                    out_dense = left_dense @ right_dense
                    out_params = out_rep.from_matrix(out_dense)
    return out_rep, out_shape, out_params


def _diag_indices_in_triangular(n: int) -> Array:
    """Return indices of diagonal elements in upper triangular storage.

    For an $n \times n$ matrix stored in upper triangular format $(n(n+1)/2)$,
    returns the indices where diagonal elements are stored.
    """
    i_diag = jnp.arange(n)
    return i_diag * (2 * n - i_diag + 1) // 2


[docs] class MatrixRep(ABC): """Strategy interface for matrix storage and operations. Each subclass defines how to pack a matrix into a flat 1D parameter array and how to perform linear algebra (matvec, transpose, outer product, etc.) while preserving the structural constraint (symmetry, diagonal, etc.). Representations are stateless --- two instances of the same class are equal. The ``embed_params`` / ``project_params`` methods convert between representations by walking the linear inheritance chain (e.g. Diagonal -> PositiveDefinite -> Symmetric -> Square -> Rectangular). """ @override def __eq__(self, other: object) -> bool: """Compare matrix representations by type (not instance identity). Since MatrixRep instances are stateless, two instances are equal if they have the same class type. """ return type(self) is type(other) @override def __hash__(self) -> int: """Hash by class type to maintain hash/equality contract.""" return hash(type(self))
[docs] @classmethod @abstractmethod def matvec(cls, shape: tuple[int, int], params: Array, vector: Array) -> Array: """Matrix-vector multiplication."""
[docs] @classmethod def matmat( cls, shape: tuple[int, int], params: Array, right_rep: MatrixRep, right_shape: tuple[int, int], right_params: Array, ) -> tuple[MatrixRep, tuple[int, int], Array]: """Multiply matrices, returning optimal representation type and parameters.""" return _matmat(cls(), shape, params, right_rep, right_shape, right_params)
[docs] @classmethod @abstractmethod def transpose(cls, shape: tuple[int, int], params: Array) -> Array: """Transform parameters to represent the transposed matrix."""
[docs] @classmethod @abstractmethod def to_matrix(cls, shape: tuple[int, int], params: Array) -> Array: """Convert 1D parameters to dense matrix form."""
[docs] @classmethod @abstractmethod def num_params(cls, shape: tuple[int, int]) -> int: """Shape of 1D parameter array needed for matrix dimensions."""
[docs] @classmethod @abstractmethod def from_matrix(cls, matrix: Array) -> Array: """Convert dense matrix to 1D parameters."""
[docs] @classmethod @abstractmethod def outer_product(cls, v1: Array, v2: Array) -> Array: """Construct parameters from outer product $v_1 \\otimes v_2$."""
[docs] @classmethod @abstractmethod def map_diagonal( cls, shape: tuple[int, int], params: Array, f: Callable[[Array], Array] ) -> Array: """Apply function f to diagonal elements while preserving matrix structure."""
[docs] @classmethod def get_diagonal(cls, shape: tuple[int, int], params: Array) -> Array: """Extract diagonal elements from the matrix.""" return jnp.diag(cls.to_matrix(shape, params))
[docs] @classmethod def embed_params( cls, shape: tuple[int, int], params: Array, target_rep: MatrixRep ) -> Array: """Recursively embed params into more complex representation.""" if not issubclass(cls, type(target_rep)): raise TypeError(f"Cannot embed {cls} into {target_rep}") cur_rep = cls cur_params = params while cur_rep is not type(target_rep): cur_params = cur_rep.embed_in_super(shape, cur_params) cur_rep: type[MatrixRep] = cur_rep.__base__ # pyright: ignore[reportAssignmentType] return cur_params
[docs] @classmethod def project_params( cls, shape: tuple[int, int], params: Array, target_rep: MatrixRep ) -> Array: """Recursively project params to simpler representation.""" if not issubclass(type(target_rep), cls): raise TypeError(f"Cannot project {cls} to {target_rep}") # Build path of representations from target up to cls path: list[type[MatrixRep]] = [] cur_rep: type[MatrixRep] = type(target_rep) while cur_rep != cls: path.append(cur_rep) cur_rep = cur_rep.__base__ # pyright: ignore[reportAssignmentType] path = list(reversed(path)) # Project through path cur_params = params for sub_rep in path: cur_params = sub_rep.project_from_super(shape, cur_params) return cur_params
[docs] @classmethod @abstractmethod def embed_in_super(cls, shape: tuple[int, int], params: Array) -> Array: """Embed parameters into immediate parent representation."""
[docs] @classmethod @abstractmethod def project_from_super(cls, shape: tuple[int, int], params: Array) -> Array: """Project parameters from immediate parent representation."""
class Rectangular(MatrixRep):
    """Full $m \\times n$ matrix, stored in row-major order. No structural constraints."""

    @classmethod
    @override
    def matvec(cls, shape: tuple[int, int], params: Array, vector: Array) -> Array:
        """Multiply the dense matrix by ``vector``."""
        matrix = cls.to_matrix(shape, params)
        return jnp.dot(matrix, vector)

    @classmethod
    @override
    def transpose(cls, shape: tuple[int, int], params: Array) -> Array:
        """Return the row-major parameters of the transposed matrix."""
        matrix = cls.to_matrix(shape, params).T
        return matrix.reshape(-1)

    @classmethod
    @override
    def to_matrix(cls, shape: tuple[int, int], params: Array) -> Array:
        """Reshape the flat parameters to ``shape``."""
        return params.reshape(shape)

    @classmethod
    @override
    def from_matrix(cls, matrix: Array) -> Array:
        """Flatten a dense matrix in row-major order."""
        return matrix.reshape(-1)

    @classmethod
    @override
    def num_params(cls, shape: tuple[int, int]) -> int:
        """Number of stored parameters: the product of both dimensions."""
        n, m = shape
        return n * m

    @classmethod
    @override
    def outer_product(cls, v1: Array, v2: Array) -> Array:
        """Create parameters from outer product."""
        matrix = jnp.outer(v1, v2)
        return cls.from_matrix(matrix)

    @classmethod
    @override
    def map_diagonal(
        cls, shape: tuple[int, int], params: Array, f: Callable[[Array], Array]
    ) -> Array:
        """Map function over diagonal elements of matrix.

        Works for non-square shapes as well: the main diagonal of an
        $m \\times n$ matrix has ``min(m, n)`` entries.
        """
        matrix = cls.to_matrix(shape, params)
        diag = jnp.diag(matrix)
        new_diag = f(diag)
        # Index with min(m, n) positions so the index set matches the length of
        # `diag`; using shape[0] would run past the last column for tall matrices.
        diag_idx = jnp.diag_indices(min(shape))
        new_matrix = matrix.at[diag_idx].set(new_diag)
        return cls.from_matrix(new_matrix)

    @classmethod
    @override
    def embed_in_super(cls, shape: tuple[int, int], params: Array) -> Array:
        """Always raises: there is no representation above Rectangular."""
        raise TypeError("Rectangular is most complex representation")

    @classmethod
    @override
    def project_from_super(cls, shape: tuple[int, int], params: Array) -> Array:
        """Always raises: there is no representation above Rectangular."""
        raise TypeError("No more complex rep to project from")
class Square(Rectangular):
    """Square $n \\times n$ matrix, adding inverse, determinant, and positive-definiteness checks."""

    @classmethod
    def is_positive_definite(cls, shape: tuple[int, int], params: Array) -> Array:
        """Return whether every eigenvalue reported by ``eigvalsh`` is strictly positive.

        ``eigvalsh`` is the symmetric/Hermitian eigenvalue routine, so the answer is
        meaningful for symmetric input.
        """
        spectrum = jnp.linalg.eigvalsh(cls.to_matrix(shape, params))
        return jnp.all(spectrum > 0)

    @classmethod
    def inverse(cls, shape: tuple[int, int], params: Array) -> Array:
        """Invert the dense matrix and pack the result back into parameters."""
        dense_inverse = jnp.linalg.inv(cls.to_matrix(shape, params))
        return cls.from_matrix(dense_inverse)

    @classmethod
    def logdet(cls, shape: tuple[int, int], params: Array) -> Array:
        """Return the log of the absolute determinant (second output of ``slogdet``)."""
        _sign, log_abs_det = jnp.linalg.slogdet(cls.to_matrix(shape, params))
        return log_abs_det

    @classmethod
    @override
    def embed_in_super(cls, shape: tuple[int, int], params: Array) -> Array:
        """Square and Rectangular share the same row-major layout, so nothing changes."""
        return params

    @classmethod
    @override
    def project_from_super(cls, shape: tuple[int, int], params: Array) -> Array:
        """Pass parameters through a round trip via an ``(n, n)`` reshape."""
        side = shape[0]
        square = params.reshape(side, side)
        return square.reshape(-1)
class Symmetric(Square):
    """Symmetric matrix ($A = A^T$), stored as upper-triangular elements.

    Roughly half the storage of a full square matrix.
    """

    @classmethod
    @override
    def transpose(cls, shape: tuple[int, int], params: Array) -> Array:
        """A symmetric matrix equals its transpose, so the parameters are returned as is."""
        return params

    @classmethod
    @override
    def to_matrix(cls, shape: tuple[int, int], params: Array) -> Array:
        """Fill the upper triangle from ``params`` and mirror it below the diagonal."""
        side = shape[0]
        upper = jnp.zeros((side, side)).at[jnp.triu_indices(side)].set(params)
        # k=1 excludes the diagonal so that it is not counted twice.
        strictly_upper = jnp.triu(upper, k=1)
        return upper + strictly_upper.T

    @classmethod
    @override
    def from_matrix(cls, matrix: Array) -> Array:
        """Keep only the upper-triangular entries (diagonal included)."""
        return matrix[jnp.triu_indices(matrix.shape[0])]

    @classmethod
    @override
    def num_params(cls, shape: tuple[int, int]) -> int:
        """Number of upper-triangular entries, $n(n+1)/2$."""
        side = shape[0]
        return side * (side + 1) // 2

    @classmethod
    @override
    def get_diagonal(cls, shape: tuple[int, int], params: Array) -> Array:
        """Read the diagonal directly out of the packed upper-triangular storage."""
        positions = _diag_indices_in_triangular(shape[0])
        return params[positions]

    @classmethod
    @override
    def map_diagonal(
        cls, shape: tuple[int, int], params: Array, f: Callable[[Array], Array]
    ) -> Array:
        """Apply ``f`` to the diagonal entries in place within the packed storage."""
        positions = _diag_indices_in_triangular(shape[0])
        updated = f(params[positions])
        return params.at[positions].set(updated)

    @classmethod
    @override
    def embed_in_super(cls, shape: tuple[int, int], params: Array) -> Array:
        """Expand to full square parameters; each off-diagonal entry appears twice."""
        return cls.to_matrix(shape, params).reshape(-1)

    @classmethod
    @override
    def project_from_super(cls, shape: tuple[int, int], params: Array) -> Array:
        """Take the upper-triangular entries of full square matrix parameters."""
        side = shape[0]
        return cls.from_matrix(params.reshape(side, side))
class PositiveDefinite(Symmetric):
    """Symmetric positive-definite matrix, using Cholesky decomposition for stable inverse and log-determinant.

    Mathematically, $A$ is positive definite iff $x^T A x > 0$ for all $x \\neq 0$,
    equivalently iff a unique Cholesky factorization $A = LL^T$ exists.
    """

    @classmethod
    def _cholesky(cls, shape: tuple[int, int], params: Array) -> Array:
        """Return the Cholesky factor of the dense matrix."""
        return jnp.linalg.cholesky(cls.to_matrix(shape, params))

    @classmethod
    def cholesky_matvec(
        cls, shape: tuple[int, int], params: Array, vector: Array
    ) -> Array:
        """Apply the Cholesky factor to a vector, or to each row of a batch of vectors."""
        factor = cls._cholesky(shape, params)
        if vector.ndim != 1:
            # Batch case: rows are vectors, so multiply on the transposed batch.
            return (factor @ vector.T).T
        return factor @ vector

    @classmethod
    def cholesky_whiten(
        cls,
        shape: tuple[int, int],
        mean1: Array,
        params1: Array,
        mean2: Array,
        params2: Array,
    ) -> tuple[Array, Array]:
        """Whiten a distribution (mean1, params1) with respect to another (mean2, params2).

        Transforms the first Gaussian into a coordinate system in which the second
        Gaussian is standard normal N(0, I). This is useful for computing KL
        divergences and other relative measures between Gaussians.

        With L the Cholesky factor of params2 (so that L @ L.T = params2):

        - new_mean = L^(-1) @ (mean1 - mean2)
        - new_params = L^(-1) @ params1 @ L^(-T)

        Returns:
            The whitened mean and the packed parameters of the whitened matrix.
        """
        reference = cls.to_matrix(shape, params2)
        target = cls.to_matrix(shape, params1)
        solve_lower = jax.scipy.linalg.solve_triangular

        factor = jnp.linalg.cholesky(reference)
        whitened_mean = solve_lower(factor, mean1 - mean2, lower=True)

        # L^(-1) @ M first, then solve against the transpose to apply L^(-T) on the right.
        half_whitened = solve_lower(factor, target, lower=True)
        whitened_matrix = solve_lower(factor, half_whitened.T, lower=True).T
        return whitened_mean, cls.from_matrix(whitened_matrix)

    @classmethod
    @override
    def is_positive_definite(cls, shape: tuple[int, int], params: Array) -> Array:
        """Check positive definiteness by testing that the Cholesky factor is finite.

        Relies on ``jax.lax.linalg.cholesky`` producing non-finite entries when the
        factorization fails.
        """
        factor = jax.lax.linalg.cholesky(cls.to_matrix(shape, params))
        return jnp.all(jnp.isfinite(factor))

    @classmethod
    @override
    def inverse(cls, shape: tuple[int, int], params: Array) -> Array:
        """Inverse via Cholesky decomposition: $A^{-1} = L^{-T} L^{-1}$."""
        factor = cls._cholesky(shape, params)
        identity = jnp.eye(shape[0])
        # Solving L X = I gives X = L^(-1).
        factor_inv = jax.scipy.linalg.solve_triangular(factor, identity, lower=True)
        return cls.from_matrix(factor_inv.T @ factor_inv)

    @classmethod
    @override
    def logdet(cls, shape: tuple[int, int], params: Array) -> Array:
        """Log determinant as twice the sum of logs of the Cholesky factor's diagonal."""
        factor_diag = jnp.diag(cls._cholesky(shape, params))
        return 2.0 * jnp.sum(jnp.log(factor_diag))

    @classmethod
    @override
    def embed_in_super(cls, shape: tuple[int, int], params: Array) -> Array:
        """Same packed layout as Symmetric, so the parameters pass through unchanged."""
        return params

    @classmethod
    @override
    def project_from_super(cls, shape: tuple[int, int], params: Array) -> Array:
        """Same packed layout as Symmetric, so the parameters pass through unchanged."""
        return params
class Diagonal(PositiveDefinite):
    """Diagonal matrix $A = \\text{diag}(a_1, \\ldots, a_n)$, storing only the $n$ diagonal entries.

    All operations reduce to element-wise arithmetic, giving $O(n)$ storage and compute.
    """

    @classmethod
    @override
    def is_positive_definite(cls, shape: tuple[int, int], params: Array) -> Array:
        """Return whether every diagonal entry is strictly positive."""
        return jnp.all(params > 0)

    @classmethod
    @override
    def matvec(cls, shape: tuple[int, int], params: Array, vector: Array) -> Array:
        """Element-wise product of the diagonal with ``vector``."""
        return params * vector

    @classmethod
    @override
    def transpose(cls, shape: tuple[int, int], params: Array) -> Array:
        """A diagonal matrix equals its transpose."""
        return params

    @classmethod
    @override
    def to_matrix(cls, shape: tuple[int, int], params: Array) -> Array:
        """Place the stored entries on the diagonal of an otherwise zero matrix."""
        side = shape[0]
        return jnp.zeros((side, side)).at[jnp.diag_indices(side)].set(params)

    @classmethod
    @override
    def from_matrix(cls, matrix: Array) -> Array:
        """Keep only the diagonal of a dense matrix."""
        return jnp.diag(matrix)

    @classmethod
    @override
    def num_params(cls, shape: tuple[int, int]) -> int:
        """One parameter per diagonal entry."""
        return shape[0]

    @classmethod
    @override
    def inverse(cls, shape: tuple[int, int], params: Array) -> Array:
        """Element-wise reciprocal of the diagonal."""
        return 1.0 / params

    @classmethod
    @override
    def logdet(cls, shape: tuple[int, int], params: Array) -> Array:
        """Sum of the logs of the diagonal entries."""
        return jnp.sum(jnp.log(params))

    @classmethod
    @override
    def _cholesky(cls, shape: tuple[int, int], params: Array) -> Array:
        """Element-wise square root of the stored parameters (not expanded to a dense matrix)."""
        return jnp.sqrt(params)

    @classmethod
    @override
    def outer_product(cls, v1: Array, v2: Array) -> Array:
        """Create parameters from outer product, keeping only diagonal."""
        return v1 * v2

    @classmethod
    @override
    def cholesky_matvec(
        cls, shape: tuple[int, int], params: Array, vector: Array
    ) -> Array:
        """Scale ``vector`` element-wise by the square root of the diagonal."""
        root = jnp.sqrt(params)
        return vector * root

    @classmethod
    @override
    def cholesky_whiten(
        cls,
        shape: tuple[int, int],
        mean1: Array,
        params1: Array,
        mean2: Array,
        params2: Array,
    ) -> tuple[Array, Array]:
        """Whiten (mean1, params1) with respect to (mean2, params2) element-wise.

        The Cholesky factor of a diagonal matrix is the square root of its entries, so
        the mean is divided by ``sqrt(params2)`` and the parameters by ``params2``.
        """
        centered = mean1 - mean2
        return centered / jnp.sqrt(params2), params1 / params2

    @classmethod
    @override
    def map_diagonal(
        cls, shape: tuple[int, int], params: Array, f: Callable[[Array], Array]
    ) -> Array:
        """The parameters are the diagonal, so ``f`` is applied to them directly."""
        return f(params)

    @classmethod
    @override
    def get_diagonal(cls, shape: tuple[int, int], params: Array) -> Array:
        """The parameters are the diagonal."""
        return params

    @classmethod
    @override
    def embed_in_super(cls, shape: tuple[int, int], params: Array) -> Array:
        """Put diagonal elements into upper triangular format, zero elsewhere."""
        packed = jnp.zeros(PositiveDefinite.num_params(shape))
        positions = _diag_indices_in_triangular(shape[0])
        return packed.at[positions].set(params)

    @classmethod
    @override
    def project_from_super(cls, shape: tuple[int, int], params: Array) -> Array:
        """Extract diagonal elements from upper triangular format."""
        positions = _diag_indices_in_triangular(shape[0])
        return params[positions]
class Scale(Diagonal):
    """Scalar multiple of the identity, $A = \\alpha I$. Single parameter, $O(1)$ storage."""

    @classmethod
    @override
    def is_positive_definite(cls, shape: tuple[int, int], params: Array) -> Array:
        """Return whether the scale factor is strictly positive."""
        alpha = params[0]
        return alpha > 0

    @classmethod
    @override
    def matvec(cls, shape: tuple[int, int], params: Array, vector: Array) -> Array:
        """Multiply ``vector`` by the scale factor."""
        alpha = params[0]
        return alpha * vector

    @classmethod
    @override
    def to_matrix(cls, shape: tuple[int, int], params: Array) -> Array:
        """Dense form: the scale factor times the identity matrix."""
        alpha = params[0]
        return alpha * jnp.eye(shape[0])

    @classmethod
    @override
    def from_matrix(cls, matrix: Array) -> Array:
        """Use the mean of the dense matrix's diagonal as the single parameter."""
        mean_diag = jnp.mean(jnp.diag(matrix))
        return jnp.array([mean_diag])

    @classmethod
    @override
    def num_params(cls, shape: tuple[int, int]) -> int:
        """A single parameter regardless of shape."""
        return 1

    @classmethod
    @override
    def logdet(cls, shape: tuple[int, int], params: Array) -> Array:
        """Log determinant of $\\alpha I$: $n \\log \\alpha$."""
        return shape[0] * jnp.log(params[0])

    @classmethod
    @override
    def outer_product(cls, v1: Array, v2: Array) -> Array:
        """Average outer product to single scale parameter."""
        mean_product = jnp.mean(v1 * v2)
        return jnp.array([mean_product])

    @classmethod
    @override
    def map_diagonal(
        cls, shape: tuple[int, int], params: Array, f: Callable[[Array], Array]
    ) -> Array:
        """Apply ``f`` to the scale factor and rewrap it as a one-element array."""
        mapped = f(params[0])
        return jnp.array([mapped])

    @classmethod
    @override
    def get_diagonal(cls, shape: tuple[int, int], params: Array) -> Array:
        """Diagonal of length ``shape[0]`` with every entry equal to the scale factor."""
        return jnp.full(shape[0], params[0])

    @classmethod
    @override
    def embed_in_super(cls, shape: tuple[int, int], params: Array) -> Array:
        """Expand scalar to diagonal vector."""
        return jnp.full(shape[0], params[0])

    @classmethod
    @override
    def project_from_super(cls, shape: tuple[int, int], params: Array) -> Array:
        """Average diagonal elements to scalar."""
        mean_entry = jnp.mean(params)
        return jnp.array([mean_entry])
class Identity(Scale):
    """The identity matrix $A = I$. Zero parameters --- fully determined by shape."""

    @classmethod
    @override
    def matvec(cls, shape: tuple[int, int], params: Array, vector: Array) -> Array:
        """Return ``vector`` unchanged."""
        return vector

    @classmethod
    @override
    def is_positive_definite(cls, shape: tuple[int, int], params: Array) -> Array:
        """Identity is always positive definite."""
        return jnp.array(True)

    @classmethod
    @override
    def to_matrix(cls, shape: tuple[int, int], params: Array) -> Array:
        """Dense identity matrix of side ``shape[0]``."""
        return jnp.eye(shape[0])

    @classmethod
    @override
    def from_matrix(cls, matrix: Array) -> Array:
        """Return an empty parameter array; the input matrix is ignored."""
        del matrix  # intentionally unused
        return jnp.array([])

    @classmethod
    @override
    def num_params(cls, shape: tuple[int, int]) -> int:
        """No parameters are stored."""
        return 0

    @classmethod
    @override
    def inverse(cls, shape: tuple[int, int], params: Array) -> Array:
        """The identity is its own inverse, so the parameters are returned as is."""
        return params

    @classmethod
    @override
    def logdet(cls, shape: tuple[int, int], params: Array) -> Array:
        """Log determinant of the identity is zero."""
        return jnp.array(0.0)

    @classmethod
    @override
    def outer_product(cls, v1: Array, v2: Array) -> Array:
        """Identity ignores input vectors."""
        return jnp.array([])

    @classmethod
    @override
    def cholesky_matvec(
        cls, shape: tuple[int, int], params: Array, vector: Array
    ) -> Array:
        """Return ``vector`` unchanged."""
        return vector

    @classmethod
    @override
    def cholesky_whiten(
        cls,
        shape: tuple[int, int],
        mean1: Array,
        params1: Array,
        mean2: Array,
        params2: Array,
    ) -> tuple[Array, Array]:
        """Whitening against the identity only centers the mean; ``params1`` is returned as is."""
        centered = mean1 - mean2
        return centered, params1

    @classmethod
    @override
    def get_diagonal(cls, shape: tuple[int, int], params: Array) -> Array:
        """Vector of ones of length ``shape[0]``."""
        return jnp.ones(shape[0])

    @classmethod
    @override
    def embed_in_super(cls, shape: tuple[int, int], params: Array) -> Array:
        """Empty params to unit scalar."""
        return jnp.array([1.0])

    @classmethod
    @override
    def project_from_super(cls, shape: tuple[int, int], params: Array) -> Array:
        """Scalar to empty params."""
        return jnp.array([])