Coverage for eminus/localizer.py: 100.00%
130 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-18 08:43 +0000
1# SPDX-FileCopyrightText: 2021 The eminus developers
2# SPDX-License-Identifier: Apache-2.0
3"""Utilities to localize and analyze orbitals."""
5import numpy as np
6from scipy.linalg import eig, expm, norm, qr
7from scipy.stats import unitary_group
9from .logger import log
10from .utils import handle_k, handle_spin
@handle_k(mode="skip")
def eval_psi(atoms, psi, r):
    """Evaluate orbitals at a given coordinate point.

    The orbitals are translated so that the point r lands on the first real-space grid point,
    transformed to real space, and read off at that grid point.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.
        r: Real-space position.

    Returns:
        Values of psi at point r.
    """
    # Displacement that moves r onto the zeroth lattice point (normally this is (0,0,0))
    shift = -(r - atoms.r[0])
    psi_shifted = atoms.T(psi, shift)
    # Entry 0 of the real-space representation holds the values at the zeroth lattice point
    return atoms.I(psi_shifted, 0)[0]
@handle_k(mode="skip")
def get_R(atoms, psi, fods):
    """Calculate transformation matrix to build Fermi orbitals.

    Row i holds the complex-conjugated orbital values at FOD i, divided by the square root of
    the summed squared magnitudes of all orbital values at that FOD.

    Reference: J. Chem. Phys. 153, 084104.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.
        fods: Fermi-orbital descriptors.

    Returns:
        Transformation matrix R with shape (len(fods), len(fods)).
    """
    # We only calculate occupied orbitals: R is square in the number of FODs
    n_fod = len(fods)
    R = np.empty((n_fod, n_fod), dtype=complex)

    for i, fod in enumerate(fods):
        # Values of every orbital at this FOD position
        values = eval_psi(atoms, psi, fod)
        # NOTE(review): the normalization sums over every orbital in psi, not only the first
        # n_fod ones - confirm psi only holds occupied states here
        norm_factor = np.sqrt(np.sum(values.conj() * values))
        R[i] = values[:n_fod].conj() / norm_factor
    return R
@handle_k(mode="skip")
def get_FO(atoms, psi, fods):
    """Calculate Fermi orbitals from Kohn-Sham orbitals.

    Fermi orbital i of a spin channel is the linear combination sum_j R_ij psi_j over the
    Kohn-Sham orbitals j with positive occupation, where R comes from get_R. Only the first
    len(fods[spin]) columns are filled; any remaining columns stay zero.

    Reference: J. Chem. Phys. 153, 084104.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.
        fods: Fermi-orbital descriptors.

    Returns:
        Real-space Fermi orbitals.
    """
    n_spin = atoms.occ.Nspin
    n_state = atoms.occ.Nstate
    fo = np.zeros((n_spin, atoms.Ns, n_state), dtype=complex)

    # Real-space representation of the Kohn-Sham orbitals
    psi_rs = atoms.I(psi, 0)
    for spin in range(n_spin):
        R = get_R(atoms, psi[spin], fods[spin])
        # Only states with a positive occupation contribute to the Fermi orbitals
        occupied = [j for j in range(n_state) if atoms.occ.f[0, spin, j] > 0]
        for i in range(len(R)):
            for j in occupied:
                fo[spin, :, i] += R[i, j] * psi_rs[spin, :, j]
    return fo
@handle_spin
def get_S(atoms, psirs):
    """Calculate overlap matrix between orbitals.

    Reference: J. Chem. Phys. 153, 084104.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Complex overlap matrix S of shape (Nstate, Nstate).
    """
    # Overlap elements: S_ij = \int psi_i^* psi_j dr
    # All pairs are evaluated with one matrix product instead of an Nstate^2 Python loop that
    # walks over the full real-space grid for every element
    states = psirs[:, : atoms.occ.Nstate]
    S = atoms.dV * (states.conj().T @ states)
    # Keep the complex dtype of the original implementation, even for real-valued input
    return S.astype(complex, copy=False)
@handle_k(mode="skip")
def get_FLO(atoms, psi, fods):
    """Calculate Fermi-Loewdin orbitals by orthonormalizing Fermi orbitals.

    Loewdin's symmetric orthonormalization is applied per spin channel:
    FLO = FO S^(-1/2) with S^(-1/2) = T diag(Q)^(-1/2) T^H, where Q and T are the
    eigenvalues and eigenvectors of the (Hermitian) overlap matrix S of the Fermi orbitals.

    Reference: J. Chem. Phys. 153, 084104.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.
        fods: Fermi-orbital descriptors.

    Returns:
        Real-space Fermi-Loewdin orbitals.
    """
    fo = get_FO(atoms, psi, fods)
    flo = np.empty((atoms.occ.Nspin, atoms.Ns, atoms.occ.Nstate), dtype=complex)

    for spin in range(atoms.occ.Nspin):
        # Calculate the overlap matrix for FOs
        S = get_S(atoms, fo[spin])
        # S is Hermitian: eigh yields real eigenvalues and a unitary eigenbasis, which eig does
        # not guarantee for degenerate eigenvalues
        # S^(-1/2) requires positive eigenvalues, i.e., linearly independent FOs
        Q, T = np.linalg.eigh(S)
        # Loewdins symmetric orthonormalization method
        Q12 = np.diag(1 / np.sqrt(Q))
        # The conjugate transpose is required since T is complex in general
        flo[spin] = fo[spin] @ (T @ Q12 @ T.conj().T)
    return flo
