Coverage for eminus/kpoints.py: 97.46%

197 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-21 12:19 +0000

1# SPDX-FileCopyrightText: 2023 The eminus developers 

2# SPDX-License-Identifier: Apache-2.0 

3"""Generate k-points and sample band paths.""" 

4 

5import math 

6import numbers 

7 

8import numpy as np 

9from scipy.spatial import Voronoi 

10 

11from . import backend as xp 

12from .data import LATTICE_VECTORS, SPECIAL_POINTS 

13from .logger import log 

14from .utils import BaseObject 

15 

16 

class KPoints(BaseObject):
    """KPoints object that holds k-points properties and build methods.

    Changing any of the k-point parameters (mesh, weights, coordinates, number of k-points,
    shift, grid type, or band path) marks the object as not built; call build() to regenerate
    the k-points. Setting a k-mesh clears the band path and vice versa.

    Args:
        lattice: Lattice system.
        a: Cell size.
    """

    def __init__(self, lattice, a=None):
        """Initialize the KPoints object."""
        self.lattice = lattice  #: Lattice system.
        if a is None:
            a = LATTICE_VECTORS[self.lattice]
        if isinstance(a, numbers.Real):
            # A scalar scales the default lattice vectors of the lattice system
            a = a * xp.asarray(LATTICE_VECTORS[self.lattice], dtype=float)
        self.a = xp.asarray(a, dtype=float)  #: Cell size.
        self.kmesh = [1, 1, 1]  #: k-point mesh.
        self.wk = [1]  #: k-point weights.
        self.k = [[0, 0, 0]]  #: k-point coordinates.
        # Scaled coordinates of the default Gamma point; without this, k_scaled would raise an
        # AttributeError since build() returns early for the already built default object
        self._k_scaled = xp.asarray([[0, 0, 0]], dtype=float)
        self.kshift = [0, 0, 0]  #: k-point shift-vector.
        self._gamma_centered = True  #: Generate a Gamma-point centered grid.
        self.is_built = True  #: Determines the KPoints object build status.

    # ### Class properties ###

    @property
    def kmesh(self):
        """Monkhorst-Pack k-point mesh."""
        return self._kmesh

    @kmesh.setter
    def kmesh(self, value):
        if value is not None:
            # A single integer is used for all three axes
            if isinstance(value, numbers.Integral):
                value = value * xp.ones(3, dtype=int)
            self._kmesh = xp.asarray(value, dtype=int)
            self.path = None
            self.is_built = False
        # If we set a band path to the object the k-mesh gets set to None
        else:
            self._kmesh = None

    @property
    def wk(self):
        """k-point weights."""
        return self._wk

    @wk.setter
    def wk(self, value):
        self._wk = xp.asarray(value, dtype=float)
        # The number of k-points follows the number of weights
        self._Nk = len(self._wk)
        self.is_built = False

    @property
    def k(self):
        """k-point coordinates."""
        return self._k

    @k.setter
    def k(self, value):
        self._k = xp.asarray(value, dtype=float)
        self.is_built = False

    @property
    def Nk(self):
        """Number of k-points."""
        return self._Nk

    @Nk.setter
    def Nk(self, value):
        self._Nk = int(value)
        self.is_built = False

    @property
    def kshift(self):
        """k-point shift-vector."""
        return self._kshift

    @kshift.setter
    def kshift(self, value):
        self._kshift = xp.asarray(value, dtype=float)
        self.is_built = False

    @property
    def gamma_centered(self):
        """Generate a Gamma-point centered grid."""
        return self._gamma_centered

    @gamma_centered.setter
    def gamma_centered(self, value):
        # Only invalidate the build status if the grid type actually changes
        if value != self._gamma_centered:
            self._gamma_centered = value
            self.is_built = False

    @property
    def path(self):
        """k-point band path."""
        return self._path

    @path.setter
    def path(self, value):
        if value is not None:
            self._path = value.upper()
            self.kmesh = None
            self.is_built = False
        # If we set a k-mesh to the object the band path gets set to None
        else:
            self._path = None

    # ### Read-only properties ###

    @property
    def k_scaled(self):
        """Scaled k-point coordinates."""
        # This will not be set when setting the k-point coordinates manually
        return self._k_scaled

    # ### Class methods ###

    def build(self):
        """Build all parameters of the KPoints object.

        Generates the scaled k-points from the k-mesh (Gamma-centered or Monkhorst-Pack) or from
        the band path, converts them to Cartesian coordinates, adds the shift, and assigns equal
        weights. Returns early (after the lattice consistency check) if already built.

        Returns:
            The KPoints object itself.
        """
        if self.lattice == "sc" and not xp.all(self.a == xp.diag(xp.diag(self.a))):
            log.warning("Lattice system and lattice vectors do not match.")
        if self.is_built:
            return self
        if self.kmesh is not None:
            if self.gamma_centered:
                self._k_scaled = gamma_centered(self.kmesh)
            else:
                self._k_scaled = monkhorst_pack(self.kmesh)
        else:
            self._k_scaled = bandpath(self)
        # Without removing redundancies the weight is the same for all k-points
        self.wk = xp.ones(len(self._k_scaled)) / len(self._k_scaled)
        self.k = kpoint_convert(self._k_scaled, self.a) + self.kshift
        # The setters above reset the build status, so set it at the very end
        self.is_built = True
        return self

    kernel = build

    def trs(self):
        """Reduce k-points using time reversal symmetry (k=-k).

        For every pair of k-points with k_i = -k_j (i < j), k_i is removed and its weight is
        added to k_j. The scaled coordinates in k_scaled are not reduced.

        Returns:
            The KPoints object itself.
        """
        if not self.is_built:
            self.build()
        if not xp.any(self.k < 0):
            log.warning("No negative k-points found. Nothing to do.")
            return self
        idx_to_remove = []
        for i in range(self.Nk):
            for j in range(i + 1, self.Nk):
                # Check k=-k within some tolerance
                if xp.sum(xp.abs(self.k[i] + self.k[j])) < 1e-15:
                    idx_to_remove.append(i)
                    self.wk[j] += self.wk[i]  # Adjust weights
        # Delete k-points and weights
        self.k = xp.delete(self.k, idx_to_remove, axis=0)
        self.wk = xp.delete(self.wk, idx_to_remove)
        # The setters above reset the build status, but the reduced set is complete; keep the
        # object marked as built so a later build() does not regenerate the full k-point set
        self.is_built = True
        return self

    def _assert_gamma_only(self):
        """Make sure that the object only contains the Gamma point.

        Raises:
            NotImplementedError: If any k-point coordinate is non-zero.
        """
        if not xp.all(self.k == 0):
            msg = "The k-points object does not contain only the Gamma point."
            raise NotImplementedError(msg)

    def __repr__(self):
        """Print the parameters stored in the KPoints object."""
        return (
            f"Number of k-points: {self.Nk}\n"
            f"k-mesh: {self.kmesh}\n"
            f"Band path: {self.path}\n"
            f"Shift: {self.kshift}\n"
            f"Weights: {self.wk}"
        )

