Coverage for /builds/debichem-team/python-ase/ase/spectrum/band_structure.py: 84.32%

185 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-03-06 04:00 +0000

1import numpy as np 

2 

3import ase # Annotations 

4from ase.calculators.calculator import PropertyNotImplementedError 

5from ase.utils import jsonable 

6 

7 

def calculate_band_structure(atoms, path=None, scf_kwargs=None,
                             bs_kwargs=None, kpts_tol=1e-6, cell_tol=1e-6):
    """Calculate band structure.

    The purpose of this function is to abstract a band structure calculation
    so the workflow does not depend on the calculator.

    First trigger SCF calculation if necessary, then set arguments
    on the calculator for band structure calculation, then return
    calculated band structure.

    The difference from get_band_structure() is that the latter
    expects the calculation to already have been done.

    atoms: Atoms
        Structure with an attached calculator.
    path: BandPath
        Band path to evaluate.  Defaults to ``atoms.cell.bandpath()``.
    scf_kwargs: dict
        Keywords passed to ``calc.set()`` before the SCF step.
    bs_kwargs: dict
        Extra keywords passed to ``calc.set()`` together with
        ``kpts=path`` for the non-self-consistent band structure step.
    kpts_tol: float
        Maximum allowed deviation between the calculator's k-points and
        those of ``path``.
    cell_tol: float
        Maximum allowed difference (as measured by
        ``ase.lattice.celldiff``) between the path's cell and the atoms'
        cell.

    Returns the band structure object produced by the calculator (when it
    accepts the ``bandpath`` keyword) or by :func:`get_band_structure`.

    Raises ValueError if the path does not match the atoms' periodicity or
    cell, or if the atoms have no calculator.  Raises RuntimeError if the
    calculator's k-points afterwards do not match those of ``path``.
    """
    if path is None:
        path = atoms.cell.bandpath()

    from ase.lattice import celldiff  # Should this be a method on cell?

    # A cell vector of the path must be nonzero exactly along the
    # periodic directions of the atoms.
    if any(path.cell.any(1) != atoms.pbc):
        raise ValueError('The band path\'s cell, {}, does not match the '
                         'periodicity {} of the atoms'
                         .format(path.cell, atoms.pbc))
    cell_err = celldiff(path.cell, atoms.cell.uncomplete(atoms.pbc))
    if cell_err > cell_tol:
        raise ValueError('Atoms and band path have different unit cells. '
                         'Please reduce atoms to standard form. '
                         'Cell lengths and angles are {} vs {}'
                         .format(atoms.cell.cellpar(), path.cell.cellpar()))

    calc = atoms.calc
    if calc is None:
        raise ValueError('Atoms have no calculator')

    if scf_kwargs is not None:
        calc.set(**scf_kwargs)

    # Proposed standard mechanism for calculators to advertise that they
    # use the bandpath keyword to handle band structures rather than
    # a double (SCF + BS) run.
    use_bandpath_kw = getattr(calc, 'accepts_bandpath_keyword', False)
    if use_bandpath_kw:
        calc.set(bandpath=path)
        atoms.get_potential_energy()
        return calc.band_structure()

    # SCF step.
    atoms.get_potential_energy()

    # Take the reference energy from the SCF run, before the calculator is
    # reconfigured for the band structure run.
    if hasattr(calc, 'get_fermi_level'):
        # What is the protocol for a calculator to tell whether
        # it has fermi_energy?
        eref = calc.get_fermi_level()
    else:
        eref = 0.0

    if bs_kwargs is None:
        bs_kwargs = {}

    calc.set(kpts=path, **bs_kwargs)
    calc.results.clear()  # XXX get rid of me

    # Calculators are too inconsistent here:
    # * atoms.get_potential_energy() will fail when total energy is
    #   not in results after BS calculation (Espresso)
    # * calc.calculate(atoms) doesn't ask for any quantity, so some
    #   calculators may not calculate anything at all
    # * 'bandstructure' is not a recognized property we can ask for
    try:
        atoms.get_potential_energy()
    except PropertyNotImplementedError:
        pass

    ibzkpts = np.asarray(calc.get_ibz_k_points())
    path_kpts_shape = np.shape(path.kpts)
    # Check the shape explicitly: a different number of k-points would
    # otherwise surface as an obscure broadcasting error (or broadcast
    # silently when the calculator returns a single point).
    if ibzkpts.shape != path_kpts_shape:
        raise RuntimeError('Kpoints of calculator differ from those '
                           'of the band path we just used; '
                           f'shape {ibzkpts.shape} != {path_kpts_shape}')
    kpts_err = np.abs(path.kpts - ibzkpts).max()
    if kpts_err > kpts_tol:
        raise RuntimeError('Kpoints of calculator differ from those '
                           'of the band path we just used; '
                           'err={} > tol={}'.format(kpts_err, kpts_tol))

    return get_band_structure(atoms, path=path, reference=eref)