@handle_k(mode="skip")
@handle_spin
def get_scdm(atoms, psi):
    """Calculate localized orbitals via QR decomposition, as given in the SCDM method.

    A column-pivoted QR factorization of the conjugate transpose of the real-space orbitals is
    computed, and its unitary factor Q is used to rotate the orbitals.

    Reference: J. Chem. Theory Comput. 11, 1463.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.

    Returns:
        Real-space SCDM orbitals.
    """
    psi_real = atoms.I(psi, 0)
    # Only Q of psi^H P = Q R is needed; the triangular factor and the pivots are discarded
    q_factor = qr(psi_real.T.conj(), pivoting=True)[0]
    return psi_real @ q_factor
@handle_k(mode="skip")
@handle_spin
def wannier_cost(atoms, psirs):
    """Calculate the Wannier cost function, namely the orbital variance. Equivalent to Foster-Boys.

    The variance of orbital i is <r^2>_i - |<r>_i|^2. This function does not account for
    periodicity, it may be a good idea to center the system.

    Reference: J. Chem. Phys. 137, 224114.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Variance per orbital.
    """
    # First and second moments of every orbital
    centers = wannier_center(atoms, psirs)
    moments = second_moment(atoms, psirs)
    log.debug(f"Centers:\n{centers}\nMoments:\n{moments}")

    squared_center_norms = norm(centers, axis=1) ** 2
    costs = moments - squared_center_norms
    log.info(f"Costs:\n{costs}")
    return costs
@handle_k(mode="skip")
@handle_spin
def wannier_center(atoms, psirs):
    """Calculate Wannier centers, i.e., the expectation values of r.

    Reference: J. Chem. Phys. 137, 224114.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Wannier centers per orbital, with shape (Nstate, 3).
    """
    n_states = atoms.occ.Nstate
    centers = np.empty((n_states, 3))
    for i in range(n_states):
        orbital = psirs[:, i]
        for dim in range(3):
            # <psi_i| r_dim |psi_i> integrated over the real-space grid
            integrand = orbital.conj() * atoms.r[:, dim] * orbital
            centers[i, dim] = atoms.dV * np.real(np.sum(integrand, axis=0))
    return centers
@handle_k(mode="skip")
@handle_spin
def second_moment(atoms, psirs):
    """Calculate the second moments, i.e., the expectation values of r^2.

    Reference: J. Chem. Phys. 137, 224114.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Second moments per orbital, with shape (Nstate,).
    """
    # Squared distance of every grid point from the coordinate origin
    r_squared = norm(atoms.r, axis=1) ** 2

    n_states = atoms.occ.Nstate
    moments = np.empty(n_states)
    for i in range(n_states):
        orbital = psirs[:, i]
        weighted_density = orbital.conj() * r_squared * orbital
        moments[i] = atoms.dV * np.real(np.sum(weighted_density, axis=0))
    return moments
@handle_spin
def wannier_supercell_matrices(atoms, psirs):
    """Calculate matrices for the supercell Wannier localization.

    For every Cartesian direction k the matrix M_ij = dV * sum_r psi_i^*(r) exp(-i 2 pi r_k / a_kk)
    psi_j(r) is built, using only the diagonal entries a_kk of the cell.

    Reference: Phys. Rev. B 59, 9703.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Matrices X, Y, and Z.
    """
    # Similar to the expectation value of r, but accounting for periodicity
    psirs_dagger = psirs.conj().T
    matrices = []
    for dim in range(3):
        phase = np.exp(-1j * 2 * np.pi * atoms.r[:, dim] / atoms.a[dim, dim])
        matrices.append((psirs_dagger * phase) @ psirs * atoms.dV)
    X, Y, Z = matrices
    return X, Y, Z
def wannier_supercell_cost(X, Y, Z):
    """Calculate the supercell Wannier cost.

    This is an equivalent criterion to the spread criterion, but not the same. This cost function
    will be maximized instead of the minimization of the spread. Only the diagonal elements enter:
    cost = sum_n |X_nn|^2 + |Y_nn|^2 + |Z_nn|^2.

    Reference: Phys. Rev. B 59, 9703.

    Args:
        X: Calculation specific matrix.
        Y: Calculation specific matrix.
        Z: Calculation specific matrix.

    Returns:
        Supercell Wannier cost.
    """
    squared_diagonals = sum(np.abs(np.diagonal(M)) ** 2 for M in (X, Y, Z))
    return np.sum(squared_diagonals)