191 

192 

def kpoint_convert(k_points, lattice_vectors):
    """Convert scaled k-points to cartesian coordinates.

    The scaled coordinates are multiplied with the matrix 2 pi (A^-1)^T, where A holds the
    lattice vectors.

    Reference: https://gitlab.com/ase/ase/-/blob/master/ase/dft/kpoints.py

    Args:
        k_points: k-points in scaled coordinates.
        lattice_vectors: Lattice vectors.

    Returns:
        k-points in cartesian coordinates.
    """
    scaled = xp.asarray(k_points, dtype=float)
    cell = xp.asarray(lattice_vectors, dtype=float)
    reciprocal = 2 * math.pi * xp.linalg.inv(cell).T
    return scaled @ reciprocal

209 

210 

def monkhorst_pack(nk):
    """Generate a Monkhorst-Pack mesh of k-points, i.e., equally spaced k-points.

    Every axis is sampled at (i + 1/2) / n - 1/2 for i = 0, ..., n - 1, which is symmetric
    around Gamma (Gamma itself is only included for odd n).

    Reference: https://gitlab.com/ase/ase/-/blob/master/ase/dft/kpoints.py

    Args:
        nk: Number of k-points per axis.

    Returns:
        k-points.
    """
    # One row of integer grid indices (i, j, l) per k-point in C order,
    # the same index matrix as in Atoms._get_index_matrices()
    index_rows = np.reshape(np.moveaxis(np.indices(nk), 0, -1), (-1, 3))
    n_per_axis = xp.asarray(nk)
    shifted = xp.asarray(index_rows) + 0.5
    return shifted / n_per_axis - 0.5  # Normal Monkhorst-Pack grid

