Coverage for eminus/utils.py: 96.62%

148 statements  

Report generated by coverage.py v7.6.9 at 2024-12-18 08:43 +0000.

1# SPDX-FileCopyrightText: 2021 The eminus developers 

2# SPDX-License-Identifier: Apache-2.0 

3"""Linear algebra calculation utilities.""" 

4 

5import functools 

6import pathlib 

7import re 

8 

9import numpy as np 

10from scipy.linalg import norm 

11 

12import eminus 

13 

14from . import config 

15from .units import rad2deg 

16 

17 

class BaseObject:
    """Base eminus class that implements some shared functionalities."""

    def view(self, *args, **kwargs):
        """Unified display function.

        Delegates to ``eminus.extras.view`` with this object as the first argument.

        Args:
            args: Pass-through arguments.

        Keyword Args:
            kwargs: Pass-through keyword arguments.

        Returns:
            Viewable object.
        """
        return eminus.extras.view(self, *args, **kwargs)

    def write(self, filename, *args, **kwargs):
        """Unified file writer function.

        If the file name has no extension and does not contain "POSCAR", ".json" is appended so the
        object is saved as a JSON file.

        Args:
            filename: Output file path/name, either a string or a path-like object.
            args: Pass-through arguments.

        Keyword Args:
            kwargs: Pass-through keyword arguments.

        Returns:
            The return value of ``eminus.io.write``.
        """
        # Work on the string form so path-like objects (e.g. pathlib.Path) are supported as well;
        # the substring test and the string concatenation below would fail on them otherwise
        # pathlib.Path(...) still rejects unsupported types (e.g. bytes) with a TypeError as before
        name = filename if isinstance(filename, str) else str(pathlib.Path(filename))
        # Save the object as a JSON file if no extension is given
        # POSCAR files typically have no extension, so they are left untouched
        if not pathlib.Path(name).suffix and "POSCAR" not in name:
            filename = name + ".json"
        return eminus.io.write(self, filename, *args, **kwargs)

52 

53 

def dotprod(a, b):
    """Efficiently calculate the expression a * b.

    Computes the real part of the complex dot product sum(conj(a) * b). An extra check makes sure
    the result is never zero since this function is used as a denominator in minimizers.

    Args:
        a: Array of vectors.
        b: Array of vectors.

    Returns:
        The real part of the expression's result, replaced by a small positive value if it is
        (almost) zero.
    """
    eps = 1e-15  # Slightly above the float64 machine epsilon of 2.22e-16
    # The dot product of complex vectors looks like the expression below, but this is slow
    # res = np.real(np.trace(a.conj().T @ b))
    # We can calculate the trace faster by taking the sum of the Hadamard product
    # Only the real part is returned, so that is the value that has to be guarded against zero;
    # checking the complex magnitude would let a vanishing real part through whenever the
    # imaginary part is large
    res = np.real(np.sum(a.conj() * b))
    if abs(res) < eps:
        return eps
    return res

75 

76 

