Coverage for eminus/kpoints.py: 97.21%

179 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-18 08:43 +0000

1# SPDX-FileCopyrightText: 2021 The eminus developers 

2# SPDX-License-Identifier: Apache-2.0 

3"""Generate k-points and sample band paths.""" 

4 

5import numbers 

6 

7import numpy as np 

8from scipy.linalg import inv, norm 

9from scipy.spatial import Voronoi 

10 

11from .data import LATTICE_VECTORS, SPECIAL_POINTS 

12from .logger import log 

13from .utils import BaseObject 

14 

15 

class KPoints(BaseObject):
    """KPoints object that holds k-points properties and build methods.

    Args:
        lattice: Lattice system.
        a: Cell size.
    """

    def __init__(self, lattice, a=None):
        """Initialize the KPoints object.

        Args:
            lattice: Lattice system, used as key for the tabulated lattice vectors.
            a: Cell size. If None, the tabulated lattice vectors of the lattice system are used;
                if a real number, the tabulated lattice vectors are scaled by it; otherwise the
                value is used as the lattice vectors directly.
        """
        self.lattice = lattice  #: Lattice system.
        if a is None:
            a = LATTICE_VECTORS[self.lattice]
        if isinstance(a, numbers.Real):
            a = a * np.asarray(LATTICE_VECTORS[self.lattice])
        # Store the cell as an array regardless of how it was provided
        self.a = np.asarray(a)  #: Cell size.
        # Setting a k-mesh also resets the band path (see the kmesh setter)
        self.kmesh = [1, 1, 1]  #: k-point mesh.
        self.wk = [1]  #: k-point weights.
        self.k = [[0, 0, 0]]  #: k-point coordinates.
        self.kshift = [0, 0, 0]  #: k-point shift-vector.
        self._gamma_centered = True  #: Generate a Gamma-point centered grid.
        # The default object already holds the Gamma point and is flagged as built, so build()
        # returns early and would never set the scaled coordinates; set them here so that the
        # k_scaled property is always available
        self._k_scaled = np.zeros((1, 3))
        self.is_built = True  #: Determines the KPoints object build status.

    # ### Class properties ###

    @property
    def kmesh(self):
        """Monkhorst-Pack k-point mesh."""
        return self._kmesh

    @kmesh.setter
    def kmesh(self, value):
        if value is not None:
            # A single integer is expanded to the same number of k-points along every axis
            if isinstance(value, numbers.Integral):
                value = value * np.ones(3, dtype=int)
            self._kmesh = np.asarray(value)
            # A k-mesh and a band path are mutually exclusive
            self.path = None
            self.is_built = False
        else:
            # Reached when a band path is set, which resets the k-mesh to None
            self._kmesh = None

    @property
    def wk(self):
        """k-point weights."""
        return self._wk

    @wk.setter
    def wk(self, value):
        self._wk = np.asarray(value)
        # The number of k-points follows the number of weights
        self._Nk = len(self._wk)
        self.is_built = False

    @property
    def k(self):
        """k-point coordinates."""
        return self._k

    @k.setter
    def k(self, value):
        self._k = np.asarray(value)
        self.is_built = False

    @property
    def Nk(self):
        """Number of k-points."""
        return self._Nk

    @Nk.setter
    def Nk(self, value):
        self._Nk = int(value)
        self.is_built = False

    @property
    def kshift(self):
        """k-point shift-vector."""
        return self._kshift

    @kshift.setter
    def kshift(self, value):
        self._kshift = np.asarray(value)
        self.is_built = False

    @property
    def gamma_centered(self):
        """Generate a Gamma-point centered grid."""
        return self._gamma_centered

    @gamma_centered.setter
    def gamma_centered(self, value):
        # Only invalidate the build status when the value actually changes
        if value != self._gamma_centered:
            self._gamma_centered = value
            self.is_built = False

    @property
    def path(self):
        """k-point band path."""
        return self._path

    @path.setter
    def path(self, value):
        if value is not None:
            self._path = value.upper()
            # A band path and a k-mesh are mutually exclusive
            self.kmesh = None
            self.is_built = False
        else:
            # Reached when a k-mesh is set, which resets the band path to None
            self._path = None

    # ### Read-only properties ###

    @property
    def k_scaled(self):
        """Scaled k-point coordinates."""
        # This will not be set when setting the k-point coordinates manually
        return self._k_scaled

    # ### Class methods ###

    def build(self):
        """Build all parameters of the KPoints object.

        Scaled k-points are generated from the k-mesh (Gamma-centered or Monkhorst-Pack) or,
        if no k-mesh is set, from the band path. All k-points get equal weights and are converted
        to cartesian coordinates.

        Returns:
            The KPoints object itself.
        """
        if self.lattice == "sc" and not (self.a == np.diag(np.diag(self.a))).all():
            log.warning("Lattice system and lattice vectors do not match.")
        if self.is_built:
            return self
        if self.kmesh is not None:
            # These calls resolve to the module-level grid generators, the gamma_centered property
            # only selects which one to use
            if self.gamma_centered:
                self._k_scaled = gamma_centered(self.kmesh)
            else:
                self._k_scaled = monkhorst_pack(self.kmesh)
        else:
            self._k_scaled = bandpath(self)
        # Without removing redundancies the weight is the same for all k-points
        self.wk = np.ones(len(self._k_scaled)) / len(self._k_scaled)
        # The shift is added after the conversion, i.e., in cartesian coordinates
        self.k = kpoint_convert(self._k_scaled, self.a) + self.kshift
        self.is_built = True
        return self

    kernel = build

    def _assert_gamma_only(self):
        """Make sure that the object only contains the Gamma point.

        Raises:
            NotImplementedError: If any k-point coordinate is non-zero.
        """
        if not np.all(self.k == 0):
            msg = "The k-points object does not contain only the Gamma point."
            raise NotImplementedError(msg)

    def __repr__(self):
        """Print the parameters stored in the KPoints object."""
        return (
            f"Number of k-points: {self.Nk}\n"
            f"k-mesh: {self.kmesh}\n"
            f"Band path: {self.path}\n"
            f"Shift: {self.kshift}\n"
            f"Weights: {self.wk}"
        )

