Coverage for eminus/backend.py: 56.67%
60 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-23 09:07 +0000
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-23 09:07 +0000
1# SPDX-FileCopyrightText: 2025 The eminus developers
2# SPDX-License-Identifier: Apache-2.0
3"""Array backend handling.
5Includes helper and compatibility functions.
6For more information see https://wangenau.gitlab.io/eminus/backend.html.
7"""
9import pathlib
10import sys
12import numpy as np
13import scipy
15from . import config
if "stubtest" not in pathlib.Path(sys.argv[0]).name:
    # When the running script is stubtest, this module must not get a catch-all
    # __getattr__, so the dynamic backend lookup is only installed otherwise

    def __getattr__(name):
        """Resolve unknown module attributes against the active array backend.

        Args:
            name: Name of the module attribute that was not found.

        Returns:
            The attribute of the same name from NumPy, or from the Torch namespace of
            array_api_compat when the configured backend is "torch".
        """
        if config.backend != "torch":
            return getattr(np, name)
        from array_api_compat import torch as torch_namespace

        return getattr(torch_namespace, name)
28# ### Helper functions ###
def is_array(value):
    """Check whether an object is an array of one of the supported backends.

    A NumPy array always qualifies. A Torch tensor only qualifies while Torch is the
    configured backend.

    Args:
        value: Object to inspect.

    Returns:
        True if value is a NumPy array, or a Torch tensor while the Torch backend is active,
        otherwise False.
    """
    if isinstance(value, np.ndarray):
        return True
    if config.backend != "torch":
        # With the NumPy backend only NumPy arrays count as arrays
        return False
    from array_api_compat import is_torch_array

    return is_torch_array(value)
def to_np(arr):
    """Convert an array, or a list of Torch tensors, to NumPy on the CPU.

    Torch tensors are moved to the CPU after ``resolve_conj()`` is applied, then converted.
    Every other input is handed to ``np.asarray``.

    No copy is forced. NumPy arrays are returned as-is, and for tensors that already live
    on the CPU the result may share memory with the input.

    Args:
        arr: Input array, a list of Torch tensors (e.g. wave functions), or None.

    Returns:
        None for a None input, a list of NumPy arrays for a non-empty list of Torch
        tensors, otherwise a NumPy array.
    """
    if arr is None:
        return None
    try:
        from array_api_compat import is_torch_array
    except ImportError:
        # array_api_compat is unavailable, so fall back to a plain NumPy conversion
        return np.asarray(arr)
    if is_torch_array(arr):
        return np.asarray(arr.resolve_conj().cpu())
    # Wave functions are lists of arrays, so convert each tensor individually
    if isinstance(arr, list) and arr and is_torch_array(arr[0]):
        return [np.asarray(tensor.resolve_conj().cpu()) for tensor in arr]
    return np.asarray(arr)
73# ### Compatibility functions ###
def delete(arr, obj, axis=None):
    """Return a new array with sub-arrays along an axis deleted.

    NumPy arrays are handled by ``np.delete``. For other backends (e.g. Torch tensors) the
    kept positions are gathered explicitly. This mirrors the NumPy semantics for negative
    axes, negative indices, boolean masks, and slices.

    Ref: https://gist.github.com/velikodniy/6efef837e67aee2e7152eb5900eb0258

    Args:
        arr: Input array.
        obj: Indicate indices of sub-arrays to remove along the specified axis. Can be an
            integer, a sequence or array of integers, a boolean mask, or a slice.

    Keyword Args:
        axis: The axis along which to delete the subarray defined by obj. If `axis` is `None`, `obj`
            is applied to the flattened array.

    Returns:
        A copy of `arr` with the elements specified by `obj` removed. If `axis` is `None`, `out` is
        a flattened array.
    """
    if isinstance(arr, np.ndarray):
        return np.delete(arr, obj, axis)
    if axis is None:
        axis = 0
        arr = arr.ravel()
    if axis < 0:
        # Normalize negative axes, otherwise no dimension would match in the index tuple below
        axis += arr.ndim
    # shape[axis] works for any array backend, unlike the Torch-only size(axis) method
    size = arr.shape[axis]
    if isinstance(obj, slice):
        obj = np.arange(size)[obj]
    elif is_array(obj):
        obj = to_np(obj)
    else:
        obj = np.asarray(obj)
    if obj.dtype == bool:
        # A boolean mask marks the positions to delete
        obj = np.flatnonzero(obj)
    # Wrap negative indices (e.g. -1 for the last entry) like np.delete does
    obj = np.where(obj < 0, obj + size, obj)
    removed = set(obj.ravel().tolist())
    keep = [i for i in range(size) if i not in removed]
    indices = tuple(keep if i == axis else slice(None) for i in range(arr.ndim))
    return arr[indices]
