Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import collections
2from functools import reduce, singledispatch
3from typing import (Any, Dict, Iterable, List, Optional,
4 overload, Sequence, TypeVar, Union)
6import numpy as np
7from ase.spectrum.dosdata import DOSData, RawDOSData, GridDOSData, Info
8from ase.utils.plotting import SimplePlottingAxes
10# This import is for the benefit of type-checking / mypy
11if False:
12 import matplotlib.axes
class DOSCollection(collections.abc.Sequence):
    """Base class for a collection of DOSData objects

    The collection implements the read-only Sequence interface: integer
    indexing returns a single DOSData object, slicing returns a new
    collection of the same type, and '+' joins two collections of the same
    type or appends a single DOSData object.
    """
    def __init__(self, dos_series: Iterable[DOSData]) -> None:
        """
        Args:
            dos_series: DOSData objects to collect. Any iterable (including
                a generator) is accepted, as the items are stored in a new
                list.
        """
        self._data = list(dos_series)

    def _sample(self,
                energies: Sequence[float],
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        This samples the underlying DOS data in the same way as the .sample()
        method of those DOSData items, returning a 2-D array with columns
        corresponding to x and rows corresponding to the collected data series.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns:
            Weights sampled from a broadened DOS at values corresponding to x,
            in rows corresponding to DOSData entries contained in this object

        Raises:
            IndexError: if the collection is empty
        """

        if len(self) == 0:
            raise IndexError("No data to sample")

        return np.asarray(
            [data._sample(energies, width=width, smearing=smearing)
             for data in self])

    def plot(self,
             npts: int = 1000,
             xmin: Optional[float] = None,
             xmax: Optional[float] = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: Optional[str] = None,
             mplargs: Optional[dict] = None) -> 'matplotlib.axes.Axes':
        """Simple plot of collected DOS data, resampled onto a grid

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        `ax.legend()`).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """
        # The data is resampled (and broadened) exactly once, here.
        sampled = self.sample_grid(npts,
                                   xmin=xmin, xmax=xmax,
                                   width=width, smearing=smearing)

        # GridDOSCollection.plot() is left at its default npts/width so that
        # it draws the stored grid data directly. Forwarding npts and width
        # again would trigger a second resampling pass and apply the
        # broadening kernel twice.
        return sampled.plot(xmin=xmin, xmax=xmax,
                            ax=ax, show=show, filename=filename,
                            mplargs=mplargs)

    def sample_grid(self,
                    npts: int,
                    xmin: Optional[float] = None,
                    xmax: Optional[float] = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSCollection':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled energy value; if unspecified, a default is
                chosen
            xmax: Maximum sampled energy value; if unspecified, a default is
                chosen
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel, passed to the sample_grid()
                method of each collected DOSData
            smearing: selection of broadening kernel, passed to the
                sample_grid() method of each collected DOSData

        Returns:
            GridDOSCollection in which every collected data series has been
            sampled with the same npts, xmin and xmax

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sample")

        # Default limits span the energies of *all* entries so that every
        # series is sampled over one common range.
        if xmin is None:
            xmin = (min(min(data.get_energies()) for data in self)
                    - (padding * width))
        if xmax is None:
            xmax = (max(max(data.get_energies()) for data in self)
                    + (padding * width))

        return GridDOSCollection(
            [data.sample_grid(npts, xmin=xmin, xmax=xmax, width=width,
                              smearing=smearing)
             for data in self])

    @classmethod
    def from_data(cls,
                  energies: Sequence[float],
                  weights: Sequence[Sequence[float]],
                  info: Optional[Sequence[Info]] = None) -> 'DOSCollection':
        """Create a DOSCollection from data sharing a common set of energies

        This is a convenience method to be used when all the DOS data in the
        collection has a common energy axis. There is no performance advantage
        in using this method for the generic DOSCollection, but for
        GridDOSCollection it is more efficient.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns:
            Collection of DOS data (in RawDOSData format)

        Raises:
            ValueError: if info is given and its length differs from the
                number of rows in weights
        """

        info = cls._check_weights_and_info(weights, info)

        return cls(RawDOSData(energies, row_weights, row_info)
                   for row_weights, row_info in zip(weights, info))

    @staticmethod
    def _check_weights_and_info(weights: Sequence[Sequence[float]],
                                info: Union[Sequence[Info], None],
                                ) -> Sequence[Info]:
        """Check info against weights, creating empty info dicts if needed

        Args:
            weights: rows of DOS weights
            info: one info dict per row of weights, or None

        Returns:
            The given info, or a list with one new empty dict per row of
            weights if info was None

        Raises:
            ValueError: if info and weights have different lengths
        """
        if info is None:
            info = [{} for _ in range(len(weights))]
        else:
            if len(info) != len(weights):
                raise ValueError("Length of info must match number of rows in "
                                 "weights")
        return info

    @overload
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'DOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        # NumPy integers (e.g. the result of np.argmax) are valid list
        # indices but are not instances of int, so accept them explicitly.
        if isinstance(item, (int, np.integer)):
            return self._data[item]
        elif isinstance(item, slice):
            return type(self)(self._data[item])
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    def __len__(self) -> int:
        return len(self._data)

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSCollection for testing purposes"""
        if not isinstance(other, type(self)):
            return False
        elif not len(self) == len(other):
            return False
        else:
            # Generator form lets all() stop at the first mismatch
            return all(a._almost_equals(b) for a, b in zip(self, other))

    def total(self) -> DOSData:
        """Sum all the DOSData in this Collection and label it as 'Total'"""
        data = self.sum_all()
        data.info.update({'label': 'Total'})
        return data

    def sum_all(self) -> DOSData:
        """Sum all the DOSData contained in this Collection

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sum")
        elif len(self) == 1:
            # Copy so that callers (e.g. total()) can modify the result
            # without touching the single stored entry.
            data = self[0].copy()
        else:
            data = reduce(lambda x, y: x + y, self)
        return data

    D = TypeVar('D', bound=DOSData)

    @staticmethod
    def _select_to_list(dos_collection: Sequence[D],  # Bug in flakes
                        info_selection: Dict[str, str],  # misses 'D' def
                        negative: bool = False) -> List[D]:  # noqa: F821
        """Filter data by whether its info contains all selected items

        Args:
            dos_collection: data to filter
            info_selection: key/value pairs that must all be present in the
                info of a matching item
            negative: if True, return the items that do *not* match

        Returns:
            List of the selected items, in their original order
        """
        query = set(info_selection.items())

        if negative:
            return [data for data in dos_collection
                    if not query.issubset(set(data.info.items()))]
        else:
            return [data for data in dos_collection
                    if query.issubset(set(data.info.items()))]

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items with specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        """

        matches = self._select_to_list(self, info_selection)
        return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items without specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        return type(self)(matches)

    def sum_by(self, *info_keys: str) -> 'DOSCollection':
        """Return a DOSCollection with some data summed by common attributes

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                              DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        then ::

          dc.sum_by('b')

        will return a collection equivalent to ::

          DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'})
                         + DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                         DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        where the resulting contained DOSData have info attributes of
        {'b': '1'} and {'b': '2'} respectively.

        dc.sum_by('a', 'b') on the other hand would return the full three-entry
        collection, as none of the entries have common 'a' *and* 'b' info.

        Every entry contributes to exactly one summed item: entries that
        lack some of the requested keys are grouped by the keys they do have.

        """

        def _matching_info_tuples(data: DOSData):
            """Get relevant dict entries in tuple form

            e.g. if data.info = {'a': 1, 'b': 2, 'c': 3}
                    and info_keys = ('a', 'c')

                 then return (('a', 1), ('c', 3))
            """
            matched_keys = set(info_keys) & set(data.info)
            return tuple(sorted([(key, data.info[key])
                                 for key in matched_keys]))

        # Sorting inside info matching gives each key/value combination one
        # canonical tuple, which is used to bin the entries. Binning (rather
        # than a select() per combination) ensures an entry is never counted
        # twice: select() with a reduced or empty set of keys would also match
        # entries belonging to more specific combinations.
        groups = {}  # type: Dict[tuple, List[DOSData]]
        for data in self:
            groups.setdefault(_matching_info_tuples(data), []).append(data)

        # Combos are sorted() to ensure consistent output across sessions.
        collection_data = [type(self)(groups[combo]).sum_all()
                           for combo in sorted(groups)]
        return type(self)(collection_data)

    def __add__(self, other: Union['DOSCollection', DOSData]
                ) -> 'DOSCollection':
        """Join entries between two DOSCollection objects of the same type

        It is also possible to add a single DOSData object without wrapping it
        in a new collection: i.e. ::

          DOSCollection([dosdata1]) + DOSCollection([dosdata2])

        or ::

          DOSCollection([dosdata1]) + dosdata2

        will return ::

          DOSCollection([dosdata1, dosdata2])

        """
        return _add_to_collection(other, self)
