Coverage for eminus/localizer.py: 100.00%

130 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-18 08:43 +0000

1# SPDX-FileCopyrightText: 2021 The eminus developers 

2# SPDX-License-Identifier: Apache-2.0 

3"""Utilities to localize and analyze orbitals.""" 

4 

5import numpy as np 

6from scipy.linalg import eig, expm, norm, qr 

7from scipy.stats import unitary_group 

8 

9from .logger import log 

10from .utils import handle_k, handle_spin 

11 

12 

@handle_k(mode="skip")
def eval_psi(atoms, psi, r):
    """Evaluate orbitals at given coordinate points.

    The orbitals are translated by the negative offset between the requested position and the
    first grid point. After the transformation with atoms.I, the value at that first grid
    point is returned.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.
        r: Real-space positions.

    Returns:
        Values of psi at points r.
    """
    # Offset that moves the evaluation point onto the zeroth grid point atoms.r[0]
    shift = -(r - atoms.r[0])
    shifted_rs = atoms.I(atoms.T(psi, shift), 0)
    # The zeroth entry now holds the value at the requested position
    return shifted_rs[0]

30 

31 

@handle_k(mode="skip")
def get_R(atoms, psi, fods):
    """Calculate transformation matrix to build Fermi orbitals.

    Row i of the matrix holds the complex-conjugated orbital values at the i-th FOD position,
    divided by the square root of the summed squared magnitudes of all orbital values there.

    Reference: J. Chem. Phys. 153, 084104.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.
        fods: Fermi-orbital descriptors.

    Returns:
        Transformation matrix R.
    """
    # The matrix is square, with one row and one column per FOD
    nfods = len(fods)
    R = np.empty((nfods, nfods), dtype=complex)

    for i, fod in enumerate(fods):
        # Values of all orbitals at the current FOD position
        psi_fod = eval_psi(atoms, psi, fod)
        fod_norm = np.sqrt(np.sum(psi_fod.conj() * psi_fod))
        # Only the first nfods orbital values enter the matrix
        R[i] = [psi_fod[j].conj() / fod_norm for j in range(nfods)]
    return R

56 

57 

@handle_k(mode="skip")
def get_FO(atoms, psi, fods):
    """Calculate Fermi orbitals from Kohn-Sham orbitals.

    Each Fermi orbital is a linear combination of the real-space orbitals, weighted with the
    rows of the transformation matrix R. Only states with a positive occupation contribute.

    Reference: J. Chem. Phys. 153, 084104.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.
        fods: Fermi-orbital descriptors.

    Returns:
        Real-space Fermi orbitals.
    """
    occ = atoms.occ
    fo = np.zeros((occ.Nspin, atoms.Ns, occ.Nstate), dtype=complex)

    # Real-space representation of the input orbitals
    psi_rs = atoms.I(psi, 0)
    for spin in range(occ.Nspin):
        # Transformation matrix for the current spin channel
        R = get_R(atoms, psi[spin], fods[spin])
        # Indices of states with non-zero occupation, read from the first index of occ.f
        occupied = [j for j in range(occ.Nstate) if occ.f[0, spin, j] > 0]
        for i, R_i in enumerate(R):
            for j in occupied:
                fo[spin, :, i] += R_i[j] * psi_rs[spin, :, j]
    return fo

84 

85 

@handle_spin
def get_S(atoms, psirs):
    """Calculate overlap matrix between orbitals.

    The integral is evaluated as a sum over all grid points, scaled with atoms.dV.

    Reference: J. Chem. Phys. 153, 084104.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Overlap matrix S.
    """
    nstates = atoms.occ.Nstate
    S = np.empty((nstates, nstates), dtype=complex)

    # S_ij = \int psi_i^* psi_j dr
    for i in range(nstates):
        bra = psirs[:, i].conj()
        for j in range(nstates):
            S[i, j] = atoms.dV * np.sum(bra * psirs[:, j])
    return S

106 

107 

@handle_k(mode="skip")
def get_FLO(atoms, psi, fods):
    """Calculate Fermi-Loewdin orbitals by orthonormalizing Fermi orbitals.

    The Fermi orbitals are orthonormalized with Loewdin's symmetric method, i.e., they are
    multiplied with the inverse square root of their overlap matrix.

    Reference: J. Chem. Phys. 153, 084104.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.
        fods: Fermi-orbital descriptors.

    Returns:
        Real-space Fermi-Loewdin orbitals.
    """
    fo = get_FO(atoms, psi, fods)
    flo = np.empty((atoms.occ.Nspin, atoms.Ns, atoms.occ.Nstate), dtype=complex)

    for spin in range(atoms.occ.Nspin):
        # Calculate the overlap matrix for FOs
        S = get_S(atoms, fo[spin])
        # S is Hermitian, so use the Hermitian eigensolver: it returns real eigenvalues and
        # orthonormal eigenvectors, also for degenerate eigenvalues
        Q, T = np.linalg.eigh(S)
        # Loewdins symmetric orthonormalization method: S^(-1/2) = T Q^(-1/2) T^dagger
        # The conjugate transpose is required since the eigenvectors can be complex
        Q12 = np.diag(1 / np.sqrt(Q))
        flo[spin] = fo[spin] @ (T @ Q12 @ T.conj().T)
    return flo

