Coverage for /builds/debichem-team/python-ase/ase/spectrum/doscollection.py: 97.84%
185 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
1import collections
2from functools import reduce, singledispatch
3from typing import (
4 Any,
5 Dict,
6 Iterable,
7 List,
8 Optional,
9 Sequence,
10 TypeVar,
11 Union,
12 overload,
13)
15import numpy as np
16from matplotlib.axes import Axes
18from ase.spectrum.dosdata import DOSData, Floats, GridDOSData, Info, RawDOSData
19from ase.utils.plotting import SimplePlottingAxes
class DOSCollection(collections.abc.Sequence):
    """Base class for a collection of DOSData objects"""

    def __init__(self, dos_series: Iterable[DOSData]) -> None:
        # Materialise the input so generators can be passed in directly
        self._data = list(dos_series)

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        This samples the underlying DOS data in the same way as the .sample()
        method of those DOSData items, returning a 2-D array with columns
        corresponding to x and rows corresponding to the collected data series.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns:
            Weights sampled from a broadened DOS at values corresponding to x,
            in rows corresponding to DOSData entries contained in this object

        Raises:
            IndexError: if the collection is empty
        """

        if len(self) == 0:
            raise IndexError("No data to sample")

        return np.asarray(
            [data._sample(energies, width=width, smearing=smearing)
             for data in self])

    def plot(self,
             npts: int = 1000,
             xmin: Optional[float] = None,
             xmax: Optional[float] = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: Optional[Axes] = None,
             show: bool = False,
             filename: Optional[str] = None,
             mplargs: Optional[dict] = None) -> Axes:
        """Simple plot of collected DOS data, resampled onto a grid

        The data is broadened once with self.sample_grid(). The resulting
        GridDOSCollection is then drawn without further resampling. Each
        plotted line is labelled using the info of the corresponding DOSData
        item; see GridDOSCollection.plot() for the details of labelling.

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """
        grid_dos = self.sample_grid(npts,
                                    xmin=xmin, xmax=xmax,
                                    width=width, smearing=smearing)

        # The grid data is already broadened. npts=0 / width=None are the
        # GridDOSCollection.plot defaults, which plot the stored data
        # directly. Forwarding npts/width again would resample and re-smear
        # the data (and pad the range a second time).
        return grid_dos.plot(npts=0,
                             xmin=xmin, xmax=xmax,
                             width=None, smearing=smearing,
                             ax=ax, show=show, filename=filename,
                             mplargs=mplargs)

    def sample_grid(self,
                    npts: int,
                    xmin: Optional[float] = None,
                    xmax: Optional[float] = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSCollection':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled energy value; if unspecified, a default is
                chosen
            xmax: Maximum sampled energy value; if unspecified, a default is
                chosen
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel, passed to the sample_grid()
                method of each contained DOSData
            smearing: selection of broadening kernel, passed to the
                sample_grid() method of each contained DOSData

        Returns:
            GridDOSCollection holding one sampled series per contained item,
            all on the same energy grid

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sample")

        # Default limits span all items' energies, padded so that broadened
        # peaks at the edges are not cut off.
        if xmin is None:
            xmin = (min(min(data.get_energies()) for data in self)
                    - (padding * width))
        if xmax is None:
            xmax = (max(max(data.get_energies()) for data in self)
                    + (padding * width))

        return GridDOSCollection(
            [data.sample_grid(npts, xmin=xmin, xmax=xmax, width=width,
                              smearing=smearing)
             for data in self])

    @classmethod
    def from_data(cls,
                  energies: Floats,
                  weights: Sequence[Floats],
                  info: Optional[Sequence[Info]] = None) -> 'DOSCollection':
        """Create a DOSCollection from data sharing a common set of energies

        This is a convenience method to be used when all the DOS data in the
        collection has a common energy axis. There is no performance advantage
        in using this method for the generic DOSCollection, but for
        GridDOSCollection it is more efficient.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns:
            Collection of DOS data (in RawDOSData format)

        Raises:
            ValueError: if info is given and its length differs from the
                number of weights rows
        """

        info = cls._check_weights_and_info(weights, info)

        return cls(RawDOSData(energies, row_weights, row_info)
                   for row_weights, row_info in zip(weights, info))

    @staticmethod
    def _check_weights_and_info(weights: Sequence[Floats],
                                info: Optional[Sequence[Info]],
                                ) -> Sequence[Info]:
        """Return info, or one empty dict per weights row if info is None

        Raises:
            ValueError: if len(info) != len(weights)
        """
        if info is None:
            info = [{} for _ in range(len(weights))]
        else:
            if len(info) != len(weights):
                raise ValueError("Length of info must match number of rows in "
                                 "weights")
        return info

    @overload
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'DOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        """Return one DOSData (integer index) or a sub-collection (slice)

        Numpy integer scalars are accepted as well as plain ints, so indices
        taken from numpy arrays can be used directly.
        """
        if isinstance(item, (int, np.integer)):
            return self._data[item]
        elif isinstance(item, slice):
            return type(self)(self._data[item])
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    def __len__(self) -> int:
        return len(self._data)

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSCollection for testing purposes"""
        if not isinstance(other, type(self)):
            return False
        elif len(self) != len(other):
            return False
        else:
            return all(a._almost_equals(b) for a, b in zip(self, other))

    def total(self) -> DOSData:
        """Sum all the DOSData in this Collection and label it as 'Total'"""
        data = self.sum_all()
        data.info.update({'label': 'Total'})
        return data

    def sum_all(self) -> DOSData:
        """Sum all the DOSData contained in this Collection

        A single item is copied, so the returned object is never one that is
        stored in the collection.

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sum")
        elif len(self) == 1:
            data = self[0].copy()
        else:
            data = reduce(lambda x, y: x + y, self)
        return data

    D = TypeVar('D', bound=DOSData)

    @staticmethod
    def _select_to_list(dos_collection: Sequence[D],  # Bug in flakes
                        info_selection: Dict[str, str],  # misses 'D' def
                        negative: bool = False) -> List[D]:  # noqa: F821
        """Filter items whose info contains all of info_selection

        With negative=True, return instead the items whose info does NOT
        contain every key/value pair of info_selection.
        """
        query = set(info_selection.items())

        if negative:
            return [data for data in dos_collection
                    if not query.issubset(set(data.info.items()))]
        else:
            return [data for data in dos_collection
                    if query.issubset(set(data.info.items()))]

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items with specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        """

        matches = self._select_to_list(self, info_selection)
        return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items without specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        return type(self)(matches)

    def sum_by(self, *info_keys: str) -> 'DOSCollection':
        """Return a DOSCollection with some data summed by common attributes

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                              DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        then ::

          dc.sum_by('b')

        will return a collection equivalent to ::

          DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'})
                         + DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                         DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        where the resulting contained DOSData have info attributes of
        {'b': '1'} and {'b': '2'} respectively.

        dc.sum_by('a', 'b') on the other hand would return the full three-entry
        collection, as none of the entries have common 'a' *and* 'b' info.

        Entries are grouped by exactly which of info_keys they carry and with
        which values. Each entry therefore contributes to exactly one summed
        result, including entries that lack some or all of the keys.

        """

        def _matching_info_tuples(data: DOSData):
            """Get relevant dict entries in tuple form

            e.g. if data.info = {'a': 1, 'b': 2, 'c': 3}
                 and info_keys = ('a', 'c')

                 then return (('a', 1), ('c', 3))
            """
            matched_keys = set(info_keys) & set(data.info)
            return tuple(sorted([(key, data.info[key])
                                 for key in matched_keys]))

        # Group each entry under its own key/value combination. Re-querying
        # with select() would match any superset of a combination, so entries
        # missing some keys (e.g. the empty combination) would also be summed
        # into other groups. Within a group, collection order is preserved.
        groups: Dict[tuple, List[DOSData]] = {}
        for data in self:
            groups.setdefault(_matching_info_tuples(data), []).append(data)

        # Combos are sorted() to ensure consistent output across sessions.
        collection_data = [type(self)(groups[combo]).sum_all()
                           for combo in sorted(groups)]
        return type(self)(collection_data)

    def __add__(self, other: Union['DOSCollection', DOSData]
                ) -> 'DOSCollection':
        """Join entries between two DOSCollection objects of the same type

        It is also possible to add a single DOSData object without wrapping it
        in a new collection: i.e. ::

          DOSCollection([dosdata1]) + DOSCollection([dosdata2])

        or ::

          DOSCollection([dosdata1]) + dosdata2

        will return ::

          DOSCollection([dosdata1, dosdata2])

        """
        return _add_to_collection(other, self)
