Coverage for eminus/minimizer.py: 98.42% (253 statements)

# SPDX-FileCopyrightText: 2022 The eminus developers
# SPDX-License-Identifier: Apache-2.0
"""Minimization algorithms."""

import copy
import logging
import math

from . import backend as xp
from .dft import get_epsilon, get_grad
from .energies import get_E, get_Eentropy
from .logger import name
from .utils import dotprod


def scf_step(scf, step):
    """Perform one SCF step for a DFT calculation.

    Calculating intermediate results speeds up the energy and gradient calculation.
    This function is similar to H_precompute but sets all variables and energies in the SCF
    class and returns the total energy.

    Args:
        scf: SCF object.
        step: Optimization step.

    Returns:
        Total energy.
    """
    scf.callback(scf, step)
    scf._precompute()
    # Update occupations every smear_update'th cycle
    if scf.atoms.occ.smearing > 0 and step % scf.smear_update == 0:
        epsilon = get_epsilon(scf, scf.W, **scf._precomputed)
        Efermi = scf.atoms.occ.smear(epsilon)
        get_Eentropy(scf, epsilon, Efermi)
    return get_E(scf)
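
# A sketch of a user-supplied callback hook as invoked by scf_step above
# (hypothetical example; any callable taking the SCF object and the step index
# works, and the scf.energies.Etot attribute is assumed here for illustration):
#
#     def my_callback(scf, step):
#         print(f"step {step}: Etot = {scf.energies.Etot:.6f} Eh")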


def check_convergence(scf, method, Elist, linmin=None, cg=None, norm_g=None):
    """Check the energies for every SCF cycle and handle the output.

    Args:
        scf: SCF object.
        method: Minimization method.
        Elist: Total energies per SCF step.

    Keyword Args:
        linmin: Cosine between previous search direction and current gradient.
        cg: Conjugate-gradient orthogonality.
        norm_g: Gradient norm.

    Returns:
        Convergence condition.
    """
    iteration = len(Elist)

    # Print all the data
    print_scf_step(scf, method, Elist, linmin, cg, norm_g)

    if iteration > 1:
        # Check for convergence
        if scf.gradtol is None or norm_g is None:
            if abs(Elist[-1] - Elist[-2]) < scf.etol:
                scf.is_converged = True
                return True
        # If a gradient tolerance has been set we also check norm_g for convergence
        elif abs(Elist[-1] - Elist[-2]) < scf.etol and (xp.sum(norm_g, axis=0) < scf.gradtol).all():
            scf.is_converged = True
            return True
        # Check if the current energy is higher than the last two values
        if (xp.asarray(Elist[-3:-1]) < Elist[-1]).all():
            scf._log.warning("Total energy is not decreasing.")
    return False
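
# In short, convergence means |Elist[-1] - Elist[-2]| < scf.etol and, if a
# gradient tolerance is set and gradient norms are available, additionally that
# the k-point-summed gradient norm of every spin channel drops below scf.gradtol.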


def print_scf_step(scf, method, Elist, linmin, cg, norm_g):
    """Print the data of one SCF step and the header at the beginning.

    Args:
        scf: SCF object.
        method: Minimization method.
        Elist: Total energies per SCF step.
        linmin: Cosine between previous search direction and current gradient.
        cg: Conjugate-gradient orthogonality.
        norm_g: Gradient norm.
    """
    iteration = len(Elist)

    # Print a column header at the beginning
    # The ljust values have just been chosen such that the output looks decent
    if iteration == 1:
        header = "Method".ljust(8)
        header += "Iteration".ljust(11)
        header += "Etot [Eh]".ljust(13)
        header += "dEtot [Eh]".ljust(13)
        # Print the gradient norm for cg methods
        if method not in {"sd", "lm", "pclm"}:
            header += "|Gradient|".ljust(10 * scf.atoms.occ.Nspin + 3)
        # Print extra debugging information if available
        if scf._log.level <= logging.DEBUG:
            if method != "sd":
                header += "linmin-test".ljust(10 * scf.atoms.occ.Nspin + 3)
            if method not in {"sd", "lm", "pclm"}:
                header += "cg-test".ljust(10 * scf.atoms.occ.Nspin + 3)
            scf._log.debug(header)
        else:
            scf._log.info(header)

    # Print the information for every cycle
    info = f"{method:<8}{iteration:>8}   {Elist[-1]:<+13,.6f}"
    # In the first step we do not have all information yet
    if iteration > 1:
        info += f"{Elist[-1] - Elist[-2]:<+13,.4e}"
        if norm_g is not None:
            norm_g_str = [f"{x:+.2e}" for x in xp.sum(norm_g, axis=0)]
            info += f"[{' '.join(norm_g_str)}]".ljust(10 * scf.atoms.occ.Nspin + 3)
        if scf._log.level <= logging.DEBUG:
            if method != "sd" and linmin is not None:
                linmin_str = [f"{x:+.2e}" for x in xp.sum(linmin, axis=0)]
                info += f"[{' '.join(linmin_str)}]".ljust(10 * scf.atoms.occ.Nspin + 3)
            if method not in {"sd", "lm", "pclm"} and cg is not None:
                cg_str = [f"{x:+.2e}" for x in xp.sum(cg, axis=0)]
                info += f"[{' '.join(cg_str)}]".ljust(10 * scf.atoms.occ.Nspin + 3)
    if scf._log.level <= logging.DEBUG:
        scf._log.debug(info)
    else:
        scf._log.info(info)


def linmin_test(g, d):
    """Do the line minimization test.

    Calculate the cosine of the angle between g and d.

    Reference: https://trond.hjorteland.com/thesis/node26.html

    Args:
        g: Current gradient.
        d: Previous search direction.

    Returns:
        Linmin angle.
    """
    # cos = A B / |A| |B|
    return dotprod(g, d) / math.sqrt(dotprod(g, g) * dotprod(d, d))
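
# A minimal numeric sanity check (illustration only, assuming dotprod reduces
# to a real inner product for real-valued arrays):
#
#     import numpy as np
#     linmin_test(np.array([1.0, 0.0]), np.array([-1.0, 0.0]))  # -> -1.0
#
# A value of -1 means the current gradient is exactly anti-parallel to the
# previous search direction; values near zero indicate a well-converged line
# minimization, since an exact line search leaves the new gradient orthogonal
# to the old direction.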


def cg_test(atoms, ik, g, g_old, precondition=True):
    """Test the gradient-orthogonality theorem, i.e., g and g_old should be orthogonal.

    Calculate the cosine of the angle between g and g_old. For an angle of 90 degrees the cosine
    goes to zero.

    Reference: https://math.uci.edu/~chenlong/CAMtips/CG.html

    Args:
        atoms: Atoms object.
        ik: k-point index.
        g: Current gradient.
        g_old: Previous gradient.

    Keyword Args:
        precondition: Whether to use a preconditioner.

    Returns:
        CG angle.
    """
    if precondition:
        Kg, Kg_old = atoms.K(g, ik), atoms.K(g_old, ik)
    else:
        Kg, Kg_old = g, g_old
    # cos = A B / |A| |B|
    return dotprod(g, Kg_old) / math.sqrt(dotprod(g, Kg) * dotprod(g_old, Kg_old))
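
# The analogous sanity check (illustration only; atoms may be None here since
# it is not used when precondition=False):
#
#     import numpy as np
#     cg_test(None, 0, np.array([1.0, 0.0]), np.array([0.0, 1.0]), precondition=False)  # -> 0.0
#
# Values near zero indicate that consecutive gradients stay orthogonal, as
# conjugate-gradient theory predicts.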


def cg_method(scf, ik, cgform, g, g_old, d_old, precondition=True):
    """Do different variants of the conjugate gradient method.

    Reference: https://indrag49.github.io/Numerical-Optimization/conjugate-gradient-methods-1.html

    Args:
        scf: SCF object.
        ik: k-point index.
        cgform: Conjugate gradient form.
        g: Current gradient.
        g_old: Previous gradient.
        d_old: Previous search direction.

    Keyword Args:
        precondition: Whether to use a preconditioner.

    Returns:
        Conjugate scalar and gradient norm.
    """
    atoms = scf.atoms

    if precondition:
        Kg, Kg_old = atoms.K(g, ik), atoms.K(g_old, ik)
    else:
        Kg, Kg_old = g, g_old
    norm_g = dotprod(g, Kg)

    if cgform == 1:  # Fletcher-Reeves
        return norm_g / dotprod(g_old, Kg_old), norm_g
    if cgform == 2:  # Polak-Ribiere
        return dotprod(g - g_old, Kg) / dotprod(g_old, Kg_old), norm_g
    if cgform == 3:  # Hestenes-Stiefel
        return dotprod(g - g_old, Kg) / dotprod(g - g_old, d_old), norm_g
    if cgform == 4:  # Dai-Yuan
        return norm_g / dotprod(g - g_old, d_old), norm_g
    msg = f'No cgform found for "{cgform}".'
    raise ValueError(msg)
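
# For reference, writing <a, b> for dotprod and Kg for the (optionally
# preconditioned) gradient, the four branches above implement
#   Fletcher-Reeves:   beta = <g, Kg> / <g_old, Kg_old>
#   Polak-Ribiere:     beta = <g - g_old, Kg> / <g_old, Kg_old>
#   Hestenes-Stiefel:  beta = <g - g_old, Kg> / <g - g_old, d_old>
#   Dai-Yuan:          beta = <g, Kg> / <g - g_old, d_old>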


@name("steepest descent minimization")
def sd(scf, Nit, cost=scf_step, grad=get_grad, condition=check_convergence, betat=3e-5, **kwargs):
    """Steepest descent minimization algorithm.

    Args:
        scf: SCF object.
        Nit: Maximum number of SCF steps.

    Keyword Args:
        cost: Function that will run every SCF step.
        grad: Function that calculates the respective gradient.
        condition: Function to check and log the convergence condition.
        betat: Step size.
        **kwargs: Throwaway arguments.

    Returns:
        Total energies per SCF cycle.
    """
    atoms = scf.atoms
    costs = []

    for i in range(Nit):
        c = cost(scf, i)
        costs.append(c)
        if condition(scf, "sd", costs):
            break
        for ik in range(atoms.kpts.Nk):
            for spin in range(atoms.occ.Nspin):
                g = grad(scf, ik, spin, scf.W, **scf._precomputed)
                scf.W[ik][spin] = scf.W[ik][spin] - betat * g
    return costs
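
# The update above is a fixed-step walk along the negative gradient,
# W <- W - betat * dE/dW, applied independently to every k-point and spin.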


@name("preconditioned line minimization")
def pclm(
    scf,
    Nit,
    cost=scf_step,
    grad=get_grad,
    condition=check_convergence,
    betat=3e-5,
    precondition=True,
    **kwargs,
):
    """Preconditioned line minimization algorithm.

    Args:
        scf: SCF object.
        Nit: Maximum number of SCF steps.

    Keyword Args:
        cost: Function that will run every SCF step.
        grad: Function that calculates the respective gradient.
        condition: Function to check and log the convergence condition.
        betat: Step size.
        precondition: Whether to use a preconditioner.
        **kwargs: Throwaway arguments.

    Returns:
        Total energies per SCF cycle.
    """
    atoms = scf.atoms
    costs = []

    if precondition:
        method = "pclm"
    else:
        method = "lm"

    # Scalars that need to be saved for each spin and k-point
    linmin = xp.empty((atoms.kpts.Nk, atoms.occ.Nspin))
    # Search directions and gradients that need to be saved for each spin and k-point
    d = [xp.empty_like(Wk) for Wk in scf.W]
    g = [xp.empty_like(Wk) for Wk in scf.W]

    for i in range(Nit):
        W_tmp = copy.deepcopy(scf.W)
        for ik in range(atoms.kpts.Nk):
            for spin in range(atoms.occ.Nspin):
                g[ik][spin] = grad(scf, ik, spin, scf.W, **scf._precomputed)
                # Calculate linmin for each spin and k-point separately
                if scf._log.level <= logging.DEBUG and i > 0:
                    linmin[ik][spin] = linmin_test(g[ik][spin], d[ik][spin])
                if precondition:
                    d[ik][spin] = -atoms.K(g[ik][spin], ik)
                else:
                    d[ik][spin] = -g[ik][spin]
                scf.W[ik][spin] = scf.W[ik][spin] + betat * d[ik][spin]

        scf._precompute()
        for ik in range(atoms.kpts.Nk):
            for spin in range(atoms.occ.Nspin):
                gt = grad(scf, ik, spin, scf.W, **scf._precomputed)
                beta = abs(
                    betat
                    * dotprod(g[ik][spin], d[ik][spin])
                    / dotprod(g[ik][spin] - gt, d[ik][spin])
                )
                scf.W[ik][spin] = W_tmp[ik][spin] + beta * d[ik][spin]
        c = cost(scf, i)
        costs.append(c)
        if condition(scf, method, costs, linmin):
            break
    return costs
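
# The step width used above follows from a secant estimate: after a trial step
# of size betat along d, the directional derivative changes from <g, d> to
# <gt, d>, so interpolating linearly places its zero at
# beta = betat * <g, d> / <g - gt, d>. The abs() guards against a negative
# estimate when this interpolation is unreliable.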


@name("line minimization")
def lm(scf, Nit, cost=scf_step, grad=get_grad, condition=check_convergence, betat=3e-5, **kwargs):
    """Line minimization algorithm.

    Args:
        scf: SCF object.
        Nit: Maximum number of SCF steps.

    Keyword Args:
        cost: Function that will run every SCF step.
        grad: Function that calculates the respective gradient.
        condition: Function to check and log the convergence condition.
        betat: Step size.
        **kwargs: Throwaway arguments.

    Returns:
        Total energies per SCF cycle.
    """
    return pclm(scf, Nit, cost, grad, condition, betat, precondition=False)


@name("preconditioned conjugate-gradient minimization")
def pccg(
    scf,
    Nit,
    cost=scf_step,
    grad=get_grad,
    condition=check_convergence,
    betat=3e-5,
    cgform=1,
    precondition=True,
):
    """Preconditioned conjugate-gradient minimization algorithm.

    Args:
        scf: SCF object.
        Nit: Maximum number of SCF steps.

    Keyword Args:
        cost: Function that will run every SCF step.
        grad: Function that calculates the respective gradient.
        condition: Function to check and log the convergence condition.
        betat: Step size.
        cgform: Conjugate gradient form.
        precondition: Whether to use a preconditioner.

    Returns:
        Total energies per SCF cycle.
    """
    atoms = scf.atoms
    costs = []

    if precondition:
        method = "pccg"
    else:
        method = "cg"

    # Scalars that need to be saved for each spin and k-point
    linmin = xp.empty((atoms.kpts.Nk, atoms.occ.Nspin))
    cg = xp.empty((atoms.kpts.Nk, atoms.occ.Nspin))
    norm_g = xp.empty((atoms.kpts.Nk, atoms.occ.Nspin))
    # Gradients that need to be saved for each spin and k-point
    d = [xp.empty_like(Wk) for Wk in scf.W]
    g = [xp.empty_like(Wk) for Wk in scf.W]
    d_old = [xp.empty_like(Wk) for Wk in scf.W]
    g_old = [xp.empty_like(Wk) for Wk in scf.W]

    # Do the first step without the linmin and cg tests, and without the cg_method
    W_tmp = copy.deepcopy(scf.W)
    for ik in range(atoms.kpts.Nk):
        for spin in range(atoms.occ.Nspin):
            g[ik][spin] = grad(scf, ik, spin, scf.W, **scf._precomputed)
            if precondition:
                d[ik][spin] = -atoms.K(g[ik][spin], ik)
            else:
                d[ik][spin] = -g[ik][spin]
            scf.W[ik][spin] = scf.W[ik][spin] + betat * d[ik][spin]

    # Calculate the optimal step width
    scf._precompute()
    for ik in range(atoms.kpts.Nk):
        for spin in range(atoms.occ.Nspin):
            gt = grad(scf, ik, spin, scf.W, **scf._precomputed)
            beta = abs(
                betat * dotprod(g[ik][spin], d[ik][spin]) / dotprod(g[ik][spin] - gt, d[ik][spin])
            )
            scf.W[ik][spin] = W_tmp[ik][spin] + beta * d[ik][spin]
            g_old[ik][spin], d_old[ik][spin] = g[ik][spin], d[ik][spin]

    # Evaluate the cost function
    c = cost(scf, -1)
    costs.append(c)
    condition(scf, method, costs)

    # Start the iteration
    for i in range(1, Nit):
        W_tmp = copy.deepcopy(scf.W)
        for ik in range(atoms.kpts.Nk):
            for spin in range(atoms.occ.Nspin):
                g[ik][spin] = grad(scf, ik, spin, scf.W, **scf._precomputed)
                # Calculate linmin and cg for each spin and k-point separately if needed
                if scf._log.level <= logging.DEBUG:
                    linmin[ik][spin] = linmin_test(g[ik][spin], d[ik][spin])
                    cg[ik][spin] = cg_test(atoms, ik, g[ik][spin], g_old[ik][spin], precondition)
                beta, norm_g[ik][spin] = cg_method(
                    scf, ik, cgform, g[ik][spin], g_old[ik][spin], d_old[ik][spin], precondition
                )
                if precondition:
                    d[ik][spin] = -atoms.K(g[ik][spin], ik) + beta * d_old[ik][spin]
                else:
                    d[ik][spin] = -g[ik][spin] + beta * d_old[ik][spin]
                scf.W[ik][spin] = scf.W[ik][spin] + betat * d[ik][spin]

        scf._precompute()
        for ik in range(atoms.kpts.Nk):
            for spin in range(atoms.occ.Nspin):
                gt = grad(scf, ik, spin, scf.W, **scf._precomputed)
                beta = abs(
                    betat
                    * dotprod(g[ik][spin], d[ik][spin])
                    / dotprod(g[ik][spin] - gt, d[ik][spin])
                )
                scf.W[ik][spin] = W_tmp[ik][spin] + beta * d[ik][spin]
                g_old[ik][spin], d_old[ik][spin] = g[ik][spin], d[ik][spin]

        c = cost(scf, i)
        costs.append(c)
        if condition(scf, method, costs, linmin, cg, norm_g):
            break
    return costs
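
# Note that the first step above is taken without the CG history (no beta
# mixing), i.e., it is effectively one preconditioned line minimization step
# whose results then seed g_old and d_old for the CG recurrence.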


@name("conjugate-gradient minimization")
def cg(scf, Nit, cost=scf_step, grad=get_grad, condition=check_convergence, betat=3e-5, cgform=1):
    """Conjugate-gradient minimization algorithm.

    Args:
        scf: SCF object.
        Nit: Maximum number of SCF steps.

    Keyword Args:
        cost: Function that will run every SCF step.
        grad: Function that calculates the respective gradient.
        condition: Function to check and log the convergence condition.
        betat: Step size.
        cgform: Conjugate gradient form.

    Returns:
        Total energies per SCF cycle.
    """
    return pccg(scf, Nit, cost, grad, condition, betat, cgform, precondition=False)


@name("auto minimization")
def auto(scf, Nit, cost=scf_step, grad=get_grad, condition=check_convergence, betat=3e-5, cgform=1):  # noqa: C901
    """Automatic preconditioned conjugate-gradient minimization algorithm.

    This function chooses an sd step over the pccg step if the energy goes up.

    Args:
        scf: SCF object.
        Nit: Maximum number of SCF steps.

    Keyword Args:
        cost: Function that will run every SCF step.
        grad: Function that calculates the respective gradient.
        condition: Function to check and log the convergence condition.
        betat: Step size.
        cgform: Conjugate gradient form.

    Returns:
        Total energies per SCF cycle.
    """
    atoms = scf.atoms
    costs = []

    # Scalars that need to be saved for each spin and k-point
    linmin = xp.empty((atoms.kpts.Nk, atoms.occ.Nspin))
    cg = xp.empty((atoms.kpts.Nk, atoms.occ.Nspin))
    norm_g = xp.empty((atoms.kpts.Nk, atoms.occ.Nspin))
    # Gradients that need to be saved for each spin and k-point
    d = [xp.empty_like(Wk) for Wk in scf.W]
    g = [xp.empty_like(Wk) for Wk in scf.W]
    d_old = [xp.empty_like(Wk) for Wk in scf.W]
    g_old = [xp.empty_like(Wk) for Wk in scf.W]

    # Do the first step without the linmin and cg tests, and without the cg_method
    W_tmp = copy.deepcopy(scf.W)
    for ik in range(atoms.kpts.Nk):
        for spin in range(atoms.occ.Nspin):
            g[ik][spin] = grad(scf, ik, spin, scf.W, **scf._precomputed)
            d[ik][spin] = -atoms.K(g[ik][spin], ik)
            scf.W[ik][spin] = scf.W[ik][spin] + betat * d[ik][spin]

    # Calculate the optimal step width
    scf._precompute()
    for ik in range(atoms.kpts.Nk):
        for spin in range(atoms.occ.Nspin):
            gt = grad(scf, ik, spin, scf.W, **scf._precomputed)
            beta = abs(
                betat * dotprod(g[ik][spin], d[ik][spin]) / dotprod(g[ik][spin] - gt, d[ik][spin])
            )
            scf.W[ik][spin] = W_tmp[ik][spin] + beta * d[ik][spin]
            g_old[ik][spin], d_old[ik][spin] = g[ik][spin], d[ik][spin]

    # Evaluate the cost function
    c = cost(scf, -1)
    costs.append(c)
    if condition(scf, "pccg", costs):
        return costs

    # Start the iteration
    for i in range(1, Nit):
        W_tmp = copy.deepcopy(scf.W)
        for ik in range(atoms.kpts.Nk):
            for spin in range(atoms.occ.Nspin):
                g[ik][spin] = grad(scf, ik, spin, scf.W, **scf._precomputed)
                # Calculate linmin and cg for each spin and k-point separately
                if scf._log.level <= logging.DEBUG:
                    linmin[ik][spin] = linmin_test(g[ik][spin], d[ik][spin])
                    cg[ik][spin] = cg_test(atoms, ik, g[ik][spin], g_old[ik][spin])
                beta, norm_g[ik][spin] = cg_method(
                    scf, ik, cgform, g[ik][spin], g_old[ik][spin], d_old[ik][spin]
                )
                d[ik][spin] = -atoms.K(g[ik][spin], ik) + beta * d_old[ik][spin]
                scf.W[ik][spin] = scf.W[ik][spin] + betat * d[ik][spin]

        scf._precompute()
        for ik in range(atoms.kpts.Nk):
            for spin in range(atoms.occ.Nspin):
                gt = grad(scf, ik, spin, scf.W, **scf._precomputed)
                beta = abs(
                    betat
                    * dotprod(g[ik][spin], d[ik][spin])
                    / dotprod(g[ik][spin] - gt, d[ik][spin])
                )
                scf.W[ik][spin] = W_tmp[ik][spin] + beta * d[ik][spin]
                g_old[ik][spin], d_old[ik][spin] = g[ik][spin], d[ik][spin]

        c = cost(scf, i)
        # If the energy does not go down use the steepest descent step and recalculate the energy
        if c > costs[-1]:
            for ik in range(atoms.kpts.Nk):
                scf.W[ik] = W_tmp[ik] - betat * g[ik]
            c = cost(scf, -1)
            costs.append(c)
            # Do not print cg and linmin if we do the sd step
            if condition(scf, "sd", costs, norm_g=norm_g):
                break
        else:
            costs.append(c)
            if condition(scf, "pccg", costs, linmin, cg, norm_g):
                break
    return costs
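
# The fallback above resets the wavefunctions to W_tmp and reuses the gradient
# g computed for the rejected pccg step to take a plain steepest descent step
# before the energy is recalculated.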


#: Map minimizer names to their respective implementations.
IMPLEMENTED = {
    "sd": sd,
    "lm": lm,
    "pclm": pclm,
    "cg": cg,
    "pccg": pccg,
    "auto": auto,
}
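
# A minimal dispatch sketch (hypothetical usage with an already initialized SCF
# object named scf; in practice the eminus SCF class drives this mapping itself):
#
#     minimizer = IMPLEMENTED["pccg"]
#     energies_per_cycle = minimizer(scf, 100)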