Coverage for eminus/kpoints.py: 97.42%

194 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-06-02 10:16 +0000

1# SPDX-FileCopyrightText: 2023 The eminus developers 

2# SPDX-License-Identifier: Apache-2.0 

3"""Generate k-points and sample band paths.""" 

4 

5import numbers 

6 

7import numpy as np 

8from scipy.linalg import inv, norm 

9from scipy.spatial import Voronoi 

10 

11from .data import LATTICE_VECTORS, SPECIAL_POINTS 

12from .logger import log 

13from .utils import BaseObject 

14 

15 

class KPoints(BaseObject):
    """KPoints object that holds k-points properties and build methods.

    Assigning one of the writable properties marks the object as not built, such that the next
    call to build() regenerates the k-points. A k-mesh and a band path exclude each other:
    Setting one of them resets the other one to None.

    Args:
        lattice: Lattice system.
        a: Cell size.
    """

    def __init__(self, lattice, a=None):
        """Initialize the KPoints object."""
        self.lattice = lattice  #: Lattice system.
        if a is None:
            a = LATTICE_VECTORS[self.lattice]
        if isinstance(a, numbers.Real):
            # A scalar cell size scales the tabulated lattice vectors of the lattice system
            a = a * np.asarray(LATTICE_VECTORS[self.lattice])
        self.a = a  #: Cell size.
        self.kmesh = [1, 1, 1]  #: k-point mesh.
        self.wk = [1]  #: k-point weights.
        self.k = [[0, 0, 0]]  #: k-point coordinates.
        self.kshift = [0, 0, 0]  #: k-point shift-vector.
        self._gamma_centered = True  #: Generate a Gamma-point centered grid.
        # The setters above reset the build status, the defaults already are a complete set
        self.is_built = True  #: Determines the KPoints object build status.

    # ### Class properties ###

    @property
    def kmesh(self):
        """Monkhorst-Pack k-point mesh."""
        return self._kmesh

    @kmesh.setter
    def kmesh(self, value):
        if value is not None:
            # A single integer is used for all three axes
            if isinstance(value, numbers.Integral):
                value = value * np.ones(3, dtype=int)
            self._kmesh = np.asarray(value)
            self.path = None
            self.is_built = False
        # If we set a band path to the object the k-mesh gets set to None
        else:
            self._kmesh = None

    @property
    def wk(self):
        """k-point weights."""
        return self._wk

    @wk.setter
    def wk(self, value):
        self._wk = np.asarray(value)
        # Keep the number of k-points in sync with the number of weights
        self._Nk = len(self._wk)
        self.is_built = False

    @property
    def k(self):
        """k-point coordinates."""
        return self._k

    @k.setter
    def k(self, value):
        self._k = np.asarray(value)
        self.is_built = False

    @property
    def Nk(self):
        """Number of k-points."""
        return self._Nk

    @Nk.setter
    def Nk(self, value):
        self._Nk = int(value)
        self.is_built = False

    @property
    def kshift(self):
        """k-point shift-vector."""
        return self._kshift

    @kshift.setter
    def kshift(self, value):
        self._kshift = np.asarray(value)
        self.is_built = False

    @property
    def gamma_centered(self):
        """Generate a Gamma-point centered grid."""
        return self._gamma_centered

    @gamma_centered.setter
    def gamma_centered(self, value):
        # Only invalidate the build status if the setting actually changes
        if value != self._gamma_centered:
            self._gamma_centered = value
            self.is_built = False

    @property
    def path(self):
        """k-point band path."""
        return self._path

    @path.setter
    def path(self, value):
        if value is not None:
            # Band paths are stored in upper case
            self._path = value.upper()
            self.kmesh = None
            self.is_built = False
        # If we set a k-mesh to the object the band path gets set to None
        else:
            self._path = None

    # ### Read-only properties ###

    @property
    def k_scaled(self):
        """Scaled k-point coordinates."""
        # This will not be set when setting the k-point coordinates manually
        return self._k_scaled

    # ### Class methods ###

    def build(self):
        """Build all parameters of the KPoints object.

        Generates the scaled k-points from the k-mesh (Gamma-centered or Monkhorst-Pack) or, if
        no k-mesh is set, from the band path. Afterwards, equal weights are assigned and the
        shifted cartesian coordinates are calculated.

        Returns:
            KPoints object.
        """
        # A simple cubic cell is expected to have diagonal lattice vectors
        if self.lattice == "sc" and not (self.a == np.diag(np.diag(self.a))).all():
            log.warning("Lattice system and lattice vectors do not match.")
        if self.is_built:
            return self
        if self.kmesh is not None:
            if self.gamma_centered:
                self._k_scaled = gamma_centered(self.kmesh)
            else:
                self._k_scaled = monkhorst_pack(self.kmesh)
        else:
            self._k_scaled = bandpath(self)
        # Without removing redundancies the weight is the same for all k-points
        self.wk = np.ones(len(self._k_scaled)) / len(self._k_scaled)
        self.k = kpoint_convert(self._k_scaled, self.a) + self.kshift
        # The setters above reset the build status
        self.is_built = True
        return self

    kernel = build

    def trs(self):
        """Reduce k-points using time reversal symmetry (k=-k).

        For every pair of k-points with k_i = -k_j the first one is removed and its weight is
        added to the partner. The object is marked as built afterwards, otherwise the next call
        to build() would regenerate the full set of k-points and discard the reduction. The
        scaled coordinates in k_scaled are not reduced.

        Returns:
            KPoints object.
        """
        if not self.is_built:
            self.build()
        if not np.any(self.k < 0):
            log.warning("No negative k-points found. Nothing to do.")
            return self
        idx_to_remove = []
        for i in range(self.Nk):
            for j in range(i + 1, self.Nk):
                # Check k=-k within some tolerance
                if np.sum(np.abs(self.k[i] + self.k[j])) < 1e-15:
                    idx_to_remove.append(i)
                    self.wk[j] += self.wk[i]  # Adjust weights
                    # Hand over the weight only once to conserve the sum of all weights
                    break
        # Delete k-points and weights
        self.k = np.delete(self.k, idx_to_remove, axis=0)
        self.wk = np.delete(self.wk, idx_to_remove)
        # The k and wk setters reset the build status, restore it to keep the reduced set
        self.is_built = True
        return self

    def _assert_gamma_only(self):
        """Make sure that the object only contains the Gamma point.

        Raises:
            NotImplementedError: If any k-point coordinate is non-zero.
        """
        if not np.all(self.k == 0):
            msg = "The k-points object does not contain only the Gamma point."
            raise NotImplementedError(msg)

    def __repr__(self):
        """Print the parameters stored in the KPoints object."""
        return (
            f"Number of k-points: {self.Nk}\n"
            f"k-mesh: {self.kmesh}\n"
            f"Band path: {self.path}\n"
            f"Shift: {self.kshift}\n"
            f"Weights: {self.wk}"
        )

