Coverage for eminus/utils.py: 96.62%
148 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-18 08:43 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-18 08:43 +0000
1# SPDX-FileCopyrightText: 2021 The eminus developers
2# SPDX-License-Identifier: Apache-2.0
3"""Linear algebra calculation utilities."""
5import functools
6import pathlib
7import re
9import numpy as np
10from scipy.linalg import norm
12import eminus
14from . import config
15from .units import rad2deg
class BaseObject:
    """Common base class that bundles functionality shared by eminus objects."""

    def view(self, *args, **kwargs):
        """Display the object with the unified viewer.

        Args:
            args: Pass-through arguments.

        Keyword Args:
            kwargs: Pass-through keyword arguments.

        Returns:
            Viewable object, i.e., the result of eminus.extras.view.
        """
        viewer = eminus.extras.view
        return viewer(self, *args, **kwargs)

    def write(self, filename, *args, **kwargs):
        """Write the object to a file with the unified writer.

        Args:
            filename: Output file path/name.
            args: Pass-through arguments.

        Keyword Args:
            kwargs: Pass-through keyword arguments.

        Returns:
            Result of eminus.io.write.
        """
        # Fall back to JSON when the name carries no extension, POSCAR names are left untouched
        has_extension = bool(pathlib.Path(filename).suffix)
        if not (has_extension or "POSCAR" in filename):
            filename = filename + ".json"
        writer = eminus.io.write
        return writer(self, filename, *args, **kwargs)
def dotprod(a, b):
    """Efficiently calculate the real part of the dot product of a and b.

    Add an extra check to make sure the result is never zero since this function is used as a
    denominator in minimizers.

    Args:
        a: Array of vectors.
        b: Array of vectors.

    Returns:
        The real part of sum(conj(a) * b), or a small positive number if its magnitude vanishes.
    """
    eps = 1e-15  # 2.22e-16 is the range of float64 machine precision
    # The dot product of complex vectors looks like the expression below, but this is slow
    # res = np.real(np.trace(a.conj().T @ b))
    # We can calculate the trace faster by taking the sum of the Hadamard product
    res = np.real(np.sum(a.conj() * b))
    # Guard the value that is actually returned: a purely imaginary sum has a large modulus but a
    # vanishing real part, which would otherwise slip through and end up as a zero denominator
    if abs(res) < eps:
        return eps
    return res
def Ylm_real(l, m, G):  # noqa: C901, PLR0911
    """Evaluate real spherical harmonics for cartesian vectors.

    Reference: https://scipython.com/blog/visualizing-the-real-forms-of-the-spherical-harmonics

    Args:
        l: Angular momentum number.
        m: Magnetic quantum number.
        G: Reciprocal lattice vector or array of lattice vectors.

    Returns:
        Real spherical harmonics, one value per vector.

    Raises:
        ValueError: If no expression is implemented for the given l and m.
    """
    tiny = 1e-9
    # Promote a single vector to an array with one row
    G = np.atleast_2d(G)

    # The l=0 harmonic is a constant, skip the angle calculation
    if l == 0:
        return 0.5 * np.sqrt(1 / np.pi) * np.ones(len(G))

    # cos(theta)=Gz/|G|
    G_abs = norm(G, axis=1)
    with np.errstate(divide="ignore", invalid="ignore"):
        ct = G[:, 2] / G_abs
    # Vectors with a vanishing magnitude get cos(theta)=0
    ct[G_abs < tiny] = 0

    # sin(theta)=sqrt(max(0, 1-cos(theta)^2)), the clamp protects against round-off
    st = np.sqrt(np.maximum(np.zeros_like(ct), 1 - ct**2))

    # phi=arctan(Gy/Gx)
    phi = np.arctan2(G[:, 1], G[:, 0])
    # For Gx=0 use phi=pi/2*sign(Gy)
    no_gx = np.abs(G[:, 0]) < tiny
    phi[no_gx] = np.pi / 2 * np.sign(G[no_gx, 1])

    if l == 1:
        if m == -1:  # py
            return 0.5 * np.sqrt(3 / np.pi) * st * np.sin(phi)
        if m == 0:  # pz
            return 0.5 * np.sqrt(3 / np.pi) * ct
        if m == 1:  # px
            return 0.5 * np.sqrt(3 / np.pi) * st * np.cos(phi)
    if l == 2:
        if m == -2:  # dxy
            return np.sqrt(15 / 16 / np.pi) * st**2 * np.sin(2 * phi)
        if m == -1:  # dyz
            return np.sqrt(15 / 4 / np.pi) * ct * st * np.sin(phi)
        if m == 0:  # dz2
            return 0.25 * np.sqrt(5 / np.pi) * (3 * ct**2 - 1)
        if m == 1:  # dxz
            return np.sqrt(15 / 4 / np.pi) * ct * st * np.cos(phi)
        if m == 2:  # dx2-y2
            return np.sqrt(15 / 16 / np.pi) * st**2 * np.cos(2 * phi)
    if l == 3:
        if m == -3:
            return 0.25 * np.sqrt(35 / 2 / np.pi) * st**3 * np.sin(3 * phi)
        if m == -2:
            return 0.25 * np.sqrt(105 / np.pi) * st**2 * ct * np.sin(2 * phi)
        if m == -1:
            return 0.25 * np.sqrt(21 / 2 / np.pi) * st * (5 * ct**2 - 1) * np.sin(phi)
        if m == 0:
            return 0.25 * np.sqrt(7 / np.pi) * (5 * ct**3 - 3 * ct)
        if m == 1:
            return 0.25 * np.sqrt(21 / 2 / np.pi) * st * (5 * ct**2 - 1) * np.cos(phi)
        if m == 2:
            return 0.25 * np.sqrt(105 / np.pi) * st**2 * ct * np.cos(2 * phi)
        if m == 3:
            return 0.25 * np.sqrt(35 / 2 / np.pi) * st**3 * np.cos(3 * phi)

    # Every supported combination has returned by now
    msg = f"No definition found for Ylm({l}, {m})."
    raise ValueError(msg)
