Coverage for eminus/localizer.py: 99.26%

136 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-21 12:19 +0000

1# SPDX-FileCopyrightText: 2021 The eminus developers 

2# SPDX-License-Identifier: Apache-2.0 

3"""Utilities to localize and analyze orbitals.""" 

4 

5import math 

6 

7from scipy.linalg import qr 

8from scipy.stats import unitary_group 

9 

10from . import backend as xp 

11from . import config 

12from .logger import log 

13from .utils import handle_k, handle_spin 

14 

15 

@handle_k(mode="skip")
def eval_psi(atoms, psi, r):
    """Evaluate orbitals at given coordinate points.

    The orbitals are shifted such that the requested position coincides with the zeroth grid
    point, transformed to real-space, and read off at that grid point.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.
        r: Real-space positions.

    Returns:
        Values of psi at points r.
    """
    # Displacement that moves r onto the zeroth lattice point (normally this is (0,0,0))
    shift = -(r - atoms.r[0])
    # Shift, transform to real-space, and take the value at the zeroth lattice point
    return atoms.I(atoms.T(psi, shift), 0)[0]

33 

34 

@handle_k(mode="skip")
def get_R(atoms, psi, fods):
    """Calculate transformation matrix to build Fermi orbitals.

    Row i of the matrix holds the complex-conjugated orbital values at the i-th FOD position,
    divided by the square root of the summed squared magnitudes of all orbitals at that position.

    Reference: J. Chem. Phys. 153, 084104.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.
        fods: Fermi-orbital descriptors.

    Returns:
        Transformation matrix R.
    """
    # We only calculate occupied orbitals, so R is square in the number of FODs
    nfods = len(fods)
    R = xp.empty((nfods, nfods), dtype=complex)

    for row, fod in enumerate(fods):
        # Values of all psi at this FOD position
        psi_at_fod = eval_psi(atoms, psi, fod)
        norm = xp.sqrt(xp.sum(psi_at_fod.conj() * psi_at_fod))
        # Fill the whole row at once, using only the first nfods orbitals
        R[row] = psi_at_fod[:nfods].conj() / norm
    return R

59 

60 

@handle_k(mode="skip")
def get_FO(atoms, psi, fods):
    """Calculate Fermi orbitals from Kohn-Sham orbitals.

    Each Fermi orbital is a linear combination of the occupied real-space orbitals, weighted
    with the rows of the transformation matrix from get_R.

    Reference: J. Chem. Phys. 153, 084104.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.
        fods: Fermi-orbital descriptors.

    Returns:
        Real-space Fermi orbitals.
    """
    nspin, nstates = atoms.occ.Nspin, atoms.occ.Nstate
    fo = xp.zeros((nspin, atoms.Ns, nstates), dtype=complex)

    # Transform psi to real-space
    psi_rs = atoms.I(psi, 0)
    for spin in range(nspin):
        # Get the transformation matrix R
        R = get_R(atoms, psi[spin], fods[spin])
        # Only states with a non-zero occupation contribute
        occupied = [j for j in range(nstates) if atoms.occ.f[0, spin, j] > 0]
        for i in range(len(R)):
            for j in occupied:
                fo[spin, :, i] += R[i, j] * psi_rs[spin, :, j]
    return fo

87 

88 

@handle_spin
def get_S(atoms, psirs):
    """Calculate overlap matrix between orbitals.

    Reference: J. Chem. Phys. 153, 084104.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Overlap matrix S.
    """
    nstates = atoms.occ.Nstate
    S = xp.empty((nstates, nstates), dtype=complex)

    # Overlap elements: S_ij = \int psi_i^* psi_j dr, integrated as a sum times the volume element
    for i in range(nstates):
        bra = psirs[:, i].conj()
        for j in range(nstates):
            S[i, j] = atoms.dV * xp.sum(bra * psirs[:, j])
    return S

109 

110 

@handle_k(mode="skip")
def get_FLO(atoms, psi, fods):
    """Calculate Fermi-Loewdin orbitals by orthonormalizing Fermi orbitals.

    The Fermi orbitals are orthonormalized with Loewdin's symmetric method, i.e., they are
    multiplied with S^(-1/2), where S is their overlap matrix.

    Reference: J. Chem. Phys. 153, 084104.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.
        fods: Fermi-orbital descriptors.

    Returns:
        Real-space Fermi-Loewdin orbitals.
    """
    fo = get_FO(atoms, psi, fods)
    flo = xp.empty((atoms.occ.Nspin, atoms.Ns, atoms.occ.Nstate), dtype=complex)

    for spin in range(atoms.occ.Nspin):
        # Calculate the overlap matrix for FOs
        S = get_S(atoms, fo[spin])
        # Calculate eigenvalues and eigenvectors
        Q, T = xp.linalg.eig(S)
        # Loewdins symmetric orthonormalization method: S^(-1/2) = T Q^(-1/2) T^dagger
        Q12 = xp.diag(1 / xp.sqrt(Q))
        # The conjugate transpose is needed since the eigenvectors of the Hermitian overlap
        # matrix can be complex, the plain transpose is only valid for real eigenvectors
        flo[spin] = fo[spin] @ (T @ Q12 @ T.conj().T)
    return flo

