Coverage for /builds/debichem-team/python-ase/ase/spectrum/dosdata.py: 100.00%

152 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-03-06 04:00 +0000

1# Refactor of DOS-like data objects 

2# towards replacing ase.dft.dos and ase.dft.pdos 

3import warnings 

4from abc import ABCMeta, abstractmethod 

5from typing import Any, Dict, Sequence, Tuple, TypeVar, Union 

6 

7import numpy as np 

8from matplotlib.axes import Axes 

9 

10from ase.utils.plotting import SimplePlottingAxes 

11 

# Metadata attached to a DOS series. For now this is strictly str -> str so
# that entries compare (and hash) reliably; other value types may be allowed
# later if they also support dependable comparison.
Info = Dict[str, str]

# There is still no satisfactory static type for numpy arrays, so accept
# either a plain sequence of floats or an ndarray.
Floats = Union[Sequence[float], np.ndarray]

18 

19 

class DOSData(metaclass=ABCMeta):
    """Abstract base class for a single series of DOS-like data

    Only the 'info' is a mutable attribute; DOS data is set at init"""

    def __init__(self,
                 info: Info = None) -> None:
        """Store the metadata dictionary.

        Args:
            info: str -> str metadata. ``None`` gives a fresh empty dict; a
                dict is stored by reference (it is not copied here).

        Raises:
            TypeError: if info is neither a dict nor None.
        """
        if info is None:
            self.info = {}
        elif isinstance(info, dict):
            self.info = info
        else:
            raise TypeError("Info must be a dict or None")

    @abstractmethod
    def get_energies(self) -> Floats:
        """Get energy data stored in this object"""

    @abstractmethod
    def get_weights(self) -> Floats:
        """Get DOS weights stored in this object"""

    @abstractmethod
    def copy(self) -> 'DOSData':
        """Returns a copy in which info dict can be safely mutated"""

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        Note that no correction is made here for the sampling bin width; total
        intensity will vary with sampling density.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns:
            Weights sampled from a broadened DOS at the requested energies

        Raises:
            ValueError: if width is not positive or smearing is unknown
        """

        self._check_positive_width(width)
        weights_grid = np.zeros(len(energies), float)
        weights = self.get_weights()
        energies = np.asarray(energies, float)

        # Add one broadened peak per stored state, accumulating into a single
        # 1-D array over the sampling points.
        for i, raw_energy in enumerate(self.get_energies()):
            delta = self._delta(energies, raw_energy, width, smearing=smearing)
            weights_grid += weights[i] * delta
        return weights_grid

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSData for testing purposes

        Objects are considered equal when they have the same type, identical
        info dicts, and energies/weights of the same shape that agree within
        np.allclose tolerances.
        """
        if not isinstance(other, type(self)):
            return False
        if self.info != other.info:
            return False

        my_weights, other_weights = self.get_weights(), other.get_weights()
        # np.allclose broadcasts its inputs: without an explicit shape check
        # a single point would "match" any number of identical points, and
        # non-broadcastable lengths would raise instead of returning False.
        if np.shape(my_weights) != np.shape(other_weights):
            return False
        if not np.allclose(my_weights, other_weights):
            return False

        my_energies, other_energies = self.get_energies(), other.get_energies()
        if np.shape(my_energies) != np.shape(other_energies):
            return False
        return np.allclose(my_energies, other_energies)

    @staticmethod
    def _delta(x: np.ndarray,
               x0: float,
               width: float,
               smearing: str = 'Gauss') -> np.ndarray:
        """Return a delta-function centered at 'x0'.

        For 'gauss' smearing (case-insensitive) this is a unit-area Gaussian
        with standard deviation `width`.

        This function is used with numpy broadcasting; if x is a row and x0 is
        a column vector, the returned data will be a 2D array with each row
        corresponding to a different delta center.

        Raises:
            ValueError: if the smearing type is not recognized.
        """
        if smearing.lower() == 'gauss':
            x1 = -0.5 * ((x - x0) / width)**2
            return np.exp(x1) / (np.sqrt(2 * np.pi) * width)
        else:
            msg = 'Requested smearing type not recognized. Got {}'.format(
                smearing)
            raise ValueError(msg)

    @staticmethod
    def _check_positive_width(width):
        """Raise ValueError unless width > 0."""
        if width <= 0.0:
            msg = 'Cannot add 0 or negative width smearing'
            raise ValueError(msg)

    def sample_grid(self,
                    npts: int,
                    xmin: float = None,
                    xmax: float = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSData':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled x value; if unspecified, a default is chosen
            xmax: Maximum sampled x value; if unspecified, a default is chosen
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only 'Gauss' is
                implemented)

        Returns:
            GridDOSData holding the energy grid and the sampled DOS; its info
            is a shallow copy of self.info.

        Raises:
            ValueError: if xmin/xmax must be derived from the data but there
                are no energies, or if width/smearing are invalid.
        """

        if xmin is None or xmax is None:
            # Fetch the (copied) energies once rather than once per bound.
            raw_energies = self.get_energies()
            if len(raw_energies) == 0:
                raise ValueError("Cannot choose default xmin/xmax for DOS "
                                 "data with no energies; set them explicitly")
            if xmin is None:
                xmin = min(raw_energies) - (padding * width)
            if xmax is None:
                xmax = max(raw_energies) + (padding * width)

        energies_grid = np.linspace(xmin, xmax, npts)
        weights_grid = self._sample(energies_grid, width=width,
                                    smearing=smearing)

        return GridDOSData(energies_grid, weights_grid, info=self.info.copy())

    def plot(self,
             npts: int = 1000,
             xmin: float = None,
             xmax: float = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: Axes = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> Axes:
        """Simple 1-D plot of DOS data, resampled onto a grid

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel for self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). This dict is not
                modified.

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        if mplargs is None:
            mplargs = {}
        else:
            # Work on a copy so the automatic 'label' entry added below does
            # not leak into the caller's dict.
            mplargs = dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        grid_dos = self.sample_grid(npts, xmin=xmin, xmax=xmax,
                                    width=width, smearing=smearing)
        # GridDOSData.plot with its default npts=0 does not resample again;
        # xmin/xmax are forwarded to set the axis limits.
        return grid_dos.plot(ax=ax, xmin=xmin, xmax=xmax,
                             show=show, filename=filename,
                             mplargs=mplargs)

    @staticmethod
    def label_from_info(info: Dict[str, str]):
        """Generate an automatic legend label from info dict

        The 'label' entry is returned verbatim if present; otherwise all
        entries are joined as "key: value" pairs separated by "; ".
        """
        if 'label' in info:
            return info['label']
        return '; '.join(f'{key}: {value}' for key, value in info.items())