def _supercell_grad_term(M):
    """Return the gradient contribution g[m, n] of one supercell matrix M (vectorized)."""
    d = np.diagonal(M)
    d_conj = d.conj()
    # g[m, n] = M[n, m] (M[n, n]^* - M[m, m]^*) - M[m, n]^* (M[m, m] - M[n, n])
    return M.T * (d_conj[np.newaxis, :] - d_conj[:, np.newaxis]) - M.conj() * (
        d[:, np.newaxis] - d[np.newaxis, :]
    )


def wannier_supercell_grad(atoms, X, Y, Z):
    """Calculate the supercell Wannier gradient.

    The gradient element for the orbital pair (m, n) is the sum over M in (X, Y, Z) of
    M[n, m] (M[n, n]^* - M[m, m]^*) - M[m, n]^* (M[m, m] - M[n, n]). The result is
    anti-Hermitian.

    Reference: Phys. Rev. B 59, 9703.

    Args:
        atoms: Atoms object.
        X: Calculation specific matrix.
        Y: Calculation specific matrix.
        Z: Calculation specific matrix.

    Returns:
        Supercell Wannier gradient of shape (Nstate, Nstate).
    """
    n = atoms.occ.Nstate
    # Evaluate the indexed gradient from the paper with broadcasting instead of a Python double
    # loop; this runs in every iteration of the Wannier localizer
    return sum(_supercell_grad_term(M[:n, :n]) for M in (X, Y, Z))
@handle_k(mode="skip")
@handle_spin
def get_wannier(atoms, psirs, Nit=10000, conv_tol=1e-7, mu=1, random_guess=False, seed=None):
    """Steepest descent supercell Wannier localization.

    This function is rather sensitive to the starting point, thus it is a good idea to start from
    already localized orbitals.

    This optimizes the given orbitals under unitary constraint matrices, see
    IEEE Trans. Signal Process. 56, 1134.

    Reference: Phys. Rev. B 59, 9703.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Keyword Args:
        Nit: Number of iterations.
        conv_tol: Convergence tolerance.
        mu: Step size.
        random_guess: Whether to use a random unitary starting guess or the identity.
        seed: Seed to get a reproducible random guess.

    Returns:
        Localized orbitals, i.e., psirs rotated by the accumulated unitary transformation.
        psirs is returned unchanged if the cell matrix is not diagonal.
    """
    # The supercell matrices only use the diagonal entries of the cell
    if not (np.diag(np.diag(atoms.a)) == atoms.a).all():
        log.warning("The Wannier localization needs a cubic unit cell.")
        return psirs

    X, Y, Z = wannier_supercell_matrices(atoms, psirs)  # Calculate matrices only once
    # The initial unitary transformation is the identity or a random unitary matrix
    if random_guess and atoms.occ.Nstate > 1:
        U = unitary_group.rvs(atoms.occ.Nstate, random_state=seed)
        # For the rotated orbitals psirs @ U the matrices transform as U^H M U; without this the
        # optimization would localize the unrotated orbitals while U is still applied on return
        U_H = U.conj().T
        X = U_H @ X @ U
        Y = U_H @ Y @ U
        Z = U_H @ Z @ U
    else:
        U = np.eye(atoms.occ.Nstate)
    costs = [0]  # Add a zero to the costs to allow the sign evaluation in the first iteration

    atoms._log.debug(f"{'Iteration':<11}{'Cost [a0^2]':<13}{'dCost [a0^2]':<13}")
    for i in range(Nit):
        sign = 1
        costs.append(wannier_supercell_cost(X, Y, Z))
        if abs(costs[-2] - costs[-1]) < conv_tol:
            atoms._log.info(f"Wannier localizer converged after {i} iterations.")
            break
        # Step along -dOmega while the cost increases; if the cost decreased in the last step,
        # the direction is reversed to +dOmega
        if costs[-2] - costs[-1] < 0:
            sign = -1

        # Calculate unitary transformation
        dOmega = wannier_supercell_grad(atoms, X, Y, Z)
        A = sign * mu * dOmega
        # dOmega is anti-hermitian, therefore calculate -A instead of A.conj().T
        # expm(A) will be unitary
        expA_pos, expA_neg = expm(A), expm(-A)
        # Update total rotation
        U = U @ expA_pos
        # Update matrices so they keep describing the orbitals psirs @ U
        X = expA_neg @ X @ expA_pos
        Y = expA_neg @ Y @ expA_pos
        Z = expA_neg @ Z @ expA_pos

        atoms._log.debug(f"{i:>8} {costs[-1]:<+13,.6f}{costs[-2] - costs[-1]:<+13,.4e}")

    if len(costs) > 1 and abs(costs[-2] - costs[-1]) > conv_tol:
        atoms._log.warning("Wannier localizer not converged!")
    # Return the localized orbitals by rotating them
    return psirs @ U