225 

226 

def gamma_centered(nk):
    """Generate a Gamma-point centered mesh of k-points.

    Every axis is sampled at i / n for i = 0, ..., n - 1, so the first k-point is always Gamma.

    Reference: https://github.com/pyscf/pyscf/blob/master/pyscf/pbc/gto/cell.py

    Args:
        nk: Number of k-points per axis.

    Returns:
        k-points.
    """
    # Flatten each index component in C order and stack them column-wise,
    # giving the same index matrix as in Atoms._get_index_matrices()
    columns = [component.ravel() for component in np.indices(nk)]
    index_rows = xp.asarray(np.stack(columns, axis=1))
    return index_rows / xp.asarray(nk)  # Gamma-centered grid

241 

242 

def bandpath(kpts):
    """Generate sampled band paths.

    The number of requested k-points (kpts.Nk, but at least the number of special points) is
    distributed over the path segments proportionally to their cartesian length. A comma in the
    path marks a jump between two special points; no k-points are sampled across a jump.

    Args:
        kpts: KPoints object.

    Returns:
        Sampled k-points.

    Raises:
        KeyError: If the path contains a point that is not a special point of the lattice.
    """
    # Convert path to a list and get special points
    path_list = list(kpts.path)
    s_points = SPECIAL_POINTS[kpts.lattice]
    # Commas indicate jumps and are no special points
    N_special = len([p for p in path_list if p != ","])

    # Input handling
    N = kpts.Nk
    if N_special > N:
        log.warning("Sampling is smaller than the number of special points.")
        N = N_special
    for p in path_list:
        if p not in {*s_points, ","}:
            msg = f"{p} is not a special point for the {kpts.lattice} lattice."
            raise KeyError(msg)

    # Calculate distances between special points
    dists = []
    for i in range(len(path_list) - 1):
        if "," not in path_list[i : i + 2]:
            # Use subtract since s_points are lists
            dist = xp.asarray(s_points[path_list[i + 1]]) - xp.asarray(s_points[path_list[i]])
            dists.append(xp.linalg.norm(kpoint_convert(dist, kpts.a)))
        else:
            # Set distance to zero when jumping between special points
            dists.append(0)

    # Calculate sample points between the special points
    N_extra = N - N_special
    total_dist = sum(dists)
    if total_dist > 0:
        scaled_dists = N_extra * xp.asarray(dists) / total_dist
        samplings = xp.asarray(xp.round(scaled_dists), dtype=int)
    else:
        # Degenerate path (a single point, only jumps, or coinciding points): there is nothing
        # to sample in between, and dividing by the zero total length would produce NaNs
        samplings = xp.asarray([0] * len(dists), dtype=int)

    # If our sampling does not match the given N add the difference to the segment with the most
    # samples; zero-length segments (jumps) are skipped since no samples are generated there
    diff = N_extra - xp.sum(samplings)
    if diff != 0:
        segments = [i for i, d in enumerate(dists) if d > 0]
        if segments:
            # max() keeps the first maximal segment, matching argmax tie-breaking
            i_max = max(segments, key=lambda i: int(samplings[i]))
            samplings[i_max] += diff

    # Generate k-point coordinates
    k_points = [xp.asarray(s_points[path_list[0]])]  # Insert the first special point
    for i in range(len(path_list) - 1):
        # Only do something when not jumping between special points
        if "," not in path_list[i : i + 2]:
            s_start = xp.asarray(s_points[path_list[i]])
            s_end = xp.asarray(s_points[path_list[i + 1]])
            # Get the vector between special points
            k_dist = s_end - s_start
            # Add scaled vectors to the special point to get the new k-points
            k_points += [
                s_start + k_dist * (n + 1) / (samplings[i] + 1) for n in range(samplings[i])
            ]
            # Append the special point we are ending at
            k_points.append(s_end)
        # If we jump, add the new special point to start from
        elif path_list[i] == ",":
            k_points.append(xp.asarray(s_points[path_list[i + 1]]))
    return xp.stack(k_points)