def Ylm_real(l, m, G):
    """Calculate real spherical harmonics from cartesian coordinates.

    Reference: https://scipython.com/blog/visualizing-the-real-forms-of-the-spherical-harmonics

    Harmonics are implemented for 0 <= l <= 3. For l = 0 the constant value is returned without
    inspecting m.

    Args:
        l: Angular momentum number.
        m: Magnetic quantum number.
        G: Reciprocal lattice vector or array of lattice vectors.

    Returns:
        Real spherical harmonics, one value per vector in G.

    Raises:
        ValueError: If no harmonic is defined for the given l and m.
    """
    eps = 1e-9
    # Promote a single vector to a one-row array so both inputs are handled the same way
    G = np.atleast_2d(G)

    # Y00 is constant, the angles are not needed
    if l == 0:
        return 0.5 * np.sqrt(1 / np.pi) * np.ones(len(G))

    # Polar angle from cos(theta)=Gz/|G|; vectors with a tiny magnitude get cos(theta)=0
    G_norm = norm(G, axis=1)
    with np.errstate(divide="ignore", invalid="ignore"):
        cos_theta = G[:, 2] / G_norm
    cos_theta[G_norm < eps] = 0
    # sin(theta)=sqrt(max(0, 1-cos(theta)^2)), clamped so rounding never yields a negative radicand
    sin_theta = np.sqrt(np.maximum(0, 1 - cos_theta**2))

    # Azimuthal angle phi=arctan(Gy/Gx), with phi=pi/2*sign(Gy) forced wherever Gx is (almost) zero
    phi = np.arctan2(G[:, 1], G[:, 0])
    tiny_x = np.abs(G[:, 0]) < eps
    phi[tiny_x] = np.pi / 2 * np.sign(G[tiny_x, 1])

    # Real harmonics for p (l=1), d (l=2), and f (l=3) orbitals, keyed by (l, m)
    harmonics = {
        (1, -1): lambda: 0.5 * np.sqrt(3 / np.pi) * sin_theta * np.sin(phi),  # py
        (1, 0): lambda: 0.5 * np.sqrt(3 / np.pi) * cos_theta,  # pz
        (1, 1): lambda: 0.5 * np.sqrt(3 / np.pi) * sin_theta * np.cos(phi),  # px
        (2, -2): lambda: np.sqrt(15 / 16 / np.pi) * sin_theta**2 * np.sin(2 * phi),  # dxy
        (2, -1): lambda: np.sqrt(15 / 4 / np.pi) * cos_theta * sin_theta * np.sin(phi),  # dyz
        (2, 0): lambda: 0.25 * np.sqrt(5 / np.pi) * (3 * cos_theta**2 - 1),  # dz2
        (2, 1): lambda: np.sqrt(15 / 4 / np.pi) * cos_theta * sin_theta * np.cos(phi),  # dxz
        (2, 2): lambda: np.sqrt(15 / 16 / np.pi) * sin_theta**2 * np.cos(2 * phi),  # dx2-y2
        (3, -3): lambda: 0.25 * np.sqrt(35 / 2 / np.pi) * sin_theta**3 * np.sin(3 * phi),
        (3, -2): lambda: 0.25
        * np.sqrt(105 / np.pi)
        * sin_theta**2
        * cos_theta
        * np.sin(2 * phi),
        (3, -1): lambda: 0.25
        * np.sqrt(21 / 2 / np.pi)
        * sin_theta
        * (5 * cos_theta**2 - 1)
        * np.sin(phi),
        (3, 0): lambda: 0.25 * np.sqrt(7 / np.pi) * (5 * cos_theta**3 - 3 * cos_theta),
        (3, 1): lambda: 0.25
        * np.sqrt(21 / 2 / np.pi)
        * sin_theta
        * (5 * cos_theta**2 - 1)
        * np.cos(phi),
        (3, 2): lambda: 0.25
        * np.sqrt(105 / np.pi)
        * sin_theta**2
        * cos_theta
        * np.cos(2 * phi),
        (3, 3): lambda: 0.25 * np.sqrt(35 / 2 / np.pi) * sin_theta**3 * np.cos(3 * phi),
    }
    harmonic = harmonics.get((l, m))
    if harmonic is None:
        msg = f"No definition found for Ylm({l}, {m})."
        raise ValueError(msg)
    return harmonic()

150 

151 

def handle_spin(func):
    """Handle spin calculating the function for each channel separately.

    This can only be applied if the only spin-dependent indexing is the wave function W.

    Explicit spin handling would require looping over the spin channels in many places. This
    decorator hides that loop: operators keep working on arrays holding the wave functions of all
    states or of a single state, and additionally accept arrays holding the wave functions of all
    spins and all states. In the latter case (a 3-dimensional W) the function is evaluated once per
    spin channel and the results are stacked into a NumPy array.

    Args:
        func: Function that acts on spin-states.

    Returns:
        Decorator.
    """

    @functools.wraps(func)
    def wrapper(obj, W, *args, **kwargs):
        # Anything other than a 3d array is passed through untouched
        if W.ndim != 3:
            return func(obj, W, *args, **kwargs)
        results = [func(obj, W_channel, *args, **kwargs) for W_channel in W]
        return np.asarray(results)

    return wrapper

177 

178 