137 

138 

@handle_k(mode="skip")
@handle_spin
def get_scdm(atoms, psi):
    """Calculate localized SCDM orbitals via QR decomposition.

    Reference: J. Chem. Theory Comput. 11, 1463.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.

    Returns:
        Real-space SCDM orbitals.
    """
    # Transform psi to real-space
    psi_rs = atoms.I(psi, 0)

    # Pivoted QR factorization of the conjugate-transposed orbitals, done by SciPy on a NumPy array
    psi_dagger = xp.to_np(psi_rs.T.conj())
    # Only the Q factor is needed, R and the pivot indices are discarded
    Q = qr(psi_dagger, pivoting=True)[0]
    # Convert back to a backend array with the dtype of the orbitals
    Q = xp.asarray(Q, dtype=psi_rs.dtype)

    # Apply the transformation
    return psi_rs @ Q

162 

163 

@handle_k(mode="skip")
@handle_spin
def wannier_cost(atoms, psirs):
    """Calculate the Wannier cost function, namely the orbital variance. Equivalent to Foster-Boys.

    This function does not account for periodicity, it may be a good idea to center the system.

    Reference: J. Chem. Phys. 137, 224114.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Variance per orbital.
    """
    # First and second moments of the position operator per orbital
    centers = wannier_center(atoms, psirs)
    moments = second_moment(atoms, psirs)
    # Variance = \int psi r^2 psi - (\int psi r psi)^2
    centers_squared = xp.linalg.norm(centers, axis=1) ** 2
    costs = moments - centers_squared
    log.debug(f"Centers:\n{centers}\nMoments:\n{moments}")
    log.info(f"Costs:\n{costs}")
    return costs

187 

188 

@handle_k(mode="skip")
@handle_spin
def wannier_center(atoms, psirs):
    """Calculate Wannier centers, i.e., the expectation values of r.

    Reference: J. Chem. Phys. 137, 224114.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Wannier centers per orbital.
    """
    nstates = atoms.occ.Nstate
    centers = xp.empty((nstates, 3))
    for i in range(nstates):
        orbital = psirs[:, i]
        for dim in range(3):
            # Integrand psi^* r_dim psi, only the real part of the integral is kept
            integrand = orbital.conj() * atoms.r[:, dim] * orbital
            centers[i, dim] = atoms.dV * xp.real(xp.sum(integrand, axis=0))
    return centers

210 

211 

@handle_k(mode="skip")
@handle_spin
def second_moment(atoms, psirs):
    """Calculate the second moments, i.e., the expectation values of r^2.

    Reference: J. Chem. Phys. 137, 224114.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Second moments per orbital.
    """
    # Squared distance of every grid point from the origin
    r2 = xp.linalg.norm(atoms.r, axis=1) ** 2

    nstates = atoms.occ.Nstate
    moments = xp.empty(nstates)
    for i in range(nstates):
        orbital = psirs[:, i]
        # Integrand psi^* r^2 psi, only the real part of the integral is kept
        integrand = orbital.conj() * r2 * orbital
        moments[i] = atoms.dV * xp.real(xp.sum(integrand, axis=0))
    return moments

232 

233 

@handle_spin
def wannier_supercell_matrices(atoms, psirs):
    """Calculate matrices for the supercell Wannier localization.

    For every Cartesian direction the matrix elements of exp(-i 2 pi r_dim / a_dim) between all
    orbitals are calculated, where a_dim is the respective diagonal element of the cell.

    Reference: Phys. Rev. B 59, 9703.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Matrices X, Y, and Z.
    """
    prefactor = -1j * 2 * math.pi
    bra = psirs.conj().T

    # Similar to the expectation value of r, but accounting for periodicity
    matrices = []
    for dim in range(3):
        phase = xp.exp(prefactor * atoms.r[:, dim] / atoms.a[dim, dim])
        matrices.append(((bra * phase) @ psirs) * atoms.dV)
    X, Y, Z = matrices
    return X, Y, Z

252 

253 

def wannier_supercell_cost(X, Y, Z):
    """Calculate the supercell Wannier cost.

    This is an equivalent criterion to the spread criterion, but not the same. This cost function
    will be maximized instead of the minimization of the spread.

    Reference: Phys. Rev. B 59, 9703.

    Args:
        X: Calculation specific matrix.
        Y: Calculation specific matrix.
        Z: Calculation specific matrix.

    Returns:
        Supercell Wannier cost.
    """
    # Squared magnitudes of the diagonal elements of each matrix
    x_diag2, y_diag2, z_diag2 = (xp.abs(xp.diagonal(mat)) ** 2 for mat in (X, Y, Z))
    return xp.sum(x_diag2 + y_diag2 + z_diag2)

