Coverage for eminus/utils.py: 95.36%
151 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-16 10:16 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-16 10:16 +0000
1# SPDX-FileCopyrightText: 2021 The eminus developers
2# SPDX-License-Identifier: Apache-2.0
3"""Linear algebra calculation utilities."""
5import functools
6import pathlib
7import re
9import numpy as np
10from scipy.linalg import norm
12import eminus
14from . import config
15from .units import rad2deg
class BaseObject:
    """Base eminus class that implements some shared functionalities."""

    def view(self, *args, **kwargs):
        """Unified display function.

        Forwards this object and all arguments to eminus.extras.view.

        Args:
            args: Pass-through arguments.

        Keyword Args:
            kwargs: Pass-through keyword arguments.

        Returns:
            Viewable object.
        """
        return eminus.extras.view(self, *args, **kwargs)

    def write(self, filename, *args, **kwargs):
        """Unified file writer function.

        If the file name has no extension, ".json" is appended so the object is saved as a JSON
        file. File names containing "POSCAR" are exempt from this (presumably because such files
        conventionally carry no extension).

        Args:
            filename: Output file path/name, either a string or a path-like object.
            args: Pass-through arguments.

        Keyword Args:
            kwargs: Pass-through keyword arguments.

        Returns:
            None.
        """
        # Save the object as a JSON file if no extension is given
        # Compare on the string form so path-like objects work too; they support neither
        # substring checks nor "+=" with strings
        if not pathlib.Path(filename).suffix and "POSCAR" not in str(filename):
            filename = f"{filename}.json"
        return eminus.io.write(self, filename, *args, **kwargs)
def dotprod(a, b):
    """Efficiently calculate the real inner product of a and b.

    Computes Re(sum(conj(a) * b)), i.e., the real part of the trace of a^H b. Since this
    function is used as a denominator in minimizers, the result is never allowed to be
    (close to) zero.

    Args:
        a: Array of vectors.
        b: Array of vectors.

    Returns:
        The real part of the inner product, or eps if its magnitude is below eps.
    """
    eps = 1e-15  # float64 machine epsilon is about 2.22e-16
    # The dot product of complex vectors looks like the expression below, but this is slow
    # res = np.real(np.trace(a.conj().T @ b))
    # We can calculate the trace faster by taking the sum of the Hadamard product
    res = np.real(np.sum(a.conj() * b))
    # Guard the value that is actually returned: checking the complex magnitude instead would let
    # a vanishing real part with a non-zero imaginary part slip through as a (near) zero result
    if abs(res) < eps:
        return eps
    return res