@singledispatch
def _add_to_collection(other: DOSCollection,
                       collection: DOSCollection) -> DOSCollection:
    """Return a new collection holding the entries of both collections

    This is the fallback of the single-dispatch function: it handles any
    'other' for which no more specific implementation is registered.

    Raises:
        TypeError: if other is not an instance of the type of collection
    """
    joined_type = type(collection)
    if isinstance(other, joined_type):
        return joined_type([*collection, *other])

    # Not joinable: pick the message according to what was passed in.
    if isinstance(other, DOSCollection):
        message = ("Only DOSCollection objects of the same type may "
                   "be joined with '+'.")
    else:
        message = ("DOSCollection may only be joined to DOSData or "
                   "DOSCollection objects with '+'.")
    raise TypeError(message)
@_add_to_collection.register(DOSData)
def _add_data(other: DOSData, collection: DOSCollection) -> DOSCollection:
    """Return a new DOSCollection with an additional DOSData item

    The new collection has the same type as the one passed in, with 'other'
    placed after the existing entries.
    """
    extended = [*collection, other]
    return type(collection)(extended)
class RawDOSCollection(DOSCollection):
    """DOSCollection that only accepts RawDOSData entries"""
    def __init__(self, dos_series: Iterable[RawDOSData]) -> None:
        """
        Args:
            dos_series: RawDOSData objects to collect

        Raises:
            TypeError: if any collected item is not a RawDOSData
        """
        super().__init__(dos_series)
        # Validate after storing, so that a one-shot iterable is only
        # consumed once (by the parent constructor).
        if not all(isinstance(entry, RawDOSData) for entry in self):
            raise TypeError("RawDOSCollection can only store "
                            "RawDOSData objects.")
