
dq.tracemm

tracemm(x: ArrayLike, y: ArrayLike) -> Array

Return the trace of a matrix multiplication using a fast implementation.

The trace is computed as sum(x * y.T), where * is the element-wise product and the transpose acts on the last two axes, instead of trace(x @ y), where @ is the matrix product. Indeed, we have:

\[ \tr{xy} = \sum_i (xy)_{ii} = \sum_{i,j} x_{ij} y_{ji} = \sum_{i,j} x_{ij} (y^\intercal)_{ij} = \sum_{i,j} (x * y^\intercal)_{ij} \]
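The identity above can be checked numerically with a minimal JAX sketch. The helper `tracemm_sketch` is a hypothetical name used only for illustration, not the dynamiqs implementation. It multiplies `x` element-wise by the transpose of `y` over the last two axes and sums those axes, then compares the result with the naive `jnp.trace(x @ y)`.

```python
import jax.numpy as jnp


def tracemm_sketch(x, y):
    # Hypothetical helper (not dq.tracemm itself): element-wise product of x
    # with y transposed over its last two axes, then sum over those axes.
    # This touches each of the n^2 entries once per matrix.
    return jnp.sum(x * jnp.swapaxes(y, -1, -2), axis=(-2, -1))


x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([[5.0, 6.0], [7.0, 8.0]])

fast = tracemm_sketch(x, y)  # sum_{i,j} x_ij * y_ji
naive = jnp.trace(x @ y)     # builds the full 2x2 product first
print(fast, naive)  # 69.0 69.0
```

Here x @ y = [[19, 22], [43, 50]], whose trace is 69, and x * y.T = [[5, 14], [18, 32]], whose entries also sum to 69.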
Note

The resulting time complexity for \(n\times n\) matrices is \(\mathcal{O}(n^2)\) instead of \(\mathcal{O}(n^3)\) for the naive formula, which computes the full product x @ y before taking its trace.
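The same \(\mathcal{O}(n^2)\) contraction can also be written with `jnp.einsum`. This is a swapped-in formulation for illustration, not necessarily how `dq.tracemm` is implemented: contracting both indices at once in `"...ij,...ji->..."` sums \(x_{ij} y_{ji}\) directly, so the full product x @ y never appears in the expression. The sketch checks it against the naive batched trace on random inputs (the shapes and seed are arbitrary choices).

```python
import jax
import jax.numpy as jnp

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (5, 4, 4))  # batch of five 4x4 matrices
y = jax.random.normal(key_y, (5, 4, 4))

# Contract i and j together: sum_{i,j} x_ij y_ji, one scalar per batch element.
via_einsum = jnp.einsum("...ij,...ji->...", x, y)
# Naive reference: full batched matrix product, then trace over the last two axes.
via_matmul = jnp.trace(x @ y, axis1=-2, axis2=-1)

print(via_einsum.shape)  # (5,)
print(jnp.allclose(via_einsum, via_matmul, atol=1e-5))
```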

Parameters

  • x (array_like of shape (..., n, n))

    First matrix (or batch of matrices).

  • y (array_like of shape (..., n, n))

    Second matrix (or batch of matrices).

Returns

(array of shape (...)) Trace of x @ y.
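To illustrate the shape convention, the sketch below applies the element-wise formula to a batch of two 3x3 matrices with matching leading dimensions. It uses plain `jax.numpy` rather than `dq.tracemm`, so it only assumes the formula stated above: the leading axis is kept and one trace is returned per matrix.

```python
import jax.numpy as jnp

# A batch of two 3x3 matrices: the identity and twice the identity.
x = jnp.stack([jnp.eye(3), 2.0 * jnp.eye(3)])  # shape (2, 3, 3)
y = jnp.ones((2, 3, 3))                         # shape (2, 3, 3)

# Fast formula applied over the last two axes only; the batch axis survives.
traces = jnp.sum(x * jnp.swapaxes(y, -1, -2), axis=(-2, -1))
print(traces.shape)  # (2,)
print(traces)        # [3. 6.]
```

Since I @ ones = ones and (2I) @ ones = 2 * ones, the two traces are 3 and 6.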

Examples

>>> import dynamiqs as dq
>>> import jax.numpy as jnp
>>> x = jnp.ones((3, 3))
>>> y = jnp.ones((3, 3))
>>> dq.tracemm(x, y)
Array(9., dtype=float32)