Coverage for eminus/atoms.py: 98.17%
273 statements
« prev ^ index » next coverage.py v7.14.2, created at 2026-06-23 09:07 +0000
1# SPDX-FileCopyrightText: 2021 The eminus developers
2# SPDX-License-Identifier: Apache-2.0
3"""Atoms class definition."""
5import math
6import numbers
8import numpy as np
9from scipy.fft import next_fast_len
11from . import backend as xp
12from . import operators
13from .kpoints import KPoints
14from .logger import create_logger, get_level, log
15from .occupations import Occupations
16from .tools import center_of_mass, cutoff2gridspacing, inertia_tensor
17from .utils import atom2charge, BaseObject, molecule2list
class Atoms(BaseObject):
    """Atoms object that holds all system and cell parameters.

    Args:
        atom: Atom symbols.

            A string can be given, e.g., with :code:`CH4` that will be parsed to
            :code:`["C", "H", "H", "H", "H"]`. When calculating atoms one can directly provide the
            charge, e.g., with :code:`Li-q3`.
        pos: Atom positions.

    Keyword Args:
        ecut: Cut-off energy.

            Defaults to 30 Eh (ca. 816 eV).
        a: Cell size or vacuum size.

            Floats will create a cubic unit cell. Defaults to a 20 a0 (ca. 10.5 A) cubic cell.
            Scaled lattice vectors can be given to build a custom cell.
        spin: Number of unpaired electrons.

            This is the difference between the number of up and down electrons.
        charge: Charge of the system.
        unrestricted: Handling of spin.

            :code:`False` for restricted, :code:`True` for unrestricted, and :code:`None` for
            automatic detection.
        center: Center the system inside the cell.

            Aligns the geometric center of mass with the center of the cell and rotates the system,
            such that its geometric moment of inertia aligns with the coordinate axes. Can be one of
            bool, "shift", and "rotate".
        verbose: Level of output.

            Can be one of "critical", "error", "warning", "info" (default), or "debug". An integer
            value can be used as well, where larger numbers mean more output, starting from 0.
            None will use the global logger verbosity value.
    """

    def __init__(
        self,
        atom,
        pos,
        ecut=30,
        a=20,
        spin=None,
        charge=0,
        unrestricted=None,
        center=False,
        verbose=None,
    ):
        """Initialize the Atoms object."""
        # Set the input parameters (the ordering is important):
        # the verbose setter configures self._log, the Z setter (called by the atom setter) writes
        # to self.occ, the pos setter compares against Natoms, the ecut setter reads self.a, and
        # the center setter reads both pos and a
        self._log = create_logger(self)  #: Logger object.
        self.verbose = verbose  #: Verbosity level.
        self.occ = Occupations()  #: Occupations object.
        self.atom = atom  #: Atom symbols.
        self.pos = pos  #: Atom positions.
        self.a = a  #: Cell/Vacuum size.
        self.ecut = ecut  #: Cut-off energy.
        self.center = center  #: Enables centering the system in the cell.
        self.charge = charge  #: System charge.
        self.spin = spin  #: Number of unpaired electrons.
        self.unrestricted = unrestricted  #: Enables unrestricted spin handling.
        self.kpts = KPoints("sc", self.a)  #: KPoints object.

        # Initialize other attributes
        self.occ.fill()  #: Fill states from the given input.
        self.is_built = False  #: Determines the Atoms object build status.
        self.clear()

    # ### Class properties ###

    @property
    def atom(self):
        """Atom symbols."""
        return self._atom

    @atom.setter
    def atom(self, value):
        # Quick option to set the charge for single atoms, e.g., "Li-q3"
        if isinstance(value, str) and "-q" in value:
            atom, Z = value.split("-q")
            self._atom = [atom]
            self._Natoms = 1
            self.Z = int(Z)
        else:
            # If a string, i.e., chemical formula is given convert it to a list of chemical symbols
            if isinstance(value, str):
                self._atom = molecule2list(value)
            else:
                self._atom = value
            # Get the number of atoms and determine the charges
            self._Natoms = len(self._atom)
            # None makes the Z setter look up the valence charges via atom2charge
            self.Z = None

    @property
    def pos(self):
        """Atom positions."""
        return self._pos

    @pos.setter
    def pos(self, value):
        # We need atom positions as a two-dimensional array
        self._pos = xp.atleast_2d(xp.asarray(value, dtype=float))
        if self.Natoms != len(self._pos) and self.Natoms > 0:
            msg = (
                f"Mismatch between number of atoms ({self.Natoms}) and number of "
                f"coordinates ({len(self._pos)})."
            )
            raise ValueError(msg)
        # Positions are contracted with the three-dimensional G-vectors, so fail early here
        # instead of with an opaque shape error when building the structure factor
        if self.Natoms > 0 and self._pos.shape[1] != 3:
            msg = (
                f"Atom positions need three Cartesian coordinates per atom, "
                f"got {self._pos.shape[1]}."
            )
            raise ValueError(msg)
        if self.Natoms == 0:
            log.warning("No atoms are placed in the unit cell.")
        # The structure factor changes when changing pos
        self.is_built = False

    @property
    def ecut(self):
        """Cut-off energy."""
        return self._ecut

    @ecut.setter
    def ecut(self, value):
        self._ecut = value
        # Calculate the sampling from the cut-off energy
        s = xp.asarray(xp.linalg.norm(self.a, axis=1) / cutoff2gridspacing(value), dtype=int)
        # Multiply by two and add one to match PWDFT.jl
        s = 2 * s + 1
        # Calculate a fast length to optimize the FFT calculations
        self.s = [int(next_fast_len(i)) for i in s]
        # The cell discretization changes when changing s or ecut
        self.is_built = False

    @property
    def a(self):
        """Cell/Vacuum size."""
        return self._a

    @a.setter
    def a(self, value):
        # Build a cubic cell if a number or 1d-array is given
        # (a 1d-array broadcasts against the identity matrix and yields a diagonal cell)
        if xp.asarray(value).ndim <= 1:
            self._a = xp.asarray(value, dtype=float) * xp.eye(3)
        # Otherwise scaled cell vectors are given
        else:
            self._a = xp.asarray(value, dtype=float)
        # Update ecut and s if it has been set before
        # (during __init__ the ecut getter raises AttributeError, so hasattr is False)
        if hasattr(self, "ecut"):
            self.ecut = self.ecut
        # Calculate the unit cell volume
        self._Omega = abs(xp.linalg.det(self._a))
        if hasattr(self, "kpts"):
            self.kpts.a = self._a
        # The cell changes when changing a
        self.is_built = False

    @property
    def spin(self):
        """Number of unpaired electrons."""
        return self.occ.spin

    @spin.setter
    def spin(self, value):
        self.occ.spin = value

    @property
    def charge(self):
        """System charge."""
        return self.occ.charge

    @charge.setter
    def charge(self, value):
        self.occ.charge = value

    @property
    def unrestricted(self):
        """Enables unrestricted spin handling."""
        return self.occ.Nspin == 2

    @unrestricted.setter
    def unrestricted(self, value):
        # None is passed through for automatic detection, otherwise False/True map to 1/2 channels
        if value is None:
            self.occ.Nspin = value
        else:
            self.occ.Nspin = value + 1

    @property
    def center(self):
        """Enables centering the system in the cell."""
        return self._center

    @center.setter
    def center(self, value):
        if isinstance(value, str):
            self._center = value.lower()
            if self._center not in {"rotate", "shift", "recentered"}:
                log.error(f"{self._center} is not a recognized center method.")
        else:
            self._center = value
        # Do nothing when recentering
        if self._center == "recentered":
            return
        # Center system such that the geometric inertia tensor will be diagonal
        # Rotate before shifting!
        if self._center is True or self._center == "rotate":
            I = inertia_tensor(self.pos)
            _, eigvecs = xp.linalg.eigh(I)
            self.pos = (xp.linalg.inv(eigvecs) @ self.pos.T).T
        # Shift system such that its geometric center of mass is in the center of the cell
        if self._center is True or self._center == "shift":
            com = center_of_mass(self.pos)
            self.pos = self.pos - (com - xp.sum(self.a, axis=0) / 2)
        # The structure factor changes when changing pos
        self.is_built = False

    @property
    def verbose(self):
        """Verbosity level."""
        return self._verbose

    @verbose.setter
    def verbose(self, value):
        # If no verbosity is given use the global verbosity level
        if value is None:
            value = log.verbose
        self._verbose = get_level(value)
        self._log.verbose = self._verbose

    # ### Class properties with a setter outside of the init method ###

    @property
    def f(self):
        """Occupation numbers per state."""
        return self.occ.f

    @f.setter
    def f(self, value):
        # Pass through to the Occupations object
        self.occ.f = value

    @property
    def s(self):
        """Real-space sampling of the cell."""
        return self._s

    @s.setter
    def s(self, value):
        # Choose the same sampling for every direction if an integer is given
        if isinstance(value, numbers.Integral):
            value = value * xp.ones(3, dtype=int)
        self._s = xp.asarray(value, dtype=int)
        self._Ns = int(xp.prod(self._s))
        # The cell discretization changes when changing s
        self.is_built = False

    @property
    def Z(self):
        """Valence charge per atom."""
        return self._Z

    @Z.setter
    def Z(self, value):
        # Assume same charges for all atoms if an integer is given
        if isinstance(value, numbers.Integral):
            value = value * xp.ones(self.Natoms, dtype=int)
        elif isinstance(value, dict):
            value = [value[ia] for ia in self.atom]
        # Get the valence charges from the GTH files
        elif value is None or isinstance(value, str):
            value = atom2charge(self.atom, value)
        self._Z = xp.asarray(value, dtype=int)
        if self.Natoms != len(self._Z):
            msg = (
                f"Mismatch between number of atoms ({self.Natoms}) and number of "
                f"charges ({len(self._Z)})."
            )
            raise ValueError(msg)
        # Get the number of calculated electrons and pass it to occ
        self.occ.Nelec = xp.sum(self._Z) - self.charge
        # NOTE(review): bands = 0 presumably lets Occupations pick the minimal amount (see the
        # log message); the floor division may undercount for odd electron numbers - confirm
        if self.occ.Nspin and self.occ.bands < self.occ.Nelec * self.occ.Nspin // 2:
            log.warning("The number of bands is too small, reset to the minimally needed amount.")
            self.occ.bands = 0

    # ### Read-only properties ###

    @property
    def Natoms(self):
        """Number of atoms."""
        return self._Natoms

    @property
    def Ns(self):
        """Number of real-space grid points."""
        return self._Ns

    @property
    def Omega(self):
        """Unit cell volume."""
        return self._Omega

    @property
    def r(self):
        """Real-space sampling points."""
        return self._r

    @property
    def active(self):
        """Mask for active G-vectors."""
        return self._active

    @property
    def G(self):
        """G-vectors."""
        return self._G

    @property
    def G2(self):
        """Squared magnitudes of G-vectors."""
        return self._G2

    @property
    def G2c(self):
        """Truncated squared magnitudes of G-vectors."""
        return self._G2c

    @property
    def Gk2(self):
        """Squared magnitudes of G+k-vectors."""
        return self._Gk2

    @property
    def Gk2c(self):
        """Truncated squared magnitudes of G+k-vectors."""
        return self._Gk2c

    @property
    def Sf(self):
        """Structure factor per atom."""
        return self._Sf

    @property
    def dV(self):
        """Volume element to multiply when integrating field properties."""
        return self.Omega / self._Ns

    @property
    def _atoms(self):
        """The Atoms object itself."""
        # This way we can access the object from Atoms and SCF classes with the same code
        return self

    # ### Class methods ###

    def build(self):
        """Build all parameters of the Atoms object.

        Returns:
            The Atoms object itself.
        """
        self.kpts.build()
        self._sample_unit_cell()
        self.occ.wk = self.kpts.wk  # Pass the weights of k-points to the Occupations object
        self.occ.fill()
        self.is_built = True
        return self

    # Alias so the object can be "run" like other calculation objects
    kernel = build

    def recenter(self, center=None):
        """Recenter the system inside the cell.

        Keyword Args:
            center: Point to center the system around.

                Defaults to the center of the cell.

        Returns:
            The Atoms object itself.
        """
        com = center_of_mass(self.pos)
        if center is None:
            self.pos = self.pos - (com - xp.sum(self.a, axis=0) / 2)
        else:
            center = xp.asarray(center, dtype=float)
            self.pos = self.pos - (com - center)
        if self.Sf is not None:
            # Recalculate the structure factor since it depends on the atom positions
            self._Sf = xp.exp(1j * (self.G @ self.pos.T)).T
        # Mark the positions as recentered so the center setter will not move them again
        self._center = "recentered"
        return self

    def set_k(self, k, wk=None):
        """Interface to set custom k-points.

        Args:
            k: k-point coordinates.

        Keyword Args:
            wk: k-point weights.

                Defaults to equal weights that sum up to one.

        Returns:
            The Atoms object itself.

        Raises:
            ValueError: If the number of weights does not match the number of k-points.
        """
        k = xp.atleast_2d(xp.asarray(k, dtype=float))
        if wk is None:
            wk = xp.ones(len(k)) / len(k)
        else:
            wk = xp.asarray(wk, dtype=float)
        # Every k-point needs exactly one weight; otherwise the k-point loops in
        # _sample_unit_cell would index out of bounds or silently skip k-points
        # (validate before mutating the KPoints object)
        if len(wk) != len(k):
            msg = (
                f"Mismatch between number of k-points ({len(k)}) and number of "
                f"weights ({len(wk)})."
            )
            raise ValueError(msg)
        self.kpts.build()
        self.kpts._k = k
        self.kpts._wk = wk
        self.kpts._Nk = len(wk)
        # Custom k-points do not come from a Monkhorst-Pack mesh
        self.kpts._kmesh = None
        self.occ.wk = self.kpts.wk
        self._sample_unit_cell()
        return self

    def clear(self):
        """Initialize or clear parameters that will be built out of the inputs.

        Returns:
            The Atoms object itself.
        """
        self._r = None  # Sample points in cell
        self._active = None  # Mask for active G-vectors
        self._G = None  # G-vectors
        self._G2 = None  # Squared magnitudes of G-vectors
        self._G2c = None  # Truncated squared magnitudes of G-vectors
        self._Gk2 = None  # Squared magnitudes of G+k-vectors
        self._Gk2c = None  # Truncated squared magnitudes of G+k-vectors
        self._Sf = None  # Structure factor
        self.is_built = False  # Flag to determine if the object was built or not
        return self

    def _get_index_matrices(self):
        """Build index matrices M and N to build the real and reciprocal space samplings.

        The matrices are using C ordering (the last index is the fastest).

        Returns:
            Index matrices.
        """
        # Build index matrix M: row m holds the (i, j, k) grid indices of the m-th point,
        # i.e., the same as unravelling arange(Ns) over the shape s with the last index fastest
        M = xp.asarray(np.indices(self.s, dtype=float).transpose(1, 2, 3, 0).reshape(-1, 3))
        # Build index matrix N: indices above s/2 are shifted by -s to get signed frequencies
        N = M - (self.s / 2 < M) * self.s
        return M, N

    def _sample_unit_cell(self):
        """Build the real-space sampling and all G-vector parameters."""
        # Calculate index matrices
        M, N = self._get_index_matrices()
        # Build the real-space sampling
        self._r = M @ xp.linalg.inv(xp.diag(xp.astype(self.s, float))) @ self.a
        # Build G-vectors
        self._G = 2 * math.pi * N @ xp.linalg.inv(self.a.T)
        # Calculate squared magnitudes of G-vectors
        self._G2 = xp.linalg.norm(self.G, axis=1) ** 2
        # Calculate the G2 restriction
        self._active = [
            xp.nonzero(2 * self.ecut >= xp.linalg.norm(self.G + self.kpts.k[ik], axis=1) ** 2)
            for ik in range(self.kpts.Nk)
        ]
        self._G2c = self._G2[xp.nonzero(2 * self.ecut >= self._G2)]
        # Calculate G+k-vectors
        self._Gk2 = xp.stack(
            [xp.linalg.norm(self.G + self.kpts.k[ik], axis=1) ** 2 for ik in range(self.kpts.Nk)]
        )
        self._Gk2c = [self._Gk2[ik][self._active[ik]] for ik in range(self.kpts.Nk)]
        # Calculate the structure factor per atom
        self._Sf = xp.exp(1j * (self.G @ self.pos.T)).T

        # Create the grid used for the non-wave function fields and append it to the end
        self._active.append(xp.nonzero(2 * self.ecut >= self._G2))
        self._Gk2 = xp.vstack((self._Gk2, self._G2))
        self._Gk2c.append(self._G2c)

    # Expose the functions of the operators module as methods of this class
    O = operators.O
    L = operators.L
    Linv = operators.Linv
    K = operators.K
    T = operators.T
    I = operators.I
    J = operators.J
    Idag = operators.Idag
    Jdag = operators.Jdag

    def __repr__(self):
        """Print the parameters stored in the Atoms object."""
        out = "Atom Valence Position"
        for i in range(self.Natoms):
            out += (
                f"\n{self.atom[i]:>3} {self.Z[i]:>6} "
                f"{self.pos[i, 0]:10.5f} {self.pos[i, 1]:10.5f} {self.pos[i, 2]:10.5f}"
            )
        return out