Coverage for eminus/kpoints.py: 97.21%
179 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-18 08:43 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-18 08:43 +0000
1# SPDX-FileCopyrightText: 2021 The eminus developers
2# SPDX-License-Identifier: Apache-2.0
3"""Generate k-points and sample band paths."""
5import numbers
7import numpy as np
8from scipy.linalg import inv, norm
9from scipy.spatial import Voronoi
11from .data import LATTICE_VECTORS, SPECIAL_POINTS
12from .logger import log
13from .utils import BaseObject
class KPoints(BaseObject):
    """KPoints object that holds k-points properties and build methods.

    Setting a k-mesh clears the band path and vice versa. Most setters mark the object as not
    built, so build has to be called again to regenerate the k-points.

    Args:
        lattice: Lattice system.
        a: Cell size.
    """

    def __init__(self, lattice, a=None):
        """Initialize the KPoints object."""
        self.lattice = lattice  #: Lattice system.
        if a is None:
            # Copy the default vectors so in-place changes of self.a cannot alter LATTICE_VECTORS
            a = np.array(LATTICE_VECTORS[self.lattice])
        if isinstance(a, numbers.Real):
            a = a * np.asarray(LATTICE_VECTORS[self.lattice])
        self.a = a  #: Cell size.
        self.kmesh = [1, 1, 1]  #: k-point mesh.
        self.wk = [1]  #: k-point weights.
        self.k = [[0, 0, 0]]  #: k-point coordinates.
        # The object starts as built, so build() returns early and would never set the scaled
        # coordinates; initialize them for the default Gamma point so k_scaled is always readable
        self._k_scaled = np.zeros((1, 3))
        self.kshift = [0, 0, 0]  #: k-point shift-vector.
        self._gamma_centered = True  #: Generate a Gamma-point centered grid.
        self.is_built = True  #: Determines the KPoints object build status.

    # ### Class properties ###

    @property
    def kmesh(self):
        """Monkhorst-Pack k-point mesh."""
        return self._kmesh

    @kmesh.setter
    def kmesh(self, value):
        if value is not None:
            # A single integer is used for all three axes
            if isinstance(value, numbers.Integral):
                value = value * np.ones(3, dtype=int)
            self._kmesh = np.asarray(value)
            self.path = None
            self.is_built = False
        # If we set a band path to the object the k-mesh gets set to None
        else:
            self._kmesh = None

    @property
    def wk(self):
        """k-point weights."""
        return self._wk

    @wk.setter
    def wk(self, value):
        self._wk = np.asarray(value)
        # The number of k-points follows the number of weights
        self._Nk = len(self._wk)
        self.is_built = False

    @property
    def k(self):
        """k-point coordinates."""
        return self._k

    @k.setter
    def k(self, value):
        self._k = np.asarray(value)
        self.is_built = False

    @property
    def Nk(self):
        """Number of k-points."""
        return self._Nk

    @Nk.setter
    def Nk(self, value):
        self._Nk = int(value)
        self.is_built = False

    @property
    def kshift(self):
        """k-point shift-vector."""
        return self._kshift

    @kshift.setter
    def kshift(self, value):
        self._kshift = np.asarray(value)
        self.is_built = False

    @property
    def gamma_centered(self):
        """Generate a Gamma-point centered grid."""
        return self._gamma_centered

    @gamma_centered.setter
    def gamma_centered(self, value):
        # Only invalidate the build when the grid type actually changes
        if value != self._gamma_centered:
            self._gamma_centered = value
            self.is_built = False

    @property
    def path(self):
        """k-point band path."""
        return self._path

    @path.setter
    def path(self, value):
        if value is not None:
            self._path = value.upper()
            self.kmesh = None
            self.is_built = False
        # If we set a k-mesh to the object the band path gets set to None
        else:
            self._path = None

    # ### Read-only properties ###

    @property
    def k_scaled(self):
        """Scaled k-point coordinates."""
        # This will not be set when setting the k-point coordinates manually
        return self._k_scaled

    # ### Class methods ###

    def build(self):
        """Build all parameters of the KPoints object."""
        if self.lattice == "sc" and not (self.a == np.diag(np.diag(self.a))).all():
            log.warning("Lattice system and lattice vectors do not match.")
        if self.is_built:
            return self
        if self.kmesh is not None:
            if self.gamma_centered:
                # Refers to the module-level gamma_centered function, not the property
                self._k_scaled = gamma_centered(self.kmesh)
            else:
                self._k_scaled = monkhorst_pack(self.kmesh)
        else:
            self._k_scaled = bandpath(self)
        # Without removing redundancies the weight is the same for all k-points
        self.wk = np.ones(len(self._k_scaled)) / len(self._k_scaled)
        self.k = kpoint_convert(self._k_scaled, self.a) + self.kshift
        self.is_built = True
        return self

    kernel = build

    def _assert_gamma_only(self):
        """Make sure that the object only contains the Gamma point."""
        if not np.all(self.k == 0):
            msg = "The k-points object does not contain only the Gamma point."
            raise NotImplementedError(msg)

    def __repr__(self):
        """Print the parameters stored in the KPoints object."""
        return (
            f"Number of k-points: {self.Nk}\n"
            f"k-mesh: {self.kmesh}\n"
            f"Band path: {self.path}\n"
            f"Shift: {self.kshift}\n"
            f"Weights: {self.wk}"
        )
