from __future__ import annotations
import math
from collections.abc import Callable
import numpy as np
import numpy.typing as npt
import pylops
from pylops.utils.backend import deps
from ..backends import get_arange, get_conj, get_meshgrid, get_ravel, get_zeros
if deps.jax_enabled:
import jax
import jax.numpy as jnp
def _unit_lerp(idx, n):
    """Rescale grid indices from {0, ..., n-1} onto the unit interval [0, 1].

    Accepts a scalar index or an array of indices. For ``n > 1`` the result
    matches ``np.linspace(0, 1, n)[idx]`` without building the full linspace
    array. A grid with a single point (``n <= 1``) maps every index to 0.

    Parameters
    ----------
    idx : scalar or array
        Index or indices to rescale.
    n : int
        Number of grid points along the axis.

    Returns
    -------
    scalar or array
        ``idx / (n - 1)`` when ``n > 1``, otherwise ``idx * 0``.
    """
    # Multiplying by zero (rather than returning a literal 0) keeps the dtype
    # and the backend array type of ``idx`` in the single-point case.
    return idx * 0 if n <= 1 else idx / (n - 1)
def _insert_mode_coord(other_coords, mode_val, mode, num_idxs):
    """Assemble the full per-axis coordinate list for one mode value.

    Parameters
    ----------
    other_coords : sequence
        Coordinates of the non-mode axes, in increasing axis order.
    mode_val
        Coordinate to place at axis ``mode``.
    mode : int
        Position in the output list that receives ``mode_val``.
    num_idxs : int
        Length of the output list (total number of axes).

    Returns
    -------
    list
        ``num_idxs`` entries: ``mode_val`` at position ``mode`` and the
        entries of ``other_coords``, in order, at every other position.
    """
    full = []
    consumed = 0  # how many entries of other_coords have been placed so far
    for axis in range(num_idxs):
        if axis != mode:
            full.append(other_coords[consumed])
            consumed += 1
        else:
            full.append(mode_val)
    return full