274 

275 

def wannier_supercell_grad(atoms, X, Y, Z):
    """Calculate the supercell Wannier gradient.

    The gradient is the sum of three contributions of identical form, one for each of the
    matrices X, Y, and Z.

    Reference: Phys. Rev. B 59, 9703.

    Args:
        atoms: Atoms object.
        X: Calculation specific matrix.
        Y: Calculation specific matrix.
        Z: Calculation specific matrix.

    Returns:
        Supercell Wannier gradient.
    """
    nstates = atoms.occ.Nstate

    def contribution(mat):
        # Just the indexed gradient from the paper, without fancy optimization
        grad = xp.empty((nstates, nstates), dtype=complex)
        for n in range(nstates):
            for m in range(nstates):
                first = mat[n, m] * (mat[n, n].conj() - mat[m, m].conj())
                second = mat[m, n].conj() * (mat[m, m] - mat[n, n])
                grad[m, n] = first - second
        return grad

    return contribution(X) + contribution(Y) + contribution(Z)

306 

307 

@handle_k(mode="skip")
@handle_spin
def get_wannier(atoms, psirs, Nit=10000, conv_tol=1e-7, mu=1, random_guess=False, seed=None):
    """Steepest descent supercell Wannier localization.

    This function is rather sensitive to the starting point, thus it is a good idea to start from
    already localized orbitals.

    This optimizes the given orbitals under unitary constraint matrices, see
    IEEE Trans. Signal Process. 56, 1134.

    If the cell matrix is not diagonal a warning is logged and the orbitals are returned unchanged.

    Reference: Phys. Rev. B 59, 9703.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Keyword Args:
        Nit: Number of iterations.
        conv_tol: Convergence tolerance.
        mu: Step size.
        random_guess: Whether to use a random unitary starting guess or the identity.
        seed: Seed to get a reproducible random guess.

    Returns:
        Localized orbitals.
    """
    # Use the matrix exponential that matches the array type of the active backend
    if config.backend == "torch":
        expm = xp.linalg.matrix_exp
    else:
        from scipy.linalg import expm

    # The cell matrix has to be diagonal, otherwise return the orbitals untouched
    if not (xp.diag(xp.diag(atoms.a)) == atoms.a).all():
        log.warning("The Wannier localization needs a cubic unit cell.")
        return psirs

    X, Y, Z = wannier_supercell_matrices(atoms, psirs)  # Calculate matrices only once
    # The initial unitary transformation is the identity or a random unitary matrix
    if random_guess and atoms.occ.Nstate > 1:
        U = xp.asarray(unitary_group.rvs(atoms.occ.Nstate, random_state=seed))
        # The matrices have to describe the rotated orbitals psirs @ U, otherwise the optimization
        # would start from the unrotated orbitals and the returned rotation would be inconsistent
        U_dag = U.conj().T
        X, Y, Z = U_dag @ X @ U, U_dag @ Y @ U, U_dag @ Z @ U
    else:
        U = xp.eye(atoms.occ.Nstate, dtype=complex)
    costs = [0]  # Add a zero to the costs to allow the sign evaluation in the first iteration

    atoms._log.debug(f"{'Iteration':<11}{'Cost [a0^2]':<13}{'dCost [a0^2]':<13}")
    for i in range(Nit):
        sign = 1
        costs.append(wannier_supercell_cost(X, Y, Z))
        if abs(costs[-1] - costs[-2]) < conv_tol:
            atoms._log.info(f"Wannier localizer converged after {i} iterations.")
            break
        # Step with a negative sign while the cost grows, the positive default is only used if
        # the cost got smaller in the last step, i.e., the direction is changed
        if costs[-1] - costs[-2] > 0:
            sign = -1

        # Calculate unitary transformation
        dOmega = wannier_supercell_grad(atoms, X, Y, Z)
        A = sign * mu * dOmega
        # dOmega is anti-hermitian, therefore calculate -A instead of A.conj().T
        # expm(A) will be unitary
        expA_pos, expA_neg = expm(A), expm(-A)
        # Update total rotation
        U = U @ expA_pos
        # Update matrices
        X = expA_neg @ X @ expA_pos
        Y = expA_neg @ Y @ expA_pos
        Z = expA_neg @ Z @ expA_pos

        atoms._log.debug(f"{i:>8}   {costs[-1]:<+13,.6f}{costs[-1] - costs[-2]:<+13,.4e}")

    # The last cost difference is still above the tolerance if the loop ended without converging
    if len(costs) > 1 and abs(costs[-1] - costs[-2]) > conv_tol:
        atoms._log.warning("Wannier localizer not converged!")
    # Return the localized orbitals by rotating them
    return psirs @ U