171 

172 

def kpoint_convert(k_points, lattice_vectors):
    """Transform scaled (fractional) k-point coordinates into cartesian ones.

    The transformation matrix is 2π times the transposed inverse of the lattice vector matrix.
    Assuming the lattice vectors are stored as rows, its rows are the reciprocal lattice vectors.

    Reference: https://gitlab.com/ase/ase/-/blob/master/ase/dft/kpoints.py

    Args:
        k_points: k-points in scaled coordinates (a single vector or one k-point per row).
        lattice_vectors: Lattice vectors.

    Returns:
        k-points in cartesian coordinates.
    """
    cell_inverse = inv(lattice_vectors)
    reciprocal = 2 * np.pi * cell_inverse.T
    return np.matmul(k_points, reciprocal)

187 

188 

def monkhorst_pack(nk):
    """Create an equally spaced Monkhorst-Pack grid of scaled k-points.

    The grid points are offset by half a grid spacing from the mesh indices and shifted so that
    every coordinate lies in the open interval (-0.5, 0.5).

    Reference: https://gitlab.com/ase/ase/-/blob/master/ase/dft/kpoints.py

    Args:
        nk: Number of k-points per axis.

    Returns:
        k-points, one per row, with the last axis index varying fastest.
    """
    # Same index matrix as in Atoms._get_index_matrices()
    index_grid = np.indices(nk)  # shape (3, nk[0], nk[1], nk[2])
    mesh = index_grid.transpose(1, 2, 3, 0).reshape(-1, 3)
    centered = mesh + 0.5
    return centered / nk - 0.5  # Normal Monkhorst-Pack grid

203 

204 

def gamma_centered(nk):
    """Create a grid of scaled k-points that contains the Gamma point.

    Every coordinate is a mesh index divided by the number of points along that axis, so the
    values lie in the half-open interval [0, 1).

    Reference: https://github.com/pyscf/pyscf/blob/master/pyscf/pbc/gto/cell.py

    Args:
        nk: Number of k-points per axis.

    Returns:
        k-points, one per row, with the last axis index varying fastest.
    """
    # Same index matrix as in Atoms._get_index_matrices()
    index_grid = np.indices(nk)  # shape (3, nk[0], nk[1], nk[2])
    mesh = index_grid.transpose(1, 2, 3, 0).reshape(-1, 3)
    return mesh / nk  # Gamma-centered grid

219 

220 