197 

198 

class GeneralDOSData(DOSData):
    """Base class for a single series of DOS-like data

    Only the 'info' is a mutable attribute; DOS data is set at init.

    This is the base class for DOSData objects that accept separate
    "energies" and "weights" sequences of equal length at init. Both are
    stored together in a single float array, and the getters hand out
    copies of it.
    """

    def __init__(self,
                 energies: Floats,
                 weights: Floats,
                 info: Info = None) -> None:
        super().__init__(info=info)

        count = len(energies)
        if count != len(weights):
            raise ValueError("Energies and weights must be the same length")

        # Two-row layout: row 0 holds energies, row 1 holds weights.
        table = np.empty((2, count), dtype=float, order='C')
        table[0, :], table[1, :] = energies, weights
        self._data = table

    def get_energies(self) -> np.ndarray:
        """Return a copy of the stored energy values."""
        return np.copy(self._data[0])

    def get_weights(self) -> np.ndarray:
        """Return a copy of the stored weight values."""
        return np.copy(self._data[1])

    D = TypeVar('D', bound='GeneralDOSData')

    def copy(self: D) -> D:  # noqa F821
        """Return a new object of the same type with copied data and info."""
        cls = type(self)
        return cls(self.get_energies(), self.get_weights(),
                   info=self.info.copy())

235 

236 