87 

88 

def get_band_structure(atoms=None, calc=None, path=None, reference=None):
    """Create band structure object from Atoms or calculator.

    The calculation must already have been done; this only collects the
    eigenvalues stored in the calculator.

    atoms: Atoms
        Structure; defaults to ``calc.atoms``.
    calc: Calculator
        Calculator holding the eigenvalues; defaults to ``atoms.calc``.
    path: BandPath
        Band path matching the calculator's k-points.  If omitted, a path
        is reconstructed from the calculator's k-points and the standard
        special points of ``atoms.cell``.
    reference: float
        Reference energy.  Defaults to ``calc.get_fermi_level()``, or 0.0
        if that returns None.

    Returns a :class:`BandStructure`.

    Raises ValueError if neither ``atoms`` nor ``calc`` is given, or if no
    calculator can be found.
    """
    # path and reference are used internally at the moment, but
    # the exact implementation will probably change.  WIP.
    #
    # XXX We throw away info about the bandpath when we create the calculator.
    # If we have kept the bandpath, we can provide it as an argument here.
    # It would be wise to check that the bandpath kpoints are the same as
    # those stored in the calculator.
    if atoms is None and calc is None:
        raise ValueError('Either atoms or calc must be given')
    atoms = atoms if atoms is not None else calc.atoms
    calc = calc if calc is not None else atoms.calc
    if calc is None:
        raise ValueError('Atoms have no calculator')

    kpts = calc.get_ibz_k_points()

    # Resulting shape: (nspins, nkpts, nbands).
    energies = np.array([
        [calc.get_eigenvalues(kpt=k, spin=s) for k in range(len(kpts))]
        for s in range(calc.get_number_of_spins())
    ])

    if path is None:
        from ase.dft.kpoints import (
            BandPath,
            find_bandpath_kinks,
            resolve_custom_points,
        )
        standard_path = atoms.cell.bandpath(npoints=0)
        # Kpoints are already evaluated, we just need to put them into
        # the path (whether they fit our idea of what the path is, or not).
        #
        # Depending on how the path was established, the kpoints might
        # be valid high-symmetry points, but since there are multiple
        # high-symmetry points of each type, they may not coincide
        # with ours if the bandpath was generated by another code.
        #
        # Here we hack it so the BandPath has proper points even if they
        # come from some weird source.
        #
        # This operation (manually hacking the bandpath) is liable to break.
        # TODO: Make it available as a proper (documented) bandpath method.
        kinks = find_bandpath_kinks(atoms.cell, kpts, eps=1e-5)
        pathspec, special_points = resolve_custom_points(
            kpts[kinks], standard_path.special_points, eps=1e-5)
        path = BandPath(standard_path.cell,
                        kpts=kpts,
                        path=pathspec,
                        special_points=special_points)

    # XXX If we *did* get the path, now would be a good time to check
    # that it matches the cell!  Although the path can only be passed
    # because we internally want to not re-evaluate the Bravais
    # lattice type.  (We actually need an eps parameter, too.)

    if reference is None:
        # Fermi level should come from the GS calculation, not the BS one!
        reference = calc.get_fermi_level()

    if reference is None:
        # Fermi level may not be available, e.g., with non-Fermi smearing.
        # XXX Actually get_fermi_level() should raise an error when Fermi
        # level wasn't available, so we should fix that.
        reference = 0.0

    return BandStructure(path=path,
                         energies=energies,
                         reference=reference)

