Coverage for eminus/localizer.py: 99.26%
136 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 12:19 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 12:19 +0000
1# SPDX-FileCopyrightText: 2021 The eminus developers
2# SPDX-License-Identifier: Apache-2.0
3"""Utilities to localize and analyze orbitals."""
5import math
7from scipy.linalg import qr
8from scipy.stats import unitary_group
10from . import backend as xp
11from . import config
12from .logger import log
13from .utils import handle_k, handle_spin
@handle_k(mode="skip")
def eval_psi(atoms, psi, r):
    """Evaluate orbitals at given coordinate points.

    The orbitals are translated with ``atoms.T`` by the negated offset between ``r`` and the
    zeroth grid point ``atoms.r[0]``, transformed with ``atoms.I``, and read off at index 0.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.
        r: Real-space position(s) at which to evaluate the orbitals.

    Returns:
        Values of psi at points r.
    """
    # Translate such that the requested point lands on the zeroth lattice point
    # (normally this is (0,0,0))
    shift = -(r - atoms.r[0])
    shifted_rs = atoms.I(atoms.T(psi, shift), 0)
    # The entry at the zeroth lattice point now holds the value at r
    return shifted_rs[0]
@handle_k(mode="skip")
def get_R(atoms, psi, fods):
    """Build the transformation matrix that turns Kohn-Sham orbitals into Fermi orbitals.

    Row ``i`` holds the complex-conjugated orbital values at the i-th FOD, divided by the
    square root of the summed squared magnitudes of all orbital values at that FOD.

    Reference: J. Chem. Phys. 153, 084104.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.
        fods: Fermi-orbital descriptors.

    Returns:
        Transformation matrix R.
    """
    nfod = len(fods)
    # R is square in the number of FODs, so only the first len(fods) orbital values enter a row
    R = xp.empty((nfod, nfod), dtype=complex)

    for i, fod in enumerate(fods):
        # Values of all orbitals at this FOD position
        psi_fod = eval_psi(atoms, psi, fod)
        norm = xp.sqrt(xp.sum(psi_fod.conj() * psi_fod))
        R[i, :] = psi_fod[:nfod].conj() / norm
    return R
@handle_k(mode="skip")
def get_FO(atoms, psi, fods):
    """Calculate Fermi orbitals from Kohn-Sham orbitals.

    For every spin channel the real-space orbitals are combined linearly with the rows of the
    matrix returned by ``get_R``. States whose occupation ``atoms.occ.f[0, spin, j]`` is not
    positive are left out of the sum.

    Reference: J. Chem. Phys. 153, 084104.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.
        fods: Fermi-orbital descriptors.

    Returns:
        Real-space Fermi orbitals.
    """
    nspin, nstate = atoms.occ.Nspin, atoms.occ.Nstate
    fo = xp.zeros((nspin, atoms.Ns, nstate), dtype=complex)

    # Real-space representation of all orbitals
    psi_rs = atoms.I(psi, 0)
    for spin in range(nspin):
        R = get_R(atoms, psi[spin], fods[spin])
        # Only states with a positive occupation contribute
        occupied = [j for j in range(nstate) if atoms.occ.f[0, spin, j] > 0]
        for i in range(len(R)):
            for j in occupied:
                fo[spin, :, i] += R[i, j] * psi_rs[spin, :, j]
    return fo
@handle_spin
def get_S(atoms, psirs):
    """Calculate the overlap matrix of a set of real-space orbitals.

    Reference: J. Chem. Phys. 153, 084104.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Overlap matrix S.
    """
    nstate = atoms.occ.Nstate
    S = xp.empty((nstate, nstate), dtype=complex)

    # S_ij = \int psi_i^* psi_j dr, evaluated as a grid sum weighted with dV
    for i in range(nstate):
        bra = psirs[:, i].conj()
        for j in range(nstate):
            S[i, j] = atoms.dV * xp.sum(bra * psirs[:, j])
    return S
@handle_k(mode="skip")
def get_FLO(atoms, psi, fods):
    """Calculate Fermi-Loewdin orbitals by orthonormalizing Fermi orbitals.

    For every spin channel the overlap matrix S of the Fermi orbitals is diagonalized and the
    orbitals are multiplied from the right with S^(-1/2) (Loewdin's symmetric
    orthonormalization).

    Reference: J. Chem. Phys. 153, 084104.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.
        fods: Fermi-orbital descriptors.

    Returns:
        Real-space Fermi-Loewdin orbitals.
    """
    fo = get_FO(atoms, psi, fods)
    flo = xp.empty((atoms.occ.Nspin, atoms.Ns, atoms.occ.Nstate), dtype=complex)

    for spin in range(atoms.occ.Nspin):
        # Calculate the overlap matrix for FOs
        S = get_S(atoms, fo[spin])
        # Calculate eigenvalues and eigenvectors
        Q, T = xp.linalg.eig(S)
        # Loewdins symmetric orthonormalization method: S^(-1/2) = T Q^(-1/2) T^H
        Q12 = xp.diag(1 / xp.sqrt(Q))
        # S is Hermitian, so the inverse of the eigenvector matrix is its conjugate transpose.
        # A plain transpose is only correct for real eigenvectors; for those both forms agree.
        flo[spin] = fo[spin] @ (T @ Q12 @ T.conj().T)
    return flo