@singledispatch
def _add_to_collection(other: Union[DOSData, DOSCollection],
                       collection: DOSCollection) -> DOSCollection:
    """Generic '+' handler: join another collection onto ``collection``

    DOSData operands are dispatched to a separately registered handler. Here
    ``other`` must be a collection of the same type as ``collection`` (or a
    subclass of that type); the items of ``collection`` come first.

    Raises:
        TypeError: if ``other`` is not a DOSCollection, or is a DOSCollection
            of an incompatible type
    """
    target_type = type(collection)

    if not isinstance(other, DOSCollection):
        raise TypeError("DOSCollection may only be joined to DOSData or "
                        "DOSCollection objects with '+'.")
    if not isinstance(other, target_type):
        raise TypeError("Only DOSCollection objects of the same type may "
                        "be joined with '+'.")

    return target_type([*collection, *other])
@_add_to_collection.register(DOSData)
def _add_data(other: DOSData, collection: DOSCollection) -> DOSCollection:
    """Return a new collection of the same type with ``other`` appended"""
    joined = list(collection)
    joined.append(other)
    return type(collection)(joined)
class RawDOSCollection(DOSCollection):
    """DOSCollection whose items must all be RawDOSData objects"""

    def __init__(self, dos_series: Iterable[RawDOSData]) -> None:
        # The base constructor stores dos_series as a list; the stored
        # items are then validated.
        super().__init__(dos_series)
        if not all(isinstance(item, RawDOSData) for item in self):
            raise TypeError("RawDOSCollection can only store "
                            "RawDOSData objects.")