190 

191 

def kpoint_convert(k_points, lattice_vectors):
    """Transform scaled k-points into cartesian coordinates.

    The scaled coordinates are multiplied with the reciprocal cell, i.e., 2 pi times the
    transposed inverse of the lattice vectors.

    Reference: https://gitlab.com/ase/ase/-/blob/master/ase/dft/kpoints.py

    Args:
        k_points: k-points.
        lattice_vectors: Lattice vectors.

    Returns:
        k-points in cartesian coordinates.
    """
    reciprocal_cell = inv(lattice_vectors).T * (2 * np.pi)
    return k_points @ reciprocal_cell

206 

207 

def monkhorst_pack(nk):
    """Build a Monkhorst-Pack mesh, i.e., a grid of equally spaced scaled k-points.

    The grid is centered around the origin, the index of the last axis varies fastest.

    Reference: https://gitlab.com/ase/ase/-/blob/master/ase/dft/kpoints.py

    Args:
        nk: Number of k-points per axis.

    Returns:
        k-points.
    """
    # One row of integer grid indices per k-point
    grid_idx = np.moveaxis(np.indices(nk), 0, -1).reshape(-1, 3)
    # Move to the center of each interval, normalize, and center the grid around zero
    centered = grid_idx + 0.5
    return centered / nk - 0.5

222 

223 

def gamma_centered(nk):
    """Build a mesh of scaled k-points that contains the Gamma point.

    The grid starts at the origin, the index of the last axis varies fastest.

    Reference: https://github.com/pyscf/pyscf/blob/master/pyscf/pbc/gto/cell.py

    Args:
        nk: Number of k-points per axis.

    Returns:
        k-points.
    """
    # One row of integer grid indices per k-point
    grid_idx = np.moveaxis(np.indices(nk), 0, -1).reshape(-1, 3)
    # Normalize the indices by the number of k-points per axis
    return np.divide(grid_idx, nk)

238 

239 