@handle_k(mode="skip")
@handle_spin
def get_scdm(atoms, psi):
    """Calculate localized SCDM orbitals via a column-pivoted QR decomposition.

    Reference: J. Chem. Theory Comput. 11, 1463.

    Args:
        atoms: Atoms object.
        psi: Set of orbitals in reciprocal space.

    Returns:
        Real-space SCDM orbitals.
    """
    # Work with the orbitals in real-space
    psi_rs = atoms.I(psi, 0)

    # Pivoted QR of the conjugate-transposed orbitals; SciPy gets the to_np-converted array
    psi_rs_H = xp.to_np(psi_rs.T.conj())
    Q = qr(psi_rs_H, pivoting=True)[0]
    # Bring Q back to the backend with the dtype of the orbitals
    Q = xp.asarray(Q, dtype=psi_rs.dtype)

    # Rotate the orbitals with Q
    return psi_rs @ Q
@handle_k(mode="skip")
@handle_spin
def wannier_cost(atoms, psirs):
    """Calculate the Wannier cost function, namely the orbital variance. Equivalent to Foster-Boys.

    This function does not account for periodicity, it may be a good idea to center the system.

    The centers and moments are logged at debug level, the resulting costs at info level.

    Reference: J. Chem. Phys. 137, 224114.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Variance per orbital.
    """
    centers = wannier_center(atoms, psirs)
    moments = second_moment(atoms, psirs)
    # Variance = <r^2> - |<r>|^2 for every orbital
    center_norm2 = xp.linalg.norm(centers, axis=1) ** 2
    costs = moments - center_norm2
    log.debug(f"Centers:\n{centers}\nMoments:\n{moments}")
    log.info(f"Costs:\n{costs}")
    return costs
@handle_k(mode="skip")
@handle_spin
def wannier_center(atoms, psirs):
    """Calculate Wannier centers, i.e., the expectation values of r.

    Reference: J. Chem. Phys. 137, 224114.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Wannier centers per orbital.
    """
    nstate = atoms.occ.Nstate
    centers = xp.empty((nstate, 3))
    for i in range(nstate):
        ket = psirs[:, i]
        bra = ket.conj()
        for dim in range(3):
            # <psi_i| r_dim |psi_i> as a grid sum; only the real part is kept
            expectation = xp.sum(bra * atoms.r[:, dim] * ket, axis=0)
            centers[i, dim] = atoms.dV * xp.real(expectation)
    return centers
@handle_k(mode="skip")
@handle_spin
def second_moment(atoms, psirs):
    """Calculate the second moments, i.e., the expectation values of r^2.

    Reference: J. Chem. Phys. 137, 224114.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Second moments per orbital.
    """
    # Squared distance of every grid point from the origin
    r2 = xp.linalg.norm(atoms.r, axis=1) ** 2

    nstate = atoms.occ.Nstate
    moments = xp.empty(nstate)
    for i in range(nstate):
        ket = psirs[:, i]
        # <psi_i| r^2 |psi_i> as a grid sum; only the real part is kept
        expectation = xp.sum(ket.conj() * r2 * ket, axis=0)
        moments[i] = atoms.dV * xp.real(expectation)
    return moments
@handle_spin
def wannier_supercell_matrices(atoms, psirs):
    """Calculate matrices for the supercell Wannier localization.

    For each Cartesian axis the orbitals are sandwiched around the phase factor
    exp(-2 pi i r_dim / a_dim,dim) and the result is scaled with dV.

    Reference: Phys. Rev. B 59, 9703.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Returns:
        Matrices X, Y, and Z.
    """
    bra = psirs.conj().T
    matrices = []
    # Similar to the expectation value of r, but accounting for periodicity
    for dim in range(3):
        phase = xp.exp(-1j * 2 * math.pi * atoms.r[:, dim] / atoms.a[dim, dim])
        matrices.append(((bra * phase) @ psirs) * atoms.dV)
    X, Y, Z = matrices
    return X, Y, Z
def wannier_supercell_cost(X, Y, Z):
    """Calculate the supercell Wannier cost.

    This is an equivalent criterion to the spread criterion, but not the same. This cost function
    will be maximized instead of the minimization of the spread.

    The cost is the sum over the squared magnitudes of the diagonal elements of X, Y, and Z.

    Reference: Phys. Rev. B 59, 9703.

    Args:
        X: Calculation specific matrix.
        Y: Calculation specific matrix.
        Z: Calculation specific matrix.

    Returns:
        Supercell Wannier cost.
    """
    # |M_nn|^2 for every orbital n and every matrix
    diag2 = [xp.abs(xp.diagonal(M)) ** 2 for M in (X, Y, Z)]
    return xp.sum(diag2[0] + diag2[1] + diag2[2])