def handle_k(func=None, *, mode="gracefully"):
    """Handle k-points calculating the function for each channel with different modes.

    This uses the same principle as described in :func:`~eminus.utils.handle_spin`. The decorated
    function is evaluated per k-point if W is a list or a 4-dimensional array. Otherwise W is
    passed through unchanged.

    Modes:
        gracefully: Call the function for every k-point and return a list of the results.
        index: Like gracefully, but additionally pass the k-point index after W.
        reduce: Like index, but return the sum of all results.
        skip: Call ``obj._atoms.kpts._assert_gamma_only()`` and only use the first k-point. A
            3-dimensional array result is wrapped into a one-element list.

    Keyword Args:
        func: Function that acts on k-points.
        mode: How to handle the k-point dependency.

    Returns:
        Decorator.

    Raises:
        ValueError: If mode is not one of the modes listed above.
    """
    valid_modes = ("gracefully", "index", "reduce", "skip")
    # Fail early on unknown modes; previously they silently passed all k-points at once to func
    if mode not in valid_modes:
        msg = f"Unknown mode {mode!r}, use one of {valid_modes}."
        raise ValueError(msg)
    if func is None:
        # Used as @handle_k(mode=...), so return the actual decorator
        return functools.partial(handle_k, mode=mode)

    @functools.wraps(func)
    def decorator(obj, W, *args, **kwargs):
        # W holds one entry per k-point if it is a list or a 4d array
        if not (isinstance(W, list) or (isinstance(W, np.ndarray) and W.ndim == 4)):
            return func(obj, W, *args, **kwargs)
        # No explicit k-point indexing is needed
        if mode == "gracefully":
            return [func(obj, Wk, *args, **kwargs) for Wk in W]
        # Explicit k-point indexing is needed
        if mode == "index":
            return [func(obj, Wk, ik, *args, **kwargs) for ik, Wk in enumerate(W)]
        # Explicit k-point indexing is needed and the result has to be summed up
        if mode == "reduce":
            # The Python sum allows summing single values and NumPy arrays elementwise
            return sum(func(obj, Wk, ik, *args, **kwargs) for ik, Wk in enumerate(W))
        # mode == "skip": No k-point dependency has been implemented, so skip it
        obj._atoms.kpts._assert_gamma_only()
        ret = func(obj, W[0], *args, **kwargs)
        if isinstance(ret, np.ndarray) and ret.ndim == 3:
            return [ret]
        return ret

    return decorator

217 

218 

def handle_torch(func, *args, **kwargs):
    """Use a function optimized with Torch if available.

    Whenever ``config.use_torch`` is set at call time, the function with the same name from
    ``eminus.extras.torch`` is called instead of func.

    Args:
        func: Function with a Torch alternative.
        args: Not used by the decorator.

    Keyword Args:
        kwargs: Not used by the decorator.

    Returns:
        Decorator.
    """

    @functools.wraps(func)
    def wrapper(*call_args, **call_kwargs):
        # Decide per call, so changing config.use_torch takes effect immediately
        impl = getattr(eminus.extras.torch, func.__name__) if config.use_torch else func
        return impl(*call_args, **call_kwargs)

    return wrapper

241 

242 

def pseudo_uniform(size, seed=1234):
    """Generate reproducible pseudo-random numbers in [0, 1).

    A linear congruential generator with the MINSTD multiplier 48271 and modulus 2^31 - 1 is used,
    with an additional increment of one per step (so strictly speaking not a pure Lehmer
    generator). The numbers are stored as real parts of a complex array, filled in C (row-major)
    order, so the sequence for 3-dimensional sizes is the same as before arbitrary shapes were
    supported.

    Reference: Commun. ACM. 12, 85.

    Args:
        size: Shape of the array to create, any number of dimensions is supported.

    Keyword Args:
        seed: Seed to initialize the random number generator.

    Returns:
        Complex array with (pseudo) random real parts and zero imaginary parts.
    """
    W = np.zeros(size, dtype=complex)
    mult = 48271
    mod = (2**31) - 1
    x = (seed * mult + 1) % mod
    # Flat view into the freshly allocated (C-contiguous) array, filled element by element
    flat = W.reshape(-1)
    for idx in range(flat.size):
        x = (x * mult + 1) % mod
        flat[idx] = x / mod
    return W

