Coverage for eminus/operators.py: 98.92%

93 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-18 08:43 +0000

1# SPDX-FileCopyrightText: 2021 The eminus developers 

2# SPDX-License-Identifier: Apache-2.0 

3"""Basis set dependent operators for a plane wave basis. 

4 

5These operators act on discretized wave functions, i.e., the arrays W. 

6 

7These W are column vectors. This has been chosen to let theory and code coincide, e.g., 

8W^dagger W becomes :code:`W.conj().T @ W`. 

9 

10The downside is that the i-th state will be accessed with W[:, i] instead of W[i]. 

11Choosing the i-th state makes the array 1d. 

12 

13These operators can act on six different options, namely 

14 

151. the real-space 

162. the real-space (1d) 

173. the full reciprocal space 

184. the full reciprocal space (1d) 

195. the active reciprocal space 

206. the active reciprocal space (1d) 

21 

22The active space is the reciprocal space truncated to the points inside a sphere defined by ecut.

23 

24Every spin dependence will be handled with handle_spin by calling the operators for each 

25spin individually. The same goes for the handling of k-points, while for k-points W is represented 

26as a list of arrays. This gives the final indexing for the k-point k, spin s, and state n of 

27W[ik][s, :, n]. 

28""" 

29 

30import copy 

31 

32import numpy as np 

33from scipy.fft import fftn, ifftn 

34 

35from . import config 

36from .utils import handle_k, handle_spin, handle_torch 

37 

38 

# No @handle_spin needed: the scalar multiplication below applies to every spin channel at once
@handle_k
def O(atoms, W):
    """Overlap operator.

    Acts on the options 3, 4, 5, and 6, i.e., on the full and the active reciprocal space, both for
    single states (1d) and for matrices of column-vector states.

    Reference: Comput. Phys. Commun. 128, 1.

    Args:
        atoms: Atoms object.
        W: Expansion coefficients of unconstrained wave functions in reciprocal space.

    Returns:
        W scaled by atoms.Omega.
    """
    scaled = atoms.Omega * W
    return scaled

56 

57 

@handle_spin
def L(atoms, W, ik=-1):
    """Laplacian operator with k-point dependency.

    This operator acts on options 3, 4, 5, and 6, i.e., on the full or the active reciprocal
    space, for single states (1d) as well as for matrices whose columns are states.

    Reference: Comput. Phys. Commun. 128, 1.

    Args:
        atoms: Atoms object.
        W: Expansion coefficients of unconstrained wave functions in reciprocal space.

    Keyword Args:
        ik: k-point index.

    Returns:
        The operator applied on W.
    """
    # Use the active-space values Gk2c if W has their length, otherwise the full-space Gk2
    if len(W) == len(atoms.Gk2c[ik]):
        Gk2 = atoms.Gk2c[ik]
    else:
        Gk2 = atoms.Gk2[ik]
    # Gk2 is a 1d vector. For a matrix of column-vector states, turn it into a column so each row
    # is scaled. A 1d W must be used as-is, otherwise broadcasting would build an outer product
    if W.ndim > 1:
        Gk2 = Gk2[:, None]
    return -atoms.Omega * Gk2 * W

82 

83 

@handle_spin
def Linv(atoms, W):
    """Inverse Laplacian operator.

    Acts on the options 3 and 4, i.e., on the full reciprocal space, for single states (1d) and for
    matrices whose columns are states.

    Reference: Comput. Phys. Commun. 128, 1.

    Args:
        atoms: Atoms object.
        W: Expansion coefficients of unconstrained wave functions in reciprocal space.

    Returns:
        The operator applied on W, with the first row (the zero-G2 entry) set to zero.
    """
    # G2 is a 1d vector: use it directly for a 1d W, as a column for a matrix of states
    G2 = atoms.G2 if W.ndim == 1 else atoms.G2[:, None]
    # The first G2 entry is zero, so silence the division warnings here and overwrite that row
    # afterwards; this is cheaper than dividing only the slice W[1:]
    with np.errstate(divide="ignore", invalid="ignore"):
        out = W / (G2 * -atoms.Omega)
    out[0] = 0
    return out