[docs]
class MatricizedTensorOperator(pylops.LinearOperator):
    """Linear operator for the mode-m unfolding of a tensor defined by a callable.

    Row ``i`` of the operator holds the tensor values obtained by fixing the
    coordinate of axis ``mode`` to its ``i``-th grid point and sweeping the
    flattened ``indexing="ij"`` meshgrid of all remaining axes. Rows are
    produced on demand by calling ``tensor_callable``; the unfolding is never
    stored as a dense matrix. Grid points along every axis are evenly spaced
    in [0, 1].

    Parameters
    ----------
    tensor_callable : Callable
        Vectorized callable accepting k coordinate arguments (values in
        [0, 1]) and returning the tensor values at those coordinates. When a
        row is evaluated, the argument at position ``mode`` is a single
        element of the mode grid and every other argument is a flattened
        meshgrid array, so the callable must broadcast between them. Must be
        JAX-traceable if backend == 'jax'.
    tensor_shape : tuple[int, ...]
        Shape of the tensor (n_0, n_1, ..., n_{k-1}).
    mode : int
        Mode along which to unfold. Negative values count from the last
        axis, as in ordinary sequence indexing.
    dtype : npt.DTypeLike
        Data type of the operator. It is also passed to the backend
        ``arange`` used to build the coordinate grids.
    backend : str
        Array backend: 'numpy', 'cupy', or 'jax'.

    Raises
    ------
    IndexError
        If ``mode`` is not a valid axis index for ``tensor_shape``.
    """

    def __init__(
        self,
        tensor_callable: Callable,
        tensor_shape: tuple[int, ...],
        mode: int,
        dtype: npt.DTypeLike,
        backend: str,
    ) -> None:
        tensor_shape = tuple(tensor_shape)
        num_idxs = len(tensor_shape)

        # Normalize the mode before it is used anywhere else. Without this a
        # negative mode picks the right row count via tensor_shape[mode] but
        # is never equal to any axis index below, so no axis is excluded from
        # the column grid and the mode coordinate is never inserted.
        if not -num_idxs <= mode < num_idxs:
            raise IndexError(
                f"mode {mode} is out of range for a tensor with {num_idxs} dimension(s)"
            )
        if mode < 0:
            mode += num_idxs

        num_rows = tensor_shape[mode]
        other_dims = [i for i in range(num_idxs) if i != mode]
        num_cols = math.prod(tensor_shape[i] for i in other_dims)

        self._callable = tensor_callable
        self._tensor_shape = tensor_shape
        self._mode = mode
        self._num_idxs = num_idxs
        self._num_rows = num_rows
        self._other_dims = other_dims
        self._backend = backend

        arange = get_arange(backend)
        meshgrid = get_meshgrid(backend)
        ravel = get_ravel(backend)

        # Precompute the 1D coordinate arrays for the non-mode dimensions,
        # expand them to a full grid and keep only the flattened grids: one
        # array of length num_cols per non-mode axis.
        other_1d = [
            _unit_lerp(arange(tensor_shape[i], dtype=dtype), tensor_shape[i])
            for i in other_dims
        ]
        mesh = meshgrid(*other_1d, indexing="ij")
        self._other_coords = tuple(ravel(m) for m in mesh)
        del mesh, other_1d

        # Coordinates along the unfolding mode, one per operator row.
        self._mode_vals = _unit_lerp(arange(num_rows, dtype=dtype), num_rows)

        super().__init__(dtype=np.dtype(dtype), shape=(num_rows, num_cols))

        # Build jitted JAX kernels once to avoid retracing.
        if backend == "jax" and deps.jax_enabled:
            self._jax_matmat_fn = self._build_jax_matmat()
            self._jax_rmatmat_fn = self._build_jax_rmatmat()

    def _eval_row(self, i_m: int):
        """Evaluate all tensor entries with mode index fixed to i_m.

        Returns whatever ``tensor_callable`` returns for the flattened
        non-mode grids combined with the ``i_m``-th mode coordinate, i.e. one
        row of the unfolding.
        """
        mode_val = self._mode_vals[i_m]
        coords = _insert_mode_coord(
            self._other_coords, mode_val, self._mode, self._num_idxs
        )
        return self._callable(*coords)

    def _matvec(self, x):
        """Compute A @ x for a vector x of length num_cols."""
        if self._backend == "jax":
            # Reuse the jitted matrix kernel on a single-column matrix.
            return self._jax_matmat_fn(x[:, None]).ravel()
        zeros = get_zeros(self._backend)
        y = zeros(self._num_rows, dtype=self.dtype)
        for i in range(self._num_rows):
            y[i] = self._eval_row(i) @ x
        return y

    def _rmatvec(self, y):
        """Compute A^H @ y for a vector y of length num_rows."""
        if self._backend == "jax":
            # Reuse the jitted matrix kernel on a single-column matrix.
            return self._jax_rmatmat_fn(y[:, None]).ravel()
        conj = get_conj(self._backend)
        z = get_zeros(self._backend)(self.shape[1], dtype=self.dtype)
        # A^H y is the sum over rows of y[i] times the conjugated row.
        for i in range(self._num_rows):
            z += y[i] * conj(self._eval_row(i))
        return z

    def _matmat(self, X):
        """Compute A @ X where X is (num_cols, l)."""
        if self._backend == "jax":
            return self._jax_matmat_fn(X)
        X_cols = X.shape[1]
        Y = get_zeros(self._backend)((self._num_rows, X_cols), dtype=self.dtype)
        # Each row is evaluated once and applied to all columns of X.
        for i in range(self._num_rows):
            row = self._eval_row(i)
            Y[i, :] = row @ X
        return Y

    def _rmatmat(self, Y):
        """Compute A^H @ Y where Y is (num_rows, l)."""
        if self._backend == "jax":
            return self._jax_rmatmat_fn(Y)
        conj = get_conj(self._backend)
        Y_cols = Y.shape[1]
        num_cols = self.shape[1]
        Z = get_zeros(self._backend)((num_cols, Y_cols), dtype=self.dtype)
        # rank-1 accumulation avoids an O(num_cols * l) broadcast per row
        for i in range(self._num_rows):
            crow = conj(self._eval_row(i))
            for j in range(Y_cols):
                Z[:, j] += Y[i, j] * crow
        return Z

    def _build_jax_matmat(self):
        """Return a jitted function computing A @ X via lax.fori_loop."""
        # Copy the needed state into locals; the jitted closure below refers
        # to these names rather than to ``self``.
        fn = self._callable
        other_coords = self._other_coords
        mode_vals = self._mode_vals
        mode = self._mode
        num_idxs = self._num_idxs
        num_rows = self._num_rows
        dtype = self.dtype

        @jax.jit
        def matmat(X):
            X_cols = X.shape[1]

            def body(i, Y):
                # Evaluate row i of the unfolding and write its product with
                # X into row i of the carried output.
                coords = _insert_mode_coord(other_coords, mode_vals[i], mode, num_idxs)
                row = fn(*coords)
                return Y.at[i].set(row @ X)

            return jax.lax.fori_loop(
                0, num_rows, body, jnp.zeros((num_rows, X_cols), dtype=dtype)
            )

        return matmat

    def _build_jax_rmatmat(self):
        """Return a jitted function computing A^H @ Y via lax.fori_loop."""
        # Copy the needed state into locals; the jitted closure below refers
        # to these names rather than to ``self``.
        fn = self._callable
        other_coords = self._other_coords
        mode_vals = self._mode_vals
        mode = self._mode
        num_idxs = self._num_idxs
        num_rows = self._num_rows
        num_cols_val = self.shape[1]
        dtype = self.dtype

        @jax.jit
        def rmatmat(Y):
            Y_cols = Y.shape[1]

            def body(i, Z):
                # Add the outer product of the conjugated row i with row i of
                # Y to the carried accumulator.
                coords = _insert_mode_coord(other_coords, mode_vals[i], mode, num_idxs)
                row = fn(*coords)
                return Z + jnp.conj(row)[:, None] * Y[i, :]

            return jax.lax.fori_loop(
                0, num_rows, body, jnp.zeros((num_cols_val, Y_cols), dtype=dtype)
            )

        return rmatmat