Source code for piel.utils.numerical
import jax.numpy as jnp
from piel.types import ArrayTypes
__all__ = [
"round_complex_array",
]
[docs]
def round_complex_array(
array: ArrayTypes,
to_absolute: bool = False,
):
"""
Rounds the elements of a complex JAX numpy array to the nearest integer.
Parameters:
- array: A complex JAX numpy array.
- absolute: A boolean that determines whether the complex numbers are rounded to the nearest integers in their absolute value.
Returns:
- A JAX numpy array with the complex elements rounded to the nearest integers.
"""
real_part = jnp.around(array.real) # Round the real parts to the nearest integer
imaginary_part = jnp.around(
array.imag
) # Round the imaginary parts to the nearest integer
value = real_part + 1j * imaginary_part # Recombine the real and imaginary parts
if to_absolute:
value = jnp.abs(value) # Take the absolute value of the complex number
return value