Coverage for eminus/utils.py: 95.36%

151 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-16 10:16 +0000

1# SPDX-FileCopyrightText: 2021 The eminus developers 

2# SPDX-License-Identifier: Apache-2.0 

3"""Linear algebra calculation utilities.""" 

4 

5import functools 

6import pathlib 

7import re 

8 

9import numpy as np 

10from scipy.linalg import norm 

11 

12import eminus 

13 

14from . import config 

15from .units import rad2deg 

16 

17 

class BaseObject:
    """Base eminus class that implements some shared functionalities."""

    def view(self, *args, **kwargs):
        """Unified display function.

        Args:
            args: Pass-through arguments.

        Keyword Args:
            kwargs: Pass-through keyword arguments.

        Returns:
            Viewable object.
        """
        return eminus.extras.view(self, *args, **kwargs)

    def write(self, filename, *args, **kwargs):
        """Unified file writer function.

        Args:
            filename: File path/name, either as a string or as a pathlib path object.
            args: Pass-through arguments.

        Keyword Args:
            kwargs: Pass-through keyword arguments.

        Returns:
            The return value of the eminus.io.write call.
        """
        # The substring test and the suffix concatenation below need a plain string, convert path
        # objects first so they do not raise a TypeError
        if isinstance(filename, pathlib.PurePath):
            filename = str(filename)
        # Save the object as a JSON file if no extension is given
        if not pathlib.Path(filename).suffix and "POSCAR" not in filename:
            filename += ".json"
        return eminus.io.write(self, filename, *args, **kwargs)

52 

53 

def dotprod(a, b):
    """Efficiently calculate the expression a * b.

    Add an extra check to make sure the result is never zero since this function is used as a
    denominator in minimizers.

    Args:
        a: Array of vectors.
        b: Array of vectors.

    Returns:
        The real part of the expressions result, or a small positive number if it would vanish.
    """
    eps = 1e-15  # 2.22e-16 is the range of float64 machine precision
    # The dot product of complex vectors looks like the expression below, but this is slow
    # res = np.real(np.trace(a.conj().T @ b))
    # We can calculate the trace faster by taking the sum of the Hadamard product
    res = np.real(np.sum(a.conj() * b))
    # Guard the value that is actually returned: checking the magnitude of the complex sum would
    # let a purely imaginary sum pass and return a real part of zero
    if abs(res) < eps:
        return eps
    return res

75 

76 

def Ylm_real(l, m, G):  # noqa: C901, PLR0911
    """Calculate real spherical harmonics from cartesian coordinates.

    Reference: https://scipython.com/blog/visualizing-the-real-forms-of-the-spherical-harmonics

    Args:
        l: Angular momentum number.
        m: Magnetic quantum number.
        G: Reciprocal lattice vector or array of lattice vectors.

    Returns:
        Real spherical harmonics, one value per vector in G.

    Raises:
        ValueError: If no expression is implemented for the given combination of l and m.
    """
    eps = 1e-9
    # Account for single vectors
    G = np.atleast_2d(G)

    # No need to calculate more for l=0
    if l == 0:
        return 0.5 * np.sqrt(1 / np.pi) * np.ones(len(G))

    # cos(theta)=Gz/|G|
    Gm = norm(G, axis=1)
    # Vanishing vectors divide by zero here, silence the warnings since they are fixed right after
    with np.errstate(divide="ignore", invalid="ignore"):
        cos_theta = G[:, 2] / Gm
    # Account for small magnitudes, if norm(G) < eps: cos_theta=0
    cos_theta[Gm < eps] = 0

    # Vectorized version of sin(theta)=sqrt(max(0, 1-cos_theta^2))
    sin_theta = np.sqrt(np.amax((np.zeros_like(cos_theta), 1 - cos_theta**2), axis=0))

    # phi=arctan(Gy/Gx)
    phi = np.arctan2(G[:, 1], G[:, 0])
    # If Gx=0: phi=pi/2*sign(Gy)
    phi_idx = np.abs(G[:, 0]) < eps
    phi[phi_idx] = np.pi / 2 * np.sign(G[phi_idx, 1])

    if l == 1:
        if m == -1:  # py
            return 0.5 * np.sqrt(3 / np.pi) * sin_theta * np.sin(phi)
        if m == 0:  # pz
            return 0.5 * np.sqrt(3 / np.pi) * cos_theta
        if m == 1:  # px
            return 0.5 * np.sqrt(3 / np.pi) * sin_theta * np.cos(phi)
    elif l == 2:
        if m == -2:  # dxy
            return np.sqrt(15 / 16 / np.pi) * sin_theta**2 * np.sin(2 * phi)
        if m == -1:  # dyz
            return np.sqrt(15 / 4 / np.pi) * cos_theta * sin_theta * np.sin(phi)
        if m == 0:  # dz2
            return 0.25 * np.sqrt(5 / np.pi) * (3 * cos_theta**2 - 1)
        if m == 1:  # dxz
            return np.sqrt(15 / 4 / np.pi) * cos_theta * sin_theta * np.cos(phi)
        if m == 2:  # dx2-y2
            return np.sqrt(15 / 16 / np.pi) * sin_theta**2 * np.cos(2 * phi)
    elif l == 3:
        if m == -3:
            return 0.25 * np.sqrt(35 / 2 / np.pi) * sin_theta**3 * np.sin(3 * phi)
        if m == -2:
            return 0.25 * np.sqrt(105 / np.pi) * sin_theta**2 * cos_theta * np.sin(2 * phi)
        if m == -1:
            return 0.25 * np.sqrt(21 / 2 / np.pi) * sin_theta * (5 * cos_theta**2 - 1) * np.sin(phi)
        if m == 0:
            return 0.25 * np.sqrt(7 / np.pi) * (5 * cos_theta**3 - 3 * cos_theta)
        if m == 1:
            return 0.25 * np.sqrt(21 / 2 / np.pi) * sin_theta * (5 * cos_theta**2 - 1) * np.cos(phi)
        if m == 2:
            return 0.25 * np.sqrt(105 / np.pi) * sin_theta**2 * cos_theta * np.cos(2 * phi)
        if m == 3:
            return 0.25 * np.sqrt(35 / 2 / np.pi) * sin_theta**3 * np.cos(3 * phi)

    # Every supported combination has returned by now
    msg = f"No definition found for Ylm({l}, {m})."
    raise ValueError(msg)