def wannier_supercell_grad(atoms, X, Y, Z):
    """Calculate the supercell Wannier gradient.

    For each of the matrices M in (X, Y, Z) the indexed expression

        g[m, n] = M[n, m] * (conj(M[n, n]) - conj(M[m, m])) - conj(M[m, n]) * (M[m, m] - M[n, n])

    is evaluated for all index pairs at once via broadcasting, and the three contributions are
    summed up.

    Reference: Phys. Rev. B 59, 9703.

    Args:
        atoms: Atoms object.
        X: Calculation specific matrix.
        Y: Calculation specific matrix.
        Z: Calculation specific matrix.

    Returns:
        Supercell Wannier gradient.
    """
    nstate = atoms.occ.Nstate

    def _component(M):
        """Evaluate the indexed gradient expression for one matrix."""
        # Only the leading Nstate x Nstate part enters
        M = M[:nstate, :nstate]
        d = xp.diagonal(M)
        dc = d.conj()
        # Rows carry the index m, columns the index n
        return M.T * (dc[None, :] - dc[:, None]) - M.conj() * (d[:, None] - d[None, :])

    # Accumulate into a complex array so that the result is always complex
    grad = xp.zeros((nstate, nstate), dtype=complex)
    for M in (X, Y, Z):
        grad += _component(M)
    return grad
@handle_k(mode="skip")
@handle_spin
def get_wannier(atoms, psirs, Nit=10000, conv_tol=1e-7, mu=1, random_guess=False, seed=None):
    """Steepest descent supercell Wannier localization.

    This function is rather sensitive to the starting point, thus it is a good idea to start from
    already localized orbitals.

    This optimizes the given orbitals under unitary constraint matrices, see
    IEEE Trans. Signal Process. 56, 1134.

    If the lattice matrix ``atoms.a`` has off-diagonal entries a warning is logged and the
    orbitals are returned unchanged.

    Reference: Phys. Rev. B 59, 9703.

    Args:
        atoms: Atoms object.
        psirs: Set of orbitals in real-space.

    Keyword Args:
        Nit: Number of iterations.
        conv_tol: Convergence tolerance.
        mu: Step size.
        random_guess: Whether to use a random unitary starting guess or the identity.
        seed: Seed to get a reproducible random guess.

    Returns:
        Localized orbitals.
    """
    # Pick the matrix exponential depending on the configured backend
    if config.backend == "torch":
        expm = xp.linalg.matrix_exp
    else:
        from scipy.linalg import expm

    # The lattice matrix has to be equal to its own diagonal part
    is_diagonal = (xp.diag(xp.diag(atoms.a)) == atoms.a).all()
    if not is_diagonal:
        log.warning("The Wannier localization needs a cubic unit cell.")
        return psirs

    # The matrices are calculated once and only rotated afterwards
    X, Y, Z = wannier_supercell_matrices(atoms, psirs)
    # Start from a random unitary matrix if requested (and meaningful), otherwise the identity
    if random_guess and atoms.occ.Nstate > 1:
        U = xp.asarray(unitary_group.rvs(atoms.occ.Nstate, random_state=seed))
    else:
        U = xp.eye(atoms.occ.Nstate, dtype=complex)
    # The leading zero allows forming a cost difference already in the first iteration
    costs = [0]

    atoms._log.debug(f"{'Iteration':<11}{'Cost [a0^2]':<13}{'dCost [a0^2]':<13}")
    for i in range(Nit):
        costs.append(wannier_supercell_cost(X, Y, Z))
        dcost = costs[-1] - costs[-2]
        if abs(dcost) < conv_tol:
            atoms._log.info(f"Wannier localizer converged after {i} iterations.")
            break
        # Step with sign -1 while the cost grows, flip the direction once it gets smaller
        sign = -1 if dcost > 0 else 1

        # Generator of the unitary update
        A = sign * mu * wannier_supercell_grad(atoms, X, Y, Z)
        # dOmega is anti-hermitian, therefore calculate -A instead of A.conj().T
        # expm(A) will be unitary
        expA_pos, expA_neg = expm(A), expm(-A)
        # Accumulate the total rotation and rotate the matrices accordingly
        U = U @ expA_pos
        X, Y, Z = [expA_neg @ M @ expA_pos for M in (X, Y, Z)]

        atoms._log.debug(f"{i:>8} {costs[-1]:<+13,.6f}{costs[-1] - costs[-2]:<+13,.4e}")

    # No cost difference exists if the loop never ran
    if len(costs) > 1 and abs(costs[-1] - costs[-2]) > conv_tol:
        atoms._log.warning("Wannier localizer not converged!")
    # Return the localized orbitals by rotating them
    return psirs @ U