306 

307 

def kpoints2axis(kpts):
    """Generate the x-axis for band structures and the respective band path.

    The axis is the accumulated cartesian distance along the sampled k-points. The step across a
    jump (a comma in the path) contributes no distance.

    Args:
        kpts: KPoints object.

    Returns:
        k-point axis, special point coordinates, and labels.
    """
    path_list = list(kpts.path)
    s_points = SPECIAL_POINTS[kpts.lattice]
    k_scaled = kpts.k_scaled

    # Cartesian length of each step between neighboring k-points
    steps = kpoint_convert(k_scaled[1:] - k_scaled[:-1], kpts.a)
    dists = xp.linalg.norm(steps, axis=1)

    # Tick labels: the two points around a jump are merged into one label, e.g., "X,U"
    labels = []
    for i, point in enumerate(path_list):
        if i > 1 and path_list[i - 1] == ",":
            # This point was already merged into the label of the preceding jump
            continue
        if "," in path_list[i : i + 2]:
            if point == ",":
                labels.append("".join(path_list[i - 1 : i + 2]))
        else:
            labels.append(point)

    # Position of every label in the k-point list; the first special point is at index zero
    special_indices = [0]
    for label in labels[1:]:
        start = special_indices[-1]
        # Only search from the previous special point onward; label[0] is the first point of a
        # merged jump label. xp.nonzero/xp.all are used instead of np.flatnonzero so that this
        # also works with the Torch backend and not only with NumPy.
        target = xp.asarray(s_points[label[0]])
        is_match = xp.all(xp.asarray(k_scaled[start:]) == target, axis=1)
        position = xp.nonzero(xp.ravel(is_match))[0][0] + start
        special_indices.append(int(position))
        if "," in label:
            # No distance is travelled across a jump
            dists[position] = 0

    # Prepend a zero and accumulate the step lengths to form the axis
    k_axis = xp.concatenate((xp.asarray([0]), xp.cumsum(dists, axis=0)))
    return k_axis, k_axis[special_indices], labels

357 

358 

def get_brillouin_zone(lattice_vectors):
    """Generate the Brillouin zone for given lattice vectors.

    The Brillouin zone can be constructed with a Voronoi decomposition of the reciprocal lattice:
    the 27 reciprocal lattice points with coefficients in {-1, 0, 1} are decomposed, and the zone
    faces are the Voronoi ridges shared with the origin.

    Reference: http://staff.ustc.edu.cn/~zqj/posts/howto-plot-brillouin-zone

    Args:
        lattice_vectors: Lattice vectors.

    Returns:
        Brillouin zone vertices, as a list with one vertex array per face in which the first
        vertex is repeated at the end to close the polygon.
    """
    # Reciprocal lattice vectors as rows, converted to NumPy for SciPy
    recip = xp.to_np(kpoint_convert(xp.eye(3), lattice_vectors))

    # Cartesian coordinates of all points n1*b1 + n2*b2 + n3*b3 with n_i in {-1, 0, 1}
    coeffs = np.mgrid[-1:2, -1:2, -1:2]
    cart = np.tensordot(recip, coeffs, axes=(0, 0))
    points = np.c_[cart[0].ravel(), cart[1].ravel(), cart[2].ravel()]
    # The origin (0, 0, 0) is the central point of the 3x3x3 grid, i.e., index 13
    origin = len(points) // 2

    vor = Voronoi(points)
    return [
        vor.vertices[np.r_[ridge, [ridge[0]]]]
        for pair, ridge in zip(vor.ridge_points, vor.ridge_vertices)
        if origin in pair
    ]