267 

268 

def add_maybe_none(a, b):
    """Add a and b together, when one or both can potentially be None.

    A None operand is treated as absent: the other operand is returned as-is, and None is returned
    if both are None.

    Args:
        a: Array or None.
        b: Array or None.

    Returns:
        Sum of a and b.
    """
    # Covers both "only a is missing" and "both are missing" (then b is None as well)
    if a is None:
        return b
    if b is None:
        return a
    return a + b

286 

287 

def molecule2list(molecule):
    """Expand a chemical formula to a list of chemical symbols.

    No charges or parentheses are allowed, only chemical symbols followed by their amount, e.g.,
    "CH4" becomes ["C", "H", "H", "H", "H"].

    Args:
        molecule: Simplified chemical formula (case sensitive).

    Returns:
        Atoms of the molecule expanded to a list.
    """
    # Put a whitespace in front of every capital letter (start of a new symbol) and in front of
    # every run of digits (an amount), then split on whitespace
    # NOTE(review): the character class also contains a literal "?", which is split off like a
    # capital letter; this looks unintended but is kept for compatibility
    tokens = re.sub(r"([A-Z?]|\d+)", r" \1", molecule).split()
    atoms = []
    for token in tokens:
        if not token.isdigit():
            atoms.append(token)
            continue
        # An amount n repeats the preceding symbol n - 1 more times
        atoms.extend([atoms[-1]] * (int(token) - 1))
    return atoms

311 

312 

def atom2charge(atom, path=None):
    """Get the valence charges for a list of chemical symbols from GTH files.

    Args:
        atom: Atom symbols.
        path: Directory of GTH files. The names "pade" and "pbe" are matched case-insensitively
            and passed on in lower case; defaults to "pbe".

    Returns:
        Valence charges per atom.
    """
    # Import here to prevent circular imports
    from .io import read_gth

    if path is None:
        psp_path = "pbe"
    elif path.lower() in {"pade", "pbe"}:
        # Normalize the spelling of these names before handing them to read_gth
        psp_path = path.lower()
    else:
        psp_path = path
    return [read_gth(ia, psp_path=psp_path)["Zion"] for ia in atom]

334 

335 

def vector_angle(a, b):
    """Calculate the angle between two vectors.

    Args:
        a: Vector (array or sequence).
        b: Vector (array or sequence).

    Returns:
        Angle between a and b in Degree.
    """
    # Accept plain sequences as well as arrays
    a = np.asarray(a)
    b = np.asarray(b)
    # Normalize vectors first
    a_norm = a / norm(a)
    b_norm = b / norm(b)
    # Rounding can push the cosine of (anti)parallel vectors slightly outside of [-1, 1], where
    # arccos would return NaN, so clamp it to the valid domain
    cos_angle = np.clip(a_norm @ b_norm, -1, 1)
    return rad2deg(np.arccos(cos_angle))

351 

352 

def get_lattice(lattice_vectors):
    """Generate a cell for given lattice vectors.

    The unit cube's vertices are scaled with the lattice vectors (rows of lattice_vectors) and its
    12 edges are returned as pairs of end points.

    Args:
        lattice_vectors: Lattice vectors.

    Returns:
        List of 12 arrays, one per cell edge, each containing the two connected vertices (shape
        (2, 3) for 3x3 lattice vectors).
    """
    # Vertices of a cube
    vertices = np.array(
        [
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
        ]
    )
    # Connected vertices of a cube with the above ordering
    edges = np.array(
        [
            [0, 1],
            [0, 2],
            [0, 4],
            [1, 3],
            [1, 5],
            [2, 3],
            [2, 6],
            [3, 7],
            [4, 5],
            [4, 6],
            [5, 7],
            [6, 7],
        ]
    )
    # Scale vertices with the lattice vectors once instead of once per edge
    cell_vertices = vertices @ lattice_vectors
    # Select pairs of vertices to plot them later
    # The resulting return value is similar to the get_brillouin_zone function
    return [cell_vertices[edge, :] for edge in edges]