Coverage for eminus/kpoints.py: 97.46%
197 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-21 12:19 +0000
1# SPDX-FileCopyrightText: 2023 The eminus developers
2# SPDX-License-Identifier: Apache-2.0
3"""Generate k-points and sample band paths."""
5import math
6import numbers
8import numpy as np
9from scipy.spatial import Voronoi
11from . import backend as xp
12from .data import LATTICE_VECTORS, SPECIAL_POINTS
13from .logger import log
14from .utils import BaseObject
class KPoints(BaseObject):
    """KPoints object that holds k-points properties and build methods.

    Args:
        lattice: Lattice system.
        a: Cell size. None uses the default lattice vectors of the lattice system, a real number
            scales them, anything else is used as the lattice vectors directly.
    """

    def __init__(self, lattice, a=None):
        """Initialize the KPoints object."""
        self.lattice = lattice  #: Lattice system.
        if a is None:
            a = LATTICE_VECTORS[self.lattice]
        if isinstance(a, numbers.Real):
            a = a * xp.asarray(LATTICE_VECTORS[self.lattice], dtype=float)
        self.a = xp.asarray(a, dtype=float)  #: Cell size.
        self.kmesh = [1, 1, 1]  #: k-point mesh.
        self.wk = [1]  #: k-point weights.
        self.k = [[0, 0, 0]]  #: k-point coordinates.
        self.kshift = [0, 0, 0]  #: k-point shift-vector.
        self._gamma_centered = True  #: Generate a Gamma-point centered grid.
        # The object is flagged as built below, so the scaled coordinates of the default
        # Gamma-only mesh must exist as well; otherwise reading k_scaled before the first
        # rebuild would fail because _k_scaled was never assigned
        self._k_scaled = xp.asarray([[0, 0, 0]], dtype=float)
        self.is_built = True  #: Determines the KPoints object build status.

    # ### Class properties ###

    @property
    def kmesh(self):
        """Monkhorst-Pack k-point mesh."""
        return self._kmesh

    @kmesh.setter
    def kmesh(self, value):
        if value is not None:
            # A single integer is used for all three axes
            if isinstance(value, numbers.Integral):
                value = value * xp.ones(3, dtype=int)
            self._kmesh = xp.asarray(value, dtype=int)
            self.path = None
            self.is_built = False
        # If we set a band path to the object the k-mesh gets set to None
        else:
            self._kmesh = None

    @property
    def wk(self):
        """k-point weights."""
        return self._wk

    @wk.setter
    def wk(self, value):
        self._wk = xp.asarray(value, dtype=float)
        # The number of k-points always follows the number of weights
        self._Nk = len(self._wk)
        self.is_built = False

    @property
    def k(self):
        """k-point coordinates."""
        return self._k

    @k.setter
    def k(self, value):
        self._k = xp.asarray(value, dtype=float)
        self.is_built = False

    @property
    def Nk(self):
        """Number of k-points."""
        return self._Nk

    @Nk.setter
    def Nk(self, value):
        self._Nk = int(value)
        self.is_built = False

    @property
    def kshift(self):
        """k-point shift-vector."""
        return self._kshift

    @kshift.setter
    def kshift(self, value):
        self._kshift = xp.asarray(value, dtype=float)
        self.is_built = False

    @property
    def gamma_centered(self):
        """Generate a Gamma-point centered grid."""
        return self._gamma_centered

    @gamma_centered.setter
    def gamma_centered(self, value):
        # Only invalidate the build when the grid type actually changes
        if value != self._gamma_centered:
            self._gamma_centered = value
            self.is_built = False

    @property
    def path(self):
        """k-point band path."""
        return self._path

    @path.setter
    def path(self, value):
        if value is not None:
            self._path = value.upper()
            self.kmesh = None
            self.is_built = False
        # If we set a k-mesh to the object the band path gets set to None
        else:
            self._path = None

    # ### Read-only properties ###

    @property
    def k_scaled(self):
        """Scaled k-point coordinates."""
        # This will not be set when setting the k-point coordinates manually
        return self._k_scaled

    # ### Class methods ###

    def build(self):
        """Build all parameters of the KPoints object.

        Either a k-point mesh (Gamma-centered or Monkhorst-Pack) or a sampled band path is
        generated, depending on whether kmesh or path is set.

        Returns:
            The KPoints object itself.
        """
        if self.lattice == "sc" and not xp.all(self.a == xp.diag(xp.diag(self.a))):
            log.warning("Lattice system and lattice vectors do not match.")
        if self.is_built:
            return self
        if self.kmesh is not None:
            if self.gamma_centered:
                self._k_scaled = gamma_centered(self.kmesh)
            else:
                self._k_scaled = monkhorst_pack(self.kmesh)
        else:
            self._k_scaled = bandpath(self)
        # Without removing redundancies the weight is the same for all k-points
        self.wk = xp.ones(len(self._k_scaled)) / len(self._k_scaled)
        # The shift is added after the conversion, i.e., in cartesian coordinates
        self.k = kpoint_convert(self._k_scaled, self.a) + self.kshift
        self.is_built = True
        return self

    kernel = build

    def trs(self):
        """Reduce k-points using time reversal symmetry (k=-k).

        For every pair k_i = -k_j (i < j) the point k_i is removed and its weight is added to
        k_j. Note that the k and wk setters used to store the result reset is_built, so calling
        build() afterwards regenerates the unreduced set of k-points.

        Returns:
            The KPoints object itself.
        """
        if not self.is_built:
            self.build()
        if not xp.any(self.k < 0):
            log.warning("No negative k-points found. Nothing to do.")
            return self
        idx_to_remove = []
        for i in range(self.Nk):
            for j in range(i + 1, self.Nk):
                # Check k=-k within some (absolute) tolerance on the cartesian coordinates
                if xp.sum(xp.abs(self.k[i] + self.k[j])) < 1e-15:
                    idx_to_remove.append(i)
                    self.wk[j] += self.wk[i]  # Adjust weights
        # Delete k-points and weights
        self.k = xp.delete(self.k, idx_to_remove, axis=0)
        self.wk = xp.delete(self.wk, idx_to_remove)
        return self

    def _assert_gamma_only(self):
        """Make sure that the object only contains the Gamma point."""
        if not xp.all(self.k == 0):
            msg = "The k-points object does not contain only the Gamma point."
            raise NotImplementedError(msg)

    def __repr__(self):
        """Print the parameters stored in the KPoints object."""
        return (
            f"Number of k-points: {self.Nk}\n"
            f"k-mesh: {self.kmesh}\n"
            f"Band path: {self.path}\n"
            f"Shift: {self.kshift}\n"
            f"Weights: {self.wk}"
        )
