Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# Refactor of DOS-like data objects 

2# towards replacing ase.dft.dos and ase.dft.pdos 

3from abc import ABCMeta, abstractmethod 

4import warnings 

5from typing import Any, Dict, Sequence, Tuple, TypeVar, Union 

6 

7import numpy as np 

8from ase.utils.plotting import SimplePlottingAxes 

9 

10# This import is for the benefit of type-checking / mypy 

11if False: 

12 import matplotlib.axes 

13 

# Type of the metadata dict carried by DOS data objects (their "info").
# For now we will be strict about Info and say it has to be str->str. Perhaps
# later we will allow other types that have reliable comparison operations.
Info = Dict[str, str]

17 

18 

class DOSData(metaclass=ABCMeta):
    """Abstract base class for a single series of DOS-like data

    Only the 'info' is a mutable attribute; DOS data is set at init

    Subclasses provide the data through get_energies(), get_weights() and
    copy(); broadening, resampling, plotting and comparison are implemented
    here in terms of those accessors.
    """
    def __init__(self,
                 info: 'Info' = None) -> None:
        """Store the metadata dict

        Args:
            info: metadata for this data series. A dict is stored by
                reference (it is not copied); None gives a new empty dict.

        Raises:
            TypeError: if info is neither a dict nor None
        """
        if info is None:
            self.info = {}
        elif isinstance(info, dict):
            self.info = info
        else:
            raise TypeError("Info must be a dict or None")

    @abstractmethod
    def get_energies(self) -> Sequence[float]:
        """Get energy data stored in this object"""

    @abstractmethod
    def get_weights(self) -> Sequence[float]:
        """Get DOS weights stored in this object"""

    @abstractmethod
    def copy(self) -> 'DOSData':
        """Returns a copy in which info dict can be safely mutated"""

    def _sample(self,
                energies: Sequence[float],
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        Note that no correction is made here for the sampling bin width; total
        intensity will vary with sampling density.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns:
            Weights sampled from a broadened DOS at values corresponding to
            energies

        Raises:
            ValueError: if width is not positive, or the smearing type is not
                recognized
        """

        self._check_positive_width(width)
        weights_grid = np.zeros(len(energies), float)
        weights = self.get_weights()
        energies = np.asarray(energies, float)

        # Accumulate one broadened peak per stored (energy, weight) pair
        for i, raw_energy in enumerate(self.get_energies()):
            delta = self._delta(energies, raw_energy, width, smearing=smearing)
            weights_grid += weights[i] * delta
        return weights_grid

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSData for testing purposes

        Returns True only if other is an instance of this object's type, the
        info dicts are equal, and both the weights and the energies have the
        same shape and agree within the default np.allclose tolerances.
        """
        if not isinstance(other, type(self)):
            return False
        if self.info != other.info:
            return False
        if not self._same_values(self.get_weights(), other.get_weights()):
            return False
        return self._same_values(self.get_energies(), other.get_energies())

    @staticmethod
    def _same_values(a, b) -> bool:
        """True if a and b have the same shape and np.allclose values"""
        a, b = np.asarray(a), np.asarray(b)
        # np.allclose broadcasts its inputs, so a single point would "match"
        # a longer series of identical values, and incompatible lengths
        # would raise instead of comparing unequal. Check the shapes first.
        return a.shape == b.shape and bool(np.allclose(a, b))

    @staticmethod
    def _delta(x: np.ndarray,
               x0: float,
               width: float,
               smearing: str = 'Gauss') -> np.ndarray:
        """Return a delta-function centered at 'x0'.

        This function is used with numpy broadcasting; if x is a row and x0 is
        a column vector, the returned data will be a 2D array with each row
        corresponding to a different delta center.

        Raises:
            ValueError: if smearing is not 'gauss' (compared
                case-insensitively)
        """
        if smearing.lower() == 'gauss':
            # Normalised Gaussian with standard deviation 'width'
            x1 = -0.5 * ((x - x0) / width)**2
            return np.exp(x1) / (np.sqrt(2 * np.pi) * width)
        else:
            msg = 'Requested smearing type not recognized. Got {}'.format(
                smearing)
            raise ValueError(msg)

    @staticmethod
    def _check_positive_width(width):
        """Raise ValueError unless width is strictly positive"""
        if width <= 0.0:
            msg = 'Cannot add 0 or negative width smearing'
            raise ValueError(msg)

    def sample_grid(self,
                    npts: int,
                    xmin: float = None,
                    xmax: float = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSData':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled x value; if unspecified, a default is chosen
            xmax: Maximum sampled x value; if unspecified, a default is chosen
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only 'Gauss' is
                implemented)

        Returns:
            GridDOSData holding the sampled energy grid and the broadened
            weights, with a shallow copy of this object's info dict
        """

        if xmin is None:
            xmin = min(self.get_energies()) - (padding * width)
        if xmax is None:
            xmax = max(self.get_energies()) + (padding * width)
        energies_grid = np.linspace(xmin, xmax, npts)
        weights_grid = self._sample(energies_grid, width=width,
                                    smearing=smearing)

        return GridDOSData(energies_grid, weights_grid, info=self.info.copy())

    def plot(self,
             npts: int = 1000,
             xmin: float = None,
             xmax: float = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> 'matplotlib.axes.Axes':
        """Simple 1-D plot of DOS data, resampled onto a grid

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel for self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The dict passed
                in is not modified.

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        # Work on a copy so that adding the default label never leaks into
        # a dict the caller may reuse for other data series.
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        sampled = self.sample_grid(npts, xmin=xmin, xmax=xmax,
                                   width=width, smearing=smearing)
        return sampled.plot(ax=ax, xmin=xmin, xmax=xmax,
                            show=show, filename=filename,
                            mplargs=mplargs)

    @staticmethod
    def label_from_info(info: Dict[str, str]):
        """Generate an automatic legend label from info dict

        The value stored under 'label' is used if present; otherwise all
        items are joined as 'key: value' pairs separated by '; '.
        """
        if 'label' in info:
            return info['label']
        return '; '.join('{}: {}'.format(key, value)
                         for key, value in info.items())

195 

196 

class GeneralDOSData(DOSData):
    """DOS-like data held as explicit energy and weight series

    Only the 'info' is a mutable attribute; DOS data is set at init

    This is the base class for DOSData objects that accept/set separate
    "energies" and "weights" sequences of equal length at init.

    """
    def __init__(self,
                 energies: Union[Sequence[float], np.ndarray],
                 weights: Union[Sequence[float], np.ndarray],
                 info: Info = None) -> None:
        """Copy energies and weights into internal storage

        Raises:
            TypeError: if info is neither a dict nor None
            ValueError: if energies and weights differ in length
        """
        super().__init__(info=info)

        if len(energies) != len(weights):
            raise ValueError("Energies and weights must be the same length")

        # One float array with two rows: row 0 is energy, row 1 is weight
        storage = np.empty((2, len(energies)), dtype=float, order='C')
        storage[0] = energies
        storage[1] = weights
        self._data = storage

    def get_energies(self) -> np.ndarray:
        """Return a copy of the stored energies"""
        return self._data[0].copy()

    def get_weights(self) -> np.ndarray:
        """Return a copy of the stored weights"""
        return self._data[1].copy()

    D = TypeVar('D', bound='GeneralDOSData')

    def copy(self: D) -> D:  # noqa F821
        """Return a new object of the same type with copied data and info"""
        duplicate = type(self)(self.get_energies(), self.get_weights(),
                               info=self.info.copy())
        return duplicate

232 

233 

class RawDOSData(GeneralDOSData):
    """A collection of weighted delta functions which sum to form a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the energy data values do not form a known regular
    grid. The data may be plotted or resampled for further analysis using the
    sample_grid() and plot() methods. Multiple weights at the same
    energy value will *only* be combined in output data, and data stored in
    RawDOSData is never resampled. A plot_deltas() function is also provided
    which plots the raw data.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When RawDOSData objects are combined with the addition operator::

      big_dos = raw_dos_1 + raw_dos_2

    the energy and weights data is *concatenated* (i.e. combined without
    sorting or replacement) and the new info dictionary consists of the
    *intersection* of the inputs: only key-value pairs that were common to both
    of the input objects will be retained in the new combined object. Retained
    keys keep the order they have in the left-hand operand. For example::

      (RawDOSData([x1], [y1], info={'symbol': 'O', 'index': '1'})
       + RawDOSData([x2], [y2], info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      RawDOSData([x1, x2], [y1, y2], info={'symbol': 'O'})

    """

    def __add__(self, other: 'RawDOSData') -> 'RawDOSData':
        """Concatenate the data of two RawDOSData; keep only shared info

        Raises:
            TypeError: if other is not a RawDOSData
        """
        if not isinstance(other, RawDOSData):
            raise TypeError("RawDOSData can only be combined with other "
                            "RawDOSData objects")

        # Take intersection of metadata (i.e. only common entries are
        # retained). Iterating over self.info rather than going through a
        # set keeps the key order reproducible between runs, so labels
        # generated from the combined info are stable.
        new_info = {key: value for key, value in self.info.items()
                    if key in other.info and other.info[key] == value}

        # Concatenate the energy/weight data
        new_data = np.concatenate((self._data, other._data), axis=1)

        # Build an empty object, then attach the combined array directly
        new_object = RawDOSData([], [], info=new_info)
        new_object._data = new_data

        return new_object

    def plot_deltas(self,
                    ax: 'matplotlib.axes.Axes' = None,
                    show: bool = False,
                    filename: str = None,
                    mplargs: dict = None) -> 'matplotlib.axes.Axes':
        """Simple plot of sparse DOS data as a set of delta functions

        Items at the same x-value can overlap and will not be summed together

        Args:
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib Axes.vlines
                command (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        if mplargs is None:
            mplargs = {}

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            # One vertical line from zero up to each weight
            ax.vlines(self.get_energies(), 0, self.get_weights(), **mplargs)

        return ax

312 

313 

class GridDOSData(GeneralDOSData):
    """A collection of regularly-sampled data which represents a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the intensity values form a regular grid. This
    is generally the result of sampling or integrating into discrete
    bins, rather than a collection of unique states. The data may be
    plotted or resampled for further analysis using the sample_grid()
    and plot() methods.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When GridDOSData objects are combined with the addition operator::

      big_dos = grid_dos_1 + grid_dos_2

    the weights data is *summed* (requiring a consistent energy grid) and the
    new info dictionary consists of the *intersection* of the inputs: only
    key-value pairs that were common to both of the input objects will be
    retained in the new combined object. Retained keys keep the order they
    have in the left-hand operand. For example::

      (GridDOSData([0.1, 0.2, 0.3], [y1, y2, y3],
                   info={'symbol': 'O', 'index': '1'})
       + GridDOSData([0.1, 0.2, 0.3], [y4, y5, y6],
                     info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      GridDOSData([0.1, 0.2, 0.3], [y1+y4, y2+y5, y3+y6], info={'symbol': 'O'})

    """
    def __init__(self,
                 energies: Sequence[float],
                 weights: Sequence[float],
                 info: Info = None) -> None:
        """Validate the grid and store the data

        Raises:
            ValueError: if energies is empty or not evenly spaced, or if
                energies and weights differ in length
            TypeError: if info is neither a dict nor None
        """
        n_entries = len(energies)
        if n_entries == 0:
            # Fail clearly here, not with an IndexError on energies[0]
            raise ValueError("Energy grid must contain at least one point")

        if not np.allclose(energies,
                           np.linspace(energies[0], energies[-1], n_entries)):
            raise ValueError("Energies must be an evenly-spaced 1-D grid")

        if len(weights) != n_entries:
            raise ValueError("Energies and weights must be the same length")

        super().__init__(energies, weights, info=info)
        # NOTE(review): not read anywhere in this class; presumably a
        # broadening cutoff in units of the width, used elsewhere - confirm.
        self.sigma_cutoff = 3

    def _check_spacing(self, width) -> float:
        """Return the grid spacing, warning if width is small against it

        The spacing is taken from the first two grid points, so the grid
        needs at least two points. A warning is issued when width is less
        than twice the spacing.
        """
        current_spacing = self._data[0, 1] - self._data[0, 0]
        if width < (2 * current_spacing):
            warnings.warn(
                "The broadening width is small compared to the original "
                "sampling density. The results are unlikely to be smooth.")
        return current_spacing

    def _sample(self,
                energies: Sequence[float],
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the data at chosen points, with broadening

        Takes the same arguments as DOSData._sample; the broadened result
        is additionally multiplied by the spacing of the stored grid.
        """
        current_spacing = self._check_spacing(width)
        # presumably each stored value stands for a bin one grid spacing
        # wide, hence the scaling - TODO confirm
        return super()._sample(energies=energies,
                               width=width, smearing=smearing
                               ) * current_spacing

    def __add__(self, other: 'GridDOSData') -> 'GridDOSData':
        """Sum the weights of two GridDOSData defined on the same grid

        Raises:
            TypeError: if other is not a GridDOSData
            ValueError: if the energy grids differ in length or in values
        """
        # This method uses direct access to the mutable energy and weights data
        # (self._data) to avoid redundant copying operations. The __init__
        # method of GridDOSData will write this to a new array, so on this
        # occasion it is safe to pass references to the mutable data.

        if not isinstance(other, GridDOSData):
            raise TypeError("GridDOSData can only be combined with other "
                            "GridDOSData objects")
        if len(self._data[0, :]) != len(other._data[0, :]):
            raise ValueError("Cannot add GridDOSData objects with different-"
                             "length energy grids.")

        if not np.allclose(self._data[0, :], other._data[0, :]):
            raise ValueError("Cannot add GridDOSData objects with different "
                             "energy grids.")

        # Take intersection of metadata (i.e. only common entries are
        # retained). Iterating over self.info rather than going through a
        # set keeps the key order reproducible between runs, so labels
        # generated from the combined info are stable.
        new_info = {key: value for key, value in self.info.items()
                    if key in other.info and other.info[key] == value}

        # Sum the weights point by point; this allocates a new array
        new_weights = self._data[1, :] + other._data[1, :]

        new_object = GridDOSData(self._data[0, :], new_weights,
                                 info=new_info)
        return new_object

    @staticmethod
    def _interpret_smearing_args(npts: int,
                                 width: float = None,
                                 default_npts: int = 1000,
                                 default_width: float = 0.1
                                 ) -> Tuple[int, Union[float, None]]:
        """Figure out what the user intended: resample if width provided

        Returns:
            (npts, width). (0, None) means "do not resample" and is returned
            only when neither npts nor width was given. Otherwise whichever
            of the two is missing is replaced by its default.
        """
        if width is None and not npts:
            return (0, None)
        return (npts or default_npts,
                default_width if width is None else float(width))

    def plot(self,
             npts: int = 0,
             xmin: float = None,
             xmax: float = None,
             width: float = None,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> 'matplotlib.axes.Axes':
        """Simple 1-D plot of DOS data

        Data will be resampled onto a grid with `npts` points unless `npts` is
        set to zero, in which case:

        - no resampling takes place
        - `width` and `smearing` are ignored
        - `xmin` and `xmax` affect the axis limits of the plot, not the
          underlying data.

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid().
                If no npts was set but width is set, npts will be set to 1000.
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The dict passed
                in is not modified.

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        npts, width = self._interpret_smearing_args(npts, width)

        # Work on a copy so that adding the default label never leaks into
        # a dict the caller may reuse for other data series.
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        if npts:
            # _interpret_smearing_args always pairs a non-zero npts with a
            # float width; the assert records that for the type checker.
            assert isinstance(width, float)
            dos = self.sample_grid(npts, xmin=xmin,
                                   xmax=xmax, width=width,
                                   smearing=smearing)
        else:
            dos = self

        energies, intensity = dos.get_energies(), dos.get_weights()

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            ax.plot(energies, intensity, **mplargs)
            ax.set_xlim(left=xmin, right=xmax)

        return ax