def Ylm_real(l, m, G):
    """Evaluate real spherical harmonics for vectors given in Cartesian coordinates.

    Reference: https://scipython.com/blog/visualizing-the-real-forms-of-the-spherical-harmonics

    Expressions are implemented for l = 0, 1, 2, 3 with -l <= m <= l. For l = 0 the value of m
    is not inspected.

    Args:
        l: Angular momentum number.
        m: Magnetic quantum number.
        G: Reciprocal lattice vector or array of lattice vectors.

    Returns:
        Real spherical harmonics, one value per input vector.

    Raises:
        ValueError: If no expression is implemented for the requested (l, m).
    """
    tol = 1e-9
    # A single vector is promoted to an array holding one row
    vecs = np.atleast_2d(G)

    # The l=0 harmonic is constant, so no angles are needed
    if l == 0:
        return 0.5 * np.sqrt(1 / np.pi) * np.ones(len(vecs))

    # Polar angle via cos(theta) = Gz / |G|, forced to zero for (nearly) zero-length vectors
    mags = norm(vecs, axis=1)
    with np.errstate(divide="ignore", invalid="ignore"):
        ct = vecs[:, 2] / mags
    ct[mags < tol] = 0
    # sin(theta) = sqrt(max(0, 1 - cos(theta)^2)), clamped so round-off cannot go negative
    st = np.sqrt(np.amax((np.zeros_like(ct), 1 - ct**2), axis=0))

    # Azimuthal angle; when Gx is (nearly) zero use phi = pi/2 * sign(Gy)
    ph = np.arctan2(vecs[:, 1], vecs[:, 0])
    on_yz = np.abs(vecs[:, 0]) < tol
    ph[on_yz] = np.pi / 2 * np.sign(vecs[on_yz, 1])

    # Only the selected entry is evaluated, the lambdas defer the computation
    table = {
        (1, -1): lambda: 0.5 * np.sqrt(3 / np.pi) * st * np.sin(ph),  # py
        (1, 0): lambda: 0.5 * np.sqrt(3 / np.pi) * ct,  # pz
        (1, 1): lambda: 0.5 * np.sqrt(3 / np.pi) * st * np.cos(ph),  # px
        (2, -2): lambda: np.sqrt(15 / 16 / np.pi) * st**2 * np.sin(2 * ph),  # dxy
        (2, -1): lambda: np.sqrt(15 / 4 / np.pi) * ct * st * np.sin(ph),  # dyz
        (2, 0): lambda: 0.25 * np.sqrt(5 / np.pi) * (3 * ct**2 - 1),  # dz2
        (2, 1): lambda: np.sqrt(15 / 4 / np.pi) * ct * st * np.cos(ph),  # dxz
        (2, 2): lambda: np.sqrt(15 / 16 / np.pi) * st**2 * np.cos(2 * ph),  # dx2-y2
        (3, -3): lambda: 0.25 * np.sqrt(35 / 2 / np.pi) * st**3 * np.sin(3 * ph),
        (3, -2): lambda: 0.25 * np.sqrt(105 / np.pi) * st**2 * ct * np.sin(2 * ph),
        (3, -1): lambda: 0.25 * np.sqrt(21 / 2 / np.pi) * st * (5 * ct**2 - 1) * np.sin(ph),
        (3, 0): lambda: 0.25 * np.sqrt(7 / np.pi) * (5 * ct**3 - 3 * ct),
        (3, 1): lambda: 0.25 * np.sqrt(21 / 2 / np.pi) * st * (5 * ct**2 - 1) * np.cos(ph),
        (3, 2): lambda: 0.25 * np.sqrt(105 / np.pi) * st**2 * ct * np.cos(2 * ph),
        (3, 3): lambda: 0.25 * np.sqrt(35 / 2 / np.pi) * st**3 * np.cos(3 * ph),
    }
    evaluate = table.get((l, m))
    if evaluate is None:
        msg = f"No definition found for Ylm({l}, {m})."
        raise ValueError(msg)
    return evaluate()
def handle_spin(func):
    """Decorate a function so that it is evaluated separately for every spin channel.

    This can only be applied if the only spin-dependent indexing is the wave function W.

    Handling spin explicitly would mean looping over the spin states in many places. This
    decorator hides that loop: operators keep working on arrays of all states or of a single
    state, and additionally accept arrays holding all spins and all states. A 3-D W is split
    along its first axis, the function is called once per slice, and the results are stacked
    with np.asarray. Any other W is passed through unchanged.

    Args:
        func: Function that acts on spin-states.

    Returns:
        Decorator.
    """

    @functools.wraps(func)
    def wrapper(obj, W, *args, **kwargs):
        if W.ndim != 3:
            # No spin axis present, call the function directly
            return func(obj, W, *args, **kwargs)
        per_spin = [func(obj, W_s, *args, **kwargs) for W_s in W]
        return np.asarray(per_spin)

    return wrapper
