Coverage for eminus/utils.py: 97.87%
141 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 12:19 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 12:19 +0000
1# SPDX-FileCopyrightText: 2021 The eminus developers
2# SPDX-License-Identifier: Apache-2.0
3"""Linear algebra calculation utilities."""
5import functools
6import math
7import pathlib
8import re
10import numpy as np
12import eminus
14from . import backend as xp
15from .units import rad2deg
class BaseObject:
    """Base eminus class that implements some shared functionalities."""

    def view(self, *args, **kwargs):
        """Unified display function.

        Args:
            args: Pass-through arguments.

        Keyword Args:
            kwargs: Pass-through keyword arguments.

        Returns:
            Viewable object.
        """
        return eminus.extras.view(self, *args, **kwargs)

    def write(self, filename, *args, **kwargs):
        """Unified file writer function.

        If the file name has no extension and does not contain "POSCAR", ".json" is appended so
        the object is saved as a JSON file.

        Args:
            filename: Input file path/name, as a string or a pathlib path.
            args: Pass-through arguments.

        Keyword Args:
            kwargs: Pass-through keyword arguments.

        Returns:
            The return value of eminus.io.write.
        """
        # The checks below use string operations ("in", "+="), which fail for path objects
        if isinstance(filename, pathlib.PurePath):
            filename = str(filename)
        # Save the object as a JSON file if no extension is given
        if not pathlib.Path(filename).suffix and "POSCAR" not in filename:
            filename += ".json"
        return eminus.io.write(self, filename, *args, **kwargs)
def dotprod(a, b):
    """Efficiently calculate the real part of the inner product of a and b.

    This evaluates Re(tr(a^H b)), i.e., the real part of the sum over the elementwise product of
    conj(a) and b. Add an extra check to make sure the result is never zero since this function
    is used as a denominator in minimizers.

    Args:
        a: Array of vectors.
        b: Array of vectors.

    Returns:
        The real part of the inner product, replaced by a small positive value if it is
        (close to) zero.
    """
    eps = 1e-15  # 2.22e-16 is the range of float64 machine precision
    # The dot product of complex vectors looks like the expression below, but this is slow
    # res = xp.real(xp.trace(a.conj().T @ b))
    # We can calculate the trace faster by taking the sum of the Hadamard product
    res = xp.real(xp.sum(a.conj() * b))
    # Check the value that is actually returned: a complex sum can have a large magnitude while
    # its real part is zero, which would otherwise slip through as a zero denominator
    if abs(res) < eps:
        return eps
    return res
def Ylm_real(l, m, G):
    """Calculate real spherical harmonics from cartesian coordinates.

    Reference: https://scipython.com/blog/visualizing-the-real-forms-of-the-spherical-harmonics

    Only the angular momentum numbers 0 <= l <= 3 with -l <= m <= l are implemented.

    Args:
        l: Angular momentum number.
        m: Magnetic quantum number.
        G: Reciprocal lattice vector or array of lattice vectors.

    Returns:
        Real spherical harmonics, one value per vector in G.

    Raises:
        ValueError: If no definition exists for the (l, m) pair.
    """
    msg = f"No definition found for Ylm({l}, {m})."
    eps = 1e-9
    # Account for single vectors
    G = xp.atleast_2d(G)

    # No need to calculate more for l=0, but only m=0 is a valid magnetic quantum number there
    if l == 0:
        if m != 0:
            raise ValueError(msg)
        return 0.5 * math.sqrt(1 / math.pi) * xp.ones(len(G))

    # cos(theta)=Gz/|G|
    Gm = xp.linalg.norm(G, axis=1)
    with np.errstate(divide="ignore", invalid="ignore"):
        cos_theta = G[:, 2] / Gm
    # Account for small magnitudes, if norm(G) < eps: cos_theta=0
    cos_theta[Gm < eps] = 0

    # Vectorized version of sin(theta)=sqrt(max(0, 1-cos_theta^2)), clamping the radicand at 0
    sin_theta = xp.sqrt(xp.max(xp.stack([xp.zeros_like(cos_theta), 1 - cos_theta**2]), axis=0))

    # phi=arctan(Gy/Gx)
    phi = xp.arctan2(G[:, 1], G[:, 0])
    # If Gx=0: phi=pi/2*sign(Gy)
    phi_idx = xp.abs(G[:, 0]) < eps
    phi[phi_idx] = math.pi / 2 * xp.sign(G[phi_idx, 1])

    # Table of all implemented harmonics keyed by (l, m); the lambdas defer the evaluation so
    # only the requested harmonic is computed
    harmonics = {
        # p orbitals
        (1, -1): lambda: 0.5 * math.sqrt(3 / math.pi) * sin_theta * xp.sin(phi),  # py
        (1, 0): lambda: 0.5 * math.sqrt(3 / math.pi) * cos_theta,  # pz
        (1, 1): lambda: 0.5 * math.sqrt(3 / math.pi) * sin_theta * xp.cos(phi),  # px
        # d orbitals
        (2, -2): lambda: (  # dxy
            math.sqrt(15 / 16 / math.pi) * sin_theta**2 * xp.sin(2 * phi)
        ),
        (2, -1): lambda: (  # dyz
            math.sqrt(15 / 4 / math.pi) * cos_theta * sin_theta * xp.sin(phi)
        ),
        (2, 0): lambda: 0.25 * math.sqrt(5 / math.pi) * (3 * cos_theta**2 - 1),  # dz2
        (2, 1): lambda: (  # dxz
            math.sqrt(15 / 4 / math.pi) * cos_theta * sin_theta * xp.cos(phi)
        ),
        (2, 2): lambda: (  # dx2-y2
            math.sqrt(15 / 16 / math.pi) * sin_theta**2 * xp.cos(2 * phi)
        ),
        # f orbitals
        (3, -3): lambda: 0.25 * math.sqrt(35 / 2 / math.pi) * sin_theta**3 * xp.sin(3 * phi),
        (3, -2): lambda: (
            0.25 * math.sqrt(105 / math.pi) * sin_theta**2 * cos_theta * xp.sin(2 * phi)
        ),
        (3, -1): lambda: (
            0.25
            * math.sqrt(21 / 2 / math.pi)
            * sin_theta
            * (5 * cos_theta**2 - 1)
            * xp.sin(phi)
        ),
        (3, 0): lambda: 0.25 * math.sqrt(7 / math.pi) * (5 * cos_theta**3 - 3 * cos_theta),
        (3, 1): lambda: (
            0.25
            * math.sqrt(21 / 2 / math.pi)
            * sin_theta
            * (5 * cos_theta**2 - 1)
            * xp.cos(phi)
        ),
        (3, 2): lambda: (
            0.25 * math.sqrt(105 / math.pi) * sin_theta**2 * cos_theta * xp.cos(2 * phi)
        ),
        (3, 3): lambda: 0.25 * math.sqrt(35 / 2 / math.pi) * sin_theta**3 * xp.cos(3 * phi),
    }
    harmonic = harmonics.get((l, m))
    if harmonic is None:
        raise ValueError(msg)
    return harmonic()
