Coverage for /builds/debichem-team/python-ase/ase/spectrum/dosdata.py: 100.00%
152 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
1# Refactor of DOS-like data objects
2# towards replacing ase.dft.dos and ase.dft.pdos
3import warnings
4from abc import ABCMeta, abstractmethod
5from typing import Any, Dict, Sequence, Tuple, TypeVar, Union
7import numpy as np
8from matplotlib.axes import Axes
10from ase.utils.plotting import SimplePlottingAxes
12# For now we will be strict about Info and say it has to be str->str. Perhaps
13# later we will allow other types that have reliable comparison operations.
14Info = Dict[str, str]
16# Still no good solution to type checking with arrays.
17Floats = Union[Sequence[float], np.ndarray]
class DOSData(metaclass=ABCMeta):
    """Abstract base class for a single series of DOS-like data

    Only the 'info' is a mutable attribute; DOS data is set at init.

    Concrete subclasses must implement get_energies(), get_weights() and
    copy().
    """

    def __init__(self,
                 info: Info = None) -> None:
        """Store the metadata dictionary

        Args:
            info: metadata for this data series. If None, a new empty dict
                is created. A dict is stored as-is (it is not copied).

        Raises:
            TypeError: if info is neither a dict nor None
        """
        if info is None:
            self.info = {}
        elif isinstance(info, dict):
            self.info = info
        else:
            raise TypeError("Info must be a dict or None")

    @abstractmethod
    def get_energies(self) -> Floats:
        """Get energy data stored in this object"""

    @abstractmethod
    def get_weights(self) -> Floats:
        """Get DOS weights stored in this object"""

    @abstractmethod
    def copy(self) -> 'DOSData':
        """Returns a copy in which info dict can be safely mutated"""

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        Note that no correction is made here for the sampling bin width; total
        intensity will vary with sampling density.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns:
            Weights sampled from a broadened DOS at values corresponding to
            ``energies``

        Raises:
            ValueError: if width is not positive or smearing is not recognized
        """
        self._check_positive_width(width)

        energies = np.asarray(energies, float)
        weights_grid = np.zeros(len(energies), float)
        weights = self.get_weights()

        # Accumulate one broadened peak per stored (energy, weight) pair
        for i, raw_energy in enumerate(self.get_energies()):
            delta = self._delta(energies, raw_energy, width, smearing=smearing)
            weights_grid += weights[i] * delta
        return weights_grid

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSData for testing purposes

        Two objects are considered almost equal when ``other`` is an instance
        of this object's type, the info dicts are equal, and the weights and
        energies have the same shape and agree within np.allclose tolerances.
        """
        if not isinstance(other, type(self)):
            return False
        if self.info != other.info:
            return False

        weights = np.asarray(self.get_weights())
        other_weights = np.asarray(other.get_weights())
        energies = np.asarray(self.get_energies())
        other_energies = np.asarray(other.get_energies())

        # np.allclose broadcasts its arguments: without this check a
        # different-length series would either raise ValueError or (for a
        # length-1 series) be compared element-against-all and could pass.
        if weights.shape != other_weights.shape:
            return False
        if energies.shape != other_energies.shape:
            return False

        if not np.allclose(weights, other_weights):
            return False
        return bool(np.allclose(energies, other_energies))

    @staticmethod
    def _delta(x: np.ndarray,
               x0: float,
               width: float,
               smearing: str = 'Gauss') -> np.ndarray:
        """Return a delta-function centered at 'x0'.

        This function is used with numpy broadcasting; if x is a row and x0 is
        a column vector, the returned data will be a 2D array with each row
        corresponding to a different delta center.

        Raises:
            ValueError: if smearing is not (case-insensitively) 'gauss'
        """
        if smearing.lower() == 'gauss':
            # Normalised Gaussian with standard deviation 'width'
            x1 = -0.5 * ((x - x0) / width)**2
            return np.exp(x1) / (np.sqrt(2 * np.pi) * width)

        msg = f'Requested smearing type not recognized. Got {smearing}'
        raise ValueError(msg)

    @staticmethod
    def _check_positive_width(width):
        """Raise ValueError unless width is strictly positive"""
        if width <= 0.0:
            msg = 'Cannot add 0 or negative width smearing'
            raise ValueError(msg)

    def sample_grid(self,
                    npts: int,
                    xmin: float = None,
                    xmax: float = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSData':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled x value; if unspecified, a default is chosen
            xmax: Maximum sampled x value; if unspecified, a default is chosen
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only 'Gauss' is
                implemented)

        Returns:
            GridDOSData holding the energy grid and the sampled weights, with
            a copy of this object's info dict
        """
        if xmin is None:
            xmin = min(self.get_energies()) - (padding * width)
        if xmax is None:
            xmax = max(self.get_energies()) + (padding * width)
        energies_grid = np.linspace(xmin, xmax, npts)
        weights_grid = self._sample(energies_grid, width=width,
                                    smearing=smearing)

        return GridDOSData(energies_grid, weights_grid, info=self.info.copy())

    def plot(self,
             npts: int = 1000,
             xmin: float = None,
             xmax: float = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: Axes = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> Axes:
        """Simple 1-D plot of DOS data, resampled onto a grid

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel for self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The dict passed
                in is not modified.

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """
        # Work on a copy so the automatic label never leaks into the caller's
        # dict (which may be reused for several plot calls).
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        sampled = self.sample_grid(npts, xmin=xmin, xmax=xmax,
                                   width=width, smearing=smearing)
        return sampled.plot(ax=ax, xmin=xmin, xmax=xmax,
                            show=show, filename=filename,
                            mplargs=mplargs)

    @staticmethod
    def label_from_info(info: Dict[str, str]):
        """Generate an automatic legend label from info dict

        Returns info['label'] if that key exists; otherwise all items joined
        as 'key: value' pairs separated by '; '.
        """
        if 'label' in info:
            return info['label']
        return '; '.join(f'{key}: {value}' for key, value in info.items())
class GeneralDOSData(DOSData):
    """Base class for a single series of DOS-like data

    Only the 'info' is a mutable attribute; DOS data is set at init

    This is the base class for DOSData objects that accept/set separate
    "energies" and "weights" sequences of equal length at init.

    """

    def __init__(self,
                 energies: Floats,
                 weights: Floats,
                 info: Info = None) -> None:
        """Store energies and weights in a single (2, N) float array

        Raises:
            ValueError: if energies and weights differ in length
        """
        super().__init__(info=info)

        size = len(energies)
        if size != len(weights):
            raise ValueError("Energies and weights must be the same length")

        # Row 0 holds the energies, row 1 the corresponding weights
        table = np.empty((2, size), dtype=float, order='C')
        table[0, :] = energies
        table[1, :] = weights
        self._data = table

    def get_energies(self) -> np.ndarray:
        """Return a copy of the stored energies (row 0 of the data array)"""
        return self._data[0].copy()

    def get_weights(self) -> np.ndarray:
        """Return a copy of the stored weights (row 1 of the data array)"""
        return self._data[1].copy()

    D = TypeVar('D', bound='GeneralDOSData')

    def copy(self: D) -> D:  # noqa F821
        """Return a new object of the same type with copied data and info"""
        cls = type(self)
        return cls(self.get_energies(), self.get_weights(),
                   info=self.info.copy())
class RawDOSData(GeneralDOSData):
    """A collection of weighted delta functions which sum to form a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the energy data values do not form a known regular
    grid. The data may be plotted or resampled for further analysis using the
    sample_grid() and plot() methods. Multiple weights at the same
    energy value will *only* be combined in output data, and data stored in
    RawDOSData is never resampled. A plot_deltas() function is also provided
    which plots the raw data.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When RawDOSData objects are combined with the addition operator::

      big_dos = raw_dos_1 + raw_dos_2

    the energy and weights data is *concatenated* (i.e. combined without
    sorting or replacement) and the new info dictionary consists of the
    *intersection* of the inputs: only key-value pairs that were common to both
    of the input objects will be retained in the new combined object. For
    example::

      (RawDOSData([x1], [y1], info={'symbol': 'O', 'index': '1'})
       + RawDOSData([x2], [y2], info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      RawDOSData([x1, x2], [y1, y2], info={'symbol': 'O'})

    """

    def __add__(self, other: 'RawDOSData') -> 'RawDOSData':
        """Concatenate the data of two RawDOSData, keeping shared info only

        Raises:
            TypeError: if other is not a RawDOSData
        """
        if not isinstance(other, RawDOSData):
            raise TypeError("RawDOSData can only be combined with other "
                            "RawDOSData objects")

        # Only key-value pairs present in both info dicts survive
        shared_items = set(self.info.items()) & set(other.info.items())

        # Start from an empty container, then attach the side-by-side
        # (energy, weight) columns of both operands directly
        combined = RawDOSData([], [], info=dict(shared_items))
        combined._data = np.hstack((self._data, other._data))
        return combined

    def plot_deltas(self,
                    ax: Axes = None,
                    show: bool = False,
                    filename: str = None,
                    mplargs: dict = None) -> Axes:
        """Simple plot of sparse DOS data as a set of delta functions

        Items at the same x-value can overlap and will not be summed together

        Args:
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib Axes.vlines
                command (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """
        line_kwargs = {} if mplargs is None else mplargs

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as axes:
            # One vertical line from zero up to each weight
            axes.vlines(self.get_energies(), 0, self.get_weights(),
                        **line_kwargs)

        return axes
class GridDOSData(GeneralDOSData):
    """A collection of regularly-sampled data which represents a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the intensity values form a regular grid. This
    is generally the result of sampling or integrating into discrete
    bins, rather than a collection of unique states. The data may be
    plotted or resampled for further analysis using the sample_grid()
    and plot() methods.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When GridDOSData objects are combined with the addition operator::

      big_dos = grid_dos_1 + grid_dos_2

    the weights data is *summed* (requiring a consistent energy grid) and the
    new info dictionary consists of the *intersection* of the inputs: only
    key-value pairs that were common to both of the input objects will be
    retained in the new combined object. For example::

      (GridDOSData([0.1, 0.2, 0.3], [y1, y2, y3],
                   info={'symbol': 'O', 'index': '1'})
       + GridDOSData([0.1, 0.2, 0.3], [y4, y5, y6],
                     info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      GridDOSData([0.1, 0.2, 0.3], [y1+y4, y2+y5, y3+y6], info={'symbol': 'O'})

    """

    def __init__(self,
                 energies: Floats,
                 weights: Floats,
                 info: Info = None) -> None:
        """Validate the grid and store the data

        Raises:
            ValueError: if energies are not evenly spaced, or if energies and
                weights differ in length
        """
        n_entries = len(energies)
        # An evenly-spaced grid must match linspace between its end points
        if not np.allclose(energies,
                           np.linspace(energies[0], energies[-1], n_entries)):
            raise ValueError("Energies must be an evenly-spaced 1-D grid")

        if len(weights) != n_entries:
            raise ValueError("Energies and weights must be the same length")

        super().__init__(energies, weights, info=info)
        self.sigma_cutoff = 3

    def _check_spacing(self, width) -> float:
        """Return the grid spacing, warning if width is small relative to it

        A warning is issued when width is less than twice the spacing between
        the first two grid points.
        """
        current_spacing = self._data[0, 1] - self._data[0, 0]
        if width < (2 * current_spacing):
            warnings.warn(
                "The broadening width is small compared to the original "
                "sampling density. The results are unlikely to be smooth.")
        return current_spacing

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample with broadening, scaling the result by the grid spacing"""
        current_spacing = self._check_spacing(width)
        return super()._sample(energies=energies,
                               width=width, smearing=smearing
                               ) * current_spacing

    def __add__(self, other: 'GridDOSData') -> 'GridDOSData':
        """Sum the weights of two GridDOSData defined on the same grid

        Raises:
            TypeError: if other is not a GridDOSData
            ValueError: if the energy grids differ in length or in values
        """
        # This method uses direct access to the mutable energy and weights data
        # (self._data) to avoid redundant copying operations. The __init__
        # method of GridDOSData will write this to a new array, so on this
        # occasion it is safe to pass references to the mutable data.

        if not isinstance(other, GridDOSData):
            raise TypeError("GridDOSData can only be combined with other "
                            "GridDOSData objects")
        if len(self._data[0, :]) != len(other.get_energies()):
            raise ValueError("Cannot add GridDOSData objects with different-"
                             "length energy grids.")

        if not np.allclose(self._data[0, :], other.get_energies()):
            raise ValueError("Cannot add GridDOSData objects with different "
                             "energy grids.")

        # Take intersection of metadata (i.e. only common entries are retained)
        new_info = dict(set(self.info.items()) & set(other.info.items()))

        # Sum the weight data; the energy grid is shared
        new_weights = self._data[1, :] + other.get_weights()

        new_object = GridDOSData(self._data[0, :], new_weights,
                                 info=new_info)
        return new_object

    @staticmethod
    def _interpret_smearing_args(npts: int,
                                 width: float = None,
                                 default_npts: int = 1000,
                                 default_width: float = 0.1
                                 ) -> Tuple[int, Union[float, None]]:
        """Figure out what the user intended: resample if width provided

        Returns:
            (npts, width) to use for resampling, with defaults filled in for
            whichever was not given, or (0, None) if neither was given
            (meaning: do not resample).
        """
        if width is not None:
            if npts:
                return (npts, float(width))
            else:
                return (default_npts, float(width))
        else:
            if npts:
                return (npts, default_width)
            else:
                return (0, None)

    def plot(self,
             npts: int = 0,
             xmin: float = None,
             xmax: float = None,
             width: float = None,
             smearing: str = 'Gauss',
             ax: Axes = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> Axes:
        """Simple 1-D plot of DOS data

        Data will be resampled onto a grid with `npts` points unless `npts` is
        set to zero, in which case:

        - no resampling takes place
        - `width` and `smearing` are ignored
        - `xmin` and `xmax` affect the axis limits of the plot, not the
          underlying data.

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid().
                If no npts was set but width is set, npts will be set to 1000.
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The dict passed
                in is not modified.

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        npts, width = self._interpret_smearing_args(npts, width)

        # Work on a copy so the automatic label never leaks into the caller's
        # dict (which may be reused for several plot calls).
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        if npts:
            assert isinstance(width, float)
            dos = self.sample_grid(npts, xmin=xmin,
                                   xmax=xmax, width=width,
                                   smearing=smearing)
        else:
            dos = self

        energies, intensity = dos.get_energies(), dos.get_weights()

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            ax.plot(energies, intensity, **mplargs)
            ax.set_xlim(left=xmin, right=xmax)

        return ax