def bandpath(kpts):
    """Generate sampled band paths.

    The special points of the path are always part of the result. The remaining Nk minus the
    number of special points are distributed over the segments between neighboring special
    points, proportional to the cartesian length of each segment. Commas in the path indicate
    jumps, no k-points are sampled between the special points around a comma.

    Args:
        kpts: KPoints object.

    Returns:
        Sampled k-points.

    Raises:
        KeyError: If the path contains a point that is no special point of the lattice.
    """
    # Convert path to a list and get special points
    path_list = list(kpts.path)
    s_points = SPECIAL_POINTS[kpts.lattice]
    # Commas indicate jumps and are no special points
    N_special = sum(p != "," for p in path_list)

    # Input handling
    N = kpts.Nk
    if N_special > N:
        log.warning("Sampling is smaller than the number of special points.")
        N = N_special
    allowed = {*s_points, ","}
    for p in path_list:
        if p not in allowed:
            msg = f"{p} is not a special point for the {kpts.lattice} lattice."
            raise KeyError(msg)

    # A pair of neighboring path entries is a jump if one of them is a comma
    is_jump = ["," in path_list[i : i + 2] for i in range(len(path_list) - 1)]

    # Calculate distances between special points
    dists = []
    for i, jump in enumerate(is_jump):
        if jump:
            # Set distance to zero when jumping between special points
            dists.append(0)
        else:
            # Use subtract since s_points are lists
            dist = np.subtract(s_points[path_list[i + 1]], s_points[path_list[i]])
            dists.append(norm(kpoint_convert(dist, kpts.a)))

    # Calculate sample points between the special points
    N_between = N - N_special
    total_dist = sum(dists)
    if total_dist > 0:
        scaled_dists = N_between * np.array(dists) / total_dist
        samplings = np.round(scaled_dists).astype(np.int64)
    else:
        # Paths without any extent can not be sampled proportionally, prevent a division by zero
        samplings = np.zeros(len(is_jump), dtype=np.int64)

    # If our sampling does not match the given N add the difference to the longest distance
    remainder = N_between - np.sum(samplings)
    if remainder != 0:
        # Never assign points to a jump, they would silently get lost in the generation below
        segments = np.flatnonzero(np.logical_not(is_jump))
        if len(segments) > 0:
            # argmax returns the first maximum, i.e., the first of the most sampled segments
            samplings[segments[np.argmax(samplings[segments])]] += remainder
        else:
            log.warning("The band path has no segments to sample, only special points are used.")

    # Generate k-point coordinates
    k_points = [s_points[path_list[0]]]  # Insert the first special point
    for i, jump in enumerate(is_jump):
        # Only do something when not jumping between special points
        if not jump:
            s_start = s_points[path_list[i]]
            s_end = s_points[path_list[i + 1]]
            # Get the vector between special points
            k_dist = np.subtract(s_end, s_start)
            # Add scaled vectors to the special point to get the new k-points
            k_points += [
                s_start + k_dist * (n + 1) / (samplings[i] + 1) for n in range(samplings[i])
            ]
            # Append the special point we are ending at
            k_points.append(s_end)
        # If we jump, add the new special point to start from
        elif path_list[i] == ",":
            k_points.append(s_points[path_list[i + 1]])
    return np.asarray(k_points)

303 

304 

def kpoints2axis(kpts):
    """Create the x-axis of a band structure plot together with its special points.

    The axis is the accumulated cartesian distance along the sampled band path. Jumps in the
    path (indicated by commas) do not contribute to the axis and get a joined label.

    Args:
        kpts: KPoints object.

    Returns:
        k-point axis, special point coordinates, and labels.
    """
    # Split the path into its symbols and look up the special points of the lattice
    symbols = list(kpts.path)
    special = SPECIAL_POINTS[kpts.lattice]
    scaled = kpts.k_scaled

    # Cartesian length of every step between consecutive k-points
    steps = norm(kpoint_convert(np.diff(scaled, axis=0), kpts.a), axis=1)

    # Assemble the labels
    labels = []
    for pos, symbol in enumerate(symbols):
        # The special point after a jump is already part of the joined jump label
        if pos > 1 and symbols[pos - 1] == ",":
            continue
        if symbol == ",":
            # Merge the special points around the jump into one label
            labels.append("".join(symbols[pos - 1 : pos + 2]))
        elif "," not in symbols[pos : pos + 2]:
            # A regular special point that is not followed by a jump
            labels.append(symbol)

    # Locate the special points in the sampled k-points, the first one is always at index zero
    special_indices = [0]
    for label in labels[1:]:
        # Start searching at the previously found special point
        start = special_indices[-1]
        # Use label[0] since the label can be a joined jump label
        matches = (scaled[start:] == special[label[0]]).all(axis=1)
        hit = np.flatnonzero(matches)[0] + start
        special_indices.append(hit)
        # The step over a jump must not add any length to the axis
        if "," in label:
            steps[hit] = 0

    # The axis starts at zero and accumulates the step lengths
    k_axis = np.concatenate(([0], np.cumsum(steps)))
    return k_axis, k_axis[special_indices], labels

352 

353 

def get_brillouin_zone(lattice_vectors):
    """Construct the Brillouin zone that belongs to the given lattice vectors.

    A Voronoi decomposition of a 3x3x3 block of reciprocal lattice points is used, the Brillouin
    zone is formed by the ridges that belong to the central lattice point.

    Reference: http://staff.ustc.edu.cn/~zqj/posts/howto-plot-brillouin-zone

    Args:
        lattice_vectors: Lattice vectors.

    Returns:
        Brillouin zone vertices.
    """
    reciprocal_cell = kpoint_convert(np.eye(3), lattice_vectors)

    # Reciprocal lattice points for all index combinations of -1, 0, and 1
    offsets = np.mgrid[-1:2, -1:2, -1:2]
    px, py, pz = np.tensordot(reciprocal_cell, offsets, axes=(0, 0))
    points = np.column_stack((px.ravel(), py.ravel(), pz.ravel()))
    vor = Voronoi(points)

    # The origin is the 14th of the 27 lattice points
    center = 13
    # Collect the ridges next to the origin, the first vertex is repeated to close each ridge
    return [
        vor.vertices[[*rid, rid[0]]]
        for pid, rid in zip(vor.ridge_points, vor.ridge_vertices)
        if center in pid
    ]