def bandpath(kpts):
    """Generate sampled band paths.

    The requested number of k-points (kpts.Nk) is distributed over the path segments in
    proportion to their cartesian lengths. Commas in the path mark jumps between special points;
    no k-points are sampled across a jump.

    Args:
        kpts: KPoints object.

    Returns:
        Sampled k-points in scaled coordinates.

    Raises:
        KeyError: If the path contains a label that is not a special point of the lattice.
    """
    # Convert path to a list and get special points
    path_list = list(kpts.path)
    s_points = SPECIAL_POINTS[kpts.lattice]
    # Commas indicate jumps and are no special points
    N_special = len([p for p in path_list if p != ","])

    # Input handling
    N = kpts.Nk
    if N_special > N:
        log.warning("Sampling is smaller than the number of special points.")
        N = N_special
    valid_labels = {*s_points, ","}  # Built once instead of per path entry
    for p in path_list:
        if p not in valid_labels:
            msg = f"{p} is not a special point for the {kpts.lattice} lattice."
            raise KeyError(msg)

    # Consecutive pairs of path entries, i.e., the segments of the path
    segments = list(zip(path_list[:-1], path_list[1:]))

    # Calculate distances between special points
    dists = []
    for start, end in segments:
        if "," in (start, end):
            # Set distance to zero when jumping between special points
            dists.append(0)
        else:
            # Use subtract since s_points are lists
            dist = np.subtract(s_points[end], s_points[start])
            dists.append(norm(kpoint_convert(dist, kpts.a)))

    # Calculate sample points between the special points
    N_extra = N - N_special
    total_dist = sum(dists)
    if total_dist > 0:
        scaled_dists = N_extra * np.array(dists) / total_dist
        samplings = np.int64(np.round(scaled_dists))
        # If rounding changed the total, add the difference to the segment with the most samples
        diff = N_extra - np.sum(samplings)
        if diff != 0:
            samplings[np.argmax(samplings)] += diff
    else:
        # No segment has a finite length (e.g., a single special point or only jumps), so there is
        # nothing to sample in between; this previously divided by zero
        samplings = np.zeros(len(dists), dtype=np.int64)
        if N_extra > 0:
            log.warning("Band path has no finite length, only special points are sampled.")

    # Generate k-point coordinates
    k_points = [s_points[path_list[0]]]  # Insert the first special point
    for i, (start, end) in enumerate(segments):
        # Only do something when not jumping between special points
        if "," not in (start, end):
            s_start = s_points[start]
            s_end = s_points[end]
            # Get the vector between special points
            k_dist = np.subtract(s_end, s_start)
            # Add scaled vectors to the special point to get the new k-points
            k_points += [
                s_start + k_dist * (n + 1) / (samplings[i] + 1) for n in range(samplings[i])
            ]
            # Append the special point we are ending at
            k_points.append(s_end)
        # If we jump, add the new special point to start from
        elif start == ",":
            k_points.append(s_points[end])
    return np.asarray(k_points)

284 

285 

def kpoints2axis(kpts):
    """Generate the x-axis for band structures and the respective band path.

    Args:
        kpts: KPoints object.

    Returns:
        k-point axis, special point coordinates, and labels.
    """
    # Convert path to a list and get the special points
    path_list = list(kpts.path)
    s_points = SPECIAL_POINTS[kpts.lattice]

    # Calculate the cartesian distances between consecutive k-points
    k_dist = kpts.k_scaled[1:] - kpts.k_scaled[:-1]
    dists = norm(kpoint_convert(k_dist, kpts.a), axis=1)

    # Create the labels
    labels = []
    for i in range(len(path_list)):
        # If a jump happened before the current step the special point is already included
        if i > 1 and path_list[i - 1] == ",":
            continue
        # Append the special point if no jump happens
        if "," not in path_list[i : i + 2]:
            labels.append(path_list[i])
        # When jumping join the special points to one label
        elif path_list[i] == ",":
            labels.append("".join(path_list[i - 1 : i + 2]))

    # Get the indices of the special points
    special_indices = []
    search_start = 0
    for p in labels:
        # Only search the k-points after the previous special point, so that repeated
        # consecutive special points (e.g., "XX") resolve to distinct indices
        k = kpts.k_scaled[search_start:]
        # We index p[0] since p could be a joined label of a jump
        # This expression simply finds the special point in the k_points matrix
        index = np.flatnonzero((k == s_points[p[0]]).all(axis=1))[0] + search_start
        special_indices.append(index)
        # Set the distance between special points to zero if we have a jump; this also covers
        # a jump directly after the first special point
        if "," in p:
            dists[index] = 0
        search_start = index + 1

    # Insert a zero at the beginning and add up the lengths to create the k-axis
    k_axis = np.append([0], np.cumsum(dists))
    return k_axis, k_axis[special_indices], labels

333 

334 

def get_brillouin_zone(lattice_vectors):
    """Construct the Brillouin zone of the given lattice.

    The reciprocal lattice points with coefficients in {-1, 0, 1} are decomposed with a Voronoi
    tessellation; the ridges that belong to the cell around the origin form the Brillouin zone.

    Reference: http://staff.ustc.edu.cn/~zqj/posts/howto-plot-brillouin-zone

    Args:
        lattice_vectors: Lattice vectors.

    Returns:
        Brillouin zone vertices, one closed polygon (first vertex repeated at the end) per ridge.
    """
    reciprocal = kpoint_convert(np.eye(3), lattice_vectors)

    # Linear combinations n1*b1 + n2*b2 + n3*b3 of the reciprocal vectors with n_i in {-1, 0, 1}
    coefficients = np.mgrid[-1:2, -1:2, -1:2]
    components = np.tensordot(reciprocal, coefficients, axes=(0, 0))
    lattice_points = components.reshape(3, -1).T
    voronoi = Voronoi(lattice_points)

    # Index of the origin (coefficients 0, 0, 0) in the flattened 3x3x3 grid
    origin = 13
    return [
        voronoi.vertices[np.r_[vertices, [vertices[0]]]]
        for pair, vertices in zip(voronoi.ridge_points, voronoi.ridge_vertices)
        if origin in (pair[0], pair[1])
    ]
358 return bz_ridges