Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Refactor of DOS-like data objects
2# towards replacing ase.dft.dos and ase.dft.pdos
3from abc import ABCMeta, abstractmethod
4import warnings
5from typing import Any, Dict, Sequence, Tuple, TypeVar, Union
7import numpy as np
8from ase.utils.plotting import SimplePlottingAxes
# Type-only import: TYPE_CHECKING is False at runtime, so matplotlib is only
# imported by static type checkers (for the 'matplotlib.axes.Axes' hints
# used in this module).
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import matplotlib.axes  # noqa: F401

# For now we will be strict about Info and say it has to be str->str. Perhaps
# later we will allow other types that have reliable comparison operations.
Info = Dict[str, str]
class DOSData(metaclass=ABCMeta):
    """Abstract base class for a single series of DOS-like data

    Only the 'info' is a mutable attribute; DOS data is set at init.

    Subclasses must implement get_energies(), get_weights() and copy().
    """
    def __init__(self,
                 info: Info = None) -> None:
        """Store the metadata dict.

        Args:
            info: metadata as a str -> str dict. The dict is stored by
                reference (it is not copied). If None, an empty dict is used.

        Raises:
            TypeError: if info is neither a dict nor None.
        """
        if info is None:
            self.info = {}
        elif isinstance(info, dict):
            self.info = info
        else:
            raise TypeError("Info must be a dict or None")

    @abstractmethod
    def get_energies(self) -> Sequence[float]:
        """Get energy data stored in this object"""

    @abstractmethod
    def get_weights(self) -> Sequence[float]:
        """Get DOS weights stored in this object"""

    @abstractmethod
    def copy(self) -> 'DOSData':
        """Returns a copy in which info dict can be safely mutated"""

    def _sample(self,
                energies: Sequence[float],
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        Note that no correction is made here for the sampling bin width; total
        intensity will vary with sampling density.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns:
            Weights sampled from a broadened DOS at the given energies

        Raises:
            ValueError: if width is not positive, or (when there is stored
                data to sample) if the smearing type is not recognized.
        """
        self._check_positive_width(width)
        energies = np.asarray(energies, float)
        weights_grid = np.zeros(len(energies), float)

        # Accumulate one broadened peak per stored (energy, weight) pair.
        # A loop is used instead of a 2-D broadcast so that memory use stays
        # proportional to the number of sampling points.
        raw_data = zip(self.get_energies(), self.get_weights())
        for raw_energy, weight in raw_data:
            weights_grid += weight * self._delta(energies, raw_energy, width,
                                                 smearing=smearing)
        return weights_grid

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSData for testing purposes

        True only if other has the same type, an identical info dict, and
        numerically close (np.allclose) weights and energies.
        """
        if not isinstance(other, type(self)):
            return False
        if self.info != other.info:
            return False
        if not np.allclose(self.get_weights(), other.get_weights()):
            return False
        return np.allclose(self.get_energies(), other.get_energies())

    @staticmethod
    def _delta(x: np.ndarray,
               x0: float,
               width: float,
               smearing: str = 'Gauss') -> np.ndarray:
        """Return a delta-function centered at 'x0'.

        This function is used with numpy broadcasting; if x is a row and x0 is
        a column vector, the returned data will be a 2D array with each row
        corresponding to a different delta center.

        Only Gaussian smearing (case-insensitive 'gauss') is implemented; the
        Gaussian is normalised to unit area with standard deviation 'width'.

        Raises:
            ValueError: if the smearing type is not recognized.
        """
        if smearing.lower() == 'gauss':
            x1 = -0.5 * ((x - x0) / width)**2
            return np.exp(x1) / (np.sqrt(2 * np.pi) * width)
        else:
            msg = 'Requested smearing type not recognized. Got {}'.format(
                smearing)
            raise ValueError(msg)

    @staticmethod
    def _check_positive_width(width):
        """Raise ValueError unless width is strictly positive"""
        if width <= 0.0:
            msg = 'Cannot add 0 or negative width smearing'
            raise ValueError(msg)

    def sample_grid(self,
                    npts: int,
                    xmin: float = None,
                    xmax: float = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSData':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled x value; if unspecified, a default is chosen
            xmax: Maximum sampled x value; if unspecified, a default is chosen
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only 'Gauss' is
                implemented)

        Returns:
            GridDOSData holding the energy grid and the sampled DOS, with a
            copy of this object's info dict.
        """
        if xmin is None:
            xmin = min(self.get_energies()) - (padding * width)
        if xmax is None:
            xmax = max(self.get_energies()) + (padding * width)
        energies_grid = np.linspace(xmin, xmax, npts)
        weights_grid = self._sample(energies_grid, width=width,
                                    smearing=smearing)

        return GridDOSData(energies_grid, weights_grid, info=self.info.copy())

    def plot(self,
             npts: int = 1000,
             xmin: float = None,
             xmax: float = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> 'matplotlib.axes.Axes':
        """Simple 1-D plot of DOS data, resampled onto a grid

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel for self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The dict passed
                in is not modified.

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """
        # Copy so that adding a default label never mutates the caller's
        # dict (which may be reused for plotting other objects).
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        return self.sample_grid(npts, xmin=xmin, xmax=xmax,
                                width=width,
                                smearing=smearing
                                ).plot(ax=ax, xmin=xmin, xmax=xmax,
                                       show=show, filename=filename,
                                       mplargs=mplargs)

    @staticmethod
    def label_from_info(info: Dict[str, str]):
        """Generate an automatic legend label from info dict

        Returns info['label'] if present, otherwise all entries joined as
        'key: value' pairs separated by '; ', in dict iteration order.
        """
        if 'label' in info:
            return info['label']
        else:
            return '; '.join(map(lambda x: '{}: {}'.format(x[0], x[1]),
                                 info.items()))
class GeneralDOSData(DOSData):
    """Base class for a single series of DOS-like data

    Only the 'info' is a mutable attribute; DOS data is set at init

    This is the base class for DOSData objects that accept/set separate
    "energies" and "weights" sequences of equal length at init.

    """
    def __init__(self,
                 energies: Union[Sequence[float], np.ndarray],
                 weights: Union[Sequence[float], np.ndarray],
                 info: Info = None) -> None:
        """Copy energies and weights into a private (2, N) float array.

        Raises:
            ValueError: if energies and weights differ in length.
        """
        super().__init__(info=info)

        count = len(energies)
        if count != len(weights):
            raise ValueError("Energies and weights must be the same length")

        # Row 0 holds the energies, row 1 the weights. Filling a freshly
        # allocated array means later changes to the caller's sequences are
        # not reflected in this object.
        data = np.empty((2, count), dtype=float, order='C')
        data[0, :] = energies
        data[1, :] = weights
        self._data = data

    def get_energies(self) -> np.ndarray:
        """Return a copy of the stored energies"""
        return self._data[0].copy()

    def get_weights(self) -> np.ndarray:
        """Return a copy of the stored weights"""
        return self._data[1].copy()

    D = TypeVar('D', bound='GeneralDOSData')

    def copy(self: D) -> D:  # noqa F821
        """Return a new object of the same class with copied data and info"""
        cls = type(self)
        return cls(self.get_energies(), self.get_weights(),
                   info=self.info.copy())
class RawDOSData(GeneralDOSData):
    """A collection of weighted delta functions which sum to form a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the energy data values do not form a known regular
    grid. The data may be plotted or resampled for further analysis using the
    sample_grid() and plot() methods. Multiple weights at the same
    energy value will *only* be combined in output data, and data stored in
    RawDOSData is never resampled. A plot_deltas() function is also provided
    which plots the raw data.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When RawDOSData objects are combined with the addition operator::

      big_dos = raw_dos_1 + raw_dos_2

    the energy and weights data is *concatenated* (i.e. combined without
    sorting or replacement) and the new info dictionary consists of the
    *intersection* of the inputs: only key-value pairs that were common to both
    of the input objects will be retained in the new combined object. For
    example::

      (RawDOSData([x1], [y1], info={'symbol': 'O', 'index': '1'})
       + RawDOSData([x2], [y2], info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      RawDOSData([x1, x2], [y1, y2], info={'symbol': 'O'})

    """

    def __add__(self, other: 'RawDOSData') -> 'RawDOSData':
        """Concatenate data; keep only info entries common to both inputs

        Raises:
            TypeError: if other is not a RawDOSData.
        """
        if not isinstance(other, RawDOSData):
            raise TypeError("RawDOSData can only be combined with other "
                            "RawDOSData objects")

        # Take intersection of metadata (i.e. only common entries are
        # retained). Iterating over self.info keeps the key order stable;
        # building the result from a set would make the order (and hence
        # auto-generated labels) depend on string hash randomization.
        new_info = {key: value for key, value in self.info.items()
                    if key in other.info and other.info[key] == value}

        # Concatenate the energy/weight data
        new_data = np.concatenate((self._data, other._data), axis=1)

        # Construct empty and attach the already-combined array directly,
        # avoiding a redundant copy through __init__.
        new_object = RawDOSData([], [], info=new_info)
        new_object._data = new_data

        return new_object

    def plot_deltas(self,
                    ax: 'matplotlib.axes.Axes' = None,
                    show: bool = False,
                    filename: str = None,
                    mplargs: dict = None) -> 'matplotlib.axes.Axes':
        """Simple plot of sparse DOS data as a set of delta functions

        Items at the same x-value can overlap and will not be summed together

        Args:
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib Axes.vlines
                command (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        if mplargs is None:
            mplargs = {}

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            # One vertical line per state, from 0 up to its weight
            ax.vlines(self.get_energies(), 0, self.get_weights(), **mplargs)

        return ax
class GridDOSData(GeneralDOSData):
    """A collection of regularly-sampled data which represents a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the intensity values form a regular grid. This
    is generally the result of sampling or integrating into discrete
    bins, rather than a collection of unique states. The data may be
    plotted or resampled for further analysis using the sample_grid()
    and plot() methods.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When GridDOSData objects are combined with the addition operator::

      big_dos = grid_dos_1 + grid_dos_2

    the weights data is *summed* (requiring a consistent energy grid) and the
    new info dictionary consists of the *intersection* of the inputs: only
    key-value pairs that were common to both of the input objects will be
    retained in the new combined object. For example::

      (GridDOSData([0.1, 0.2, 0.3], [y1, y2, y3],
                   info={'symbol': 'O', 'index': '1'})
       + GridDOSData([0.1, 0.2, 0.3], [y4, y5, y6],
                     info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      GridDOSData([0.1, 0.2, 0.3], [y1+y4, y2+y5, y3+y6], info={'symbol': 'O'})

    """
    def __init__(self,
                 energies: Sequence[float],
                 weights: Sequence[float],
                 info: Info = None) -> None:
        """Validate that energies form an even grid, then store the data.

        Raises:
            ValueError: if energies are not evenly spaced (within
                np.allclose tolerance) or weights differ in length.
        """
        n_entries = len(energies)
        if not np.allclose(energies,
                           np.linspace(energies[0], energies[-1], n_entries)):
            raise ValueError("Energies must be an evenly-spaced 1-D grid")

        if len(weights) != n_entries:
            raise ValueError("Energies and weights must be the same length")

        super().__init__(energies, weights, info=info)
        # NOTE(review): not read anywhere in this class; kept for
        # compatibility with any external users.
        self.sigma_cutoff = 3

    def _check_spacing(self, width) -> float:
        """Return the grid spacing, warning if width < 2 * spacing"""
        current_spacing = self._data[0, 1] - self._data[0, 0]
        if width < (2 * current_spacing):
            warnings.warn(
                "The broadening width is small compared to the original "
                "sampling density. The results are unlikely to be smooth.")
        return current_spacing

    def _sample(self,
                energies: Sequence[float],
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Broadened sampling, scaled by the original grid spacing

        Each stored weight is treated as a bin value, so the result is
        multiplied by the bin width of the stored grid.
        """
        current_spacing = self._check_spacing(width)
        return super()._sample(energies=energies,
                               width=width, smearing=smearing
                               ) * current_spacing

    def __add__(self, other: 'GridDOSData') -> 'GridDOSData':
        """Sum weights on a shared grid; keep only common info entries

        Raises:
            TypeError: if other is not a GridDOSData.
            ValueError: if the energy grids differ in length or values.
        """
        # This method uses direct access to the mutable energy and weights data
        # (self._data) to avoid redundant copying operations. The __init__
        # method of GridDOSData will write this to a new array, so on this
        # occasion it is safe to pass references to the mutable data.

        if not isinstance(other, GridDOSData):
            raise TypeError("GridDOSData can only be combined with other "
                            "GridDOSData objects")
        if len(self._data[0, :]) != len(other.get_energies()):
            raise ValueError("Cannot add GridDOSData objects with different-"
                             "length energy grids.")

        if not np.allclose(self._data[0, :], other.get_energies()):
            raise ValueError("Cannot add GridDOSData objects with different "
                             "energy grids.")

        # Take intersection of metadata (i.e. only common entries are
        # retained). Iterating over self.info keeps the key order stable;
        # building the result from a set would make the order (and hence
        # auto-generated labels) depend on string hash randomization.
        new_info = {key: value for key, value in self.info.items()
                    if key in other.info and other.info[key] == value}

        # Sum the weights on the shared energy grid
        new_weights = self._data[1, :] + other.get_weights()

        new_object = GridDOSData(self._data[0, :], new_weights,
                                 info=new_info)
        return new_object

    @staticmethod
    def _interpret_smearing_args(npts: int,
                                 width: float = None,
                                 default_npts: int = 1000,
                                 default_width: float = 0.1
                                 ) -> Tuple[int, Union[float, None]]:
        """Figure out what the user intended: resample if width provided

        Returns (npts, width). A falsy npts is replaced by default_npts when
        width is given; width falls back to default_width when only npts is
        given; (0, None) means "do not resample".
        """
        if width is not None:
            if npts:
                return (npts, float(width))
            else:
                return (default_npts, float(width))
        else:
            if npts:
                return (npts, default_width)
            else:
                return (0, None)

    def plot(self,
             npts: int = 0,
             xmin: float = None,
             xmax: float = None,
             width: float = None,
             smearing: str = 'Gauss',
             ax: 'matplotlib.axes.Axes' = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> 'matplotlib.axes.Axes':
        """Simple 1-D plot of DOS data

        Data will be resampled onto a grid with `npts` points unless `npts` is
        set to zero, in which case:

        - no resampling takes place
        - `width` and `smearing` are ignored
        - `xmin` and `xmax` affect the axis limits of the plot, not the
          underlying data.

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid().
                If no npts was set but width is set, npts will be set to 1000.
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The dict passed
                in is not modified.

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        npts, width = self._interpret_smearing_args(npts, width)

        # Copy so that adding a default label never mutates the caller's
        # dict (which may be reused for plotting other objects).
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        if npts:
            assert isinstance(width, float)  # narrowing for type checkers
            dos = self.sample_grid(npts, xmin=xmin,
                                   xmax=xmax, width=width,
                                   smearing=smearing)
        else:
            dos = self

        energies, intensity = dos.get_energies(), dos.get_weights()

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            ax.plot(energies, intensity, **mplargs)
            ax.set_xlim(left=xmin, right=xmax)

        return ax