class GridDOSCollection(DOSCollection):
    """Collection of GridDOSData series that share one energy axis

    The data is stored as a single energy array plus a 2-D weights array
    with one row per series (columns follow the energy axis). There is also
    a parallel list of info dicts. Items are rebuilt as GridDOSData on
    indexing.
    """

    def __init__(self, dos_series: Iterable[GridDOSData],
                 energies: Optional[Floats] = None) -> None:
        dos_list = list(dos_series)
        if energies is None:
            if len(dos_list) == 0:
                raise ValueError("Must provide energies to create a "
                                 "GridDOSCollection without any DOS data.")
            self._energies = dos_list[0].get_energies()
        else:
            self._energies = np.asarray(energies)

        self._weights = np.empty((len(dos_list), len(self._energies)), float)
        self._info = []

        for i, dos_data in enumerate(dos_list):
            if not isinstance(dos_data, GridDOSData):
                raise TypeError("GridDOSCollection can only store "
                                "GridDOSData objects.")
            # Check shape first so allclose is never asked to broadcast
            # mismatched arrays
            if (dos_data.get_energies().shape != self._energies.shape
                    or not np.allclose(dos_data.get_energies(),
                                       self._energies)):
                raise ValueError("All GridDOSData objects in GridDOSCollection"
                                 " must have the same energy axis.")
            self._weights[i, :] = dos_data.get_weights()
            self._info.append(dos_data.info)

    def get_energies(self) -> Floats:
        """Return a copy of the shared energy axis"""
        return self._energies.copy()

    def get_all_weights(self) -> Union[Sequence[Floats], np.ndarray]:
        """Return a copy of the 2-D weights array (one row per series)"""
        return self._weights.copy()

    def __len__(self) -> int:
        return self._weights.shape[0]

    @overload  # noqa F811
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'GridDOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        """Return one GridDOSData (integer index) or a sub-collection (slice)

        Numpy integer scalars are accepted as well as plain ints. A slice
        that selects no rows returns an empty collection on the same energy
        axis.
        """
        if isinstance(item, (int, np.integer)):
            return GridDOSData(self._energies, self._weights[item, :],
                               info=self._info[item])
        elif isinstance(item, slice):
            # Pass the energy axis explicitly: without it, an empty slice
            # would make the constructor raise ValueError.
            return type(self)([self[i] for i in range(len(self))[item]],
                              energies=self._energies)
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    @classmethod
    def from_data(cls,
                  energies: Floats,
                  weights: Sequence[Floats],
                  info: Optional[Sequence[Info]] = None
                  ) -> 'GridDOSCollection':
        """Create a GridDOSCollection from data with a common set of energies

        This convenience method may also be more efficient as it limits
        redundant copying/checking of the data.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns:
            Collection of DOS data (in GridDOSData format)

        Raises:
            IndexError: if weights is not 2-D, is empty, or its rows do not
                match the length of energies
            ValueError: if info is given and its length differs from the
                number of weights rows
        """

        weights_array = np.asarray(weights, dtype=float)
        if len(weights_array.shape) != 2:
            raise IndexError("Weights must be a 2-D array or nested sequence")
        if weights_array.shape[0] < 1:
            raise IndexError("Weights cannot be empty")
        if weights_array.shape[1] != len(energies):
            raise IndexError("Length of weights rows must equal size of x")

        info = cls._check_weights_and_info(weights, info)

        # Build from one row to set up the energy axis, then install the
        # full weights array and info directly. This skips per-row checks.
        dos_collection = cls([GridDOSData(energies, weights_array[0])])
        dos_collection._weights = weights_array
        dos_collection._info = list(info)

        return dos_collection

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow GridDOSCollection to items with specified info

        For example, if ::

          dc = GridDOSCollection([GridDOSData(x, y1,
                                              info={'a': '1', 'b': '1'}),
                                  GridDOSData(x, y2,
                                              info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        """

        matches = self._select_to_list(self, info_selection)
        if len(matches) == 0:
            return type(self)([], energies=self._energies)
        else:
            return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow GridDOSCollection to items without specified info

        For example, if ::

          dc = GridDOSCollection([GridDOSData(x, y1,
                                              info={'a': '1', 'b': '1'}),
                                  GridDOSData(x, y2,
                                              info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        if len(matches) == 0:
            return type(self)([], energies=self._energies)
        else:
            return type(self)(matches)

    def plot(self,
             npts: int = 0,
             xmin: Optional[float] = None,
             xmax: Optional[float] = None,
             width: Optional[float] = None,
             smearing: str = 'Gauss',
             ax: Optional[Axes] = None,
             show: bool = False,
             filename: Optional[str] = None,
             mplargs: Optional[dict] = None) -> Axes:
        """Simple plot of collected DOS data, resampled onto a grid

        Each line is labelled from the info of the corresponding item (via
        DOSData.label_from_info), unless a 'label' is given in mplargs. A
        legend is added to the axes.

        Args:
            npts:
                Number of points in resampled x-axis. If set to zero (default),
                no resampling is performed and the stored data is plotted
                directly.
            xmin, xmax:
                output data range; this limits the resampling range as well as
                the plotting output
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel, passed to
                self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        # Apply defaults if necessary
        npts, width = GridDOSData._interpret_smearing_args(npts, width)

        if npts:
            assert isinstance(width, float)  # narrows Optional for typing
            dos = self.sample_grid(npts,
                                   xmin=xmin, xmax=xmax,
                                   width=width, smearing=smearing)
        else:
            dos = self

        energies, all_y = dos._energies, dos._weights

        all_labels = [DOSData.label_from_info(data.info) for data in self]

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            self._plot_broadened(ax, energies, all_y, all_labels, mplargs)

        return ax

    @staticmethod
    def _plot_broadened(ax: Axes,
                        energies: Floats,
                        all_y: np.ndarray,
                        all_labels: Sequence[str],
                        mplargs: Optional[Dict]):
        """Plot DOS data with labels to axes

        This is separated into another function so that subclasses can
        manipulate broadening, labels etc in their plot() method.

        Args:
            ax: axes to draw on
            energies: x values shared by all series
            all_y: 2-D array with one row per series
            all_labels: one label per row of all_y; ignored if mplargs
                contains 'label'
            mplargs: extra keyword arguments for ax.plot()
        """
        if mplargs is None:
            mplargs = {}

        # Transpose so each row of all_y becomes one plotted line
        all_lines = ax.plot(energies, all_y.T, **mplargs)
        # An explicit 'label' in mplargs overrules the info-derived labels
        if 'label' not in mplargs:
            for line, label in zip(all_lines, all_labels):
                line.set_label(label)
        ax.legend()

        ax.set_xlim(left=min(energies), right=max(energies))
        ax.set_ylim(bottom=0)