class RawDOSData(GeneralDOSData):
    """A collection of weighted delta functions which sum to form a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the energy values do not form a known regular
    grid. The data may be plotted or resampled for further analysis using the
    sample_grid() and plot() methods. Multiple weights at the same
    energy value will *only* be combined in output data, and data stored in
    RawDOSData is never resampled. A plot_deltas() function is also provided
    which plots the raw data.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When RawDOSData objects are combined with the addition operator::

        big_dos = raw_dos_1 + raw_dos_2

    the energy and weights data is *concatenated* (i.e. combined without
    sorting or replacement) and the new info dictionary consists of the
    *intersection* of the inputs: only key-value pairs that were common to both
    of the input objects will be retained in the new combined object. For
    example::

        (RawDOSData([x1], [y1], info={'symbol': 'O', 'index': '1'})
         + RawDOSData([x2], [y2], info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

        RawDOSData([x1, x2], [y1, y2], info={'symbol': 'O'})

    """

    def __add__(self, other: 'RawDOSData') -> 'RawDOSData':
        """Concatenate two RawDOSData series, keeping only shared metadata."""
        if not isinstance(other, RawDOSData):
            raise TypeError("RawDOSData can only be combined with other "
                            "RawDOSData objects")

        # Only key/value pairs present in both inputs survive.
        shared_info = dict(set(self.info.items()) & set(other.info.items()))

        # Columns of the two-row (energy, weight) arrays are joined directly,
        # bypassing __init__ to avoid an extra round of copying.
        joined = np.concatenate((self._data, other._data), axis=1)

        result = RawDOSData([], [], info=shared_info)
        result._data = joined
        return result

    def plot_deltas(self,
                    ax: Axes = None,
                    show: bool = False,
                    filename: str = None,
                    mplargs: dict = None) -> Axes:
        """Draw each stored state as a vertical line from zero to its weight

        Lines at the same energy may overlap; they are not summed.

        Args:
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: extra keyword arguments forwarded to matplotlib
                Axes.vlines (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """
        vlines_kwargs = {} if mplargs is None else mplargs

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as axes:
            axes.vlines(self.get_energies(), 0, self.get_weights(),
                        **vlines_kwargs)

        return axes

315 

316 