def fftn(x, *args, **kwargs):
    """Compute the N-D discrete Fourier Transform.

    Use SciPy FFTs since they are faster, support parallelism, and are more accurate.
    They will upcast some complex arrays to complex256. Using NumPy FFTs will fail some tests.

    Args:
        x: Input array, can be complex.
        args: Pass-through arguments.

    Keyword Args:
        **kwargs: Pass-through keyword arguments. For NumPy inputs, `workers` defaults to
            `config.threads` unless the caller passes it explicitly.

    Returns:
        Value of the fftn function at x.
    """
    if isinstance(x, np.ndarray):
        # Only use the configured thread count as a default; passing workers both
        # explicitly and via kwargs would raise a TypeError
        if "workers" not in kwargs:
            kwargs["workers"] = config.threads
        return scipy.fft.fftn(x, *args, **kwargs)
    from array_api_compat import array_namespace

    xp = array_namespace(x)
    return xp.fft.fftn(x, *args, **kwargs)
def ifftn(x, *args, **kwargs):
    """Compute the N-D inverse discrete Fourier Transform.

    Use SciPy FFTs since they are faster, support parallelism, and are more accurate.
    They will upcast some complex arrays to complex256. Using NumPy FFTs will fail some tests.

    Args:
        x: Input array, can be complex.
        args: Pass-through arguments.

    Keyword Args:
        **kwargs: Pass-through keyword arguments. For NumPy inputs, `workers` defaults to
            `config.threads` unless the caller passes it explicitly.

    Returns:
        Value of the ifftn function at x.
    """
    if isinstance(x, np.ndarray):
        # Only use the configured thread count as a default; passing workers both
        # explicitly and via kwargs would raise a TypeError
        if "workers" not in kwargs:
            kwargs["workers"] = config.threads
        return scipy.fft.ifftn(x, *args, **kwargs)
    from array_api_compat import array_namespace

    xp = array_namespace(x)
    return xp.fft.ifftn(x, *args, **kwargs)
def sqrtm(A, *args, **kwargs):
    """Matrix square root.

    The computation is always done with scipy.linalg.sqrtm on a NumPy copy of the input.
    The result is cast to complex.

    Args:
        A: Matrix whose square root to evaluate.
        args: Pass-through arguments for scipy.linalg.sqrtm.

    Keyword Args:
        **kwargs: Pass-through keyword arguments for scipy.linalg.sqrtm.

    Returns:
        Matrix square root of A as a complex array of A's backend. For non-NumPy inputs it
        is placed on the same device as A.
    """
    if isinstance(A, np.ndarray):
        return np.asarray(scipy.linalg.sqrtm(A, *args, **kwargs), dtype=complex)
    from array_api_compat import array_namespace, device

    xp = array_namespace(A)
    sqrt_A = np.asarray(scipy.linalg.sqrtm(to_np(A), *args, **kwargs), dtype=complex)
    # Building the result from a NumPy array would otherwise leave it on the CPU,
    # so move it back to the device of the input
    return xp.asarray(sqrt_A, device=device(A))