111 

112 

@handle_torch
@handle_k(mode="index")
@handle_spin
def I(atoms, W, ik=-1):
    """Backward transformation from reciprocal space to real-space.

    Acts on the options 3, 4, 5, and 6. Active-space input is zero-padded to the full space before
    the transformation.

    Reference: Comput. Phys. Commun. 128, 1.

    Args:
        atoms: Atoms object.
        W: Expansion coefficients of unconstrained wave functions in reciprocal space.

    Keyword Args:
        ik: k-point index.

    Returns:
        The operator applied on W, with one row per real-space grid point.
    """
    n = atoms.Ns
    # Dimensions after the grid axis: () for a single state, (Nstates,) for a matrix of states
    extra = W.shape[1:]

    if len(W) == len(atoms.Gk2[ik]):
        # W already spans the full space; copy it so the in-place FFT cannot touch the input
        Wfft = np.copy(W)
    else:
        # Scatter the active-space coefficients into a zero-filled full-space buffer
        Wfft = np.zeros((n, *extra), dtype=W.dtype)
        Wfft[atoms.active[ik]] = W

    # Unflatten the grid axis into the three real-space dimensions atoms.s, keeping any state axis
    Wgrid = Wfft.reshape((*atoms.s, *extra))
    # norm="forward" leaves the inverse transform unscaled, which saves the explicit factor of n.
    # overwrite_x is safe since Wfft is a private buffer; workers sets the number of FFT threads
    Finv = ifftn(Wgrid, workers=config.threads, overwrite_x=True, norm="forward", axes=(0, 1, 2))
    return Finv.reshape((n, *extra))

163 

164 

@handle_torch
@handle_k(mode="index")
@handle_spin
def J(atoms, W, ik=-1, full=True):
    """Forward transformation from real-space to reciprocal space.

    Acts on the options 1 and 2, i.e., on real-space data for single states (1d) or for matrices
    whose columns are states.

    Reference: Comput. Phys. Commun. 128, 1.

    Args:
        atoms: Atoms object.
        W: Wave functions or fields in real-space.

    Keyword Args:
        ik: k-point index.
        full: Whether to transform in the full or in the active space.

    Returns:
        The operator applied on W, in the full space or restricted to atoms.active[ik].
    """
    # Dimensions after the grid axis: () for a single state, (Nstates,) for a matrix of states
    extra = W.shape[1:]
    # Work on a copy so the in-place FFT (overwrite_x) never modifies the caller's array
    Wgrid = np.copy(W).reshape((*atoms.s, *extra))
    # norm="forward" applies the 1/n normalization inside the transform; workers sets the threads
    F = fftn(Wgrid, workers=config.threads, overwrite_x=True, norm="forward", axes=(0, 1, 2))
    F = F.reshape((atoms.Ns, *extra))
    # The target space cannot be inferred from W, so the caller decides; default is the full space
    return F if full else F[atoms.active[ik]]

208 

209 

@handle_torch
@handle_k(mode="index")
@handle_spin
def Idag(atoms, W, ik=-1, full=False):
    """Conjugated backward transformation from real-space to reciprocal space.

    Acts on the options 1 and 2, i.e., on real-space data.

    Reference: Comput. Phys. Commun. 128, 1.

    Args:
        atoms: Atoms object.
        W: Wave functions or fields in real-space.

    Keyword Args:
        ik: k-point index.
        full: Whether to transform in the full or in the active space.

    Returns:
        The operator applied on W, i.e., the forward transform J scaled by atoms.Ns.
    """
    # J includes a 1/Ns normalization; scaling it back by Ns yields the adjoint of I
    return J(atoms, W, ik, full) * atoms.Ns

234 

235 