class GridDOSData(GeneralDOSData):
    """A collection of regularly-sampled data which represents a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the intensity values form a regular grid. This
    is generally the result of sampling or integrating into discrete
    bins, rather than a collection of unique states. The data may be
    plotted or resampled for further analysis using the sample_grid()
    and plot() methods.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When GridDOSData objects are combined with the addition operator::

        big_dos = grid_dos_1 + grid_dos_2

    the weights data is *summed* (requiring a consistent energy grid) and the
    new info dictionary consists of the *intersection* of the inputs: only
    key-value pairs that were common to both of the input objects will be
    retained in the new combined object. For example::

        (GridDOSData([0.1, 0.2, 0.3], [y1, y2, y3],
                     info={'symbol': 'O', 'index': '1'})
         + GridDOSData([0.1, 0.2, 0.3], [y4, y5, y6],
                       info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

        GridDOSData([0.1, 0.2, 0.3], [y1+y4, y2+y5, y3+y6], info={'symbol': 'O'})

    """

    def __init__(self,
                 energies: Floats,
                 weights: Floats,
                 info: Info = None) -> None:
        """Validate that energies form an even grid, then store the data.

        Raises:
            ValueError: if energies are not evenly spaced, or if energies and
                weights differ in length.
        """
        n_entries = len(energies)
        # Both ascending and descending grids pass this check.
        if not np.allclose(energies,
                           np.linspace(energies[0], energies[-1], n_entries)):
            raise ValueError("Energies must be an evenly-spaced 1-D grid")

        if len(weights) != n_entries:
            raise ValueError("Energies and weights must be the same length")

        super().__init__(energies, weights, info=info)
        # Not read anywhere in this class. NOTE(review): presumably consumed
        # elsewhere; confirm before relying on it.
        self.sigma_cutoff = 3

    def _check_spacing(self, width) -> float:
        """Return the absolute grid spacing; warn if width is below 2 spacings.
        """
        # abs(): __init__ accepts descending grids, whose raw spacing is
        # negative. That would flip the sign of resampled weights in _sample
        # and suppress the smoothness warning below.
        current_spacing = abs(self._data[0, 1] - self._data[0, 0])
        if width < (2 * current_spacing):
            warnings.warn(
                "The broadening width is small compared to the original "
                "sampling density. The results are unlikely to be smooth.")
        return current_spacing

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Broadened sampling, scaled by the grid spacing.

        Multiplying by the spacing treats each grid weight as a density, so
        resampled intensities do not depend on the original grid density.
        """
        current_spacing = self._check_spacing(width)
        return super()._sample(energies=energies,
                               width=width, smearing=smearing
                               ) * current_spacing

    def __add__(self, other: 'GridDOSData') -> 'GridDOSData':
        """Sum weights of two GridDOSData on the same grid; keep shared info.

        Raises:
            TypeError: if other is not a GridDOSData.
            ValueError: if the energy grids differ in length or values.
        """
        # This method uses direct access to the mutable energy and weights data
        # (self._data) to avoid redundant copying operations. The __init__
        # method of GridDOSData will write this to a new array, so on this
        # occasion it is safe to pass references to the mutable data.

        if not isinstance(other, GridDOSData):
            raise TypeError("GridDOSData can only be combined with other "
                            "GridDOSData objects")
        if len(self._data[0, :]) != len(other.get_energies()):
            raise ValueError("Cannot add GridDOSData objects with different-"
                             "length energy grids.")

        if not np.allclose(self._data[0, :], other.get_energies()):
            raise ValueError("Cannot add GridDOSData objects with different "
                             "energy grids.")

        # Take intersection of metadata (i.e. only common entries are retained)
        new_info = dict(set(self.info.items()) & set(other.info.items()))

        # Sum the weight data on the shared energy grid
        new_weights = self._data[1, :] + other.get_weights()

        new_object = GridDOSData(self._data[0, :], new_weights,
                                 info=new_info)
        return new_object

    @staticmethod
    def _interpret_smearing_args(npts: int,
                                 width: float = None,
                                 default_npts: int = 1000,
                                 default_width: float = 0.1
                                 ) -> Tuple[int, Union[float, None]]:
        """Figure out what the user intended: resample if width provided

        Returns (npts, width). A falsy npts with no width means "do not
        resample" and gives (0, None); otherwise missing values are filled
        from the defaults.
        """
        if width is not None:
            if npts:
                return (npts, float(width))
            else:
                return (default_npts, float(width))
        else:
            if npts:
                return (npts, default_width)
            else:
                return (0, None)

    def plot(self,
             npts: int = 0,
             xmin: float = None,
             xmax: float = None,
             width: float = None,
             smearing: str = 'Gauss',
             ax: Axes = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> Axes:
        """Simple 1-D plot of DOS data

        Data will be resampled onto a grid with `npts` points unless `npts` is
        set to zero, in which case:

        - no resampling takes place
        - `width` and `smearing` are ignored
        - `xmin` and `xmax` affect the axis limits of the plot, not the
          underlying data.

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid().
                If no npts was set but width is set, npts will be set to 1000.
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). This dict is not
                modified.

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        npts, width = self._interpret_smearing_args(npts, width)

        if mplargs is None:
            mplargs = {}
        else:
            # Copy so the automatic 'label' entry does not leak into the
            # caller's dict.
            mplargs = dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        if npts:
            # Type narrowing: _interpret_smearing_args always yields a float
            # width whenever npts is nonzero.
            assert isinstance(width, float)
            dos = self.sample_grid(npts, xmin=xmin,
                                   xmax=xmax, width=width,
                                   smearing=smearing)
        else:
            dos = self

        energies, intensity = dos.get_energies(), dos.get_weights()

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            ax.plot(energies, intensity, **mplargs)
            ax.set_xlim(left=xmin, right=xmax)

        return ax