def handle_k(func=None, *, mode="gracefully"):
    """Handle k-points calculating the function for each channel with different modes.

    This uses the same principle as described in :func:`~eminus.utils.handle_spin`. The
    decorator can be applied bare (``@handle_k``) or with arguments (``@handle_k(mode="index")``).

    The k-point handling only applies when W is a list or a 4-D NumPy array; it is then iterated
    along its first axis. Any other W is passed to func unchanged, regardless of the mode.

    Args:
        func: Function that acts on k-points.

    Keyword Args:
        mode: How to handle the k-point dependency:
            "gracefully": call func for every k-point and return the results as a list;
            "index": like "gracefully", but also pass the k-point index after the k-point argument;
            "reduce": like "index", but sum up the results;
            "skip": call obj._atoms.kpts._assert_gamma_only(), then call func with the first
            k-point only, wrapping a 3-D array result in a list.

    Returns:
        Decorator.

    Raises:
        ValueError: If mode is not one of the supported modes.
    """
    valid_modes = ("gracefully", "index", "reduce", "skip")
    # Validate eagerly; an unknown mode used to fall through silently and call func with the
    # whole collection of k-points
    if mode not in valid_modes:
        msg = f"Unknown mode {mode!r}, expected one of {valid_modes}."
        raise ValueError(msg)
    if func is None:
        return functools.partial(handle_k, mode=mode)

    @functools.wraps(func)
    def decorator(obj, W, *args, **kwargs):
        if not (isinstance(W, list) or (isinstance(W, np.ndarray) and W.ndim == 4)):
            return func(obj, W, *args, **kwargs)
        # No explicit k-point indexing is needed
        if mode == "gracefully":
            return [func(obj, Wk, *args, **kwargs) for Wk in W]
        # Explicit k-point indexing is needed
        if mode == "index":
            return [func(obj, Wk, ik, *args, **kwargs) for ik, Wk in enumerate(W)]
        # Explicit k-point indexing is needed and the result has to be summed up
        if mode == "reduce":
            # The Python sum allows summing single values and NumPy arrays elementwise
            return sum(func(obj, Wk, ik, *args, **kwargs) for ik, Wk in enumerate(W))
        # mode == "skip": no k-point dependency has been implemented, so skip it
        obj._atoms.kpts._assert_gamma_only()
        ret = func(obj, W[0], *args, **kwargs)
        if isinstance(ret, np.ndarray) and ret.ndim == 3:
            return [ret]
        return ret

    return decorator
def handle_backend(func, *args, **kwargs):
    """Use a function optimized with a different backend if available.

    The configured backend is checked each time the wrapped function is called. For "jax" or
    "torch", the function with the same name from eminus.extras.jax or eminus.extras.torch is
    called instead of func. For any other backend, func itself is called.

    Args:
        func: Function with an alternative implementation.
        args: Accepted but not used; the wrapper forwards the arguments of each call.

    Keyword Args:
        kwargs: Accepted but not used; the wrapper forwards the keyword arguments of each call.

    Returns:
        Wrapped function that dispatches to the selected backend.
    """

    @functools.wraps(func)
    def dispatch(*args, **kwargs):
        for backend in ("jax", "torch"):
            if config.backend == backend:
                alternative = getattr(getattr(eminus.extras, backend), func.__name__)
                return alternative(*args, **kwargs)
        return func(*args, **kwargs)

    return dispatch
def pseudo_uniform(size, seed=1234):
    """Generate a reproducible array of pseudo-random numbers in [0, 1).

    The numbers come from a linear congruential generator using the MINSTD multiplier 48271 and
    modulus 2**31 - 1, combined with an increment of 1:
    x <- (48271 * x + 1) mod (2**31 - 1). Values are written in C (row-major) order.

    Reference: Commun. ACM. 12, 85.

    Args:
        size: Shape of the array to create, expected to have three entries.

    Keyword Args:
        seed: Seed to initialize the random number generator.

    Returns:
        Complex array with the (pseudo) random numbers stored in the real parts.
    """
    multiplier = 48271
    modulus = 2**31 - 1
    out = np.zeros(size, dtype=complex)
    # The initial state is derived from the seed with one generator step
    state = (seed * multiplier + 1) % modulus
    for idx in np.ndindex(size[0], size[1], size[2]):
        state = (state * multiplier + 1) % modulus
        out[idx] = state / modulus
    return out