@handle_torch
@handle_k(mode="index")
@handle_spin
def Jdag(atoms, W, ik=-1):
    """Conjugated forward transformation from reciprocal space to real-space.

    Acts on the options 3, 4, 5, and 6, i.e., on full or active reciprocal-space data.

    Reference: Comput. Phys. Commun. 128, 1.

    Args:
        atoms: Atoms object.
        W: Expansion coefficients of unconstrained wave functions in reciprocal space.

    Keyword Args:
        ik: k-point index.

    Returns:
        The operator applied on W, i.e., the backward transform I divided by atoms.Ns.
    """
    # I is unnormalized; dividing by Ns gives the adjoint of the normalized forward transform J
    return I(atoms, W, ik) / atoms.Ns

259 

260 

@handle_spin
def K(atoms, W, ik):
    """Preconditioning operator with k-point dependency.

    This operator acts on options 3, 4, 5, and 6, i.e., on the full or the active reciprocal
    space, for single states (1d) as well as for matrices whose columns are states.

    Reference: Comput. Mater. Sci. 14, 4.

    Args:
        atoms: Atoms object.
        W: Expansion coefficients of unconstrained wave functions in reciprocal space.
        ik: k-point index.

    Returns:
        The operator applied on W, i.e., W / (1 + Gk2).
    """
    # Use the active-space values Gk2c if W has their length, otherwise the full-space Gk2.
    # Previously Gk2c was always used, so full-space input failed to broadcast
    if len(W) == len(atoms.Gk2c[ik]):
        Gk2 = atoms.Gk2c[ik]
    else:
        Gk2 = atoms.Gk2[ik]
    # Gk2 is a 1d vector. For a matrix of column-vector states, turn it into a column. A 1d W must
    # be used as-is, otherwise broadcasting would build an outer product
    if W.ndim > 1:
        Gk2 = Gk2[:, None]
    return W / (1 + Gk2)

279 

280 

def T(atoms, W, dr):
    """Translation operator.

    Shifts W by the real-space vector dr using the shift theorem, i.e., every coefficient is
    multiplied with the phase factor exp(-1j * (G + k) @ dr).

    Arrays act on the options 3, 4, 5, and 6 (full or active reciprocal space, 1d or 2d). A 3d
    array is treated as spin-resolved and is shifted one spin channel at a time. Plain arrays are
    only supported for Gamma-point calculations. A list is treated as one array per k-point,
    indexed as W[ik][s, :, n].

    Reference: https://ccrma.stanford.edu/~jos/st/Shift_Theorem.html

    Args:
        atoms: Atoms object.
        W: Expansion coefficients of unconstrained wave functions in reciprocal space.
        dr: Real-space shifting vector.

    Returns:
        The operator applied on W. For list input a new list is returned and W is left unchanged.
    """
    # This operator can not use the handle_* decorators, so spin channels are handled here
    if isinstance(W, np.ndarray) and W.ndim == 3:
        return np.asarray([T(atoms, Wspin, dr) for Wspin in W])

    if isinstance(W, np.ndarray):
        # A plain array carries no k-point index, so only Gamma-point calculations are supported
        atoms.kpts._assert_gamma_only()
        if len(W) == len(atoms.Gk2c[0]):
            G = atoms.G[atoms.active[0]]
        elif len(W) == len(atoms.Gk2c[-1]):
            G = atoms.G[atoms.active[-1]]
        else:
            G = atoms.G
        factor = np.exp(-1j * G @ dr)
        if W.ndim == 2:
            factor = factor[:, None]
        return factor * W

    # W is a list of per-k-point arrays indexed as W[ik][s, :, n]. Build a new list directly:
    # every entry is replaced, so deep-copying W first would only waste time and memory
    Wshift = []
    for ik in range(atoms.kpts.Nk):
        Wk = W[ik]
        # Axis 1 holds the reciprocal-space coefficients (axis 0 is the spin)
        if Wk.shape[1] == len(atoms.Gk2c[ik]):
            G = atoms.G[atoms.active[ik]]
        else:
            G = atoms.G
        Gk = G + atoms.kpts.k[ik]
        # Shift theorem: a real-space translation by dr is a phase factor in reciprocal space
        Wshift.append(np.exp(-1j * Gk @ dr)[:, None] * Wk)
    return Wshift