134 

135 

@handle_k(mode="skip")
@handle_spin
def get_scdm(atoms, psi):
    """Calculate localized orbitals via QR decomposition, as given in the SCDM method.

    Reference: J. Chem. Theory Comput. 11, 1463.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.

    Returns:
        Real-space SCDM orbitals.
    """
    # Real-space representation of the input orbitals
    psi_rs = atoms.I(psi, 0)

    # Pivoted QR factorization of the conjugate-transposed orbitals, only Q is needed
    rotation = qr(psi_rs.T.conj(), pivoting=True)[0]
    # Rotate the orbitals with Q
    return psi_rs @ rotation

157 

158 

@handle_k(mode="skip")
@handle_spin
def wannier_cost(atoms, psirs):
    """Calculate the Wannier cost function, namely the orbital variance. Equivalent to Foster-Boys.

    This function does not account for periodicity, it may be a good idea to center the system.

    Reference: J. Chem. Phys. 137, 224114.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Variance per orbital.
    """
    # Variance = <r^2> - <r>^2, evaluated separately for every orbital
    centers = wannier_center(atoms, psirs)
    moments = second_moment(atoms, psirs)
    squared_centers = norm(centers, axis=1) ** 2
    costs = moments - squared_centers
    log.debug(f"Centers:\n{centers}\nMoments:\n{moments}")
    log.info(f"Costs:\n{costs}")
    return costs

182 

183 

@handle_k(mode="skip")
@handle_spin
def wannier_center(atoms, psirs):
    """Calculate Wannier centers, i.e., the expectation values of r.

    The expectation values are evaluated as sums over all grid points, scaled with atoms.dV.
    Only the real part is kept.

    Reference: J. Chem. Phys. 137, 224114.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Wannier centers per orbital.
    """
    nstates = atoms.occ.Nstate
    centers = np.empty((nstates, 3))
    for i in range(nstates):
        orbital = psirs[:, i]
        bra = orbital.conj()
        # One expectation value per Cartesian component
        for dim in range(3):
            expectation = np.sum(bra * atoms.r[:, dim] * orbital, axis=0)
            centers[i, dim] = atoms.dV * np.real(expectation)
    return centers

205 

206 

@handle_k(mode="skip")
@handle_spin
def second_moment(atoms, psirs):
    """Calculate the second moments, i.e., the expectation values of r^2.

    The expectation values are evaluated as sums over all grid points, scaled with atoms.dV.
    Only the real part is kept.

    Reference: J. Chem. Phys. 137, 224114.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Second moments per orbital.
    """
    # Squared distance of every grid point from the origin
    r2 = norm(atoms.r, axis=1) ** 2

    nstates = atoms.occ.Nstate
    moments = np.empty(nstates)
    for i in range(nstates):
        orbital = psirs[:, i]
        expectation = np.sum(orbital.conj() * r2 * orbital, axis=0)
        moments[i] = atoms.dV * np.real(expectation)
    return moments

227 

228 

@handle_spin
def wannier_supercell_matrices(atoms, psirs):
    """Calculate matrices for the supercell Wannier localization.

    For every Cartesian direction the orbitals are sandwiched around a complex phase factor
    that is built from the grid coordinates and the respective diagonal element of atoms.a.

    Reference: Phys. Rev. B 59, 9703.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Matrices X, Y, and Z.
    """
    bra = psirs.conj().T

    def _matrix(dim):
        # Similar to the expectation value of r, but accounting for periodicity
        phase = np.exp(-1j * 2 * np.pi * atoms.r[:, dim] / atoms.a[dim, dim])
        return ((bra * phase) @ psirs) * atoms.dV

    return _matrix(0), _matrix(1), _matrix(2)

247 

248 

