Backends
TensorRSVD supports three array backends: NumPy (default), JAX, and
CuPy. The backend is selected by passing the backend keyword argument to
any top-level or core function.
When using the JAX backend, the callable that defines your tensor (the tensor function) must be JAX-traceable. It must be a pure function without side effects, so that it can be JIT-compiled, and it must use only JAX array operations (jax.numpy) rather than NumPy ones.
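A minimal sketch of a traceable tensor function follows. It assumes, hypothetically, that the tensor function receives integer index arrays i, j, k and returns the matching entries; the calling convention TensorRSVD actually expects may differ. The names smooth_tensor and masked_tensor and the xp parameter are illustrative, chosen so the same code runs on NumPy here and on jax.numpy under JAX.

```python
import numpy as np


def smooth_tensor(i, j, k, xp=np):
    """Hypothetical tensor function T[i, j, k] = sin(i + 2j) / (1 + k).

    Pure: the result depends only on the inputs, nothing is printed or
    mutated, and only xp operations are used, so passing xp=jax.numpy
    keeps it traceable under jax.jit.
    """
    return xp.sin(i + 2.0 * j) / (1.0 + k)


def masked_tensor(i, j, k, xp=np):
    """Hypothetical: 1 where i > j, otherwise smooth_tensor(i, j, k).

    A data-dependent choice is expressed with xp.where instead of Python
    control flow. Patterns that break tracing include:
      out = np.zeros(i.shape); out[i > j] = 1.0   # NumPy call + in-place write
      if (i > j).any(): ...                       # Python branch on array values
      print(out)                                  # side effect, runs only while tracing
    """
    return xp.where(i > j, 1.0, smooth_tensor(i, j, k, xp=xp))


# Evaluate both on a small 3 x 4 x 2 index grid with NumPy.
i, j, k = np.meshgrid(np.arange(3), np.arange(4), np.arange(2), indexing="ij")
T = smooth_tensor(i, j, k)
M = masked_tensor(i, j, k)
```

With JAX installed, jax.jit(lambda i, j, k: smooth_tensor(i, j, k, xp=jax.numpy)) compiles the same function without changes.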
All linear-algebra primitives (QR, SVD), random normal sampling, array-creation routines, and element-wise operations are resolved at runtime through the helper functions below. Each helper takes the backend name and returns the corresponding callable for that backend.
Linear algebra
- tensorrsvd.backends.get_qr(backend)
Return the QR decomposition function for the given backend.
- tensorrsvd.backends.get_svd(backend)
Return an SVD function for the given backend that computes the thin (reduced) decomposition, equivalent to calling it with full_matrices=False.
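To show why these two primitives are needed, here is the generic randomized SVD scheme (Gaussian sketch, QR range finder, thin SVD of the projected matrix), written directly against NumPy. It is an illustrative sketch, not necessarily how TensorRSVD composes get_qr and get_svd; randomized_svd and its oversample default are hypothetical.

```python
import numpy as np


def randomized_svd(A, rank, oversample=5, seed=0):
    """Basic randomized SVD: A is approximated by U @ diag(s) @ Vh."""
    rng = np.random.default_rng(seed)
    m, n = A.shape
    # 1. Sketch the range of A with a Gaussian test matrix.
    Omega = rng.standard_normal((n, rank + oversample))
    Y = A @ Omega
    # 2. Orthonormal basis for the sketch (reduced QR: Q is m x (rank + oversample)).
    Q, _ = np.linalg.qr(Y)
    # 3. Project A onto that basis; B is small: (rank + oversample) x n.
    B = Q.conj().T @ A
    # 4. Thin SVD of the small matrix; full_matrices=False keeps shapes compact.
    Ub, s, Vh = np.linalg.svd(B, full_matrices=False)
    # 5. Lift the left singular vectors back and truncate to `rank`.
    U = Q @ Ub
    return U[:, :rank], s[:rank], Vh[:rank, :]


# Exact rank-4 test matrix: the sketch should recover it to rounding error.
rng = np.random.default_rng(1)
A = rng.standard_normal((60, 4)) @ rng.standard_normal((4, 40))
U, s, Vh = randomized_svd(A, rank=4)
rel_err = np.linalg.norm(A - U @ np.diag(s) @ Vh) / np.linalg.norm(A)
```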
Random sampling
- tensorrsvd.backends.get_normal(backend, seed=0)
Return a standard-normal sampling function for the given backend.
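The seed parameter implies reproducible draws. Below is a NumPy-only sketch of such a sampler, assuming (hypothetically) that the returned function takes a shape and returns an array of that shape; get_normal_sketch is an illustrative name, and the stream semantics of the real get_normal are not specified here.

```python
import numpy as np


def get_normal_sketch(seed=0):
    """Hypothetical NumPy-only analogue of get_normal: returns sample(shape)."""
    rng = np.random.default_rng(seed)

    def sample(shape):
        # Each call advances the generator, so successive draws differ,
        # but two samplers built from the same seed yield the same sequence.
        return rng.standard_normal(shape)

    return sample
```

JAX has no hidden generator state, so a JAX version would instead carry an explicit key, e.g. key = jax.random.PRNGKey(seed), splitting it with jax.random.split before each jax.random.normal(subkey, shape) call.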
Array creation
- tensorrsvd.backends.get_meshgrid(backend)
Return the meshgrid function for the given backend.
- tensorrsvd.backends.get_empty(backend)
Return the empty-array constructor for the given backend.
- tensorrsvd.backends.get_zeros(backend)
Return the zeros-array constructor for the given backend.
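As an illustration of how these constructors typically combine when evaluating a tensor on its full index grid, the sketch below uses NumPy directly; the 1/(i + j + k + 1) tensor and the shape are arbitrary examples, not something TensorRSVD prescribes.

```python
import numpy as np

shape = (4, 5, 3)

# One index grid per mode. indexing="ij" makes axis d of each grid follow
# index d; the default "xy" would swap the first two axes.
i, j, k = np.meshgrid(*(np.arange(n) for n in shape), indexing="ij")
T_full = 1.0 / (i + j + k + 1.0)  # example Hilbert-type tensor

# Same tensor built one frontal slice at a time. np.empty skips
# initialization, which is safe only because every slice is overwritten.
T_slices = np.empty(shape)
ii, jj = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij")
for kk in range(shape[2]):
    T_slices[:, :, kk] = 1.0 / (ii + jj + kk + 1.0)

# np.zeros is the safe choice when entries are accumulated with +=.
acc = np.zeros(shape[:2])
for kk in range(shape[2]):
    acc += T_slices[:, :, kk]
```

The slice-filling loop relies on in-place assignment, which works for NumPy and CuPy arrays; JAX arrays are immutable and would need T.at[:, :, kk].set(...) instead.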
Element-wise operations
- tensorrsvd.backends.get_conj(backend)
Return the complex-conjugate function for the given backend.
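Complex conjugation matters whenever the data are complex: orthogonal projections and SVD adjoints use the conjugate transpose, not the plain transpose. A short NumPy check of this (illustrative only; it does not call TensorRSVD):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((6, 4)) + 1j * rng.standard_normal((6, 4))

# Thin SVD: U is 6 x 4, s has 4 entries, Vh is 4 x 4 and already equals V^H.
U, s, Vh = np.linalg.svd(A, full_matrices=False)

# Recover V with a conjugate transpose; Vh.T alone would be wrong for complex data.
V = np.conj(Vh).T

# A @ V equals U @ diag(s); U * s scales column c of U by s[c].
lhs = A @ V
rhs = U * s

# Orthonormality of complex columns is U^H U = I (conjugate needed).
gram_adjoint = np.conj(U).T @ U
gram_plain = U.T @ U  # generally not the identity for complex U
```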
Dtype utilities