Coverage for eminus / atoms.py: 98.15%

271 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-11-21 14:20 +0000

1# SPDX-FileCopyrightText: 2021 The eminus developers 

2# SPDX-License-Identifier: Apache-2.0 

3"""Atoms class definition.""" 

4 

5import math 

6import numbers 

7 

8import numpy as np 

9from scipy.fft import next_fast_len 

10 

11from . import backend as xp 

12from . import operators 

13from .kpoints import KPoints 

14from .logger import create_logger, get_level, log 

15from .occupations import Occupations 

16from .tools import center_of_mass, cutoff2gridspacing, inertia_tensor 

17from .utils import atom2charge, BaseObject, molecule2list 

18 

19 

class Atoms(BaseObject):
    """Atoms object that holds all system and cell parameters.

    Args:
        atom: Atom symbols.

            A string can be given, e.g., with :code:`CH4` that will be parsed to
            :code:`["C", "H", "H", "H", "H"]`. When calculating single atoms one can directly
            provide the valence charge, e.g., with :code:`Li-q3`.
        pos: Atom positions.

    Keyword Args:
        ecut: Cut-off energy.

            Defaults to 30 Eh (ca. 816 eV).
        a: Cell size or vacuum size.

            Floats will create a cubic unit cell. Defaults to a 20 a0 (ca. 10.5 A) cubic cell.
            Scaled lattice vectors can be given to build a custom cell.
        spin: Number of unpaired electrons.

            This is the difference between the number of up and down electrons.
        charge: Charge of the system.
        unrestricted: Handling of spin.

            :code:`False` for restricted, :code:`True` for unrestricted, and :code:`None` for
            automatic detection.
        center: Center the system inside the cell.

            Aligns the geometric center of mass with the center of the cell and rotates the system,
            such that its geometric moment of inertia aligns with the coordinate axes. Can be one of
            bool, "shift", and "rotate".
        verbose: Level of output.

            Can be one of "critical", "error", "warning", "info" (default), or "debug". An integer
            value can be used as well, where larger numbers mean more output, starting from 0.
            None will use the global logger verbosity value.
    """

    def __init__(
        self,
        atom,
        pos,
        ecut=30,
        a=20,
        spin=None,
        charge=0,
        unrestricted=None,
        center=False,
        verbose=None,
    ):
        """Initialize the Atoms object."""
        # Set the input parameters (the ordering is important, since several setters read
        # attributes that earlier setters have already assigned)
        self._log = create_logger(self)  #: Logger object.
        self.verbose = verbose  #: Verbosity level.
        self.occ = Occupations()  #: Occupations object.
        self.atom = atom  #: Atom symbols.
        self.pos = pos  #: Atom positions.
        self.a = a  #: Cell/Vacuum size.
        self.ecut = ecut  #: Cut-off energy.
        self.center = center  #: Enables centering the system in the cell.
        self.charge = charge  #: System charge.
        self.spin = spin  #: Number of unpaired electrons.
        self.unrestricted = unrestricted  #: Enables unrestricted spin handling.
        self.kpts = KPoints("sc", self.a)  #: KPoints object.

        # Initialize other attributes
        self.occ.fill()  #: Fill states from the given input.
        self.is_built = False  #: Determines the Atoms object build status.
        self.clear()

    # ### Class properties ###

    @property
    def atom(self):
        """Atom symbols."""
        return self._atom

    @atom.setter
    def atom(self, value):
        # Quick option to set the valence charge for single atoms, e.g., "Li-q3"
        if isinstance(value, str) and "-q" in value:
            atom, Z = value.split("-q")
            self._atom = [atom]
            self._Natoms = 1
            self.Z = int(Z)
        else:
            # If a string, i.e., chemical formula is given convert it to a list of chemical symbols
            if isinstance(value, str):
                self._atom = molecule2list(value)
            else:
                self._atom = value
            # Get the number of atoms and determine the charges from the default source
            self._Natoms = len(self._atom)
            self.Z = None

    @property
    def pos(self):
        """Atom positions."""
        return self._pos

    @pos.setter
    def pos(self, value):
        # We need atom positions as a two-dimensional array (one row per atom)
        self._pos = xp.atleast_2d(xp.asarray(value, dtype=float))
        if self.Natoms != len(self._pos):
            msg = (
                f"Mismatch between number of atoms ({self.Natoms}) and number of "
                f"coordinates ({len(self._pos)})."
            )
            raise ValueError(msg)
        # The structure factor changes when changing pos
        self.is_built = False

    @property
    def ecut(self):
        """Cut-off energy."""
        return self._ecut

    @ecut.setter
    def ecut(self, value):
        self._ecut = value
        # Calculate the sampling from the cut-off energy (truncated to integers)
        s = xp.asarray(xp.linalg.norm(self.a, axis=1) / cutoff2gridspacing(value), dtype=int)
        # Multiply by two and add one to match PWDFT.jl
        s = 2 * s + 1
        # Calculate a fast length to optimize the FFT calculations
        self.s = [int(next_fast_len(i)) for i in s]
        # The cell discretization changes when changing s or ecut
        self.is_built = False

    @property
    def a(self):
        """Cell/Vacuum size."""
        return self._a

    @a.setter
    def a(self, value):
        # Build a cubic (or orthorhombic) cell if a number or 1d-array is given
        if xp.asarray(value).ndim <= 1:
            self._a = xp.asarray(value, dtype=float) * xp.eye(3)
        # Otherwise scaled cell vectors are given
        else:
            self._a = xp.asarray(value, dtype=float)
        # Update ecut and s if it has been set before (hasattr is False during __init__)
        if hasattr(self, "ecut"):
            self.ecut = self.ecut
        # Calculate the unit cell volume
        self._Omega = abs(xp.linalg.det(self._a))
        if hasattr(self, "kpts"):
            self.kpts.a = self._a
        # The cell changes when changing a
        self.is_built = False

    @property
    def spin(self):
        """Number of unpaired electrons."""
        return self.occ.spin

    @spin.setter
    def spin(self, value):
        self.occ.spin = value

    @property
    def charge(self):
        """System charge."""
        return self.occ.charge

    @charge.setter
    def charge(self, value):
        self.occ.charge = value

    @property
    def unrestricted(self):
        """Enables unrestricted spin handling."""
        return self.occ.Nspin == 2

    @unrestricted.setter
    def unrestricted(self, value):
        # None defers the decision to the Occupations object, otherwise False -> 1, True -> 2
        if value is None:
            self.occ.Nspin = value
        else:
            self.occ.Nspin = value + 1

    @property
    def center(self):
        """Enables centering the system in the cell."""
        return self._center

    @center.setter
    def center(self, value):
        if isinstance(value, str):
            self._center = value.lower()
            if self._center not in {"rotate", "shift", "recentered"}:
                # NOTE(review): this uses the module-level logger rather than self._log, so the
                # per-object verbosity is not applied here — confirm this is intended
                log.error(f"{self._center} is not a recognized center method.")
        else:
            self._center = value
        # Do nothing when recentering
        if self._center == "recentered":
            return
        # Center system such that the geometric inertia tensor will be diagonal
        # Rotate before shifting!
        if self._center is True or self._center == "rotate":
            I = inertia_tensor(self.pos)
            _, eigvecs = xp.linalg.eigh(I)
            self.pos = (xp.linalg.inv(eigvecs) @ self.pos.T).T
        # Shift system such that its geometric center of mass is in the center of the cell
        if self._center is True or self._center == "shift":
            com = center_of_mass(self.pos)
            self.pos = self.pos - (com - xp.sum(self.a, axis=0) / 2)
        # The structure factor changes when changing pos
        self.is_built = False

    @property
    def verbose(self):
        """Verbosity level."""
        return self._verbose

    @verbose.setter
    def verbose(self, value):
        # If no verbosity is given use the global verbosity level
        if value is None:
            value = log.verbose
        self._verbose = get_level(value)
        self._log.verbose = self._verbose

    # ### Class properties with a setter outside of the init method ###

    @property
    def f(self):
        """Occupation numbers per state."""
        return self.occ.f

    @f.setter
    def f(self, value):
        # Pass through to the Occupations object
        self.occ.f = value

    @property
    def s(self):
        """Real-space sampling of the cell."""
        return self._s

    @s.setter
    def s(self, value):
        # Choose the same sampling for every direction if an integer is given
        if isinstance(value, numbers.Integral):
            value = value * xp.ones(3, dtype=int)
        self._s = xp.asarray(value, dtype=int)
        self._Ns = int(xp.prod(self._s))
        # The cell discretization changes when changing s
        self.is_built = False

    @property
    def Z(self):
        """Valence charge per atom."""
        return self._Z

    @Z.setter
    def Z(self, value):
        # Assume same charges for all atoms if an integer is given
        if isinstance(value, numbers.Integral):
            value = value * xp.ones(self.Natoms, dtype=int)
        # A mapping from atom symbol to valence charge
        elif isinstance(value, dict):
            value = [value[ia] for ia in self.atom]
        # Get the valence charges from the GTH files
        elif value is None or isinstance(value, str):
            value = atom2charge(self.atom, value)
        self._Z = xp.asarray(value, dtype=int)
        if self.Natoms != len(self._Z):
            msg = (
                f"Mismatch between number of atoms ({self.Natoms}) and number of "
                f"charges ({len(self._Z)})."
            )
            raise ValueError(msg)
        # Get the number of calculated electrons and pass it to occ
        self.occ.Nelec = xp.sum(self._Z) - self.charge
        if self.occ.Nspin and self.occ.bands < self.occ.Nelec * self.occ.Nspin // 2:
            # NOTE(review): module-level logger, not self._log — confirm this is intended
            log.warning("The number of bands is too small, reset to the minimally needed amount.")
            self.occ.bands = 0

    # ### Read-only properties ###

    @property
    def Natoms(self):
        """Number of atoms."""
        return self._Natoms

    @property
    def Ns(self):
        """Number of real-space grid points."""
        return self._Ns

    @property
    def Omega(self):
        """Unit cell volume."""
        return self._Omega

    @property
    def r(self):
        """Real-space sampling points."""
        return self._r

    @property
    def active(self):
        """Mask for active G-vectors."""
        return self._active

    @property
    def G(self):
        """G-vectors."""
        return self._G

    @property
    def G2(self):
        """Squared magnitudes of G-vectors."""
        return self._G2

    @property
    def G2c(self):
        """Truncated squared magnitudes of G-vectors."""
        return self._G2c

    @property
    def Gk2(self):
        """Squared magnitudes of G+k-vectors."""
        return self._Gk2

    @property
    def Gk2c(self):
        """Truncated squared magnitudes of G+k-vectors."""
        return self._Gk2c

    @property
    def Sf(self):
        """Structure factor per atom."""
        return self._Sf

    @property
    def dV(self):
        """Volume element to multiply when integrating field properties."""
        return self.Omega / self._Ns

    @property
    def _atoms(self):
        """Return the Atoms object itself."""
        # This way we can access the object from Atoms and SCF classes with the same code
        return self

    # ### Class methods ###

    def build(self):
        """Build all parameters of the Atoms object.

        Returns:
            The Atoms object itself.
        """
        self.kpts.build()
        self._sample_unit_cell()
        self.occ.wk = self.kpts.wk  # Pass the weights of k-points to the Occupations object
        self.occ.fill()
        self.is_built = True
        return self

    kernel = build

    def recenter(self, center=None):
        """Recenter the system inside the cell.

        Keyword Args:
            center: Point to center the system around. Defaults to the center of the cell.

        Returns:
            The Atoms object itself.
        """
        com = center_of_mass(self.pos)
        if center is None:
            self.pos = self.pos - (com - xp.sum(self.a, axis=0) / 2)
        else:
            center = xp.asarray(center, dtype=float)
            self.pos = self.pos - (com - center)
        # Note: assigning pos above also resets is_built to False
        if self.Sf is not None:
            # Recalculate the structure factor since it depends on the atom positions
            self._Sf = xp.exp(1j * (self.G @ self.pos.T)).T
        self._center = "recentered"
        return self

    def set_k(self, k, wk=None):
        """Interface to set custom k-points.

        Args:
            k: k-point coordinates.

        Keyword Args:
            wk: k-point weights. Defaults to equal weights for all k-points.

        Returns:
            The Atoms object itself.
        """
        self.kpts.build()
        self.kpts._k = xp.atleast_2d(xp.asarray(k, dtype=float))
        if wk is None:
            self.kpts._wk = xp.ones(len(self.kpts._k)) / len(self.kpts._k)
        else:
            self.kpts._wk = xp.asarray(wk, dtype=float)
        self.kpts._Nk = len(self.kpts._wk)
        self.kpts._kmesh = None
        self.occ.wk = self.kpts.wk
        self._sample_unit_cell()
        return self

    def clear(self):
        """Initialize or clear parameters that will be built out of the inputs.

        Returns:
            The Atoms object itself.
        """
        self._r = None  # Sample points in cell
        self._active = None  # Mask for active G-vectors
        self._G = None  # G-vectors
        self._G2 = None  # Squared magnitudes of G-vectors
        self._G2c = None  # Truncated squared magnitudes of G-vectors
        self._Gk2 = None  # Squared magnitudes of G+k-vectors
        self._Gk2c = None  # Truncated squared magnitudes of G+k-vectors
        self._Sf = None  # Structure factor
        self.is_built = False  # Flag to determine if the object was built or not
        return self

    def _get_index_matrices(self):
        """Build index matrices M and N to build the real and reciprocal space samplings.

        The matrices are using C ordering (the last index is the fastest).

        Returns:
            Index matrices.
        """
        # Build index matrix M, i.e., all integer grid coordinates as rows of shape (Ns, 3)
        M = xp.asarray(np.indices(self.s, dtype=float).transpose(1, 2, 3, 0).reshape(-1, 3))
        # Build index matrix N by wrapping indices above s/2 to negative frequencies
        N = M - (self.s / 2 < M) * self.s
        return M, N

    def _sample_unit_cell(self):
        """Build the real-space sampling and all G-vector parameters."""
        # Calculate index matrices
        M, N = self._get_index_matrices()
        # Build the real-space sampling
        self._r = M @ xp.linalg.inv(xp.diag(xp.astype(self.s, float))) @ self.a
        # Build G-vectors
        self._G = 2 * math.pi * N @ xp.linalg.inv(self.a.T)
        # Calculate squared magnitudes of G-vectors
        self._G2 = xp.linalg.norm(self.G, axis=1) ** 2
        # Calculate squared magnitudes of G+k-vectors once per k-point and reuse them below
        Gk2_list = [
            xp.linalg.norm(self.G + self.kpts.k[ik], axis=1) ** 2 for ik in range(self.kpts.Nk)
        ]
        # Calculate the G2 restriction, i.e., the vectors inside the cut-off sphere
        self._active = [xp.nonzero(2 * self.ecut >= Gk2_ik) for Gk2_ik in Gk2_list]
        G2_active = xp.nonzero(2 * self.ecut >= self._G2)
        self._G2c = self._G2[G2_active]
        # Collect the G+k-vector magnitudes and their truncated versions
        self._Gk2 = xp.stack(Gk2_list)
        self._Gk2c = [self._Gk2[ik][self._active[ik]] for ik in range(self.kpts.Nk)]
        # Calculate the structure factor per atom
        self._Sf = xp.exp(1j * (self.G @ self.pos.T)).T

        # Create the grid used for the non-wave function fields and append it to the end
        self._active.append(G2_active)
        self._Gk2 = xp.vstack((self._Gk2, self._G2))
        self._Gk2c.append(self._G2c)

    O = operators.O
    L = operators.L
    Linv = operators.Linv
    K = operators.K
    T = operators.T
    I = operators.I
    J = operators.J
    Idag = operators.Idag
    Jdag = operators.Jdag

    def __repr__(self):
        """Print the parameters stored in the Atoms object."""
        out = "Atom  Valence  Position"
        for i in range(self.Natoms):
            out += (
                f"\n{self.atom[i]:>3}   {self.Z[i]:>6}   "
                f"{self.pos[i, 0]:10.5f}  {self.pos[i, 1]:10.5f}  {self.pos[i, 2]:10.5f}"
            )
        return out