def kpoint_convert(k_points, lattice_vectors):
    """Convert scaled k-points to cartesian coordinates.

    The reciprocal lattice vectors are the rows of 2*pi*(A^-1)^T, where the rows of A are the
    lattice vectors; scaled coordinates are then mapped by a right-multiplication with them.

    Reference: https://gitlab.com/ase/ase/-/blob/master/ase/dft/kpoints.py

    Args:
        k_points: k-points.
        lattice_vectors: Lattice vectors.

    Returns:
        k-points in cartesian coordinates.
    """
    scaled = xp.asarray(k_points, dtype=float)
    cell = xp.asarray(lattice_vectors, dtype=float)
    reciprocal = xp.linalg.inv(cell).T * (2 * math.pi)
    return scaled @ reciprocal
def monkhorst_pack(nk):
    """Generate a Monkhorst-Pack mesh of k-points, i.e., equally spaced k-points.

    Reference: https://gitlab.com/ase/ase/-/blob/master/ase/dft/kpoints.py

    Args:
        nk: Number of k-points per axis, either a sequence of three integers or a single
            integer that is used for all three axes.

    Returns:
        k-points in scaled coordinates, one row per k-point.
    """
    # Accept a single integer like the KPoints.kmesh setter does, np.indices needs three axes
    if isinstance(nk, numbers.Integral):
        nk = 3 * [nk]
    # Same index matrix as in Atoms._get_index_matrices()
    M = xp.asarray(np.indices(nk).transpose((1, 2, 3, 0)).reshape((-1, 3)))
    return (M + 0.5) / xp.asarray(nk) - 0.5  # Normal Monkhorst-Pack grid