150 

151 

def handle_spin(func):
    """Handle spin calculating the function for each channel separately.

    This can only be applied if the only spin-dependent indexing is the wave function W.

    Implementing the explicit handling of spin adds an extra layer of complexity where one has to
    loop over the spin states in many places. We can hide this complexity using this decorator while
    still supporting many use cases, e.g., the operators previously act on arrays containing wave
    functions of all states and of one state only. This decorator maintains this functionality and
    adds the option to act on arrays containing wave functions of all spins and all states as well.

    Args:
        func: Function that acts on spin-states.

    Returns:
        Decorator.
    """

    @functools.wraps(func)
    def decorator(obj, W, *args, **kwargs):
        # A three-dimensional array is iterated over its first axis, with one call per slice
        if W.ndim == 3:
            return np.asarray([func(obj, Wspin, *args, **kwargs) for Wspin in W])
        # Everything else is passed to the function unchanged
        return func(obj, W, *args, **kwargs)

    return decorator

177 

178 

def handle_k(func=None, *, mode="gracefully"):
    """Handle k-points calculating the function for each channel with different modes.

    This uses the same principle as described in :func:`~eminus.utils.handle_spin`.

    The decorator can be used bare or called with a mode, i.e., as ``@handle_k`` or as
    ``@handle_k(mode="index")``. The supported modes are "gracefully", "index", "reduce", and
    "skip". With any other mode, or with an input that is neither a list nor a four-dimensional
    array, the function is called once with the unmodified input.

    Keyword Args:
        func: Function that acts on k-points.
        mode: How to handle the k-point dependency.

    Returns:
        Decorator.
    """
    # Called with keyword arguments only, return a decorator that remembers the mode
    if func is None:
        return functools.partial(handle_k, mode=mode)

    @functools.wraps(func)
    def decorator(obj, W, *args, **kwargs):
        if isinstance(W, list) or (isinstance(W, np.ndarray) and W.ndim == 4):
            # No explicit k-point indexing is needed
            if mode == "gracefully":
                return [func(obj, Wk, *args, **kwargs) for Wk in W]
            # Explicit k-point indexing is needed
            if mode == "index":
                return [func(obj, Wk, ik, *args, **kwargs) for ik, Wk in enumerate(W)]
            # Explicit k-point indexing is needed and the result has to be summed up
            if mode == "reduce":
                # The Python sum allows summing single values and NumPy arrays elementwise
                return sum(func(obj, Wk, ik, *args, **kwargs) for ik, Wk in enumerate(W))
            # No k-point dependency has been implemented, so skip it
            if mode == "skip":
                obj._atoms.kpts._assert_gamma_only()
                ret = func(obj, W[0], *args, **kwargs)
                # Wrap three-dimensional results in a list again
                if isinstance(ret, np.ndarray) and ret.ndim == 3:
                    return [ret]
                return ret
        return func(obj, W, *args, **kwargs)

    return decorator

217 

218 

def handle_backend(func, *args, **kwargs):
    """Use a function optimized with a different backend if available.

    The backend is read from the configuration on every call. For the "jax" and "torch" backends
    the function with the same name is looked up in the respective eminus.extras module.

    Args:
        func: Function with an alternative implementation.
        args: Pass-through arguments.

    Keyword Args:
        kwargs: Pass-through keyword arguments.

    Returns:
        Decorator.
    """
    # The outer args and kwargs are not used, they are only kept for interface compatibility

    @functools.wraps(func)
    def decorator(*args, **kwargs):
        if config.backend == "jax":
            func_jax = getattr(eminus.extras.jax, func.__name__)
            return func_jax(*args, **kwargs)
        if config.backend == "torch":
            func_torch = getattr(eminus.extras.torch, func.__name__)
            return func_torch(*args, **kwargs)
        # Use the original implementation for every other backend
        return func(*args, **kwargs)

    return decorator

