dq.set_matmul_precision

set_matmul_precision(matmul_precision: Literal['low', 'high', 'highest'])

Configure the default precision for matrix multiplications on GPUs and TPUs.

Some devices allow trading off accuracy for speed when performing matrix multiplications (matmul). Three options are available:

  • 'low' reduces matmul precision to bfloat16 (fastest but least accurate),
  • 'high' reduces matmul precision to bfloat16_3x or tensorfloat32 if available (faster but less accurate),
  • 'highest' keeps matmul precision at float32 or float64 as applicable (slowest but most accurate, default setting).
Equivalent JAX syntax

This function is equivalent to setting jax_default_matmul_precision in jax.config. See the JAX documentation on matmul precision and on the different available options.
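
As an illustration of the JAX side of this equivalence, the sketch below sets jax_default_matmul_precision directly through jax.config. It only uses the 'high' and 'highest' names, which JAX accepts as written; how dq.set_matmul_precision translates 'low' internally is not stated on this page, so that case is left out. The variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

# Set the global default matmul precision directly in the JAX config.
jax.config.update("jax_default_matmul_precision", "high")
current = jax.config.jax_default_matmul_precision

# Matrix products use this default on devices that support reduced precision.
# On CPU the setting typically has no visible effect.
a = jnp.eye(2, dtype=jnp.float32)
product = a @ a

# Switch back to the most accurate option, which this page lists as the default.
jax.config.update("jax_default_matmul_precision", "highest")
restored = jax.config.jax_default_matmul_precision
```

The setting is global, so it is usually applied once at the start of a script, before any computation runs.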

Parameters

  • matmul_precision (string 'low', 'high', or 'highest') –

    Default precision for matrix multiplications on GPUs and TPUs.
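
Since the parameter only takes three string values, calling code may want to check user input before passing it on. The helper below is a hypothetical sketch (check_matmul_precision and MatmulPrecision are not part of dynamiqs), and how the library itself reacts to an invalid value is not documented here.

```python
from typing import Literal, get_args

# The three values listed in the signature above.
MatmulPrecision = Literal["low", "high", "highest"]


def check_matmul_precision(value: str) -> str:
    """Return `value` unchanged if it is an allowed option, else raise ValueError."""
    allowed = get_args(MatmulPrecision)
    if value not in allowed:
        raise ValueError(
            f"matmul_precision must be one of {allowed}, got {value!r}"
        )
    return value
```

A validated value can then be passed on, for example as dq.set_matmul_precision(check_matmul_precision(user_choice)), where user_choice is whatever string the caller collected.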