class GridDOSCollection(DOSCollection):
    """Collection of GridDOSData objects sharing one energy axis

    Rather than keeping the individual GridDOSData objects, the collection
    stores a single array of energies, a 2-D array of weights (one row per
    entry) and a list of info dicts. GridDOSData objects are constructed on
    demand when the collection is indexed.
    """
    def __init__(self, dos_series: Iterable[GridDOSData],
                 energies: Optional[Sequence[float]] = None) -> None:
        """
        Args:
            dos_series: GridDOSData objects to collect
            energies: energy axis of the collection. If not given, it is
                taken from the first item of dos_series.

        Raises:
            ValueError: if dos_series is empty and no energies are given, or
                if any item does not match the energy axis of the collection
            TypeError: if any item is not a GridDOSData
        """
        dos_list = list(dos_series)
        if energies is None:
            if len(dos_list) == 0:
                raise ValueError("Must provide energies to create a "
                                 "GridDOSCollection without any DOS data.")
            self._energies = dos_list[0].get_energies()
        else:
            self._energies = np.asarray(energies)

        self._weights = np.empty((len(dos_list), len(self._energies)), float)
        self._info = []

        for i, dos_data in enumerate(dos_list):
            if not isinstance(dos_data, GridDOSData):
                raise TypeError("GridDOSCollection can only store "
                                "GridDOSData objects.")
            # Shapes are compared first, presumably because np.allclose
            # cannot compare arrays whose shapes do not broadcast
            if (dos_data.get_energies().shape != self._energies.shape
                    or not np.allclose(dos_data.get_energies(),
                                       self._energies)):
                raise ValueError("All GridDOSData objects in GridDOSCollection"
                                 " must have the same energy axis.")
            self._weights[i, :] = dos_data.get_weights()
            self._info.append(dos_data.info)

    def get_energies(self) -> Sequence[float]:
        """Return a copy of the energy axis shared by all entries"""
        return self._energies.copy()

    def get_all_weights(self) -> Sequence[Sequence[float]]:
        """Return a copy of the weights array, with one row per entry"""
        return self._weights.copy()

    def __len__(self) -> int:
        return self._weights.shape[0]

    def _new_from_entries(self, entries: List[GridDOSData]
                          ) -> 'GridDOSCollection':
        """Create a collection of the same type from a subset of entries

        An empty collection cannot take its energy axis from its first
        entry, so in that case the energies of this collection are passed
        on explicitly.
        """
        if len(entries) == 0:
            return type(self)([], energies=self._energies)
        else:
            return type(self)(entries)

    @overload  # noqa F811
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'GridDOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        # NumPy integers (e.g. the result of np.argmax) are valid indices
        # but are not instances of int, so accept them explicitly.
        if isinstance(item, (int, np.integer)):
            return GridDOSData(self._energies, self._weights[item, :],
                               info=self._info[item])
        elif isinstance(item, slice):
            # Slicing a range resolves start/stop/step (including negative
            # values) to the selected row indices. The slice may select
            # nothing, in which case an empty collection is returned.
            return self._new_from_entries(
                [self[i] for i in range(len(self))[item]])
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    @classmethod
    def from_data(cls,
                  energies: Sequence[float],
                  weights: Sequence[Sequence[float]],
                  info: Optional[Sequence[Info]] = None
                  ) -> 'GridDOSCollection':
        """Create a GridDOSCollection from data with a common set of energies

        This convenience method may also be more efficient as it limits
        redundant copying/checking of the data.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns:
            Collection of DOS data (in GridDOSData format)

        Raises:
            IndexError: if weights is not 2-D, has no rows, or has rows whose
                length differs from the number of energies
            ValueError: if info is given and its length differs from the
                number of rows in weights
        """

        weights_array = np.asarray(weights, dtype=float)
        if len(weights_array.shape) != 2:
            raise IndexError("Weights must be a 2-D array or nested sequence")
        if weights_array.shape[0] < 1:
            raise IndexError("Weights cannot be empty")
        if weights_array.shape[1] != len(energies):
            raise IndexError("Length of weights rows must equal size of x")

        info = cls._check_weights_and_info(weights_array, info)

        # Construct from the first row only, so that the constructor sets up
        # the energy axis; the complete weights and info are then attached
        # directly rather than being checked and copied row by row.
        dos_collection = cls([GridDOSData(energies, weights_array[0])])
        dos_collection._weights = weights_array
        dos_collection._info = list(info)

        return dos_collection

    def select(self, **info_selection: str) -> 'GridDOSCollection':
        """Narrow GridDOSCollection to items with specified info

        For example, if ::

          dc = GridDOSCollection([GridDOSData(x, y1,
                                              info={'a': '1', 'b': '1'}),
                                  GridDOSData(x, y2,
                                              info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        """

        matches = self._select_to_list(self, info_selection)
        return self._new_from_entries(matches)

    def select_not(self, **info_selection: str) -> 'GridDOSCollection':
        """Narrow GridDOSCollection to items without specified info

        For example, if ::

          dc = GridDOSCollection([GridDOSData(x, y1,
                                              info={'a': '1', 'b': '1'}),
                                  GridDOSData(x, y2,
                                              info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        return self._new_from_entries(matches)

    def plot(self,
             npts: int = 0,
             xmin: Optional[float] = None,
             xmax: Optional[float] = None,
             width: Optional[float] = None,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: Optional[str] = None,
             mplargs: Optional[dict] = None) -> 'matplotlib.axes.Axes':
        """Simple plot of collected DOS data, resampled onto a grid

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        `ax.legend()`).

        Args:
            npts:
                Number of points in resampled x-axis. If set to zero (default),
                no resampling is performed and the stored data is plotted
                directly.
            xmin, xmax:
                output data range; this limits the resampling range as well as
                the plotting output
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel, passed to
                self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        # Apply defaults if necessary
        npts, width = GridDOSData._interpret_smearing_args(npts, width)

        if npts:
            # Narrows Optional[float] to float for the type checker
            assert isinstance(width, float)
            dos = self.sample_grid(npts,
                                   xmin=xmin, xmax=xmax,
                                   width=width, smearing=smearing)
        else:
            dos = self

        energies, all_y = dos._energies, dos._weights

        all_labels = [DOSData.label_from_info(data.info) for data in self]

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            self._plot_broadened(ax, energies, all_y, all_labels, mplargs)

        return ax

    @staticmethod
    def _plot_broadened(ax: 'matplotlib.axes.Axes',
                        energies: Sequence[float],
                        all_y: np.ndarray,
                        all_labels: Sequence[str],
                        mplargs: Union[Dict, None]):
        """Plot DOS data with labels to axes

        This is separated into another function so that subclasses can
        manipulate broadening, labels etc in their plot() method.

        Args:
            ax: axes to draw on
            energies: common x-values of all the data series
            all_y: 2-D array of y-values with one row per data series
            all_labels: one legend label per row of all_y
            mplargs: additional keyword arguments for ax.plot(), or None
        """
        if mplargs is None:
            mplargs = {}

        # Transposed so that each data series is passed to ax.plot() as a
        # column, giving one line per series
        all_lines = ax.plot(energies, all_y.T, **mplargs)
        for line, label in zip(all_lines, all_labels):
            line.set_label(label)
        ax.legend()

        ax.set_xlim(left=min(energies), right=max(energies))
        ax.set_ylim(bottom=0)