155 

156 

class BandStructurePlot:
    """Draw a band structure with matplotlib.

    The axes and the k-point x-coordinates set up by :meth:`prepare_plot`
    are kept on the instance, so :meth:`plot` may be called repeatedly to
    add more data to the same axes.
    """

    def __init__(self, bs):
        self.bs = bs
        self.ax = None
        self.xcoords = None

    def plot(self, ax=None, emin=-10, emax=5, filename=None,
             show=False, ylabel=None, colors=None, point_colors=None,
             label=None, loc=None,
             cmap=None, cmin=-1.0, cmax=1.0, sortcolors=False,
             colorbar=True, clabel='$s_z$', cax=None,
             **plotkwargs):
        """Plot band-structure.

        ax: Axes
            MatPlotLib Axes object.  Will be created if not supplied.  On
            repeated calls without ``ax``, the previously prepared axes is
            reused.
        emin, emax: float
            Minimum and maximum energy above reference.
        filename: str
            If given, write image to a file.
        show: bool
            Show the image (not needed in notebooks).
        ylabel: str
            The label along the y-axis.  Defaults to 'energies [eV]'
        colors: sequence of str
            A sequence of one or two color specifications, depending on
            whether there is spin.
            Default: green if no spin, yellow and blue if spin is present.
        point_colors: ndarray
            An array of numbers of the shape (nspins, n_kpts, nbands) which
            are then mapped onto colors by the colormap (see ``cmap``).
            ``colors`` and ``point_colors`` are mutually exclusive
        label: str or list of str
            Label for the curves on the legend.  A string if one spin is
            present, a list of two strings if two spins are present.
            Default: If no spin is given, no legend is made; if spin is
            present default labels 'spin up' and 'spin down' are used, but
            can be suppressed by setting ``label=False``.
        loc: str
            Location of the legend.

        If ``point_colors`` is given, the following arguments can be specified.

        cmap:
            Only used if ``point_colors`` is given.  A matplotlib
            colormap object, or a string naming a standard colormap.
            Default: The matplotlib default, typically 'viridis'.
        cmin, cmax: float
            Minimal and maximal values used for colormap translation.
            Default: -1.0 and 1.0
        colorbar: bool
            Whether to make a colorbar.
        clabel: str
            Label for the colorbar (default 's_z'; set to None to suppress).
        cax: Axes
            Axes object used for plotting colorbar.  Default: split off a
            new one.
        sortcolors (bool or callable):
            Sort points so highest color values are in front.  If a callable is
            given, then it is called on the color values to determine the sort
            order.

        Any additional keyword arguments are passed directly to matplotlib's
        plot() or scatter() methods, depending on whether point_colors is
        given.

        Returns the Axes that was plotted on.
        """
        if colors is not None and point_colors is not None:
            raise ValueError("Don't give both 'color' and 'point_color'")

        e_skn = self.bs.energies
        nspins = len(e_skn)

        if point_colors is None:
            # Validate before creating any figure, so bad arguments do not
            # leave an empty figure behind.
            colors = self._resolve_colors(colors, nspins)
            label = self._resolve_labels(label, nspins)

        import matplotlib.pyplot as plt

        if self.ax is None:
            ax = self.prepare_plot(ax, emin, emax, ylabel)
        elif ax is None:
            # A previous call already prepared the axes; draw onto it again
            # instead of calling methods on None.
            ax = self.ax

        if point_colors is None:
            # Normal band structure plot: one line per band and spin.
            for spin, e_kn in enumerate(e_skn):
                kwargs = dict(color=colors[spin])
                kwargs.update(plotkwargs)
                # Only the first band of each spin carries the legend label.
                lbl = label[spin] if label else None
                ax.plot(self.xcoords, e_kn[:, 0], label=lbl, **kwargs)
                for e_k in e_kn.T[1:]:
                    ax.plot(self.xcoords, e_k, **kwargs)
            # label=False / '' / None means no labelled curves -> no legend.
            show_legend = bool(label)
        else:
            # A color per datapoint.
            kwargs = dict(vmin=cmin, vmax=cmax, cmap=cmap, s=1)
            kwargs.update(plotkwargs)
            shape = e_skn.shape
            # Repeat the k-point x-coordinates over spins and bands.
            xcoords = np.zeros(shape)
            xcoords += np.asarray(self.xcoords)[np.newaxis, :, np.newaxis]
            if sortcolors:
                point_colors = np.asarray(point_colors)
                if callable(sortcolors):
                    sort_keys = sortcolors(point_colors)
                else:
                    sort_keys = point_colors
                perm = np.asarray(sort_keys).argsort(axis=None)
                e_skn = e_skn.ravel()[perm].reshape(shape)
                point_colors = point_colors.ravel()[perm].reshape(shape)
                xcoords = xcoords.ravel()[perm].reshape(shape)

            things = ax.scatter(xcoords, e_skn, c=point_colors, **kwargs)
            if colorbar:
                cbar = plt.colorbar(things, cax=cax)
                if clabel:
                    cbar.set_label(clabel)
            show_legend = False

        self.finish_plot(filename, show, loc, show_legend, ax=ax)

        return ax

    @staticmethod
    def _resolve_colors(colors, nspins):
        """Return the line colors to use, one per spin channel."""
        if colors is None:
            return 'g' if nspins == 1 else 'yb'
        if len(colors) != nspins:
            raise ValueError(
                f"colors should be a sequence of {nspins} colors"
            )
        return colors

    @staticmethod
    def _resolve_labels(label, nspins):
        """Return the legend labels (a list per spin), or a falsy value."""
        if label is None and nspins == 2:
            label = ['spin up', 'spin down']

        if label:
            if nspins == 1 and isinstance(label, str):
                label = [label]
            elif len(label) != nspins:
                raise ValueError(
                    f'label should be a list of {nspins} strings'
                )
        return label

    @staticmethod
    def _pretty_label(kpt):
        """Format a special-point name for use as a matplotlib tick label."""
        if kpt == 'G':
            return r'$\Gamma$'
        if len(kpt) == 2:
            # e.g. 'X1' -> 'X$_1$' (subscripted index).
            return kpt[0] + '$_' + kpt[1] + '$'
        return kpt

    @staticmethod
    def _merge_coincident_labels(labels, xcoords):
        """Join labels sharing an x-coordinate (path breaks) with commas.

        Returns ``(labels, xcoords)`` as new lists.
        """
        merged_labels = []
        merged_x = []
        for name, x in zip(labels, xcoords):
            if merged_x and merged_x[-1] == x:
                merged_labels[-1] += ',' + name
            else:
                merged_labels.append(name)
                merged_x.append(x)
        return merged_labels, merged_x

    def prepare_plot(self, ax=None, emin=-10, emax=5, ylabel=None):
        """Set up axes: special-point ticks, separators and energy window.

        Stores the axes and the k-point x-coordinates on the instance and
        returns the axes (created in a new figure if ``ax`` is None).
        """
        import matplotlib.pyplot as plt
        if ax is None:
            ax = plt.figure().add_subplot(111)

        self.xcoords, label_xcoords, orig_labels = self.bs.get_labels()
        labels, label_xcoords = self._merge_coincident_labels(
            [self._pretty_label(name) for name in orig_labels],
            label_xcoords)

        # Vertical separators at interior special points only.
        for x in label_xcoords[1:-1]:
            ax.axvline(x, color='0.5')

        ylabel = ylabel if ylabel is not None else 'energies [eV]'

        ax.set_xticks(label_xcoords)
        ax.set_xticklabels(labels)
        ax.set_ylabel(ylabel)
        ax.axhline(self.bs.reference, color='k', ls=':')
        ax.axis(xmin=0, xmax=self.xcoords[-1], ymin=emin, ymax=emax)
        self.ax = ax
        return ax

    def finish_plot(self, filename, show, loc, show_legend=False, ax=None):
        """Optionally add a legend, save the figure and show it.

        ax: Axes
            Axes to draw the legend on.  Defaults to the axes prepared by
            :meth:`prepare_plot`, or matplotlib's current axes if none.
        """
        import matplotlib.pyplot as plt

        if show_legend:
            if ax is None:
                ax = self.ax if self.ax is not None else plt.gca()
            leg = ax.legend(loc=loc)
            leg.get_frame().set_alpha(1)

        if filename:
            plt.savefig(filename)

        if show:
            plt.show()

