Coverage for /builds/debichem-team/python-ase/ase/spectrum/doscollection.py: 97.84%

185 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-03-06 04:00 +0000

1import collections 

2from functools import reduce, singledispatch 

3from typing import ( 

4 Any, 

5 Dict, 

6 Iterable, 

7 List, 

8 Optional, 

9 Sequence, 

10 TypeVar, 

11 Union, 

12 overload, 

13) 

14 

15import numpy as np 

16from matplotlib.axes import Axes 

17 

18from ase.spectrum.dosdata import DOSData, Floats, GridDOSData, Info, RawDOSData 

19from ase.utils.plotting import SimplePlottingAxes 

20 

21 

class DOSCollection(collections.abc.Sequence):
    """Base class for a collection of DOSData objects

    The collection is a read-only sequence: integer indexing returns a
    single DOSData item and slicing returns a new collection of the same
    type. Items are stored in the order given to the constructor.
    """

    def __init__(self, dos_series: Iterable[DOSData]) -> None:
        """Store the given DOSData items in iteration order

        Args:
            dos_series: DOSData objects to collect; the iterable is consumed
                once and copied into an internal list.
        """
        self._data = list(dos_series)

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        This calls the ._sample() method of every contained DOSData item
        with the same arguments and stacks the results, returning a 2-D
        array with columns corresponding to x and rows corresponding to the
        collected data series.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns:
            Weights sampled from a broadened DOS at values corresponding to x,
            in rows corresponding to DOSData entries contained in this object

        Raises:
            IndexError: if the collection is empty
        """

        if len(self) == 0:
            raise IndexError("No data to sample")

        return np.asarray(
            [data._sample(energies, width=width, smearing=smearing)
             for data in self])

    def plot(self,
             npts: int = 1000,
             xmin: Optional[float] = None,
             xmax: Optional[float] = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: Optional[Axes] = None,
             show: bool = False,
             filename: Optional[str] = None,
             mplargs: Optional[dict] = None) -> Axes:
        """Simple plot of collected DOS data, resampled onto a grid

        The data is resampled with self.sample_grid() and the resulting
        GridDOSCollection is plotted as-is. One line is drawn per item; the
        lines are labelled from each item's info and a legend is added by
        GridDOSCollection.plot().

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.

        Raises:
            IndexError: if the collection is empty
        """
        sampled = self.sample_grid(npts,
                                   xmin=xmin, xmax=xmax,
                                   width=width, smearing=smearing)

        # npts/width/smearing are deliberately NOT forwarded: a non-zero npts
        # makes GridDOSCollection.plot() call sample_grid() a second time,
        # which would resample (and re-broaden) data that has already been
        # broadened above.
        return sampled.plot(xmin=xmin, xmax=xmax,
                            ax=ax, show=show, filename=filename,
                            mplargs=mplargs)

    def sample_grid(self,
                    npts: int,
                    xmin: Optional[float] = None,
                    xmax: Optional[float] = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSCollection':
        """Sample the DOS data on an evenly-spaced energy grid

        Every contained item is sampled with its own .sample_grid() method
        using the same range, so the results share one energy axis.

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled energy value; if unspecified, the lowest
                energy found in any item, minus padding * width
            xmax: Maximum sampled energy value; if unspecified, the highest
                energy found in any item, plus padding * width
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel, passed to each item's
                sample_grid()
            smearing: selection of broadening kernel, passed to each item's
                sample_grid()

        Returns:
            GridDOSCollection holding the sampled data of every item

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sample")

        if xmin is None:
            xmin = (min(min(data.get_energies()) for data in self)
                    - (padding * width))
        if xmax is None:
            xmax = (max(max(data.get_energies()) for data in self)
                    + (padding * width))

        return GridDOSCollection(
            [data.sample_grid(npts, xmin=xmin, xmax=xmax, width=width,
                              smearing=smearing)
             for data in self])

    @classmethod
    def from_data(cls,
                  energies: Floats,
                  weights: Sequence[Floats],
                  info: Optional[Sequence[Info]] = None) -> 'DOSCollection':
        """Create a DOSCollection from data sharing a common set of energies

        This is a convenience method to be used when all the DOS data in the
        collection has a common energy axis. There is no performance advantage
        in using this method for the generic DOSCollection, but for
        GridDOSCollection it is more efficient.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns:
            Collection of DOS data (in RawDOSData format)

        Raises:
            ValueError: if info is given and its length differs from the
                number of rows in weights
        """

        info = cls._check_weights_and_info(weights, info)

        return cls(RawDOSData(energies, row_weights, row_info)
                   for row_weights, row_info in zip(weights, info))

    @staticmethod
    def _check_weights_and_info(weights: Sequence[Floats],
                                info: Optional[Sequence[Info]],
                                ) -> Sequence[Info]:
        """Return one info dict per weights row, creating empty ones if needed

        Raises:
            ValueError: if info is given and len(info) != len(weights)
        """
        if info is None:
            # A fresh dict per row, so rows never share a mutable info object
            info = [{} for _ in range(len(weights))]
        else:
            if len(info) != len(weights):
                raise ValueError("Length of info must match number of rows in "
                                 "weights")
        return info

    @overload
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'DOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        """Return one item (integer index) or a new collection (slice)

        NumPy integer scalars are accepted as well as built-in ints.

        Raises:
            TypeError: if item is neither an integer nor a slice
        """
        if isinstance(item, (int, np.integer)):
            return self._data[item]
        elif isinstance(item, slice):
            return type(self)(self._data[item])
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    def __len__(self) -> int:
        """Number of DOSData items in the collection"""
        return len(self._data)

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSCollection for testing purposes

        True only if other is an instance of this object's type, has the same
        length, and every pair of items agrees under the items' own
        _almost_equals() method.
        """
        if not isinstance(other, type(self)):
            return False
        elif len(self) != len(other):
            return False
        else:
            return all(a._almost_equals(b) for a, b in zip(self, other))

    def total(self) -> DOSData:
        """Sum all the DOSData in this Collection and label it as 'Total'"""
        data = self.sum_all()
        data.info.update({'label': 'Total'})
        return data

    def sum_all(self) -> DOSData:
        """Sum all the DOSData contained in this Collection

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sum")
        elif len(self) == 1:
            # Return a copy so that callers (e.g. total()) can modify the
            # result without touching the item stored in the collection.
            data = self[0].copy()
        else:
            data = reduce(lambda x, y: x + y, self)
        return data

    D = TypeVar('D', bound=DOSData)

    @staticmethod
    def _select_to_list(dos_collection: Sequence[D],  # Bug in flakes
                        info_selection: Dict[str, str],  # misses 'D' def
                        negative: bool = False) -> List[D]:  # noqa: F821
        """List the items whose info does (or does not) contain a selection

        Args:
            dos_collection: items to filter
            info_selection: key-value pairs that must all be present in an
                item's info for it to match
            negative: if True, return the items that do NOT match

        Returns:
            Matching items, in their original order
        """
        query = set(info_selection.items())

        if negative:
            return [data for data in dos_collection
                    if not query.issubset(set(data.info.items()))]
        else:
            return [data for data in dos_collection
                    if query.issubset(set(data.info.items()))]

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items with specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        """

        matches = self._select_to_list(self, info_selection)
        return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items without specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        return type(self)(matches)

    def sum_by(self, *info_keys: str) -> 'DOSCollection':
        """Return a DOSCollection with some data summed by common attributes

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                              DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        then ::

          dc.sum_by('b')

        will return a collection equivalent to ::

          DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'})
                         + DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                         DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        where the resulting contained DOSData have info attributes of
        {'b': '1'} and {'b': '2'} respectively.

        dc.sum_by('a', 'b') on the other hand would return the full three-entry
        collection, as none of the entries have common 'a' *and* 'b' info.

        Items that lack some of the requested keys are grouped by the keys
        they do have; every item contributes to exactly one group.

        """

        def _matching_info_tuples(data: DOSData):
            """Get relevant dict entries in tuple form

            e.g. if data.info = {'a': 1, 'b': 2, 'c': 3}
                 and info_keys = ('a', 'c')

                 then return (('a', 1), ('c', 3))
            """
            matched_keys = set(info_keys) & set(data.info)
            return tuple(sorted([(key, data.info[key])
                                 for key in matched_keys]))

        # Group items directly on their key/value tuple. Re-selecting with
        # select(**combo) would also pull in items whose info is a superset
        # of the combo, so items missing some keys would be double-counted.
        # Sorting inside _matching_info_tuples gives each group one canonical
        # key regardless of dict ordering.
        groups: Dict[Any, List[DOSData]] = {}
        for data in self:
            groups.setdefault(_matching_info_tuples(data), []).append(data)

        # Groups are emitted in sorted key order to ensure consistent output
        # across sessions.
        collection_data = [type(self)(groups[combo]).sum_all()
                           for combo in sorted(groups)]
        return type(self)(collection_data)

    def __add__(self, other: Union['DOSCollection', DOSData]
                ) -> 'DOSCollection':
        """Join entries between two DOSCollection objects of the same type

        It is also possible to add a single DOSData object without wrapping it
        in a new collection: i.e. ::

          DOSCollection([dosdata1]) + DOSCollection([dosdata2])

        or ::

          DOSCollection([dosdata1]) + dosdata2

        will return ::

          DOSCollection([dosdata1, dosdata2])

        """
        return _add_to_collection(other, self)

