dq.set_matmul_precision
set_matmul_precision(matmul_precision: Literal['low', 'high', 'highest'])
Configure the default precision for matrix multiplications on GPUs and TPUs.
Some devices allow trading off accuracy for speed when performing matrix multiplications (matmul). Three options are available:
- 'low' reduces matmul precision to bfloat16 (fastest but least accurate),
- 'high' reduces matmul precision to bfloat16_3x or tensorfloat32 if available (faster but less accurate),
- 'highest' keeps matmul precision at float32 or float64 as applicable (slowest but most accurate, default setting).
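To see what the three levels trade off, a single JAX matmul can also be given a precision explicitly. The sketch below assumes that 'low', 'high' and 'highest' correspond roughly to jax.lax.Precision.DEFAULT, HIGH and HIGHEST. This mapping and the helper matmul_error are illustrative assumptions, not dynamiqs internals. The helper measures how far a float32 matmul deviates from a float64 NumPy reference.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Assumed correspondence between the three levels and JAX's per-call
# precision enum (illustration only, not taken from dynamiqs).
LEVELS = {
    'low': jax.lax.Precision.DEFAULT,
    'high': jax.lax.Precision.HIGH,
    'highest': jax.lax.Precision.HIGHEST,
}


def matmul_error(level, n=256, seed=0):
    """Max absolute deviation of a float32 JAX matmul from a float64 reference."""
    rng = np.random.default_rng(seed)
    a = rng.standard_normal((n, n))
    b = rng.standard_normal((n, n))
    ref = a @ b  # float64 reference computed by NumPy
    out = jnp.matmul(
        jnp.asarray(a, dtype=jnp.float32),
        jnp.asarray(b, dtype=jnp.float32),
        precision=LEVELS[level],  # explicit per-call precision
    )
    return float(np.max(np.abs(np.asarray(out, dtype=np.float64) - ref)))


for level in LEVELS:
    print(level, matmul_error(level))
```

On CPU, all three levels typically report the same float32-level error, since the reduced-precision modes are an accelerator feature. Differences are expected to appear on GPUs and TPUs.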
Equivalent JAX syntax
This function is equivalent to setting the jax_default_matmul_precision option in
jax.config. See the JAX documentation on matmul precision
and the JAX documentation on the different available options.
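A minimal sketch of the JAX-side equivalent, assuming a recent JAX release. Which exact string dq.set_matmul_precision passes to JAX for each level is an assumption here. 'float32' is used because it is one of the values JAX accepts for this option and matches the full-precision 'highest' setting described above. The context manager jax.default_matmul_precision is shown as a scoped alternative.

```python
import jax
import jax.numpy as jnp

# Global default: the JAX-level counterpart of this function.
# 'float32' requests full float32 precision; the exact string dynamiqs
# uses for each level is an assumption in this sketch.
jax.config.update('jax_default_matmul_precision', 'float32')

x = jnp.ones((4, 4), dtype=jnp.float32)
y = x @ x  # uses the global default set above

# Scoped alternative: override the default only inside the `with` block.
with jax.default_matmul_precision('bfloat16'):
    z = x @ x  # reduced precision on devices that support it
```

The global update affects every later matmul in the process, while the context manager limits the change to one region of code.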
Parameters
- matmul_precision (string 'low', 'high', or 'highest') – Default precision for matrix multiplications on GPUs and TPUs.