def wannier_supercell_cost(X, Y, Z):
    """Calculate the supercell Wannier cost.

    The cost is the sum of the squared magnitudes of all diagonal elements of X, Y, and Z.

    This is an equivalent criterion to the spread criterion, but not the same. This cost function
    will be maximized instead of the minimization of the spread.

    Reference: Phys. Rev. B 59, 9703.

    Args:
        X: Calculation specific matrix.
        Y: Calculation specific matrix.
        Z: Calculation specific matrix.

    Returns:
        Supercell Wannier cost.
    """
    # Squared magnitudes of the diagonal elements, one array per matrix
    x2, y2, z2 = (np.abs(np.diagonal(M)) ** 2 for M in (X, Y, Z))
    return np.sum(x2 + y2 + z2)

269 

270 

def wannier_supercell_grad(atoms, X, Y, Z):
    """Calculate the supercell Wannier gradient.

    The same element-wise expression is evaluated for each of the three matrices, the results
    are summed up afterwards.

    Reference: Phys. Rev. B 59, 9703.

    Args:
        atoms: Atoms object.
        X: Calculation specific matrix.
        Y: Calculation specific matrix.
        Z: Calculation specific matrix.

    Returns:
        Supercell Wannier gradient.
    """
    nstates = atoms.occ.Nstate
    # One gradient contribution per matrix, stacked along the first axis
    grads = np.empty((3, nstates, nstates), dtype=complex)
    # Just the indexed gradient from the paper, without fancy optimization
    for k, M in enumerate((X, Y, Z)):
        for n in range(nstates):
            for m in range(nstates):
                first = M[n, m] * (M[n, n].conj() - M[m, m].conj())
                second = M[m, n].conj() * (M[m, m] - M[n, n])
                grads[k, m, n] = first - second
    return grads[0] + grads[1] + grads[2]

301 

302 

@handle_k(mode="skip")
@handle_spin
def get_wannier(atoms, psirs, Nit=10000, conv_tol=1e-7, mu=1, random_guess=False, seed=None):
    """Steepest descent supercell Wannier localization.

    This function is rather sensitive to the starting point, thus it is a good idea to start from
    already localized orbitals.

    This optimizes the given orbitals under unitary constraint matrices, see
    IEEE Trans. Signal Process. 56, 1134.

    If the lattice matrix atoms.a has non-zero off-diagonal elements, a warning is logged and the
    input orbitals are returned unchanged.

    Reference: Phys. Rev. B 59, 9703.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Keyword Args:
        Nit: Number of iterations.
        conv_tol: Convergence tolerance.
        mu: Step size.
        random_guess: Whether to use a random unitary starting guess or the identity.
        seed: Seed to get a reproducible random guess.

    Returns:
        Localized orbitals.
    """
    # Bail out if any off-diagonal element of the lattice matrix is non-zero
    if (np.diag(np.diag(atoms.a)) != atoms.a).any():
        log.warning("The Wannier localization needs a cubic unit cell.")
        return psirs

    nstates = atoms.occ.Nstate
    # The matrices are built once and only rotated afterwards
    X, Y, Z = wannier_supercell_matrices(atoms, psirs)
    # Start from a random unitary matrix if requested, otherwise from the identity
    if random_guess and nstates > 1:
        U = unitary_group.rvs(nstates, random_state=seed)
    else:
        U = np.eye(nstates)
    # The leading zero allows to evaluate the cost difference in the first iteration
    costs = [0]

    atoms._log.debug(f"{'Iteration':<11}{'Cost [a0^2]':<13}{'dCost [a0^2]':<13}")
    for i in range(Nit):
        costs.append(wannier_supercell_cost(X, Y, Z))
        dcost = costs[-2] - costs[-1]
        if abs(dcost) < conv_tol:
            atoms._log.info(f"Wannier localizer converged after {i} iterations.")
            break
        # The step sign is negative if the cost increased w.r.t. the previous one, else positive
        sign = -1 if dcost < 0 else 1

        # Build the generator of the unitary update from the gradient
        A = sign * mu * wannier_supercell_grad(atoms, X, Y, Z)
        # dOmega is anti-hermitian, therefore calculate -A instead of A.conj().T
        # expm(A) will be unitary
        expA_pos = expm(A)
        expA_neg = expm(-A)
        # Accumulate the total rotation
        U = U @ expA_pos
        # Rotate the matrices instead of recalculating them
        X = expA_neg @ X @ expA_pos
        Y = expA_neg @ Y @ expA_pos
        Z = expA_neg @ Z @ expA_pos

        atoms._log.debug(f"{i:>8} {costs[-1]:<+13,.6f}{costs[-2] - costs[-1]:<+13,.4e}")

    # Warn if the last cost difference is still above the tolerance
    if len(costs) > 1 and abs(costs[-2] - costs[-1]) > conv_tol:
        atoms._log.warning("Wannier localizer not converged!")
    # Return the localized orbitals by rotating them
    return psirs @ U

370 # Return the localized orbitals by rotating them 

371 return psirs @ U