def handle_spin(func):
    """Apply a function to every spin channel of a wave function separately.

    This can only be applied if the only spin-dependent indexing is the wave function W.

    Handling spin explicitly would require looping over the spin states in many places. This
    decorator hides that loop: when W is a three-dimensional array the wrapped function is called
    once per entry along the first axis and the results are stacked into one array, otherwise W is
    passed to the wrapped function unchanged.

    Args:
        func: Function that acts on spin-states.

    Returns:
        Decorator.
    """

    @functools.wraps(func)
    def spin_wrapper(obj, W, *args, **kwargs):
        # Everything that is not a 3d array is forwarded as it is
        if W.ndim != 3:
            return func(obj, W, *args, **kwargs)
        # Evaluate each spin channel on its own and collect the results
        per_spin = [func(obj, W_spin, *args, **kwargs) for W_spin in W]
        return np.asarray(per_spin)

    return spin_wrapper
def handle_k(func=None, *, mode="gracefully"):
    """Apply a function to every k-point of a wave function, depending on the mode.

    This uses the same principle as described in :func:`~eminus.utils.handle_spin`.

    A wave function counts as k-point dependent if it is a list or a four-dimensional array. The
    supported modes are "gracefully" (call the function per k-point), "index" (additionally pass
    the k-point index), "reduce" (like "index", but sum up the results), and "skip" (assert a
    Gamma-only calculation and evaluate the first k-point only).

    Keyword Args:
        func: Function that acts on k-points.
        mode: How to handle the k-point dependency.

    Returns:
        Decorator.
    """
    # Called with keyword arguments only: return a decorator with the mode already set
    if func is None:
        return functools.partial(handle_k, mode=mode)

    @functools.wraps(func)
    def k_wrapper(obj, W, *args, **kwargs):
        has_k = isinstance(W, list) or (isinstance(W, np.ndarray) and W.ndim == 4)
        # Input without k-points is forwarded unchanged
        if not has_k:
            return func(obj, W, *args, **kwargs)

        # The function does not need the k-point index
        if mode == "gracefully":
            return [func(obj, W_k, *args, **kwargs) for W_k in W]
        # The function needs the k-point index as an extra argument
        if mode == "index":
            return [func(obj, W_k, i_k, *args, **kwargs) for i_k, W_k in enumerate(W)]
        # The function needs the k-point index and the results are accumulated
        if mode == "reduce":
            # The built-in sum handles single values and elementwise NumPy array addition alike
            return sum(func(obj, W_k, i_k, *args, **kwargs) for i_k, W_k in enumerate(W))
        # The function has no k-point support, only Gamma-point calculations are allowed
        if mode == "skip":
            obj._atoms.kpts._assert_gamma_only()
            result = func(obj, W[0], *args, **kwargs)
            # Three-dimensional array results are wrapped in a list again
            if isinstance(result, np.ndarray) and result.ndim == 3:
                return [result]
            return result
        # Unknown modes fall back to calling the function directly
        return func(obj, W, *args, **kwargs)

    return k_wrapper
def handle_torch(func, *args, **kwargs):
    """Dispatch to the Torch implementation of a function if Torch usage is enabled.

    The alternative is looked up by the name of the wrapped function in eminus.extras.torch at
    call time, so the configuration can be changed after decorating.

    Args:
        func: Function with a Torch alternative.
        args: Pass-through arguments.

    Keyword Args:
        kwargs: Pass-through keyword arguments.

    Returns:
        Decorator.
    """

    @functools.wraps(func)
    def torch_wrapper(*args, **kwargs):
        # Use the original implementation unless Torch has been requested
        if not config.use_torch:
            return func(*args, **kwargs)
        torch_impl = getattr(eminus.extras.torch, func.__name__)
        return torch_impl(*args, **kwargs)

    return torch_wrapper