def kpoint_convert(k_points, lattice_vectors):
    """Convert scaled k-points to cartesian coordinates.

    The scaled coordinates are multiplied with the matrix 2π (A^-1)^T. Its rows are the reciprocal
    lattice vectors when the rows of A are the lattice vectors.

    Reference: https://gitlab.com/ase/ase/-/blob/master/ase/dft/kpoints.py

    Args:
        k_points: k-points.
        lattice_vectors: Lattice vectors.

    Returns:
        k-points in cartesian coordinates.
    """
    reciprocal_vectors = inv(lattice_vectors).T * (2 * np.pi)
    cartesian = k_points @ reciprocal_vectors
    return cartesian
def monkhorst_pack(nk):
    """Generate a Monkhorst-Pack mesh of k-points, i.e., equally spaced k-points.

    Each integer grid index i is mapped to (i + 0.5) / nk - 0.5, which gives points placed
    symmetrically around the origin.

    Reference: https://gitlab.com/ase/ase/-/blob/master/ase/dft/kpoints.py

    Args:
        nk: Number of k-points per axis.

    Returns:
        k-points.
    """
    # One integer index array per axis; flattened like in Atoms._get_index_matrices()
    ix, iy, iz = np.indices(nk)
    grid = np.column_stack((ix.ravel(), iy.ravel(), iz.ravel()))
    # Normal Monkhorst-Pack grid
    return (grid + 0.5) / nk - 0.5
def gamma_centered(nk):
    """Generate a Gamma-point centered mesh of k-points.

    Each integer grid index i is mapped to i / nk. The mesh therefore starts at the Gamma point
    and covers [0, 1) along every axis.

    Reference: https://github.com/pyscf/pyscf/blob/master/pyscf/pbc/gto/cell.py

    Args:
        nk: Number of k-points per axis.

    Returns:
        k-points.
    """
    # One integer index array per axis; flattened like in Atoms._get_index_matrices()
    idx_a, idx_b, idx_c = np.indices(nk)
    indices = np.column_stack((idx_a.ravel(), idx_b.ravel(), idx_c.ravel()))
    # Gamma-centered grid
    return indices / nk
def bandpath(kpts):
    """Generate sampled band paths.

    The band path is a string of special point labels. A comma marks a jump between two special
    points with no sampling in between. The kpts.Nk - N_special additional k-points are distributed
    over the connected segments in proportion to their cartesian lengths.

    Args:
        kpts: KPoints object.

    Returns:
        Sampled k-points.

    Raises:
        ValueError: If the band path contains no special points.
        KeyError: If a label in the band path is no special point of the lattice.
    """
    # Convert path to a list and get special points
    path_list = list(kpts.path)
    s_points = SPECIAL_POINTS[kpts.lattice]
    # Commas indicate jumps and are no special points
    N_special = len([p for p in path_list if p != ","])
    if N_special == 0:
        msg = "The band path does not contain any special points."
        raise ValueError(msg)

    # Input handling
    N = kpts.Nk
    if N_special > N:
        log.warning("Sampling is smaller than the number of special points.")
        N = N_special
    valid_points = {*s_points, ","}  # Build the lookup set once instead of once per label
    for p in path_list:
        if p not in valid_points:
            msg = f"{p} is not a special point for the {kpts.lattice} lattice."
            raise KeyError(msg)

    # Segment i connects path_list[i] and path_list[i + 1]; it is a jump if it touches a comma
    is_jump = ["," in path_list[i : i + 2] for i in range(len(path_list) - 1)]

    # Calculate distances between special points
    dists = []
    for i, jump in enumerate(is_jump):
        if jump:
            # Set distance to zero when jumping between special points
            dists.append(0)
        else:
            # Use subtract since s_points are lists
            dist = np.subtract(s_points[path_list[i + 1]], s_points[path_list[i]])
            dists.append(norm(kpoint_convert(dist, kpts.a)))

    # Calculate sample points between the special points
    N_extra = N - N_special
    total_dist = sum(dists)
    if total_dist > 0:
        scaled_dists = N_extra * np.array(dists) / total_dist
    else:
        # Only jumps or zero-length segments: avoid 0/0, which would yield NaN samplings
        scaled_dists = np.zeros(len(dists))
    samplings = np.round(scaled_dists).astype(np.int64)

    # If our sampling does not match the given N add the difference to the most sampled segment
    remainder = N_extra - np.sum(samplings)
    if remainder != 0:
        if all(is_jump):
            # No connected segment exists (e.g., a single special point), nothing can be added
            log.warning("The band path has no segments to sample, only special points are used.")
        else:
            # Exclude jumps, since samplings assigned to them would never be generated
            candidates = np.where(is_jump, -1, samplings)
            samplings[np.argmax(candidates)] += remainder

    # Generate k-point coordinates
    k_points = [s_points[path_list[0]]]  # Insert the first special point
    for i, jump in enumerate(is_jump):
        # Only do something when not jumping between special points
        if not jump:
            s_start = s_points[path_list[i]]
            s_end = s_points[path_list[i + 1]]
            # Get the vector between special points
            k_dist = np.subtract(s_end, s_start)
            # Add scaled vectors to the special point to get the new k-points
            k_points += [
                s_start + k_dist * (n + 1) / (samplings[i] + 1) for n in range(samplings[i])
            ]
            # Append the special point we are ending at
            k_points.append(s_end)
        # If we jump, add the new special point to start from
        elif path_list[i] == ",":
            k_points.append(s_points[path_list[i + 1]])
    return np.asarray(k_points)
