Coverage for eminus/backend.py: 55.17%
58 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 12:19 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 12:19 +0000
1# SPDX-FileCopyrightText: 2025 The eminus developers
2# SPDX-License-Identifier: Apache-2.0
3"""Array backend handling.
5Includes helper and compatibility functions.
6For more information see https://wangenau.gitlab.io/eminus/backend.html.
7"""
9import pathlib
10import sys
12import numpy as np
13import scipy
15from . import config
if "stubtest" not in pathlib.Path(sys.argv[0]).name:
    # Do not overwrite getattr when stubtest is running
    def __getattr__(name):
        """Access modules and functions of array backends by their name.

        Names not defined in this module are forwarded to the active backend namespace: the Torch
        namespace of array_api_compat if `config.backend` is "torch", NumPy otherwise.

        Args:
            name: Attribute name.

        Returns:
            The attribute of the active backend namespace.

        Raises:
            AttributeError: For dunder names, and for names the backend does not provide.
        """
        # Do not forward dunder lookups (e.g. __all__, __path__): the import system and
        # introspection tools query them and would otherwise see the backend's values
        if name.startswith("__") and name.endswith("__"):
            msg = f"module {__name__!r} has no attribute {name!r}"
            raise AttributeError(msg)
        if config.backend == "torch":
            from array_api_compat import torch as xp
        else:
            xp = np
        return getattr(xp, name)
28# ### Helper functions ###
def is_array(value):
    """Check if the object is a NumPy array or Torch tensor.

    Args:
        value: Object to inspect.

    Returns:
        True for NumPy arrays, and for Torch tensors while the Torch backend is active;
        False for everything else.
    """
    if isinstance(value, np.ndarray):
        return True
    if config.backend != "torch":
        # Only the Torch backend knows about tensors; all other objects are no arrays
        return False
    from array_api_compat import is_torch_array

    return is_torch_array(value)
def to_np(arr):
    """Copy the array from the current device to a CPU NumPy array.

    Torch tensors have their conjugation resolved and are moved to the CPU before conversion.
    A list whose first element is a Torch tensor is converted element-wise and returned as a list
    of NumPy arrays. Any other input is passed to `numpy.asarray`.

    Args:
        arr: Input array, or list of arrays.

    Returns:
        Copied array on the CPU, or a list of such arrays for a list of Torch tensors.
    """
    try:
        from array_api_compat import is_torch_array
    except ImportError:
        # Without array_api_compat there is no Torch support here, so treat the input as NumPy-like
        return np.asarray(arr)

    if is_torch_array(arr):
        # resolve_conj materializes a lazily conjugated tensor so NumPy can read it
        return np.asarray(arr.resolve_conj().cpu())
    # Wave functions are lists of arrays, handle them as well (only the first entry is checked)
    if isinstance(arr, list) and len(arr) > 0 and is_torch_array(arr[0]):
        return [np.asarray(a.resolve_conj().cpu()) for a in arr]
    return np.asarray(arr)
71# ### Compatibility functions ###
def _kept_indices(size, obj):
    """Return the indices along an axis of length `size` that are not selected by `obj`.

    Args:
        size: Length of the axis.
        obj: Slice, or integer index/indices (NumPy-convertible) to remove. Negative indices count
            from the end; integer indices outside of [-size, size) are ignored.

    Returns:
        Ascending list of the indices to keep.
    """
    if isinstance(obj, slice):
        # range slicing applies Python's slice semantics (negative bounds, steps)
        removed = set(range(size)[obj])
    else:
        # A set makes the membership test below O(1) instead of scanning obj per index
        removed = {int(i) % size for i in np.ravel(obj) if -size <= i < size}
    return [i for i in range(size) if i not in removed]