def handle_spin(func):
    """Decorate a function so that it is evaluated once per spin channel.

    Only use this when the wave function W is the sole spin-indexed argument.

    Spin support would otherwise require explicit loops over the spin channels in many places.
    The operators already accept wave functions of all states or of a single state; with this
    decorator they additionally accept arrays holding the wave functions of all spins and all
    states.

    Args:
        func: Function that acts on spin-states.

    Returns:
        Decorator.
    """

    @functools.wraps(func)
    def spin_wrapper(obj, W, *args, **kwargs):
        # Anything that is not three-dimensional carries no spin axis, pass it through unchanged
        if W.ndim != 3:
            return func(obj, W, *args, **kwargs)
        # The leading axis runs over the spin channels, evaluate each and stack the results again
        per_channel = [func(obj, W_channel, *args, **kwargs) for W_channel in W]
        return xp.stack(per_channel)

    return spin_wrapper
def handle_k(func=None, *, mode="gracefully"):
    """Handle k-points calculating the function for each channel with different modes.

    This uses the same principle as described in :func:`~eminus.utils.handle_spin`.

    Supported modes:
        gracefully: Call the function for every k-point and return the results as a list.
        index: Like gracefully, but also pass the k-point index after the wave function.
        reduce: Like index, but sum up the results of all k-points.
        skip: Only gamma-point calculations are supported; call the function for the first
            k-point only.

    Keyword Args:
        func: Function that acts on k-points.
        mode: How to handle the k-point dependency.

    Returns:
        Decorator.

    Raises:
        ValueError: If mode is not one of the supported modes.
    """
    valid_modes = ("gracefully", "index", "reduce", "skip")
    # Reject typos early; an unknown mode would otherwise silently pass the whole k-point list
    if mode not in valid_modes:
        msg = f"Unknown k-point handling mode '{mode}', expected one of {valid_modes}."
        raise ValueError(msg)
    if func is None:
        return functools.partial(handle_k, mode=mode)

    @functools.wraps(func)
    def decorator(obj, W, *args, **kwargs):
        if isinstance(W, list) or (xp.is_array(W) and W.ndim == 4):
            # No explicit k-point indexing is needed
            if mode == "gracefully":
                return [func(obj, Wk, *args, **kwargs) for Wk in W]
            # Explicit k-point indexing is needed
            if mode == "index":
                return [func(obj, Wk, ik, *args, **kwargs) for ik, Wk in enumerate(W)]
            # Explicit k-point indexing is needed and the result has to be summed up
            if mode == "reduce":
                # The Python sum allows summing single values and NumPy arrays elementwise
                return sum(func(obj, Wk, ik, *args, **kwargs) for ik, Wk in enumerate(W))
            # Remaining mode is "skip": no k-point dependency has been implemented, so skip it
            obj._atoms.kpts._assert_gamma_only()
            ret = func(obj, W[0], *args, **kwargs)
            # Re-wrap wave-function-like results so the caller still receives a k-point list
            if xp.is_array(ret) and ret.ndim == 3:
                return [ret]
            return ret
        return func(obj, W, *args, **kwargs)

    return decorator