def gamma_centered(nk):
    """Generate a Gamma-point centered mesh of k-points.

    Reference: https://github.com/pyscf/pyscf/blob/master/pyscf/pbc/gto/cell.py

    Args:
        nk: Number of k-points per axis, either a sequence of three integers or a single
            integer that is used for all three axes.

    Returns:
        k-points in scaled coordinates, one row per k-point.
    """
    # Accept a single integer like the KPoints.kmesh setter does, np.indices needs three axes
    if isinstance(nk, numbers.Integral):
        nk = 3 * [nk]
    # Same index matrix as in Atoms._get_index_matrices()
    M = xp.asarray(np.indices(nk).transpose((1, 2, 3, 0)).reshape((-1, 3)))
    return M / xp.asarray(nk)  # Gamma-centered grid
def bandpath(kpts):
    """Generate sampled band paths.

    The requested number of k-points kpts.Nk is distributed over the path segments
    proportionally to their lengths in cartesian coordinates. Commas in the path mark jumps,
    i.e., consecutive special points that are not connected by a sampled segment.

    Args:
        kpts: KPoints object.

    Returns:
        Sampled k-points.

    Raises:
        KeyError: If the path contains a label that is not a special point of the lattice.
    """
    # Convert path to a list and get special points
    path_list = list(kpts.path)
    s_points = SPECIAL_POINTS[kpts.lattice]
    # Commas indicate jumps and are no special points
    N_special = len([p for p in path_list if p != ","])

    # Input handling
    N = kpts.Nk
    if N_special > N:
        log.warning("Sampling is smaller than the number of special points.")
        N = N_special
    for p in path_list:
        if p not in {*s_points, ","}:
            msg = f"{p} is not a special point for the {kpts.lattice} lattice."
            raise KeyError(msg)

    # Calculate distances between special points
    dists = []
    for i in range(len(path_list) - 1):
        if "," not in path_list[i : i + 2]:
            # Use subtract since s_points are lists
            dist = xp.asarray(s_points[path_list[i + 1]]) - xp.asarray(s_points[path_list[i]])
            dists.append(xp.linalg.norm(kpoint_convert(dist, kpts.a)))
        else:
            # Set distance to zero when jumping between special points
            dists.append(0)

    # Calculate sample points between the special points
    total_dist = sum(dists)
    if total_dist > 0:
        scaled_dists = (N - N_special) * xp.asarray(dists) / total_dist
    else:
        # Degenerate path without any extent (e.g. "GG"), avoid computing 0/0
        scaled_dists = 0.0 * xp.asarray(dists)
    samplings = xp.asarray(xp.round(scaled_dists), dtype=int)

    # If our sampling does not match the given N add the difference to the longest distance
    # Select the segment via the distances: when all samplings round to zero, an argmax over the
    # samplings picks the first entry, which can be a jump that never receives any samples
    remainder = N - N_special - xp.sum(samplings)
    if remainder != 0 and len(dists) > 0:
        samplings[xp.argmax(xp.asarray(dists))] += remainder

    # Generate k-point coordinates
    k_points = [xp.asarray(s_points[path_list[0]])]  # Insert the first special point
    for i in range(len(path_list) - 1):
        # Only do something when not jumping between special points
        if "," not in path_list[i : i + 2]:
            s_start = xp.asarray(s_points[path_list[i]])
            s_end = xp.asarray(s_points[path_list[i + 1]])
            # Get the vector between special points
            k_dist = s_end - s_start
            # Add scaled vectors to the special point to get the new k-points
            k_points += [
                s_start + k_dist * (n + 1) / (samplings[i] + 1) for n in range(samplings[i])
            ]
            # Append the special point we are ending at
            k_points.append(s_end)
        # If we jump, add the new special point to start from
        elif path_list[i] == ",":
            k_points.append(xp.asarray(s_points[path_list[i + 1]]))
    return xp.stack(k_points)