def delete(arr, obj, axis=None):
    """Return a new array with sub-arrays along an axis deleted.

    Ref: https://gist.github.com/velikodniy/6efef837e67aee2e7152eb5900eb0258

    NumPy arrays are handled by `numpy.delete`. For other arrays (e.g. Torch tensors) the kept
    indices are computed here; negative indices in `obj` count from the end and out-of-range
    integer indices are ignored (unlike `numpy.delete`, which rejects them).

    Args:
        arr: Input array.
        obj: Indices of sub-arrays to remove along the specified axis, or a slice.

    Keyword Args:
        axis: The axis along which to delete the subarray defined by obj. If `axis` is `None`, `obj`
            is applied to the flattened array. Negative values count from the last axis.

    Returns:
        A copy of `arr` with the elements specified by `obj` removed. If `axis` is `None`, `out` is
        a flattened array.
    """
    if isinstance(arr, np.ndarray):
        return np.delete(arr, obj, axis)
    if axis is None:
        axis = 0
        arr = arr.ravel()
    if axis < 0:
        # Without this, a negative axis never matches below and nothing would be deleted
        axis += arr.ndim
    if not isinstance(obj, slice):
        obj = to_np(obj) if is_array(obj) else np.asarray(obj)
    keep = _kept_indices(arr.shape[axis], obj)
    indices = tuple(keep if i == axis else slice(None) for i in range(arr.ndim))
    return arr[indices]
def fftn(x, *args, **kwargs):
    """Compute the N-D discrete Fourier Transform.

    Use SciPy FFTs since they are faster, support parallelism, and are more accurate.
    They will upcast some complex arrays to complex256. Using NumPy FFTs will fail some tests.
    Non-NumPy arrays are transformed with the FFT of their own array namespace.

    Args:
        x: Input array, can be complex.
        args: Pass-through arguments.

    Keyword Args:
        **kwargs: Pass-through keyword arguments. For NumPy inputs, `workers` defaults to
            `config.threads` unless it is given explicitly.

    Returns:
        Value of the fftn function at x.
    """
    if isinstance(x, np.ndarray):
        # Only fill in the default so an explicit workers argument does not raise a TypeError
        if "workers" not in kwargs:
            kwargs["workers"] = config.threads
        return scipy.fft.fftn(x, *args, **kwargs)
    from array_api_compat import array_namespace

    xp = array_namespace(x)
    return xp.fft.fftn(x, *args, **kwargs)
def ifftn(x, *args, **kwargs):
    """Compute the N-D inverse discrete Fourier Transform.

    Use SciPy FFTs since they are faster, support parallelism, and are more accurate.
    They will upcast some complex arrays to complex256. Using NumPy FFTs will fail some tests.
    Non-NumPy arrays are transformed with the inverse FFT of their own array namespace.

    Args:
        x: Input array, can be complex.
        args: Pass-through arguments.

    Keyword Args:
        **kwargs: Pass-through keyword arguments. For NumPy inputs, `workers` defaults to
            `config.threads` unless it is given explicitly.

    Returns:
        Value of the ifftn function at x.
    """
    if isinstance(x, np.ndarray):
        # Only fill in the default so an explicit workers argument does not raise a TypeError
        if "workers" not in kwargs:
            kwargs["workers"] = config.threads
        return scipy.fft.ifftn(x, *args, **kwargs)
    from array_api_compat import array_namespace

    xp = array_namespace(x)
    return xp.fft.ifftn(x, *args, **kwargs)
def sqrtm(A, *args, **kwargs):
    """Matrix square root.

    The root is always computed with SciPy and returned with a complex dtype. Non-NumPy inputs are
    copied to NumPy first, and the result is converted back with their array namespace.

    Args:
        A: Matrix whose square root to evaluate.
        args: Pass-through arguments.

    Keyword Args:
        **kwargs: Pass-through keyword arguments.

    Returns:
        Value of the sqrt function at A.
    """
    if not isinstance(A, np.ndarray):
        from array_api_compat import array_namespace

        xp = array_namespace(A)
        root = np.asarray(scipy.linalg.sqrtm(to_np(A), *args, **kwargs), dtype=complex)
        return xp.asarray(root, dtype=complex)
    return np.asarray(scipy.linalg.sqrtm(A, *args, **kwargs), dtype=complex)