def pseudo_uniform(size, seed=1234):
    """Generate reproducible pseudo-random numbers with a linear congruential generator.

    The multiplier and modulus are the MINSTD parameters of the Lehmer generator, but an
    increment of one is added in every step. The sequence is therefore not the pure
    multiplicative MINSTD sequence. It is kept unchanged so that generated arrays stay
    reproducible.

    Reference: Commun. ACM. 12, 85.

    Args:
        size: Shape of the array to create.

    Keyword Args:
        seed: Seed to initialize the random number generator.

    Returns:
        Complex array of the given shape with (pseudo) random real parts in [0, 1) and zero
        imaginary parts.
    """
    W = xp.empty(size, dtype=complex)
    mult = 48271
    mod = 2**31 - 1
    x = (seed * mult + 1) % mod
    # np.ndindex walks all indices in row-major order, the same order as nested loops would,
    # so arrays of any rank are filled element by element
    for idx in np.ndindex(*size):
        x = (x * mult + 1) % mod
        W[idx] = x / mod
    return W
def add_maybe_none(a, b):
    """Return the sum of a and b while tolerating missing (None) operands.

    Args:
        a: Array or None.
        b: Array or None.

    Returns:
        a + b if both are given, the non-None operand if only one is given, otherwise None.
    """
    # A missing operand contributes nothing; if both are missing this yields b, i.e., None
    if a is None:
        return b
    if b is None:
        return a
    return a + b
def molecule2list(molecule):
    """Expand a chemical formula to a list of chemical symbols.

    No charges or parentheses are allowed, only chemical symbols followed by their amount,
    e.g., "CH4" becomes ["C", "H", "H", "H", "H"].

    Args:
        molecule: Simplified chemical formula (case sensitive).

    Returns:
        Atoms of the molecule expanded to a list.
    """
    # Insert a whitespace before every capital letter (the start of a chemical symbol) and before
    # every run of digits (an amount), then split at the whitespace
    tmp_list = re.sub(r"([A-Z]|\d+)", r" \1", molecule).split()
    atom_list = []
    for ia in tmp_list:
        if ia.isdigit():
            # If ia is an integer append the previous atom ia-1 times
            atom_list += [atom_list[-1]] * (int(ia) - 1)
        else:
            # If ia is a string add it to the results list
            atom_list.append(ia)
    return atom_list
def atom2charge(atom, path=None):
    """Look up the valence charge (Zion) of every chemical symbol in GTH pseudopotential files.

    Args:
        atom: Atom symbols.
        path: Directory of GTH files, or one of the keywords "pade"/"pbe" (case insensitive).
            Defaults to "pbe".

    Returns:
        Valence charges per atom.
    """
    # Local import to avoid a circular import between the utils and io modules
    from .io import read_gth

    if path is None:
        psp_path = "pbe"
    elif path.lower() in {"pade", "pbe"}:
        # Keywords are normalized to lowercase, explicit paths are passed on untouched
        psp_path = path.lower()
    else:
        psp_path = path
    return [read_gth(symbol, psp_path=psp_path)["Zion"] for symbol in atom]
def vector_angle(a, b):
    """Calculate the angle between two vectors.

    Args:
        a: Vector.
        b: Vector.

    Returns:
        Angle between a and b in Degree.
    """
    # Normalize vectors first
    a, b = xp.asarray(a, dtype=float), xp.asarray(b, dtype=float)
    a_norm = a / xp.linalg.norm(a)
    b_norm = b / xp.linalg.norm(b)
    # Rounding can push the cosine of (anti)parallel vectors slightly outside [-1, 1], where
    # arccos returns NaN, so clamp it to the valid domain first
    cos_angle = xp.clip(a_norm @ b_norm, -1, 1)
    angle = xp.arccos(cos_angle)
    return rad2deg(angle)
def get_lattice(lattice_vectors):
    """Generate a cell for given lattice vectors.

    Args:
        lattice_vectors: Lattice vectors.

    Returns:
        List of the 12 cell edges, each given as a pair of lattice vertices (start, end).
    """
    # Vertices of a cube
    vertices = xp.asarray(
        [
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
        ],
        dtype=float,
    )
    # Connected vertices of a cube with the above ordering
    edges = xp.asarray(
        [
            [0, 1],
            [0, 2],
            [0, 4],
            [1, 3],
            [1, 5],
            [2, 3],
            [2, 6],
            [3, 7],
            [4, 5],
            [4, 6],
            [5, 7],
            [6, 7],
        ]
    )
    # Scale vertices with the lattice vectors once instead of once per edge
    cell_vertices = vertices @ lattice_vectors
    # Select pairs of vertices to plot them later
    # The resulting return value is similar to the get_brillouin_zone function
    return [cell_vertices[e, :] for e in edges]