def add_maybe_none(a, b):
    """Add two operands where either or both may be None.

    A None operand is treated as absent. The other operand is returned unchanged, and None is
    returned when both are None.

    Args:
        a: Array or None.
        b: Array or None.

    Returns:
        a + b, the single non-None operand, or None.
    """
    if a is None:
        # Also covers the case where both are None, since b is then None as well
        return b
    return a if b is None else a + b
def molecule2list(molecule):
    """Expand a chemical formula to a list of chemical symbols.

    No charges or parentheses are allowed, only chemical symbols followed by their amount.

    Args:
        molecule: Simplified chemical formula (case sensitive), e.g., "CH4".

    Returns:
        Atoms of the molecule expanded to a list, e.g., ["C", "H", "H", "H", "H"].

    Raises:
        ValueError: If an amount is not preceded by a chemical symbol, e.g., "2H".
    """
    # Insert a whitespace before every capital letter (the start of a chemical symbol) and before
    # every run of digits (an amount), then split on whitespace
    # Note: a "?" inside a character class is a literal, not a quantifier, so it is not used here
    tokens = re.sub(r"([A-Z]|\d+)", r" \1", molecule).split()
    atom_list = []
    for token in tokens:
        if token.isdigit():
            if not atom_list:
                msg = f"Amount {token} in {molecule!r} is not preceded by a chemical symbol."
                raise ValueError(msg)
            # The symbol itself was already added once, so append it amount-1 more times
            atom_list += [atom_list[-1]] * (int(token) - 1)
        else:
            atom_list.append(token)
    return atom_list
def atom2charge(atom, path=None):
    """Read the valence charge of each chemical symbol from its GTH pseudopotential file.

    Args:
        atom: Atom symbols.
        path: Directory of GTH files. The names "pade" and "pbe" are matched case-insensitively
            and passed on in lower case. Defaults to "pbe".

    Returns:
        Valence charges per atom, i.e., the "Zion" entry read for every symbol.
    """
    # Import here to prevent circular imports
    from .io import read_gth

    if path is None:
        psp_path = "pbe"
    else:
        lowered = path.lower()
        psp_path = lowered if lowered in {"pade", "pbe"} else path

    charges = []
    for symbol in atom:
        gth = read_gth(symbol, psp_path=psp_path)
        charges.append(gth["Zion"])
    return charges
def vector_angle(a, b):
    """Calculate the angle between two vectors.

    Args:
        a: Vector.
        b: Vector.

    Returns:
        Angle between a and b in degree.
    """
    # Normalize vectors first
    a_norm = a / norm(a)
    b_norm = b / norm(b)
    # Round-off can push the cosine of (anti)parallel vectors slightly outside [-1, 1], e.g.,
    # 1.0000000000000002 for a = b = [1, 1, 1], where arccos would return NaN; clamp it
    cos_angle = np.clip(a_norm @ b_norm, -1, 1)
    angle = np.arccos(cos_angle)
    return rad2deg(angle)
356def get_lattice(lattice_vectors):
357 """Generate a cell for given lattice vectors.
359 Args:
360 lattice_vectors: Lattice vectors.
362 Returns:
363 Lattice vertices.
364 """
365 # Vertices of a cube
366 vertices = np.array(
367 [
368 [0, 0, 0],
369 [0, 0, 1],
370 [0, 1, 0],
371 [0, 1, 1],
372 [1, 0, 0],
373 [1, 0, 1],
374 [1, 1, 0],
375 [1, 1, 1],
376 ]
377 )
378 # Connected vertices of a cube with the above ordering
379 edges = np.array(
380 [
381 [0, 1],
382 [0, 2],
383 [0, 4],
384 [1, 3],
385 [1, 5],
386 [2, 3],
387 [2, 6],
388 [3, 7],
389 [4, 5],
390 [4, 6],
391 [5, 7],
392 [6, 7],
393 ]
394 )
395 # Scale vertices with the lattice vectors
396 # Select pairs of vertices to plot them later
397 # The resulting return value is similar to the get_brillouin_zone function
398 return [(vertices @ lattice_vectors)[e, :] for e in edges]