Coverage for eminus/utils.py: 97.87%

141 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-21 12:19 +0000

1# SPDX-FileCopyrightText: 2021 The eminus developers 

2# SPDX-License-Identifier: Apache-2.0 

3"""Linear algebra calculation utilities.""" 

4 

5import functools 

6import math 

7import pathlib 

8import re 

9 

10import numpy as np 

11 

12import eminus 

13 

14from . import backend as xp 

15from .units import rad2deg 

16 

17 

class BaseObject:
    """Base eminus class that implements some shared functionalities."""

    def view(self, *args, **kwargs):
        """Unified display function.

        Forwards this object together with all arguments to eminus.extras.view.

        Args:
            args: Pass-through arguments.

        Keyword Args:
            kwargs: Pass-through keyword arguments.

        Returns:
            Viewable object.
        """
        return eminus.extras.view(self, *args, **kwargs)

    def write(self, filename, *args, **kwargs):
        """Unified file writer function.

        If the filename has no extension and does not contain "POSCAR", ".json" is appended so the
        object is saved as a JSON file. The (possibly modified) filename is forwarded together with
        this object and all arguments to eminus.io.write.

        Args:
            filename: Output file path/name (string or path-like object).
            args: Pass-through arguments.

        Keyword Args:
            kwargs: Pass-through keyword arguments.

        Returns:
            The return value of eminus.io.write.
        """
        # Save the object as a JSON file if no extension is given; names containing "POSCAR" are
        # exempt (presumably because POSCAR files are written without an extension)
        # str() makes the checks work for pathlib.Path inputs as well, where "in" and "+=" would
        # raise a TypeError
        if not pathlib.Path(filename).suffix and "POSCAR" not in str(filename):
            filename = f"{filename}.json"
        return eminus.io.write(self, filename, *args, **kwargs)

52 

53 

def dotprod(a, b):
    """Efficiently calculate the real part of the inner product of a and b.

    The result is Re(sum(conj(a) * b)), i.e., the real part of trace(a^H @ b). An extra check makes
    sure the result is never zero since this function is used as a denominator in minimizers: if
    the magnitude of the real part is below 1e-15, 1e-15 is returned instead.

    Args:
        a: Array of vectors.
        b: Array of vectors.

    Returns:
        The real part of the inner product, or 1e-15 if its magnitude is smaller than that.
    """
    eps = 1e-15  # A few times the float64 machine epsilon of 2.22e-16
    # Equivalent to Re(trace(a^H @ b)), but summing the elementwise (Hadamard) product avoids the
    # much more expensive matrix product
    res = xp.real(xp.sum(a.conj() * b))
    # Guard the real part, i.e., the value that is actually returned; checking the magnitude of
    # the complex sum would let a vanishing real part with a non-zero imaginary part slip through
    if abs(res) < eps:
        return eps
    return res

75 

76 

def Ylm_real(l, m, G):
    """Evaluate real spherical harmonics Y_lm for cartesian vectors.

    Definitions are available for l = 0, 1, 2, 3 with -l <= m <= l. The angles theta and phi are
    derived from the direction of each vector. For l = 0 the value of m is not checked.

    Reference: https://scipython.com/blog/visualizing-the-real-forms-of-the-spherical-harmonics

    Args:
        l: Angular momentum number.
        m: Magnetic quantum number.
        G: Reciprocal lattice vector or array of lattice vectors.

    Returns:
        Real spherical harmonics, one value per vector.

    Raises:
        ValueError: If no definition exists for the (l, m) pair.
    """
    threshold = 1e-9
    # Promote a single vector to a one-row array of vectors
    vectors = xp.atleast_2d(G)

    # Y_00 is constant and needs no angles
    if l == 0:
        return 0.5 * math.sqrt(1 / math.pi) * xp.ones(len(vectors))

    # Polar angle from cos(theta) = Gz / |G|; (near) zero-length vectors get cos(theta) = 0
    lengths = xp.linalg.norm(vectors, axis=1)
    with np.errstate(divide="ignore", invalid="ignore"):
        cos_t = vectors[:, 2] / lengths
    cos_t[lengths < threshold] = 0
    # sin(theta) = sqrt(max(0, 1 - cos(theta)^2)), evaluated elementwise
    sin_t = xp.sqrt(xp.max(xp.stack([xp.zeros_like(cos_t), 1 - cos_t**2]), axis=0))

    # Azimuthal angle phi = arctan(Gy / Gx); where Gx vanishes use phi = pi/2 * sign(Gy)
    phi = xp.arctan2(vectors[:, 1], vectors[:, 0])
    gx_zero = xp.abs(vectors[:, 0]) < threshold
    phi[gx_zero] = math.pi / 2 * xp.sign(vectors[gx_zero, 1])

    # (l, m, evaluator) entries; evaluators are called lazily so only the requested one runs
    table = (
        (1, -1, lambda: 0.5 * math.sqrt(3 / math.pi) * sin_t * xp.sin(phi)),  # py
        (1, 0, lambda: 0.5 * math.sqrt(3 / math.pi) * cos_t),  # pz
        (1, 1, lambda: 0.5 * math.sqrt(3 / math.pi) * sin_t * xp.cos(phi)),  # px
        (2, -2, lambda: math.sqrt(15 / 16 / math.pi) * sin_t**2 * xp.sin(2 * phi)),  # dxy
        (2, -1, lambda: math.sqrt(15 / 4 / math.pi) * cos_t * sin_t * xp.sin(phi)),  # dyz
        (2, 0, lambda: 0.25 * math.sqrt(5 / math.pi) * (3 * cos_t**2 - 1)),  # dz2
        (2, 1, lambda: math.sqrt(15 / 4 / math.pi) * cos_t * sin_t * xp.cos(phi)),  # dxz
        (2, 2, lambda: math.sqrt(15 / 16 / math.pi) * sin_t**2 * xp.cos(2 * phi)),  # dx2-y2
        (3, -3, lambda: 0.25 * math.sqrt(35 / 2 / math.pi) * sin_t**3 * xp.sin(3 * phi)),
        (3, -2, lambda: 0.25 * math.sqrt(105 / math.pi) * sin_t**2 * cos_t * xp.sin(2 * phi)),
        (
            3,
            -1,
            lambda: 0.25 * math.sqrt(21 / 2 / math.pi) * sin_t * (5 * cos_t**2 - 1) * xp.sin(phi),
        ),
        (3, 0, lambda: 0.25 * math.sqrt(7 / math.pi) * (5 * cos_t**3 - 3 * cos_t)),
        (
            3,
            1,
            lambda: 0.25 * math.sqrt(21 / 2 / math.pi) * sin_t * (5 * cos_t**2 - 1) * xp.cos(phi),
        ),
        (3, 2, lambda: 0.25 * math.sqrt(105 / math.pi) * sin_t**2 * cos_t * xp.cos(2 * phi)),
        (3, 3, lambda: 0.25 * math.sqrt(35 / 2 / math.pi) * sin_t**3 * xp.cos(3 * phi)),
    )
    for l_ref, m_ref, evaluate in table:
        if l == l_ref and m == m_ref:
            return evaluate()

    msg = f"No definition found for Ylm({l}, {m})."
    raise ValueError(msg)