356 

357 

@singledispatch
def _add_to_collection(other: Union[DOSData, DOSCollection],
                       collection: DOSCollection) -> DOSCollection:
    """Fallback for 'collection + other' when other is not a DOSData

    Dispatch is on the type of the first argument. This default handler
    joins two collections of the same type and rejects everything else.

    Raises:
        TypeError: if other is a DOSCollection of a different type, or is
            not a DOSCollection at all
    """
    target_type = type(collection)

    if isinstance(other, target_type):
        return target_type([*collection, *other])

    if isinstance(other, DOSCollection):
        raise TypeError("Only DOSCollection objects of the same type may "
                        "be joined with '+'.")

    raise TypeError("DOSCollection may only be joined to DOSData or "
                    "DOSCollection objects with '+'.")

369 

370 

@_add_to_collection.register(DOSData)
def _add_data(other: DOSData, collection: DOSCollection) -> DOSCollection:
    """Return a new DOSCollection with an additional DOSData item

    The result has the same type as collection, with other appended after
    the existing items.
    """
    extended = [*collection, other]
    return type(collection)(extended)

375 

376 

class RawDOSCollection(DOSCollection):
    """DOSCollection restricted to RawDOSData items"""

    def __init__(self, dos_series: Iterable[RawDOSData]) -> None:
        """Store the items, rejecting anything that is not RawDOSData

        Raises:
            TypeError: if any item is not a RawDOSData instance
        """
        super().__init__(dos_series)

        # Checked after storing because dos_series may be a one-shot iterable
        if any(not isinstance(entry, RawDOSData) for entry in self):
            raise TypeError("RawDOSCollection can only store "
                            "RawDOSData objects.")

