Coverage for eminus/backend.py: 55.17%

58 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-21 12:19 +0000

1# SPDX-FileCopyrightText: 2025 The eminus developers 

2# SPDX-License-Identifier: Apache-2.0 

3"""Array backend handling. 

4 

5Includes helper and compatibility functions. 

6For more information see https://wangenau.gitlab.io/eminus/backend.html. 

7""" 

8 

9import pathlib 

10import sys 

11 

12import numpy as np 

13import scipy 

14 

15from . import config 

16 

if "stubtest" not in pathlib.Path(sys.argv[0]).name:
    # Leave the module attribute lookup untouched while stubtest is running

    def __getattr__(name):
        """Look up `name` in the array namespace of the configured backend.

        Args:
            name: Attribute name requested on this module.

        Returns:
            The attribute of the same name from the Torch compatibility namespace if the Torch
            backend is configured, otherwise from NumPy.
        """
        if config.backend != "torch":
            return getattr(np, name)
        # The import only happens in the Torch branch
        from array_api_compat import torch as xp

        return getattr(xp, name)

26 

27 

28# ### Helper functions ### 

29 

30 

def is_array(value):
    """Tell whether the object is a NumPy array or a Torch tensor.

    Args:
        value: Object to check.

    Returns:
        True for NumPy arrays and, if the Torch backend is configured, for Torch tensors.
        False otherwise.
    """
    if isinstance(value, np.ndarray):
        return True
    if config.backend != "torch":
        return False
    # Only consult array_api_compat when the Torch backend is selected
    from array_api_compat import is_torch_array

    return is_torch_array(value)

47 

48 

def to_np(arr):
    """Copy the array from the current device to a CPU NumPy array.

    Args:
        arr: Input array, or a list of arrays.

    Returns:
        The input converted with `np.asarray`. Torch tensors are moved to the CPU first. A
        non-empty list whose first element is a Torch tensor is returned as a list of NumPy
        arrays.
    """
    # Keep the try body minimal: only a missing array_api_compat should trigger the fallback,
    # an ImportError raised while converting must not be silently swallowed
    try:
        from array_api_compat import is_torch_array
    except ImportError:
        return np.asarray(arr)

    if is_torch_array(arr):
        return np.asarray(arr.resolve_conj().cpu())
    # Wave functions are lists of arrays, handle them as well
    if isinstance(arr, list) and len(arr) > 0 and is_torch_array(arr[0]):
        # Use a distinct loop variable instead of shadowing the argument
        return [np.asarray(item.resolve_conj().cpu()) for item in arr]
    return np.asarray(arr)

69 

70 

71# ### Compatibility functions ### 

72 

73 

def delete(arr, obj, axis=None):
    """Return a new array with sub-arrays along an axis deleted.

    Ref: https://gist.github.com/velikodniy/6efef837e67aee2e7152eb5900eb0258

    Args:
        arr: Input array.
        obj: Indicate indices of sub-arrays to remove along the specified axis. Can be an integer,
            a sequence or array of integers, or a slice. Negative indices count from the end.

    Keyword Args:
        axis: The axis along which to delete the subarray defined by obj. If `axis` is `None`, `obj`
            is applied to the flattened array. Negative axes count from the last axis.

    Returns:
        A copy of `arr` with the elements specified by `obj` removed. If `axis` is `None`, `out` is
        a flattened array.
    """
    if isinstance(arr, np.ndarray):
        return np.delete(arr, obj, axis)
    if axis is None:
        axis = 0
        arr = arr.ravel()
    elif axis < 0:
        # Normalize negative axes, otherwise the axis never matches in the index tuple below
        axis += arr.ndim
    length = arr.shape[axis]
    if isinstance(obj, slice):
        # Expand the slice to the explicit indices it selects
        obj = np.arange(length)[obj]
    elif is_array(obj):
        obj = to_np(obj)
    else:
        obj = np.asarray(obj)
    # Wrap negative indices and collect them in a set for constant-time membership tests
    removed = {i + length if i < 0 else i for i in np.atleast_1d(obj).ravel().tolist()}
    keep = [i for i in range(length) if i not in removed]
    indices = tuple(keep if i == axis else slice(None) for i in range(arr.ndim))
    return arr[indices]

103 

104 

def fftn(x, *args, **kwargs):
    """Compute the N-D discrete Fourier Transform.

    Use SciPy FFTs since they are faster, support parallelism, and are more accurate.
    They will upcast some complex arrays to complex256. Using NumPy FFTs will fail some tests.

    Args:
        x: Input array, can be complex.
        args: Pass-through arguments.

    Keyword Args:
        **kwargs: Pass-through keyword arguments. For NumPy arrays, `workers` defaults to
            `config.threads` if it is not given.

    Returns:
        Value of the fftn function at x.
    """
    if isinstance(x, np.ndarray):
        # Respect an explicitly passed `workers` instead of failing on a duplicate keyword
        if "workers" not in kwargs:
            kwargs["workers"] = config.threads
        return scipy.fft.fftn(x, *args, **kwargs)
    from array_api_compat import array_namespace

    xp = array_namespace(x)
    return xp.fft.fftn(x, *args, **kwargs)

127 

128 

def ifftn(x, *args, **kwargs):
    """Compute the N-D inverse discrete Fourier Transform.

    Use SciPy FFTs since they are faster, support parallelism, and are more accurate.
    They will upcast some complex arrays to complex256. Using NumPy FFTs will fail some tests.

    Args:
        x: Input array, can be complex.
        args: Pass-through arguments.

    Keyword Args:
        **kwargs: Pass-through keyword arguments. For NumPy arrays, `workers` defaults to
            `config.threads` if it is not given.

    Returns:
        Value of the ifftn function at x.
    """
    if isinstance(x, np.ndarray):
        # Respect an explicitly passed `workers` instead of failing on a duplicate keyword
        if "workers" not in kwargs:
            kwargs["workers"] = config.threads
        return scipy.fft.ifftn(x, *args, **kwargs)
    from array_api_compat import array_namespace

    xp = array_namespace(x)
    return xp.fft.ifftn(x, *args, **kwargs)

151 

152 

def sqrtm(A, *args, **kwargs):
    """Compute the matrix square root.

    The computation is always done by SciPy. Arrays that are not NumPy arrays are converted to
    NumPy first and the result is converted back to the namespace of the input.

    Args:
        A: Matrix whose square root to evaluate.
        args: Pass-through arguments.

    Keyword Args:
        **kwargs: Pass-through keyword arguments.

    Returns:
        Matrix square root of A as a complex array.
    """
    if not isinstance(A, np.ndarray):
        from array_api_compat import array_namespace

        xp = array_namespace(A)
        # Go through NumPy for the SciPy call and hand the result back to the input namespace
        root = np.asarray(scipy.linalg.sqrtm(to_np(A), *args, **kwargs), dtype=complex)
        return xp.asarray(root, dtype=complex)
    return np.asarray(scipy.linalg.sqrtm(A, *args, **kwargs), dtype=complex)