162 

163 

def handle_spin(func):
    """Handle spin by evaluating the wrapped function for each spin channel separately.

    This can only be applied if the only spin-dependent indexing is the wave function W.

    Handling spin explicitly would require looping over the spin states in many places. This
    decorator hides that loop: the wrapped operators keep working on arrays holding the wave
    functions of all states or of a single state, and additionally accept arrays that hold the
    wave functions of all spins and all states.

    Args:
        func: Function that acts on spin-states.

    Returns:
        Decorator.
    """

    @functools.wraps(func)
    def wrapper(obj, W, *args, **kwargs):
        # Anything that is not 3D is passed through unchanged
        if W.ndim != 3:
            return func(obj, W, *args, **kwargs)
        # 3D input: the leading axis enumerates the spin channels, evaluate each and re-stack
        per_spin = [func(obj, W_spin, *args, **kwargs) for W_spin in W]
        return xp.stack(per_spin)

    return wrapper

189 

190 

def handle_k(func=None, *, mode="gracefully"):
    """Handle k-points by evaluating the wrapped function per k-point with different modes.

    This uses the same principle as described in :func:`~eminus.utils.handle_spin`. The input W is
    treated as k-point resolved when it is a list or a 4D array. Modes:

    * "gracefully": call the function for every k-point and return a list of results.
    * "index": like "gracefully", but pass the k-point index as the argument following W.
    * "reduce": like "index", but return the sum of all results.
    * "skip": assert a Gamma-point-only calculation, evaluate only the first k-point, and wrap a
      3D array result in a list.

    For any other mode, or input that is not k-point resolved, the function is called unchanged.
    The decorator can be used bare (``@handle_k``) or with arguments (``@handle_k(mode=...)``).

    Args:
        func: Function that acts on k-points.

    Keyword Args:
        mode: How to handle the k-point dependency.

    Returns:
        Decorator.
    """
    # Called with arguments only, return a decorator that still expects the function
    if func is None:
        return functools.partial(handle_k, mode=mode)

    @functools.wraps(func)
    def wrapper(obj, W, *args, **kwargs):
        has_kpoints = isinstance(W, list) or (xp.is_array(W) and W.ndim == 4)
        if not has_kpoints:
            return func(obj, W, *args, **kwargs)
        if mode == "gracefully":
            # No explicit k-point indexing is needed
            return [func(obj, W_k, *args, **kwargs) for W_k in W]
        if mode == "index":
            # Explicit k-point indexing is needed
            return [func(obj, W_k, ik, *args, **kwargs) for ik, W_k in enumerate(W)]
        if mode == "reduce":
            # The builtin sum adds plain numbers and arrays (elementwise) alike
            return sum(func(obj, W_k, ik, *args, **kwargs) for ik, W_k in enumerate(W))
        if mode == "skip":
            # No k-point dependency has been implemented, so only Gamma-point runs are supported
            obj._atoms.kpts._assert_gamma_only()
            result = func(obj, W[0], *args, **kwargs)
            return [result] if xp.is_array(result) and result.ndim == 3 else result
        return func(obj, W, *args, **kwargs)

    return wrapper