384 

385 

class GridDOSCollection(DOSCollection):
    """Collection of GridDOSData items that share one energy axis

    The data is held as a single energy array, a 2-D weights array with one
    row per item, and a list of info dicts; GridDOSData objects are built on
    demand when the collection is indexed.
    """

    def __init__(self, dos_series: Iterable[GridDOSData],
                 energies: Optional[Floats] = None) -> None:
        """Collect GridDOSData items defined on a common energy axis

        Args:
            dos_series: GridDOSData objects to collect
            energies: energy axis of the collection. If omitted, it is taken
                from the first item; it is required when dos_series is empty.

        Raises:
            TypeError: if any item is not a GridDOSData instance
            ValueError: if dos_series is empty and energies is not given, or
                if an item's energy axis differs from the collection's
        """
        dos_list = list(dos_series)

        # Validate types before touching any item, so that a wrong-typed
        # first element gives the intended TypeError rather than failing
        # inside get_energies().
        for dos_data in dos_list:
            if not isinstance(dos_data, GridDOSData):
                raise TypeError("GridDOSCollection can only store "
                                "GridDOSData objects.")

        if energies is None:
            if len(dos_list) == 0:
                raise ValueError("Must provide energies to create a "
                                 "GridDOSCollection without any DOS data.")
            self._energies = dos_list[0].get_energies()
        else:
            self._energies = np.asarray(energies)

        self._weights = np.empty((len(dos_list), len(self._energies)), float)
        self._info = []

        for i, dos_data in enumerate(dos_list):
            data_energies = dos_data.get_energies()
            # Compare shapes first: np.allclose would broadcast (or raise) on
            # arrays of different length.
            if (data_energies.shape != self._energies.shape
                    or not np.allclose(data_energies, self._energies)):
                raise ValueError("All GridDOSData objects in GridDOSCollection"
                                 " must have the same energy axis.")
            self._weights[i, :] = dos_data.get_weights()
            self._info.append(dos_data.info)

    def get_energies(self) -> Floats:
        """Return a copy of the shared energy axis"""
        return self._energies.copy()

    def get_all_weights(self) -> Union[Sequence[Floats], np.ndarray]:
        """Return a copy of the weights array, one row per item"""
        return self._weights.copy()

    def __len__(self) -> int:
        """Number of items, i.e. number of rows in the weights array"""
        return self._weights.shape[0]

    @overload  # noqa F811
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'GridDOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        """Return one item (integer index) or a new collection (slice)

        An integer index builds a new GridDOSData from the stored energies,
        the matching weights row and the matching info. NumPy integer
        scalars are accepted as well as built-in ints.

        Raises:
            TypeError: if item is neither an integer nor a slice
        """
        if isinstance(item, (int, np.integer)):
            return GridDOSData(self._energies, self._weights[item, :],
                               info=self._info[item])
        elif isinstance(item, slice):
            return type(self)([self[i] for i in range(len(self))[item]])
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    @classmethod
    def from_data(cls,
                  energies: Floats,
                  weights: Sequence[Floats],
                  info: Optional[Sequence[Info]] = None
                  ) -> 'GridDOSCollection':
        """Create a GridDOSCollection from data with a common set of energies

        This convenience method may also be more efficient as it limits
        redundant copying/checking of the data.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns:
            Collection of DOS data (in GridDOSData format)

        Raises:
            IndexError: if weights is not 2-D, has no rows, or its row length
                differs from len(energies)
            ValueError: if info is given and its length differs from the
                number of rows in weights
        """

        weights_array = np.asarray(weights, dtype=float)
        if len(weights_array.shape) != 2:
            raise IndexError("Weights must be a 2-D array or nested sequence")
        if weights_array.shape[0] < 1:
            raise IndexError("Weights cannot be empty")
        if weights_array.shape[1] != len(energies):
            raise IndexError("Length of weights rows must equal size of x")

        info = cls._check_weights_and_info(weights, info)

        # Build the collection from the first row only, so the constructor
        # sets up the energy axis, then swap in the full weights and info
        # without constructing a GridDOSData per row.
        dos_collection = cls([GridDOSData(energies, weights_array[0])])
        dos_collection._weights = weights_array
        dos_collection._info = list(info)

        return dos_collection

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow GridDOSCollection to items with specified info

        For example, if ::

          dc = GridDOSCollection([GridDOSData(x, y1,
                                              info={'a': '1', 'b': '1'}),
                                  GridDOSData(x, y2,
                                              info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        If nothing matches, an empty collection on the same energy axis is
        returned.

        """

        matches = self._select_to_list(self, info_selection)
        if len(matches) == 0:
            # An empty collection cannot infer its axis from its items
            return type(self)([], energies=self._energies)
        else:
            return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow GridDOSCollection to items without specified info

        For example, if ::

          dc = GridDOSCollection([GridDOSData(x, y1,
                                              info={'a': '1', 'b': '1'}),
                                  GridDOSData(x, y2,
                                              info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        If every item is excluded, an empty collection on the same energy
        axis is returned.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        if len(matches) == 0:
            # An empty collection cannot infer its axis from its items
            return type(self)([], energies=self._energies)
        else:
            return type(self)(matches)

    def plot(self,
             npts: int = 0,
             xmin: Optional[float] = None,
             xmax: Optional[float] = None,
             width: Optional[float] = None,
             smearing: str = 'Gauss',
             ax: Optional[Axes] = None,
             show: bool = False,
             filename: Optional[str] = None,
             mplargs: Optional[dict] = None) -> Axes:
        """Simple plot of collected DOS data, optionally resampled onto a grid

        One line is drawn per item. Each line is labelled with
        DOSData.label_from_info() applied to the item's info, and a legend
        is added to the axes. The labels are set after plotting, so a
        'label' entry in mplargs does not override them.

        Args:
            npts:
                Number of points in resampled x-axis. If set to zero (default),
                no resampling is performed and the stored data is plotted
                directly.
            xmin, xmax:
                range passed to self.sample_grid() when resampling. When the
                stored data is plotted directly these are not used; the x
                limits are always set to the extent of the plotted energies.
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel, passed to
                self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        # Apply defaults if necessary
        npts, width = GridDOSData._interpret_smearing_args(npts, width)

        if npts:
            # Narrows Optional[float] for the type checker; presumably
            # _interpret_smearing_args always supplies a float width when
            # npts is non-zero (helper not visible here).
            assert isinstance(width, float)
            dos = self.sample_grid(npts,
                                   xmin=xmin, xmax=xmax,
                                   width=width, smearing=smearing)
        else:
            dos = self

        energies, all_y = dos._energies, dos._weights

        all_labels = [DOSData.label_from_info(data.info) for data in self]

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            self._plot_broadened(ax, energies, all_y, all_labels, mplargs)

        return ax

    @staticmethod
    def _plot_broadened(ax: Axes,
                        energies: Floats,
                        all_y: np.ndarray,
                        all_labels: Sequence[str],
                        mplargs: Optional[Dict]):
        """Plot DOS data with labels to axes

        This is separated into another function so that subclasses can
        manipulate broadening, labels etc in their plot() method.

        Args:
            ax: axes to draw on
            energies: x values shared by all lines
            all_y: 2-D array of y values, one row per line
            all_labels: one label per row of all_y
            mplargs: keyword arguments for ax.plot(), or None
        """
        if mplargs is None:
            mplargs = {}

        # ax.plot draws one line per column, hence the transpose
        all_lines = ax.plot(energies, all_y.T, **mplargs)
        for line, label in zip(all_lines, all_labels):
            line.set_label(label)
        ax.legend()

        ax.set_xlim(left=min(energies), right=max(energies))
        ax.set_ylim(bottom=0)
616 ax.set_ylim(bottom=0)