244 

245 

def pseudo_uniform(size, seed=1234):
    """Lehmer random number generator, following MINSTD.

    Reference: Commun. ACM. 12, 85.

    Args:
        size: Dimension of the array to create, the first three entries are iterated over.

    Keyword Args:
        seed: Seed to initialize the random number generator.

    Returns:
        Array with (pseudo) random numbers.
    """
    W = np.zeros(size, dtype=complex)
    mult = 48271
    mod = (2**31) - 1
    x = (seed * mult + 1) % mod
    # The recurrence is sequential, so fill the array element by element
    # ndindex runs over the indices with the last index changing fastest, like nested loops would
    for idx in np.ndindex(size[0], size[1], size[2]):
        x = (x * mult + 1) % mod
        W[idx] = x / mod
    return W

270 

271 

def add_maybe_none(a, b):
    """Add a and b together, when one or both can potentially be None.

    Args:
        a: Array or None.
        b: Array or None.

    Returns:
        Sum of a and b, the argument that is not None if only one is given, or None if both are.
    """
    # If a is missing the result is b, this includes the case where b is None as well
    if a is None:
        return b
    if b is None:
        return a
    return a + b

289 

290 

def molecule2list(molecule):
    """Expand a chemical formula to a list of chemical symbols.

    No charges or parentheses are allowed, only chemical symbols followed by their amount.

    Args:
        molecule: Simplified chemical formula (case sensitive).

    Returns:
        Atoms of the molecule expanded to a list.

    Raises:
        ValueError: If the formula starts with an amount instead of a chemical symbol.
    """
    # Insert a whitespace before every capital letter to separate the chemical symbols
    # Or insert it before every run of digits to separate the amounts
    tmp_list = re.sub(r"([A-Z]|\d+)", r" \1", molecule).split()
    atom_list = []
    for ia in tmp_list:
        if ia.isdigit():
            # An amount needs a preceding atom it can refer to
            if not atom_list:
                msg = f"The formula '{molecule}' has to start with a chemical symbol."
                raise ValueError(msg)
            # If ia is an integer append the previous atom ia-1 times
            atom_list += [atom_list[-1]] * (int(ia) - 1)
        else:
            # If ia is a string add it to the results list
            atom_list += [ia]
    return atom_list

314 

315 

def atom2charge(atom, path=None):
    """Get the valence charges for a list of chemical symbols from GTH files.

    Args:
        atom: Atom symbols.
        path: Directory of GTH files, or one of the names "pade" and "pbe" in any letter case.
            Defaults to "pbe" if None is given.

    Returns:
        Valence charges per atom.
    """
    # Import here to prevent circular imports
    from .io import read_gth

    # Fall back to the PBE files if no path is given
    psp_path = "pbe" if path is None else path
    # The two named file sets are matched case-insensitively, other paths are used as they are
    if psp_path.lower() in {"pade", "pbe"}:
        psp_path = psp_path.lower()
    return [read_gth(ia, psp_path=psp_path)["Zion"] for ia in atom]

337 

338 

def vector_angle(a, b):
    """Calculate the angle between two vectors.

    Args:
        a: Vector.
        b: Vector.

    Returns:
        Angle between a and b in Degree.
    """
    # Normalize vectors first
    a_norm = a / norm(a)
    b_norm = b / norm(b)
    # Rounding errors can push the cosine marginally outside of [-1, 1], e.g., for parallel
    # vectors, where arccos would return NaN, so clip it to the valid range
    cos_angle = np.clip(a_norm @ b_norm, -1, 1)
    angle = np.arccos(cos_angle)
    return rad2deg(angle)

354 

355 

def get_lattice(lattice_vectors):
    """Generate a cell for given lattice vectors.

    Args:
        lattice_vectors: Lattice vectors.

    Returns:
        Lattice vertices, as a list with one pair of connected vertices per edge of the cell.
    """
    # Vertices of a cube
    vertices = np.array(
        [
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
        ]
    )
    # Connected vertices of a cube with the above ordering
    edges = np.array(
        [
            [0, 1],
            [0, 2],
            [0, 4],
            [1, 3],
            [1, 5],
            [2, 3],
            [2, 6],
            [3, 7],
            [4, 5],
            [4, 6],
            [5, 7],
            [6, 7],
        ]
    )
    # Scale vertices with the lattice vectors, this is the same for every edge so only do it once
    scaled_vertices = vertices @ lattice_vectors
    # Select pairs of vertices to plot them later
    # The resulting return value is similar to the get_brillouin_zone function
    return [scaled_vertices[e, :] for e in edges]