def pseudo_uniform(size, seed=1234):
    """Lehmer random number generator, following MINSTD.

    Reference: Commun. ACM. 12, 85.

    The numbers are drawn one after another while iterating over the first three axes of the
    array in row-major order, so the result is reproducible for a given size and seed.

    Args:
        size: Dimension of the array to create.

    Keyword Args:
        seed: Seed to initialize the random number generator.

    Returns:
        Array with (pseudo) random numbers.
    """
    W = np.zeros(size, dtype=complex)
    multiplier = 48271
    modulus = (2**31) - 1

    def advance(state):
        # One step of the recurrence used by this generator
        return (state * multiplier + 1) % modulus

    # The seed is advanced once before the first number is drawn
    state = advance(seed)
    for i0 in range(size[0]):
        for i1 in range(size[1]):
            for i2 in range(size[2]):
                state = advance(state)
                # Scale the integer state by the modulus
                W[i0, i1, i2] = state / modulus
    return W
def add_maybe_none(a, b):
    """Add a and b, where either operand or both may be None.

    Args:
        a: Array or None.
        b: Array or None.

    Returns:
        Sum of a and b, the operand that is not None, or None if both are None.
    """
    # Without a the result is just b, this also covers the case where both are None
    if a is None:
        return b
    if b is None:
        return a
    return a + b
def molecule2list(molecule):
    """Expand a chemical formula to a list of chemical symbols.

    No charges or parentheses are allowed, only chemical symbols followed by their amount.

    Args:
        molecule: Simplified chemical formula (case sensitive).

    Returns:
        Atoms of the molecule expanded to a list.
    """
    # Put a whitespace in front of every capital letter and every run of digits, such that the
    # formula can be split into symbols and amounts
    # NOTE(review): the "?" inside the character class matches a literal question mark
    spaced = re.sub(r"([A-Z?]|\d+)", r" \1", molecule)
    symbols = []
    for token in spaced.split():
        if not token.isdigit():
            # A chemical symbol is added to the results once
            symbols.append(token)
            continue
        # An amount repeats the preceding symbol, which has already been added once
        symbols.extend([symbols[-1]] * (int(token) - 1))
    return symbols
def atom2charge(atom, path=None):
    """Look up the valence charges for a list of chemical symbols in GTH files.

    Args:
        atom: Atom symbols.
        path: Directory of GTH files, or one of the names "pade" and "pbe" (case insensitive).
            If None, "pbe" is used.

    Returns:
        Valence charges per atom.
    """
    # Import here to prevent circular imports
    from .io import read_gth

    if path is None:
        psp_path = "pbe"
    else:
        lowered = path.lower()
        # The known names are passed on in lowercase, everything else is passed on unchanged
        psp_path = lowered if lowered in {"pade", "pbe"} else path
    return [read_gth(symbol, psp_path=psp_path)["Zion"] for symbol in atom]
def vector_angle(a, b):
    """Calculate the angle between two vectors.

    Args:
        a: Vector.
        b: Vector.

    Returns:
        Angle between a and b in Degree.
    """
    # Normalize vectors first
    a_norm = a / norm(a)
    b_norm = b / norm(b)
    # Round-off can push the cosine of (anti-)parallel vectors slightly outside of [-1, 1], where
    # arccos would return NaN, so clip it to the valid domain first
    cos_angle = np.clip(a_norm @ b_norm, -1.0, 1.0)
    angle = np.arccos(cos_angle)
    return rad2deg(angle)
def get_lattice(lattice_vectors):
    """Generate the edges of a cell for given lattice vectors.

    Args:
        lattice_vectors: Lattice vectors.

    Returns:
        Lattice vertices, as a list with one pair of connected vertices per cell edge.
    """
    # Vertices of a unit cube: (0, 0, 0), (0, 0, 1), (0, 1, 0), ..., (1, 1, 1)
    vertices = np.array(list(np.ndindex(2, 2, 2)))
    # With this ordering the vertex index is the binary number formed by its coordinates, so two
    # vertices are connected if their indices differ in exactly one bit
    # This gives the pairs (0, 1), (0, 2), (0, 4), (1, 3), ..., (6, 7)
    edges = np.array(
        [(start, start + bit) for start in range(8) for bit in (1, 2, 4) if not start & bit]
    )
    # Scale the vertices with the lattice vectors
    corners = vertices @ lattice_vectors
    # Select pairs of vertices to plot them later
    # The resulting return value is similar to the get_brillouin_zone function
    return [corners[pair, :] for pair in edges]