347 

348 

@jsonable('bandstructure')
class BandStructure:
    """A band structure consists of an array of eigenvalues and a bandpath.

    BandStructure objects support JSON I/O.
    """

    def __init__(self, path, energies, reference=0.0):
        """Create a band structure.

        path: BandPath
            The k-point path; must have one k-point per row of energies.
        energies: array_like
            Eigenvalues with shape (nspins, nkpoints, nbands), where
            nspins is 1 or 2.
        reference: scalar
            Reference energy (e.g. a Fermi level).  Defaults to 0.0.
        """
        self._path = path
        self._energies = np.asarray(energies)
        assert self.energies.shape[0] in [1, 2]  # spins x kpts x bands
        assert self.energies.shape[1] == len(path.kpts)
        assert np.isscalar(reference)
        self._reference = reference

    @property
    def energies(self) -> np.ndarray:
        """The energies of this band structure.

        This is a numpy array of shape (nspins, nkpoints, nbands)."""
        return self._energies

    @property
    def path(self) -> 'ase.dft.kpoints.BandPath':
        """The :class:`~ase.dft.kpoints.BandPath` of this band structure."""
        return self._path

    @property
    def reference(self) -> float:
        """The reference energy.

        Semantics may vary; typically a Fermi energy or zero,
        depending on how the band structure was created."""
        return self._reference

    def subtract_reference(self) -> 'BandStructure':
        """Return new band structure with reference energy subtracted."""
        return BandStructure(self.path, self.energies - self.reference,
                             reference=0.0)

    def todict(self):
        """Return the constructor arguments as a dict.

        Presumably consumed by the ``jsonable`` JSON machinery for I/O.
        """
        return dict(path=self.path,
                    energies=self.energies,
                    reference=self.reference)

    def get_labels(self, eps=1e-5):
        """Return ``(xcoords, label_xcoords, labels)`` for plotting.

        Delegates to the path's ``get_linear_kpoint_axis``; see
        :func:`ase.dft.kpoints.labels_from_kpts`."""
        return self.path.get_linear_kpoint_axis(eps=eps)

    def plot(self, *args, **kwargs):
        """Plot this band structure.

        All arguments are forwarded to :meth:`BandStructurePlot.plot`;
        returns the matplotlib Axes used."""
        bsp = BandStructurePlot(self)
        return bsp.plot(*args, **kwargs)

    def __repr__(self):
        # Join however many dimensions there are; __init__ only checks the
        # first two axes, so do not assume exactly three.
        shape = 'x'.join(str(dim) for dim in self.energies.shape)
        return ('{}(path={!r}, energies=[{} values], reference={})'
                .format(self.__class__.__name__, self.path, shape,
                        self.reference))