Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import collections 

2from functools import reduce, singledispatch 

3from typing import (Any, Dict, Iterable, List, Optional, 

4 overload, Sequence, TypeVar, Union) 

5 

6import numpy as np 

7from ase.spectrum.dosdata import DOSData, RawDOSData, GridDOSData, Info 

8from ase.utils.plotting import SimplePlottingAxes 

9 

10# This import is for the benefit of type-checking / mypy 

11if False: 

12 import matplotlib.axes 

13 

14 

class DOSCollection(collections.abc.Sequence):
    """Base class for a collection of DOSData objects

    Behaves as a read-only sequence: integer indexing returns the stored
    DOSData item, slicing returns a new collection of the same type.
    """
    def __init__(self, dos_series: Iterable[DOSData]) -> None:
        # Materialise the input so that generators/iterators can be passed.
        self._data = list(dos_series)

    def _sample(self,
                energies: Sequence[float],
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        This samples the underlying DOS data in the same way as the .sample()
        method of those DOSData items, returning a 2-D array with columns
        corresponding to x and rows corresponding to the collected data series.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns:
            Weights sampled from a broadened DOS at values corresponding to x,
            in rows corresponding to DOSData entries contained in this object

        Raises:
            IndexError: if the collection is empty
        """

        if len(self) == 0:
            raise IndexError("No data to sample")

        return np.asarray(
            [data._sample(energies, width=width, smearing=smearing)
             for data in self])

    def plot(self,
             npts: int = 1000,
             xmin: float = None,
             xmax: float = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> 'matplotlib.axes.Axes':
        """Simple plot of collected DOS data, resampled onto a grid

        The collection is first sampled with self.sample_grid() and the
        resulting GridDOSCollection is drawn with its plot() method. Each
        line is labelled from the info of the corresponding entry (using
        DOSData.label_from_info); these per-entry labels replace any 'label'
        given in mplargs, and a legend is added to the axes.

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """
        # NOTE(review): the same npts/width are also forwarded to
        # GridDOSCollection.plot(), which resamples the already-broadened grid
        # again when npts is non-zero, so plotted peaks may be wider than
        # ``width``. Kept unchanged for backward compatibility -- confirm
        # whether this second broadening is intended.
        return self.sample_grid(npts,
                                xmin=xmin, xmax=xmax,
                                width=width, smearing=smearing
                                ).plot(npts=npts,
                                       xmin=xmin, xmax=xmax,
                                       width=width, smearing=smearing,
                                       ax=ax, show=show, filename=filename,
                                       mplargs=mplargs)

    def sample_grid(self,
                    npts: int,
                    xmin: float = None,
                    xmax: float = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSCollection':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled energy value; if unspecified, a default is
                chosen
            xmax: Maximum sampled energy value; if unspecified, a default is
                chosen
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel, passed to each entry's
                sample_grid()
            smearing: selection of broadening kernel, passed to each entry's
                sample_grid()

        Returns:
            GridDOSCollection with one entry per item of this collection, all
            sampled on the same energy grid.

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sample")

        # Default range spans every entry's energies, padded so that the
        # broadened peaks at the extremes are not truncated.
        if xmin is None:
            xmin = (min(min(data.get_energies()) for data in self)
                    - (padding * width))
        if xmax is None:
            xmax = (max(max(data.get_energies()) for data in self)
                    + (padding * width))

        return GridDOSCollection(
            [data.sample_grid(npts, xmin=xmin, xmax=xmax, width=width,
                              smearing=smearing)
             for data in self])

    @classmethod
    def from_data(cls,
                  energies: Sequence[float],
                  weights: Sequence[Sequence[float]],
                  info: Sequence[Info] = None) -> 'DOSCollection':
        """Create a DOSCollection from data sharing a common set of energies

        This is a convenience method to be used when all the DOS data in the
        collection has a common energy axis. There is no performance advantage
        in using this method for the generic DOSCollection, but for
        GridDOSCollection it is more efficient.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns:
            Collection of DOS data (in RawDOSData format)

        Raises:
            ValueError: if info is given and its length differs from the
                number of rows in weights
        """

        info = cls._check_weights_and_info(weights, info)

        return cls(RawDOSData(energies, row_weights, row_info)
                   for row_weights, row_info in zip(weights, info))

    @staticmethod
    def _check_weights_and_info(weights: Sequence[Sequence[float]],
                                info: Union[Sequence[Info], None],
                                ) -> Sequence[Info]:
        """Return info, or one empty dict per weights row if info is None

        Raises:
            ValueError: if info is given with a different length than weights
        """
        if info is None:
            info = [{} for _ in range(len(weights))]
        else:
            if len(info) != len(weights):
                raise ValueError("Length of info must match number of rows in "
                                 "weights")
        return info

    @overload
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'DOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        # NumPy integer scalars (e.g. from np.argmax) are accepted alongside
        # plain ints; the backing list indexes them via __index__.
        if isinstance(item, (int, np.integer)):
            return self._data[item]
        elif isinstance(item, slice):
            return type(self)(self._data[item])
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    def __len__(self) -> int:
        return len(self._data)

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSCollection for testing purposes"""
        if not isinstance(other, type(self)):
            return False
        elif not len(self) == len(other):
            return False
        else:
            return all(a._almost_equals(b) for a, b in zip(self, other))

    def total(self) -> DOSData:
        """Sum all the DOSData in this Collection and label it as 'Total'"""
        data = self.sum_all()
        data.info.update({'label': 'Total'})
        return data

    def sum_all(self) -> DOSData:
        """Sum all the DOSData contained in this Collection

        A single-entry collection returns a .copy() of that entry, so the
        result never aliases a stored item.

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sum")
        elif len(self) == 1:
            data = self[0].copy()
        else:
            data = reduce(lambda x, y: x + y, self)
        return data

    D = TypeVar('D', bound=DOSData)

    @staticmethod
    def _select_to_list(dos_collection: Sequence[D],  # Bug in flakes
                        info_selection: Dict[str, str],  # misses 'D' def
                        negative: bool = False) -> List[D]:  # noqa: F821
        """Filter items by whether their info contains all selected pairs

        Returns the items whose info contains every key/value pair in
        info_selection, or (if negative) the items that do not. Info values
        must be hashable, as they are compared via sets of items.
        """
        query = set(info_selection.items())

        if negative:
            return [data for data in dos_collection
                    if not query.issubset(set(data.info.items()))]
        else:
            return [data for data in dos_collection
                    if query.issubset(set(data.info.items()))]

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items with specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        """

        matches = self._select_to_list(self, info_selection)
        return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items without specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        return type(self)(matches)

    def sum_by(self, *info_keys: str) -> 'DOSCollection':
        """Return a DOSCollection with some data summed by common attributes

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                              DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        then ::

          dc.sum_by('b')

        will return a collection equivalent to ::

          DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'})
                         + DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                         DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        The info of a summed entry is whatever DOSData addition produces,
        while a group with a single member is returned as a .copy() of that
        member (see sum_all) and so keeps its full info dict.

        dc.sum_by('a', 'b') on the other hand would return the full three-entry
        collection, as none of the entries have common 'a' *and* 'b' info.

        """

        def _matching_info_tuples(data: DOSData):
            """Get relevant dict entries in tuple form

            e.g. if data.info = {'a': 1, 'b': 2, 'c': 3}
                 and info_keys = ('a', 'c')

                 then return (('a', 1), ('c', 3))
            """
            matched_keys = set(info_keys) & set(data.info)
            return tuple(sorted([(key, data.info[key])
                                 for key in matched_keys]))

        # Sorting inside info matching helps set() to remove redundant matches;
        # combos are then sorted() to ensure consistent output across sessions.
        all_combos = map(_matching_info_tuples, self)
        unique_combos = sorted(set(all_combos))

        # For each key/value combination, perform a select() to obtain all
        # the matching entries and sum them together.
        collection_data = [self.select(**dict(combo)).sum_all()
                           for combo in unique_combos]
        return type(self)(collection_data)

    def __add__(self, other: Union['DOSCollection', DOSData]
                ) -> 'DOSCollection':
        """Join entries between two DOSCollection objects of the same type

        It is also possible to add a single DOSData object without wrapping it
        in a new collection: i.e. ::

          DOSCollection([dosdata1]) + DOSCollection([dosdata2])

        or ::

          DOSCollection([dosdata1]) + dosdata2

        will return ::

          DOSCollection([dosdata1, dosdata2])

        """
        return _add_to_collection(other, self)

169 

170 @overload # noqa F811 

171 def __getitem__(self, item: slice) -> 'DOSCollection': # noqa F811 

172 ... 

173 

174 def __getitem__(self, item): # noqa F811 

175 if isinstance(item, int): 

176 return self._data[item] 

177 elif isinstance(item, slice): 

178 return type(self)(self._data[item]) 

179 else: 

180 raise TypeError("index in DOSCollection must be an integer or " 

181 "slice") 

182 

183 def __len__(self) -> int: 

184 return len(self._data) 

185 

186 def _almost_equals(self, other: Any) -> bool: 

187 """Compare with another DOSCollection for testing purposes""" 

188 if not isinstance(other, type(self)): 

189 return False 

190 elif not len(self) == len(other): 

191 return False 

192 else: 

193 return all([a._almost_equals(b) for a, b in zip(self, other)]) 

194 

195 def total(self) -> DOSData: 

196 """Sum all the DOSData in this Collection and label it as 'Total'""" 

197 data = self.sum_all() 

198 data.info.update({'label': 'Total'}) 

199 return data 

200 

201 def sum_all(self) -> DOSData: 

202 """Sum all the DOSData contained in this Collection""" 

203 if len(self) == 0: 

204 raise IndexError("No data to sum") 

205 elif len(self) == 1: 

206 data = self[0].copy() 

207 else: 

208 data = reduce(lambda x, y: x + y, self) 

209 return data 

210 

211 D = TypeVar('D', bound=DOSData) 

212 

213 @staticmethod 

214 def _select_to_list(dos_collection: Sequence[D], # Bug in flakes 

215 info_selection: Dict[str, str], # misses 'D' def 

216 negative: bool = False) -> List[D]: # noqa: F821 

217 query = set(info_selection.items()) 

218 

219 if negative: 

220 return [data for data in dos_collection 

221 if not query.issubset(set(data.info.items()))] 

222 else: 

223 return [data for data in dos_collection 

224 if query.issubset(set(data.info.items()))] 

225 

226 def select(self, **info_selection: str) -> 'DOSCollection': 

227 """Narrow DOSCollection to items with specified info 

228 

229 For example, if :: 

230 

231 dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}), 

232 DOSData(x2, y2, info={'a': '2', 'b': '1'})]) 

233 

234 then :: 

235 

236 dc.select(b='1') 

237 

238 will return an identical object to dc, while :: 

239 

240 dc.select(a='1') 

241 

242 will return a DOSCollection with only the first item and :: 

243 

244 dc.select(a='2', b='1') 

245 

246 will return a DOSCollection with only the second item. 

247 

248 """ 

249 

250 matches = self._select_to_list(self, info_selection) 

251 return type(self)(matches) 

252 

253 def select_not(self, **info_selection: str) -> 'DOSCollection': 

254 """Narrow DOSCollection to items without specified info 

255 

256 For example, if :: 

257 

258 dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}), 

259 DOSData(x2, y2, info={'a': '2', 'b': '1'})]) 

260 

261 then :: 

262 

263 dc.select_not(b='2') 

264 

265 will return an identical object to dc, while :: 

266 

267 dc.select_not(a='2') 

268 

269 will return a DOSCollection with only the first item and :: 

270 

271 dc.select_not(a='1', b='1') 

272 

273 will return a DOSCollection with only the second item. 

274 

275 """ 

276 matches = self._select_to_list(self, info_selection, negative=True) 

277 return type(self)(matches) 

278 

279 def sum_by(self, *info_keys: str) -> 'DOSCollection': 

280 """Return a DOSCollection with some data summed by common attributes 

281 

282 For example, if :: 

283 

284 dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}), 

285 DOSData(x2, y2, info={'a': '2', 'b': '1'}), 

286 DOSData(x3, y3, info={'a': '2', 'b': '2'})]) 

287 

288 then :: 

289 

290 dc.sum_by('b') 

291 

292 will return a collection equivalent to :: 

293 

294 DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}) 

295 + DOSData(x2, y2, info={'a': '2', 'b': '1'}), 

296 DOSData(x3, y3, info={'a': '2', 'b': '2'})]) 

297 

298 where the resulting contained DOSData have info attributes of 

299 {'b': '1'} and {'b': '2'} respectively. 

300 

301 dc.sum_by('a', 'b') on the other hand would return the full three-entry 

302 collection, as none of the entries have common 'a' *and* 'b' info. 

303 

304 """ 

305 

306 def _matching_info_tuples(data: DOSData): 

307 """Get relevent dict entries in tuple form 

308 

309 e.g. if data.info = {'a': 1, 'b': 2, 'c': 3} 

310 and info_keys = ('a', 'c') 

311 

312 then return (('a', 1), ('c': 3)) 

313 """ 

314 matched_keys = set(info_keys) & set(data.info) 

315 return tuple(sorted([(key, data.info[key]) 

316 for key in matched_keys])) 

317 

318 # Sorting inside info matching helps set() to remove redundant matches; 

319 # combos are then sorted() to ensure consistent output across sessions. 

320 all_combos = map(_matching_info_tuples, self) 

321 unique_combos = sorted(set(all_combos)) 

322 

323 # For each key/value combination, perform a select() to obtain all 

324 # the matching entries and sum them together. 

325 collection_data = [self.select(**dict(combo)).sum_all() 

326 for combo in unique_combos] 

327 return type(self)(collection_data) 

328 

329 def __add__(self, other: Union['DOSCollection', DOSData] 

330 ) -> 'DOSCollection': 

331 """Join entries between two DOSCollection objects of the same type 

332 

333 It is also possible to add a single DOSData object without wrapping it 

334 in a new collection: i.e. :: 

335 

336 DOSCollection([dosdata1]) + DOSCollection([dosdata2]) 

337 

338 or :: 

339 

340 DOSCollection([dosdata1]) + dosdata2 

341 

342 will return :: 

343 

344 DOSCollection([dosdata1, dosdata2]) 

345 

346 """ 

347 return _add_to_collection(other, self) 

348 

349 

350@singledispatch 

351def _add_to_collection(other: DOSCollection, 

352 collection: DOSCollection) -> DOSCollection: 

353 if isinstance(other, type(collection)): 

354 return type(collection)(list(collection) + list(other)) 

355 elif isinstance(other, DOSCollection): 

356 raise TypeError("Only DOSCollection objects of the same type may " 

357 "be joined with '+'.") 

358 else: 

359 raise TypeError("DOSCollection may only be joined to DOSData or " 

360 "DOSCollection objects with '+'.") 

361 

362 

363@_add_to_collection.register(DOSData) 

364def _add_data(other: DOSData, collection: DOSCollection) -> DOSCollection: 

365 """Return a new DOSCollection with an additional DOSData item""" 

366 return type(collection)(list(collection) + [other]) 

367 

368 

369class RawDOSCollection(DOSCollection): 

370 def __init__(self, dos_series: Iterable[RawDOSData]) -> None: 

371 super().__init__(dos_series) 

372 for dos_data in self: 

373 if not isinstance(dos_data, RawDOSData): 

374 raise TypeError("RawDOSCollection can only store " 

375 "RawDOSData objects.") 

376 

377 

378class GridDOSCollection(DOSCollection): 

379 def __init__(self, dos_series: Iterable[GridDOSData], 

380 energies: Optional[Sequence[float]] = None) -> None: 

381 dos_list = list(dos_series) 

382 if energies is None: 

383 if len(dos_list) == 0: 

384 raise ValueError("Must provide energies to create a " 

385 "GridDOSCollection without any DOS data.") 

386 self._energies = dos_list[0].get_energies() 

387 else: 

388 self._energies = np.asarray(energies) 

389 

390 self._weights = np.empty((len(dos_list), len(self._energies)), float) 

391 self._info = [] 

392 

393 for i, dos_data in enumerate(dos_list): 

394 if not isinstance(dos_data, GridDOSData): 

395 raise TypeError("GridDOSCollection can only store " 

396 "GridDOSData objects.") 

397 if (dos_data.get_energies().shape != self._energies.shape 

398 or not np.allclose(dos_data.get_energies(), self._energies)): 

399 raise ValueError("All GridDOSData objects in GridDOSCollection" 

400 " must have the same energy axis.") 

401 self._weights[i, :] = dos_data.get_weights() 

402 self._info.append(dos_data.info) 

403 

404 def get_energies(self) -> Sequence[float]: 

405 return self._energies.copy() 

406 

407 def get_all_weights(self) -> Sequence[Sequence[float]]: 

408 return self._weights.copy() 

409 

410 def __len__(self) -> int: 

411 return self._weights.shape[0] 

412 

413 @overload # noqa F811 

414 def __getitem__(self, item: int) -> DOSData: 

415 ... 

416 

417 @overload # noqa F811 

418 def __getitem__(self, item: slice) -> 'GridDOSCollection': # noqa F811 

419 ... 

420 

421 def __getitem__(self, item): # noqa F811 

422 if isinstance(item, int): 

423 return GridDOSData(self._energies, self._weights[item, :], 

424 info=self._info[item]) 

425 elif isinstance(item, slice): 

426 return type(self)([self[i] for i in range(len(self))[item]]) 

427 else: 

428 raise TypeError("index in DOSCollection must be an integer or " 

429 "slice") 

430 

431 @classmethod 

432 def from_data(cls, 

433 energies: Sequence[float], 

434 weights: Sequence[Sequence[float]], 

435 info: Sequence[Info] = None) -> 'GridDOSCollection': 

436 """Create a GridDOSCollection from data with a common set of energies 

437 

438 This convenience method may also be more efficient as it limits 

439 redundant copying/checking of the data. 

440 

441 Args: 

442 energies: common set of energy values for input data 

443 weights: array of DOS weights with rows corresponding to different 

444 datasets 

445 info: sequence of info dicts corresponding to weights rows. 

446 

447 Returns: 

448 Collection of DOS data (in RawDOSData format) 

449 """ 

450 

451 weights_array = np.asarray(weights, dtype=float) 

452 if len(weights_array.shape) != 2: 

453 raise IndexError("Weights must be a 2-D array or nested sequence") 

454 if weights_array.shape[0] < 1: 

455 raise IndexError("Weights cannot be empty") 

456 if weights_array.shape[1] != len(energies): 

457 raise IndexError("Length of weights rows must equal size of x") 

458 

459 info = cls._check_weights_and_info(weights, info) 

460 

461 dos_collection = cls([GridDOSData(energies, weights_array[0])]) 

462 dos_collection._weights = weights_array 

463 dos_collection._info = list(info) 

464 

465 return dos_collection 

466 

467 def select(self, **info_selection: str) -> 'DOSCollection': 

468 """Narrow GridDOSCollection to items with specified info 

469 

470 For example, if :: 

471 

472 dc = GridDOSCollection([GridDOSData(x, y1, 

473 info={'a': '1', 'b': '1'}), 

474 GridDOSData(x, y2, 

475 info={'a': '2', 'b': '1'})]) 

476 

477 then :: 

478 

479 dc.select(b='1') 

480 

481 will return an identical object to dc, while :: 

482 

483 dc.select(a='1') 

484 

485 will return a DOSCollection with only the first item and :: 

486 

487 dc.select(a='2', b='1') 

488 

489 will return a DOSCollection with only the second item. 

490 

491 """ 

492 

493 matches = self._select_to_list(self, info_selection) 

494 if len(matches) == 0: 

495 return type(self)([], energies=self._energies) 

496 else: 

497 return type(self)(matches) 

498 

499 def select_not(self, **info_selection: str) -> 'DOSCollection': 

500 """Narrow GridDOSCollection to items without specified info 

501 

502 For example, if :: 

503 

504 dc = GridDOSCollection([GridDOSData(x, y1, 

505 info={'a': '1', 'b': '1'}), 

506 GridDOSData(x, y2, 

507 info={'a': '2', 'b': '1'})]) 

508 

509 then :: 

510 

511 dc.select_not(b='2') 

512 

513 will return an identical object to dc, while :: 

514 

515 dc.select_not(a='2') 

516 

517 will return a DOSCollection with only the first item and :: 

518 

519 dc.select_not(a='1', b='1') 

520 

521 will return a DOSCollection with only the second item. 

522 

523 """ 

524 matches = self._select_to_list(self, info_selection, negative=True) 

525 if len(matches) == 0: 

526 return type(self)([], energies=self._energies) 

527 else: 

528 return type(self)(matches) 

529 

530 def plot(self, 

531 npts: int = 0, 

532 xmin: float = None, 

533 xmax: float = None, 

534 width: float = None, 

535 smearing: str = 'Gauss', 

536 ax: 'matplotlib.axes.Axes' = None, 

537 show: bool = False, 

538 filename: str = None, 

539 mplargs: dict = None) -> 'matplotlib.axes.Axes': 

540 """Simple plot of collected DOS data, resampled onto a grid 

541 

542 If the special key 'label' is present in self.info, this will be set 

543 as the label for the plotted line (unless overruled in mplargs). The 

544 label is only seen if a legend is added to the plot (i.e. by calling 

545 `ax.legend()`). 

546 

547 Args: 

548 npts: 

549 Number of points in resampled x-axis. If set to zero (default), 

550 no resampling is performed and the stored data is plotted 

551 directly. 

552 xmin, xmax: 

553 output data range; this limits the resampling range as well as 

554 the plotting output 

555 width: Width of broadening kernel, passed to self.sample() 

556 smearing: selection of broadening kernel, passed to self.sample() 

557 ax: existing Matplotlib axes object. If not provided, a new figure 

558 with one set of axes will be created using Pyplot 

559 show: show the figure on-screen 

560 filename: if a path is given, save the figure to this file 

561 mplargs: additional arguments to pass to matplotlib plot command 

562 (e.g. {'linewidth': 2} for a thicker line). 

563 

564 Returns: 

565 Plotting axes. If "ax" was set, this is the same object. 

566 """ 

567 

568 # Apply defaults if necessary 

569 npts, width = GridDOSData._interpret_smearing_args(npts, width) 

570 

571 if npts: 

572 assert isinstance(width, float) 

573 dos = self.sample_grid(npts, 

574 xmin=xmin, xmax=xmax, 

575 width=width, smearing=smearing) 

576 else: 

577 dos = self 

578 

579 energies, all_y = dos._energies, dos._weights 

580 

581 all_labels = [DOSData.label_from_info(data.info) for data in self] 

582 

583 with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax: 

584 self._plot_broadened(ax, energies, all_y, all_labels, mplargs) 

585 

586 return ax 

587 

588 @staticmethod 

589 def _plot_broadened(ax: 'matplotlib.axes.Axes', 

590 energies: Sequence[float], 

591 all_y: np.ndarray, 

592 all_labels: Sequence[str], 

593 mplargs: Union[Dict, None]): 

594 """Plot DOS data with labels to axes 

595 

596 This is separated into another function so that subclasses can 

597 manipulate broadening, labels etc in their plot() method.""" 

598 if mplargs is None: 

599 mplargs = {} 

600 

601 all_lines = ax.plot(energies, all_y.T, **mplargs) 

602 for line, label in zip(all_lines, all_labels): 

603 line.set_label(label) 

604 ax.legend() 

605 

606 ax.set_xlim(left=min(energies), right=max(energies)) 

607 ax.set_ylim(bottom=0)