229 

230 

def pseudo_uniform(size, seed=1234):
    """Generate reproducible pseudo-random numbers in [0, 1) with a Lehmer-style generator.

    Uses the MINSTD multiplier 48271 and modulus 2**31 - 1. Note that an increment of 1 is added in
    every step, so the sequence differs from the pure Lehmer/MINSTD generator. The first three
    dimensions of size are filled in row-major order, and the values are stored in an array of
    complex dtype.

    Reference: Commun. ACM. 12, 85.

    Args:
        size: Dimension of the array to create.

    Keyword Args:
        seed: Seed to initialize the random number generator.

    Returns:
        Array with (pseudo) random numbers.
    """
    out = xp.empty(size, dtype=complex)
    multiplier = 48271
    modulus = 2**31 - 1
    state = (seed * multiplier + 1) % modulus
    # Row-major walk over the first three axes; size[1] and size[2] are only read once needed
    indices = ((i, j, k) for i in range(size[0]) for j in range(size[1]) for k in range(size[2]))
    for idx in indices:
        state = (state * multiplier + 1) % modulus
        out[idx] = state / modulus
    return out

255 

256 

def add_maybe_none(a, b):
    """Add a and b together, when one or both can potentially be None.

    None acts as a neutral element: if both are None the result is None, if exactly one is None
    the other operand is returned unchanged, otherwise a + b is returned.

    Args:
        a: Array or None.
        b: Array or None.

    Returns:
        Sum of a and b.
    """
    if a is None:
        # Also covers the case where both are None, since b is then returned as None
        return b
    return a if b is None else a + b

274 

275 

def molecule2list(molecule):
    """Expand a chemical formula to a list of chemical symbols.

    No charges or parentheses are allowed, only chemical symbols followed by their amount, e.g.,
    "CH4" expands to ["C", "H", "H", "H", "H"].

    Args:
        molecule: Simplified chemical formula (case sensitive).

    Returns:
        Atoms of the molecule expanded to a list.
    """
    # Insert a whitespace before every capital letter (the start of a chemical symbol) and before
    # every run of digits (an amount), then split on the whitespace
    tokens = re.sub(r"([A-Z]|\d+)", r" \1", molecule).split()
    atom_list = []
    for token in tokens:
        if token.isdigit():
            # The previous symbol was already added once, so append it count - 1 more times
            atom_list += [atom_list[-1]] * (int(token) - 1)
        else:
            # A chemical symbol, add it to the results list
            atom_list.append(token)
    return atom_list

299 

300 

def atom2charge(atom, path=None):
    """Get the valence charges (Zion) for a list of chemical symbols from GTH files.

    Args:
        atom: Atom symbols.
        path: Directory of GTH files. The names "pade" and "pbe" are matched case-insensitively
            and passed on in lowercase; None selects "pbe".

    Returns:
        Valence charges per atom.
    """
    # Import here to prevent circular imports
    from .io import read_gth

    if path is None:
        psp_path = "pbe"
    elif path.lower() in {"pade", "pbe"}:
        psp_path = path.lower()
    else:
        psp_path = path
    return [read_gth(symbol, psp_path=psp_path)["Zion"] for symbol in atom]

322 

323 

def vector_angle(a, b):
    """Calculate the angle between two vectors.

    Args:
        a: Vector.
        b: Vector.

    Returns:
        Angle between a and b in Degree.
    """
    a, b = xp.asarray(a, dtype=float), xp.asarray(b, dtype=float)
    # Normalize both vectors so that their dot product is the cosine of the enclosed angle
    a_norm = a / xp.linalg.norm(a)
    b_norm = b / xp.linalg.norm(b)
    # Round-off can push the cosine of (anti)parallel vectors slightly outside [-1, 1], where
    # arccos returns NaN, so clamp it to the valid domain first
    cos_angle = xp.clip(a_norm @ b_norm, -1, 1)
    return rad2deg(xp.arccos(cos_angle))

340 

341 

def get_lattice(lattice_vectors):
    """Generate the edges of the cell spanned by the given lattice vectors.

    Args:
        lattice_vectors: Lattice vectors, one per row.

    Returns:
        List of the 12 cell edges, each given as the pair of its end-point coordinates (a (2, 3)
        array for 3x3 lattice vectors).
    """
    # Vertices of the unit cube in fractional coordinates; vertex i has the bit pattern of i
    vertices = xp.asarray(
        [
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
        ],
        dtype=float,
    )
    # Connected vertices of a cube with the above ordering (indices differing in one coordinate)
    edges = xp.asarray(
        [
            [0, 1],
            [0, 2],
            [0, 4],
            [1, 3],
            [1, 5],
            [2, 3],
            [2, 6],
            [3, 7],
            [4, 5],
            [4, 6],
            [5, 7],
            [6, 7],
        ]
    )
    # Scale the vertices with the lattice vectors once, then select the vertex pair of each edge
    # The resulting return value is similar to the get_brillouin_zone function
    cell_vertices = vertices @ lattice_vectors
    return [cell_vertices[e, :] for e in edges]