def kpoints2axis(kpts):
    """Generate the x-axis for band structures and the respective band path.

    The axis is the accumulated cartesian distance along the sampled k-points. The step across a
    jump (a comma in the band path) adds no distance, and the two special points around the jump
    share one joined label such as "X,L".

    Args:
        kpts: KPoints object.

    Returns:
        k-point axis, special point coordinates, and labels.
    """
    path_list = list(kpts.path)
    s_points = SPECIAL_POINTS[kpts.lattice]
    k_scaled = kpts.k_scaled

    # Cartesian lengths of the steps between consecutive k-points
    steps = kpoint_convert(np.diff(k_scaled, axis=0), kpts.a)
    dists = norm(steps, axis=1)

    # Build the labels
    labels = []
    for i, point in enumerate(path_list):
        # A point right after a jump is already part of the joined jump label
        if i > 1 and path_list[i - 1] == ",":
            continue
        window = path_list[i : i + 2]
        if "," not in window:
            labels.append(point)
        elif point == ",":
            labels.append("".join(path_list[i - 1 : i + 2]))

    # Find the position of every labeled special point in the sampled k-points
    special_indices = [0]  # The first special point is trivial
    for label in labels[1:]:
        # Search only from the previous special point onwards
        start = special_indices[-1]
        # label[0] is the special point itself, or the point before the jump for joined labels
        target = s_points[label[0]]
        hits = np.flatnonzero((k_scaled[start:] == target).all(axis=1))
        index = hits[0] + start
        special_indices.append(index)
        if "," in label:
            # Jumps do not contribute any distance to the axis
            dists[index] = 0

    # Prepend the origin and accumulate the step lengths
    k_axis = np.append([0], np.cumsum(dists))
    return k_axis, k_axis[special_indices], labels
def get_brillouin_zone(lattice_vectors):
    """Generate the Brillouin zone for given lattice vectors.

    The Brillouin zone can be constructed with a Voronoi decomposition of the reciprocal lattice.
    It is the Voronoi cell of the reciprocal lattice point at the origin.

    Reference: http://staff.ustc.edu.cn/~zqj/posts/howto-plot-brillouin-zone

    Args:
        lattice_vectors: Lattice vectors.

    Returns:
        Brillouin zone vertices.
    """
    # Converting the unit vectors yields the matrix 2π (A^-1)^T of reciprocal lattice vectors
    recip = kpoint_convert(np.eye(3), lattice_vectors)
    # Reciprocal lattice points n1*b1 + n2*b2 + n3*b3 for all n_i in {-1, 0, 1}
    coords = np.tensordot(recip, np.mgrid[-1:2, -1:2, -1:2], axes=(0, 0))
    points = coords.reshape(3, -1).T
    vor = Voronoi(points)

    # (n1, n2, n3) = (0, 0, 0) is the center of the flattened 3x3x3 grid: 1*9 + 1*3 + 1 = 13
    origin = 13
    bz_ridges = []
    for (p1, p2), verts in zip(vor.ridge_points, vor.ridge_vertices):
        # Keep only the ridges bounding the Voronoi cell around the origin
        if origin in (p1, p2):
            # Repeat the first vertex to close the polygon
            bz_ridges.append(vor.vertices[[*verts, verts[0]]])
    return bz_ridges