Coverage for eminus/atoms.py: 98.17%

273 statements  

« prev     ^ index     » next       coverage.py v7.14.2, created at 2026-06-23 09:07 +0000

1# SPDX-FileCopyrightText: 2021 The eminus developers 

2# SPDX-License-Identifier: Apache-2.0 

3"""Atoms class definition.""" 

4 

5import math 

6import numbers 

7 

8import numpy as np 

9from scipy.fft import next_fast_len 

10 

11from . import backend as xp 

12from . import operators 

13from .kpoints import KPoints 

14from .logger import create_logger, get_level, log 

15from .occupations import Occupations 

16from .tools import center_of_mass, cutoff2gridspacing, inertia_tensor 

17from .utils import atom2charge, BaseObject, molecule2list 

18 

19 

class Atoms(BaseObject):
    """Atoms object that holds all system and cell parameters.

    Args:
        atom: Atom symbols.

            A string can be given, e.g., with :code:`CH4` that will be parsed to
            :code:`["C", "H", "H", "H", "H"]`. When calculating a single atom, its valence charge
            can be provided directly, e.g., with :code:`Li-q3`.
        pos: Atom positions.

    Keyword Args:
        ecut: Cut-off energy.

            Defaults to 30 Eh (ca. 816 eV).
        a: Cell size or vacuum size.

            Floats will create a cubic unit cell. Defaults to a 20 a0 (ca. 10.5 A) cubic cell.
            Scaled lattice vectors can be given to build a custom cell.
        spin: Number of unpaired electrons.

            This is the difference between the number of up and down electrons.
        charge: Charge of the system.
        unrestricted: Handling of spin.

            :code:`False` for restricted, :code:`True` for unrestricted, and :code:`None` for
            automatic detection.
        center: Center the system inside the cell.

            Aligns the geometric center of mass with the center of the cell and rotates the system,
            such that its geometric moment of inertia aligns with the coordinate axes. Can be one of
            bool, "shift", and "rotate".
        verbose: Level of output.

            Can be one of "critical", "error", "warning", "info" (default), or "debug". An integer
            value can be used as well, where larger numbers mean more output, starting from 0.
            None will use the global logger verbosity value.
    """

    def __init__(
        self,
        atom,
        pos,
        ecut=30,
        a=20,
        spin=None,
        charge=0,
        unrestricted=None,
        center=False,
        verbose=None,
    ):
        """Initialize the Atoms object."""
        # Set the input parameters (the ordering is important):
        # - the logger has to exist before any setter below can log,
        # - occ has to exist before atom, since the atom setter writes the electron count to occ,
        # - pos is validated against the number of atoms derived from atom,
        # - ecut derives the grid sampling from a,
        # - center transforms the already-set positions inside the already-set cell.
        self._log = create_logger(self)  #: Logger object.
        self.verbose = verbose  #: Verbosity level.
        self.occ = Occupations()  #: Occupations object.
        self.atom = atom  #: Atom symbols.
        self.pos = pos  #: Atom positions.
        self.a = a  #: Cell/Vacuum size.
        self.ecut = ecut  #: Cut-off energy.
        self.center = center  #: Enables centering the system in the cell.
        self.charge = charge  #: System charge.
        self.spin = spin  #: Number of unpaired electrons.
        self.unrestricted = unrestricted  #: Enables unrestricted spin handling.
        self.kpts = KPoints("sc", self.a)  #: KPoints object.

        # Initialize other attributes
        self.occ.fill()  # Fill states from the given input
        self.is_built = False  #: Determines the Atoms object build status.
        self.clear()

    # ### Class properties ###

    @property
    def atom(self):
        """Atom symbols."""
        return self._atom

    @atom.setter
    def atom(self, value):
        # Quick option to set the valence charge for single atoms, e.g., "Li-q3"
        if isinstance(value, str) and "-q" in value:
            atom, Z = value.split("-q")
            self._atom = [atom]
            self._Natoms = 1
            self.Z = int(Z)
        else:
            # If a string, i.e., chemical formula is given convert it to a list of chemical symbols
            if isinstance(value, str):
                self._atom = molecule2list(value)
            else:
                self._atom = value
            # Get the number of atoms and determine the charges
            self._Natoms = len(self._atom)
            self.Z = None

    @property
    def pos(self):
        """Atom positions."""
        return self._pos

    @pos.setter
    def pos(self, value):
        # We need atom positions as a two-dimensional array (one row per atom)
        self._pos = xp.atleast_2d(xp.asarray(value, dtype=float))
        if self.Natoms != len(self._pos) and self.Natoms > 0:
            msg = (
                f"Mismatch between number of atoms ({self.Natoms}) and number of "
                f"coordinates ({len(self._pos)})."
            )
            raise ValueError(msg)
        if self.Natoms == 0:
            # Use the object logger so the per-object verbosity is respected
            self._log.warning("No atoms are placed in the unit cell.")
        # The structure factor changes when changing pos
        self.is_built = False

    @property
    def ecut(self):
        """Cut-off energy."""
        return self._ecut

    @ecut.setter
    def ecut(self, value):
        self._ecut = value
        # Calculate the sampling from the cut-off energy (truncated to integers)
        s = xp.asarray(xp.linalg.norm(self.a, axis=1) / cutoff2gridspacing(value), dtype=int)
        # Multiply by two and add one to match PWDFT.jl
        s = 2 * s + 1
        # Calculate a fast length to optimize the FFT calculations
        self.s = [int(next_fast_len(i)) for i in s]
        # The cell discretization changes when changing s or ecut
        self.is_built = False

    @property
    def a(self):
        """Cell/Vacuum size."""
        return self._a

    @a.setter
    def a(self, value):
        # Build a cubic (or orthorhombic for three values) cell if a number or 1d-array is given
        if xp.asarray(value).ndim <= 1:
            self._a = xp.asarray(value, dtype=float) * xp.eye(3)
        # Otherwise scaled cell vectors are given (one lattice vector per row)
        else:
            self._a = xp.asarray(value, dtype=float)
        # Update ecut and s if it has been set before (not the case during __init__)
        if hasattr(self, "ecut"):
            self.ecut = self.ecut
        # Calculate the unit cell volume
        self._Omega = abs(xp.linalg.det(self._a))
        if hasattr(self, "kpts"):
            self.kpts.a = self._a
        # The cell changes when changing a
        self.is_built = False

    @property
    def spin(self):
        """Number of unpaired electrons."""
        return self.occ.spin

    @spin.setter
    def spin(self, value):
        self.occ.spin = value

    @property
    def charge(self):
        """System charge."""
        return self.occ.charge

    @charge.setter
    def charge(self, value):
        self.occ.charge = value

    @property
    def unrestricted(self):
        """Enables unrestricted spin handling."""
        return self.occ.Nspin == 2

    @unrestricted.setter
    def unrestricted(self, value):
        # Map False/True to one/two spin channels; None is forwarded for automatic detection
        self.occ.Nspin = None if value is None else value + 1

    @property
    def center(self):
        """Enables centering the system in the cell."""
        return self._center

    @center.setter
    def center(self, value):
        if isinstance(value, str):
            self._center = value.lower()
            if self._center not in {"rotate", "shift", "recentered"}:
                # Unknown methods are reported but otherwise ignored (no transformation applied)
                self._log.error(f"{self._center} is not a recognized center method.")
        else:
            self._center = value
        # Do nothing when recentering
        if self._center == "recentered":
            return
        # Center system such that the geometric inertia tensor will be diagonal
        # Rotate before shifting!
        if self._center is True or self._center == "rotate":
            I = inertia_tensor(self.pos)
            _, eigvecs = xp.linalg.eigh(I)
            self.pos = (xp.linalg.inv(eigvecs) @ self.pos.T).T
        # Shift system such that its geometric center of mass is in the center of the cell
        if self._center is True or self._center == "shift":
            com = center_of_mass(self.pos)
            self.pos = self.pos - (com - xp.sum(self.a, axis=0) / 2)
        # The structure factor changes when changing pos
        self.is_built = False

    @property
    def verbose(self):
        """Verbosity level."""
        return self._verbose

    @verbose.setter
    def verbose(self, value):
        # If no verbosity is given use the global verbosity level
        if value is None:
            value = log.verbose
        self._verbose = get_level(value)
        self._log.verbose = self._verbose

    # ### Class properties with a setter outside of the init method ###

    @property
    def f(self):
        """Occupation numbers per state."""
        return self.occ.f

    @f.setter
    def f(self, value):
        # Pass through to the Occupations object
        self.occ.f = value

    @property
    def s(self):
        """Real-space sampling of the cell."""
        return self._s

    @s.setter
    def s(self, value):
        # Choose the same sampling for every direction if an integer is given
        if isinstance(value, numbers.Integral):
            value = value * xp.ones(3, dtype=int)
        self._s = xp.asarray(value, dtype=int)
        self._Ns = int(xp.prod(self._s))
        # The cell discretization changes when changing s
        self.is_built = False

    @property
    def Z(self):
        """Valence charge per atom."""
        return self._Z

    @Z.setter
    def Z(self, value):
        # Assume same charges for all atoms if an integer is given
        if isinstance(value, numbers.Integral):
            value = value * xp.ones(self.Natoms, dtype=int)
        # A dict maps atom symbols to charges
        elif isinstance(value, dict):
            value = [value[ia] for ia in self.atom]
        # Get the valence charges from the GTH files
        elif value is None or isinstance(value, str):
            value = atom2charge(self.atom, value)
        self._Z = xp.asarray(value, dtype=int)
        if self.Natoms != len(self._Z):
            msg = (
                f"Mismatch between number of atoms ({self.Natoms}) and number of "
                f"charges ({len(self._Z)})."
            )
            raise ValueError(msg)
        # Get the number of calculated electrons and pass it to occ
        self.occ.Nelec = xp.sum(self._Z) - self.charge
        # NOTE(review): bands = 0 presumably lets Occupations pick the minimum; confirm there
        if self.occ.Nspin and self.occ.bands < self.occ.Nelec * self.occ.Nspin // 2:
            self._log.warning(
                "The number of bands is too small, reset to the minimally needed amount."
            )
            self.occ.bands = 0

    # ### Read-only properties ###

    @property
    def Natoms(self):
        """Number of atoms."""
        return self._Natoms

    @property
    def Ns(self):
        """Number of real-space grid points."""
        return self._Ns

    @property
    def Omega(self):
        """Unit cell volume."""
        return self._Omega

    @property
    def r(self):
        """Real-space sampling points."""
        return self._r

    @property
    def active(self):
        """Mask for active G-vectors."""
        return self._active

    @property
    def G(self):
        """G-vectors."""
        return self._G

    @property
    def G2(self):
        """Squared magnitudes of G-vectors."""
        return self._G2

    @property
    def G2c(self):
        """Truncated squared magnitudes of G-vectors."""
        return self._G2c

    @property
    def Gk2(self):
        """Squared magnitudes of G+k-vectors."""
        return self._Gk2

    @property
    def Gk2c(self):
        """Truncated squared magnitudes of G+k-vectors."""
        return self._Gk2c

    @property
    def Sf(self):
        """Structure factor per atom."""
        return self._Sf

    @property
    def dV(self):
        """Volume element to multiply when integrating field properties."""
        return self.Omega / self._Ns

    @property
    def _atoms(self):
        """The Atoms object itself."""
        # This way we can access the object from Atoms and SCF classes with the same code
        return self

    # ### Class methods ###

    def build(self):
        """Build all parameters of the Atoms object.

        Returns:
            The Atoms object itself.
        """
        self.kpts.build()
        self._sample_unit_cell()
        self.occ.wk = self.kpts.wk  # Pass the weights of k-points to the Occupations object
        self.occ.fill()
        self.is_built = True
        return self

    # Alias so the object can be "run" like other objects with a kernel method
    kernel = build

    def recenter(self, center=None):
        """Recenter the system inside the cell.

        Keyword Args:
            center: Point to center the system around. Defaults to the center of the cell.

        Returns:
            The Atoms object itself.
        """
        com = center_of_mass(self.pos)
        if center is None:
            self.pos = self.pos - (com - xp.sum(self.a, axis=0) / 2)
        else:
            center = xp.asarray(center, dtype=float)
            self.pos = self.pos - (com - center)
        # NOTE(review): assigning pos resets is_built to False even though Sf is refreshed below
        if self.Sf is not None:
            # Recalculate the structure factor since it depends on the atom positions
            self._Sf = xp.exp(1j * (self.G @ self.pos.T)).T
        self._center = "recentered"
        return self

    def set_k(self, k, wk=None):
        """Interface to set custom k-points.

        Args:
            k: k-point coordinates.

        Keyword Args:
            wk: k-point weights. Defaults to equal weights that sum to one.

        Returns:
            The Atoms object itself.

        Raises:
            ValueError: If the number of weights does not match the number of k-points.
        """
        # Validate the input before touching the KPoints object, so it is never left half-updated
        k = xp.atleast_2d(xp.asarray(k, dtype=float))
        if wk is None:
            wk = xp.ones(len(k)) / len(k)
        else:
            wk = xp.asarray(wk, dtype=float)
        if len(wk) != len(k):
            msg = (
                f"Mismatch between number of k-points ({len(k)}) and number of "
                f"weights ({len(wk)})."
            )
            raise ValueError(msg)
        # Build the KPoints object first, then overwrite its sampling with the custom points
        self.kpts.build()
        self.kpts._k = k
        self.kpts._wk = wk
        self.kpts._Nk = len(wk)
        self.kpts._kmesh = None
        self.occ.wk = self.kpts.wk
        self._sample_unit_cell()
        return self

    def clear(self):
        """Initialize or clear parameters that will be built out of the inputs.

        Returns:
            The Atoms object itself.
        """
        self._r = None  # Sample points in cell
        self._active = None  # Mask for active G-vectors
        self._G = None  # G-vectors
        self._G2 = None  # Squared magnitudes of G-vectors
        self._G2c = None  # Truncated squared magnitudes of G-vectors
        self._Gk2 = None  # Squared magnitudes of G+k-vectors
        self._Gk2c = None  # Truncated squared magnitudes of G+k-vectors
        self._Sf = None  # Structure factor
        self.is_built = False  # Flag to determine if the object was built or not
        return self

    def _get_index_matrices(self):
        """Build index matrices M and N to build the real and reciprocal space samplings.

        The matrices are using C ordering (the last index is the fastest).

        Returns:
            Index matrices.
        """
        # Build index matrix M: row m holds the 3D grid index of the m-th grid point, i.e.,
        # (m // (s[1] * s[2]) % s[0], m // s[2] % s[1], m % s[2])
        M = xp.asarray(np.indices(self.s, dtype=float).transpose(1, 2, 3, 0).reshape(-1, 3))
        # Build index matrix N: indices larger than s / 2 are shifted by -s (wrapped to negative)
        N = M - (self.s / 2 < M) * self.s
        return M, N

    def _sample_unit_cell(self):
        """Build the real-space sampling and all G-vector parameters.

        The per-k-point lists/arrays (active, Gk2, Gk2c) get one extra trailing entry that
        describes the k-independent grid used for the non-wave function fields.
        """
        # Calculate index matrices
        M, N = self._get_index_matrices()
        # Build the real-space sampling: grid indices scaled by 1 / s and mapped into the cell
        self._r = M @ xp.linalg.inv(xp.diag(xp.astype(self.s, float))) @ self.a
        # Build G-vectors
        self._G = 2 * math.pi * N @ xp.linalg.inv(self.a.T)
        # Calculate squared magnitudes of G-vectors
        self._G2 = xp.linalg.norm(self.G, axis=1) ** 2
        # Calculate squared magnitudes of G+k-vectors once and reuse them for the restriction
        Gk2 = xp.stack(
            [xp.linalg.norm(self.G + self.kpts.k[ik], axis=1) ** 2 for ik in range(self.kpts.Nk)]
        )
        # Calculate the G2 restriction, i.e., keep vectors with |G+k|^2 / 2 <= ecut
        self._active = [xp.nonzero(2 * self.ecut >= Gk2[ik]) for ik in range(self.kpts.Nk)]
        self._Gk2c = [Gk2[ik][self._active[ik]] for ik in range(self.kpts.Nk)]
        G2_active = xp.nonzero(2 * self.ecut >= self._G2)
        self._G2c = self._G2[G2_active]
        # Calculate the structure factor per atom
        self._Sf = xp.exp(1j * (self.G @ self.pos.T)).T

        # Create the grid used for the non-wave function fields and append it to the end
        self._active.append(G2_active)
        self._Gk2 = xp.vstack((Gk2, self._G2))
        self._Gk2c.append(self._G2c)

    # Operators bound as methods, so they can be called on the Atoms object directly
    O = operators.O
    L = operators.L
    Linv = operators.Linv
    K = operators.K
    T = operators.T
    I = operators.I
    J = operators.J
    Idag = operators.Idag
    Jdag = operators.Jdag

    def __repr__(self):
        """Print the parameters stored in the Atoms object."""
        rows = [
            f"{atom:>3} {Z:>6} {pos[0]:10.5f} {pos[1]:10.5f} {pos[2]:10.5f}"
            for atom, Z, pos in zip(self.atom, self.Z, self.pos)
        ]
        return "\n".join(["Atom Valence Position", *rows])