def kpoints2axis(kpts):
    """Generate the x-axis for band structures and the respective band path.

    The axis is the accumulated cartesian distance between consecutive k-points, where the
    distance across a jump (a comma in the path) is set to zero.

    Args:
        kpts: KPoints object.

    Returns:
        k-point axis, special point coordinates, and labels.
    """
    # Convert path to a list and get the special points
    path_list = list(kpts.path)
    s_points = SPECIAL_POINTS[kpts.lattice]

    # Calculate the distances between k-points
    k_dist = kpts.k_scaled[1:] - kpts.k_scaled[:-1]
    dists = xp.linalg.norm(kpoint_convert(k_dist, kpts.a), axis=1)

    # Create the labels
    labels = []
    for i in range(len(path_list)):
        # If a jump happened before the current step the special point is already included
        if i > 1 and path_list[i - 1] == ",":
            continue
        # Append the special point if no jump happens
        if "," not in path_list[i : i + 2]:
            labels.append(path_list[i])
        # When jumping join the special points to one label
        elif path_list[i] == ",":
            labels.append("".join(path_list[i - 1 : i + 2]))

    # Get the indices of the special points
    special_indices = []
    start = 0
    for p in labels:
        # Only search the k-points after the previous special point, so that a special point
        # appearing again later in the path (e.g. X in "GX,UX") is not matched a second time at
        # its earlier position
        k = xp.asarray(kpts.k_scaled[start:])
        # We index p[0] since p could be a joined label of a jump
        # This expression simply finds the special point in the k_points matrix
        # The following expressions is a bit more readable, but only works with NumPy, not Torch
        # index = np.flatnonzero((k == s_points[p[0]]).all(axis=1))[0] + start
        index = xp.nonzero(xp.ravel(xp.all(k == xp.asarray(s_points[p[0]]), axis=1)))[0][0] + start
        special_indices.append(int(index))
        # Set the distance between special points to zero if we have a jump
        # This includes a jump in the very first label of the path
        if "," in p:
            dists[index] = 0
        start = int(index) + 1

    # Insert a zero at the beginning and add up the lengths to create the k-axis
    k_axis = xp.concatenate((xp.asarray([0]), xp.cumsum(dists, axis=0)))
    return k_axis, k_axis[special_indices], labels
def get_brillouin_zone(lattice_vectors):
    """Generate the Brillouin zone for given lattice vectors.

    The first Brillouin zone is the Voronoi cell around the origin of the reciprocal lattice, so
    a Voronoi decomposition of the reciprocal lattice points with coefficients in {-1, 0, 1}
    yields its faces.

    Reference: http://staff.ustc.edu.cn/~zqj/posts/howto-plot-brillouin-zone

    Args:
        lattice_vectors: Lattice vectors.

    Returns:
        Brillouin zone vertices, one closed polygon (first vertex repeated at the end) per face.
    """
    # Reciprocal lattice vectors as rows, converted to NumPy for SciPy
    recip = xp.to_np(kpoint_convert(xp.eye(3), lattice_vectors))

    # Cartesian coordinates of all points n1*b1 + n2*b2 + n3*b3 with n_i in {-1, 0, 1}
    coords = np.tensordot(recip, np.mgrid[-1:2, -1:2, -1:2], axes=(0, 0))
    lattice_points = np.column_stack([c.ravel() for c in coords])
    voronoi = Voronoi(lattice_points)

    # In the flattened 3x3x3 grid, index 13 corresponds to n = (0, 0, 0), i.e., the origin
    origin = 13
    return [
        voronoi.vertices[np.r_[verts, [verts[0]]]]
        for (p, q), verts in zip(voronoi.ridge_points, voronoi.ridge_vertices)
        if origin in (p, q)
    ]