Coverage for eminus/backend.py: 56.67%

60 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-23 09:07 +0000

1# SPDX-FileCopyrightText: 2025 The eminus developers 

2# SPDX-License-Identifier: Apache-2.0 

3"""Array backend handling. 

4 

5Includes helper and compatibility functions. 

6For more information see https://wangenau.gitlab.io/eminus/backend.html. 

7""" 

8 

9import pathlib 

10import sys 

11 

12import numpy as np 

13import scipy 

14 

15from . import config 

16 

if "stubtest" not in pathlib.Path(sys.argv[0]).name:
    # Only install the module-level __getattr__ when the running program is not stubtest

    def __getattr__(name):
        """Forward unknown module attribute lookups to the active array backend.

        Args:
            name: Name of the module or function to look up.

        Returns:
            The attribute `name` from the array_api_compat Torch namespace if `config.backend` is
            "torch", otherwise the attribute `name` from NumPy.
        """
        if config.backend != "torch":
            return getattr(np, name)
        from array_api_compat import torch as torch_namespace

        return getattr(torch_namespace, name)

26 

27 

28# ### Helper functions ### 

29 

30 

def is_array(value):
    """Check whether an object is a NumPy array or a Torch tensor.

    Args:
        value: Object to inspect.

    Returns:
        True if `value` is a NumPy array, or if the Torch backend is selected and `value` is a Torch
        tensor; False otherwise.
    """
    if isinstance(value, np.ndarray):
        return True
    if config.backend != "torch":
        return False
    # array_api_compat is only imported when the Torch backend is selected
    from array_api_compat import is_torch_array

    return is_torch_array(value)

47 

48 

def to_np(arr):
    """Copy the array from the current device to a CPU NumPy array.

    Args:
        arr: Input array, a list of arrays (e.g. wave functions), or None.

    Returns:
        Copied array on the CPU; a list of CPU NumPy arrays if `arr` is a list of Torch tensors;
        None if `arr` is None.
    """
    if arr is None:
        return None
    try:
        from array_api_compat import is_torch_array
    except ImportError:
        # Fall back to a plain NumPy conversion when array_api_compat is unavailable
        return np.asarray(arr)

    if is_torch_array(arr):
        # resolve_conj materializes a lazily conjugated tensor before converting it
        return np.asarray(arr.resolve_conj().cpu())
    # Wave functions are lists of arrays, handle them as well
    if isinstance(arr, list) and arr and is_torch_array(arr[0]):
        return [np.asarray(item.resolve_conj().cpu()) for item in arr]
    return np.asarray(arr)

71 

72 

73# ### Compatibility functions ### 

74 

75 

def delete(arr, obj, axis=None):
    """Return a new array with sub-arrays along an axis deleted.

    NumPy arrays are delegated to `numpy.delete`. Other arrays (e.g. Torch tensors) are indexed
    with the positions that are kept along `axis`.
    Ref: https://gist.github.com/velikodniy/6efef837e67aee2e7152eb5900eb0258

    Args:
        arr: Input array.
        obj: Indicate indices of sub-arrays to remove along the specified axis. As in
            `numpy.delete`, negative indices count from the end, and slices and boolean masks are
            accepted.

    Keyword Args:
        axis: The axis along which to delete the subarray defined by obj. If `axis` is `None`, `obj`
            is applied to the flattened array. Negative values count from the last axis.

    Returns:
        A copy of `arr` with the elements specified by `obj` removed. If `axis` is `None`, `out` is
        a flattened array.

    Raises:
        IndexError: If `axis` is out of bounds for a non-NumPy array.
    """
    if isinstance(arr, np.ndarray):
        return np.delete(arr, obj, axis)
    if axis is None:
        axis = 0
        arr = arr.ravel()
    ndim = arr.ndim
    if not -ndim <= axis < ndim:
        msg = f"axis {axis} is out of bounds for array of dimension {ndim}"
        raise IndexError(msg)
    # Normalize negative axes, otherwise no dimension index below matches and nothing is deleted
    axis %= ndim
    size = arr.shape[axis]

    if isinstance(obj, slice):
        remove = set(range(size)[obj])
    else:
        idx = (to_np(obj) if is_array(obj) else np.asarray(obj)).ravel()
        if idx.dtype == bool:
            # Boolean masks mark the entries to delete, as in numpy.delete
            remove = set(np.flatnonzero(idx).tolist())
        else:
            # Negative indices count from the end; out-of-range indices are ignored
            remove = {int(i) + size if i < 0 else int(i) for i in idx.tolist()}

    keep = [i for i in range(size) if i not in remove]
    indices = tuple(keep if i == axis else slice(None) for i in range(ndim))
    return arr[indices]

105 

106 

def fftn(x, *args, **kwargs):
    """Compute the N-D discrete Fourier Transform.

    Use SciPy FFTs since they are faster, support parallelism, and are more accurate.
    They will upcast some complex arrays to complex256. Using NumPy FFTs will fail some tests.

    Args:
        x: Input array, can be complex.
        args: Pass-through arguments.

    Keyword Args:
        **kwargs: Pass-through keyword arguments. For NumPy input, `workers` defaults to
            `config.threads` unless the caller passes it explicitly.

    Returns:
        Value of the fftn function at x.
    """
    if isinstance(x, np.ndarray):
        # Only fill in the default; always adding workers= would raise a TypeError when the caller
        # already passed it
        if "workers" not in kwargs:
            kwargs["workers"] = config.threads
        return scipy.fft.fftn(x, *args, **kwargs)
    from array_api_compat import array_namespace

    xp = array_namespace(x)
    return xp.fft.fftn(x, *args, **kwargs)

129 

130 

def ifftn(x, *args, **kwargs):
    """Compute the N-D inverse discrete Fourier Transform.

    Use SciPy FFTs since they are faster, support parallelism, and are more accurate.
    They will upcast some complex arrays to complex256. Using NumPy FFTs will fail some tests.

    Args:
        x: Input array, can be complex.
        args: Pass-through arguments.

    Keyword Args:
        **kwargs: Pass-through keyword arguments. For NumPy input, `workers` defaults to
            `config.threads` unless the caller passes it explicitly.

    Returns:
        Value of the ifftn function at x.
    """
    if isinstance(x, np.ndarray):
        # Only fill in the default; always adding workers= would raise a TypeError when the caller
        # already passed it
        if "workers" not in kwargs:
            kwargs["workers"] = config.threads
        return scipy.fft.ifftn(x, *args, **kwargs)
    from array_api_compat import array_namespace

    xp = array_namespace(x)
    return xp.fft.ifftn(x, *args, **kwargs)

153 

154 

def sqrtm(A, *args, **kwargs):
    """Matrix square root.

    The computation uses SciPy. Non-NumPy inputs are copied to a CPU NumPy array first, and the
    result is copied back to the device of `A`.

    Args:
        A: Matrix whose square root to evaluate.
        args: Pass-through arguments.

    Keyword Args:
        **kwargs: Pass-through keyword arguments.

    Returns:
        Value of the sqrt function at A as a complex array of the same array type as A.
    """
    if isinstance(A, np.ndarray):
        return np.asarray(scipy.linalg.sqrtm(A, *args, **kwargs), dtype=complex)
    from array_api_compat import array_namespace, device

    xp = array_namespace(A)
    sqrt_np = np.asarray(scipy.linalg.sqrtm(to_np(A), *args, **kwargs), dtype=complex)
    # to_np moved the data to the CPU, so place the result back on the device of the input
    return xp.asarray(sqrt_np, dtype=complex, device=device(A))