Coverage for /builds/debichem-team/python-ase/ase/dft/bz.py: 94.15%
188 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
1from itertools import product
2from math import cos, pi, sin
3from typing import Any, Dict, Optional, Tuple, Union
5import numpy as np
6from matplotlib.patches import FancyArrowPatch
7from mpl_toolkits.mplot3d import Axes3D, proj3d
8from scipy.spatial.transform import Rotation
10from ase.cell import Cell
def bz_vertices(icell, dim=3):
    """Return the faces of the first Brillouin zone.

    The zone is obtained as the Voronoi cell of the origin among the 27
    reciprocal-lattice points ``n1*b1 + n2*b2 + n3*b3`` with every
    ``n_i`` in {-1, 0, 1}.

    Parameters
    ----------
    icell:
        Reciprocal cell whose rows are the reciprocal lattice vectors.
        A copy is taken, so the argument itself is never modified.
    dim: int
        Number of periodic dimensions.  For ``dim < 3`` (and ``dim < 2``)
        the diagonal entry ``[2, 2]`` (and ``[1, 1]``) of the copy is set to
        1e-3 before the Voronoi construction.

    Returns
    -------
    list of (ndarray, ndarray)
        One ``(vertices, normal)`` pair per face of the zone: the vertex
        coordinates of the face and its outward unit normal.

    See https://xkcd.com/1421 ...
    """
    from scipy.spatial import Voronoi

    cell = icell.copy()
    if dim < 3:
        cell[2, 2] = 1e-3
    if dim < 2:
        cell[1, 1] = 1e-3

    # Integer coefficients (n1, n2, n3) in {-1, 0, 1}; in this C-ordered
    # enumeration, entry 13 is (0, 0, 0), i.e. the origin (Gamma).
    coefficients = (np.indices((3, 3, 3)) - 1).reshape((3, 27))
    lattice_points = np.dot(cell.T, coefficients).T

    voronoi = Voronoi(lattice_points)
    faces = []
    for ridge, sites in zip(voronoi.ridge_vertices, voronoi.ridge_points):
        # Skip unbounded ridges (a vertex index of -1 means "at infinity")
        # and ridges that do not border the origin's Voronoi cell.
        if -1 in ridge or 13 not in sites:
            continue
        # One of the two sites is the origin, so the sum is the other
        # lattice point, which points outward through the shared face.
        outward = lattice_points[sites].sum(0)
        outward /= (outward**2).sum()**0.5
        faces.append((voronoi.vertices[ridge], outward))
    return faces
class FlatPlot:
    """Plotting helper for Brillouin zones of 1D and 2D cells.

    Everything is drawn on an ordinary two-dimensional matplotlib Axes.
    """

    # The drawing surface is 2D even when the Brillouin zone is 1D.
    axis_dim = 2
    # Extra keyword arguments used when scattering k-points.
    point_options = {'zorder': 5}

    def new_axes(self, fig):
        """Return the current Axes of ``fig``."""
        return fig.gca()

    def adjust_view(self, ax, minp, maxp, symmetric: bool = True):
        """Adjust the view limits of the drawn BZ (1D/2D).

        Parameters
        ----------
        ax: Axes
            matplotlib Axes object.
        minp: float
            Minimum coordinate of the plotting region, which determines the
            bottom-left corner of the figure.  Ignored when ``symmetric``
            is True.
        maxp: float
            Maximum coordinate of the plotting region, which determines the
            top-right corner of the figure.
        symmetric: bool
            If True, put the (0, 0) position (Gamma-bar point) at the
            center of the figure.
        """
        ax.autoscale_view(tight=True)
        upper = maxp * 1.05  # leave a 5 % margin around the zone
        lower = -upper if symmetric else minp * 1.05
        ax.set_xlim(lower, upper)
        ax.set_ylim(lower, upper)
        ax.set_aspect('equal')

    def draw_arrow(self, ax, vector, **kwargs):
        """Draw the first two components of ``vector`` as an arrow from the origin."""
        dx, dy = vector[0], vector[1]
        ax.arrow(0, 0, dx, dy, lw=1, length_includes_head=True,
                 head_width=0.03, head_length=0.05, **kwargs)

    def label_options(self, point):
        """Return text alignment that places a label away from the origin.

        Labels right of (above) the origin are left- (bottom-) aligned,
        labels left of (below) it are right- (top-) aligned; a zero
        coordinate gives 'right' / 'bottom'.
        """
        x, y = point
        horizontal = {-1: 'right', 0: 'right', 1: 'left'}[int(np.sign(x))]
        vertical = {-1: 'top', 0: 'bottom', 1: 'bottom'}[int(np.sign(y))]
        return {'ha': horizontal, 'va': vertical, 'zorder': 4}

    def view(self):
        """No-op placeholder; flat plots have no viewing direction."""
class SpacePlot:
    """Helper class for ordinary (3D) Brillouin zone plots.

    Attributes
    ----------
    azim : float
        Azimuthal angle in radians for viewing the 3D BZ.
        Default value is pi/5.
    elev : float
        Elevation angle in radians for viewing the 3D BZ.
        Default value is pi/6.
    view : list of float
        Unit vector from the plot origin towards the camera, following the
        same convention as ``Axes3D.view_init`` (which ``adjust_view`` calls
        with these angles).  ``bz_plot`` draws a face dotted when its
        normal has a negative dot product with this vector.
    """
    axis_dim = 3
    point_options: Dict[str, Any] = {}

    def __init__(self, *, azim: Optional[float] = None,
                 elev: Optional[float] = None):
        class Arrow3D(FancyArrowPatch):
            # A FancyArrowPatch whose end points are stored in 3D data
            # coordinates and projected to 2D each time it is drawn.
            def __init__(self, ax, xs, ys, zs, *args, **kwargs):
                FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
                self._verts3d = xs, ys, zs
                self.ax = ax

            def draw(self, renderer):
                xs3d, ys3d, zs3d = self._verts3d
                xs, ys, _zs = proj3d.proj_transform(xs3d, ys3d,
                                                    zs3d, self.ax.axes.M)
                self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
                FancyArrowPatch.draw(self, renderer)

            # FIXME: Compatibility fix for matplotlib 3.5.0: Handling of 3D
            # artists have changed and all 3D artists now need
            # "do_3d_projection". Since this class is a hack that manually
            # projects onto the 3D axes we don't need to do anything in this
            # method. Ideally we shouldn't resort to a hack like this.
            def do_3d_projection(self, *_, **__):
                return 0

        self.arrow3d = Arrow3D
        self.azim: float = pi / 5 if azim is None else azim
        self.elev: float = pi / 6 if elev is None else elev
        # Matplotlib puts the camera at
        # (cos(elev) cos(azim), cos(elev) sin(azim), sin(elev)): a positive
        # azimuth rotates it from the +x axis towards the +y axis.  The x
        # component must therefore carry cos(azim) and the y component
        # sin(azim); swapping them made the hidden-face test disagree with
        # the view actually rendered.
        self.view = [
            cos(self.azim) * cos(self.elev),
            sin(self.azim) * cos(self.elev),
            sin(self.elev),
        ]

    def new_axes(self, fig):
        """Add and return a new 3D subplot on ``fig``."""
        return fig.add_subplot(projection='3d')

    def draw_arrow(self, ax: Axes3D, vector, **kwargs):
        """Draw ``vector`` as a 3D arrow starting at the origin."""
        ax.add_artist(self.arrow3d(
            ax,
            [0, vector[0]],
            [0, vector[1]],
            [0, vector[2]],
            mutation_scale=20,
            arrowstyle='-|>',
            **kwargs))

    def adjust_view(self, ax, minp, maxp, symmetric=True):
        """Adjust the view properties of the drawn BZ (3D).

        Parameters
        ----------
        ax: Axes
            matplotlib Axes object.
        minp: float
            Minimum coordinate of the plotting region, which determines the
            lower limit of every axis.
        maxp: float
            Maximum coordinate of the plotting region, which determines the
            upper limit of every axis.
        symmetric: bool
            Currently not used; kept for consistency with the 2D version.
        """
        import matplotlib.pyplot as plt

        # ax.set_aspect('equal') <-- won't work anymore in 3.1.0
        ax.view_init(azim=np.rad2deg(self.azim), elev=np.rad2deg(self.elev))
        # We want aspect 'equal', but apparently there was a bug in
        # matplotlib causing wrong behaviour. Matplotlib raises
        # NotImplementedError as of v3.1.0. This is a bit unfortunate
        # because the workarounds known to StackOverflow and elsewhere
        # all involve using set_aspect('equal') and then doing
        # something more.
        #
        # We try to get square axes here by setting a square figure,
        # but this is probably rather inexact.
        fig = ax.get_figure()
        xx = plt.figaspect(1.0)
        fig.set_figheight(xx[1])
        fig.set_figwidth(xx[0])

        ax.set_proj_type('ortho')

        minp0 = 0.9 * minp  # Here we cheat a bit to trim spacings
        maxp0 = 0.9 * maxp
        ax.set_xlim3d(minp0, maxp0)
        ax.set_ylim3d(minp0, maxp0)
        ax.set_zlim3d(minp0, maxp0)

        ax.set_box_aspect([1, 1, 1])

    def label_options(self, point):
        """Return text options for a special-point label (centered, above)."""
        return dict(ha='center', va='bottom')
def normalize_name(name):
    """Convert a special-point label into text suitable for LaTeX math.

    ``'G'`` becomes ``'\\Gamma'``.  A label longer than one character must
    consist of non-digits followed by optional digits; any trailing digits
    become a subscript (``'X1'`` -> ``'X_{1}'``).  Other single-character
    labels are returned unchanged.

    Raises
    ------
    ValueError
        If a multi-character label does not have that form.
    """
    if name == 'G':
        return '\\Gamma'
    if len(name) <= 1:
        return name

    import re

    match = re.match(r'^(\D+?)(\d*)$', name)
    if match is None:
        raise ValueError(f'Bad label: {name}')
    letters, digits = match.group(1, 2)
    return f'{letters}_{{{digits}}}' if digits else letters
def _apply_transforms(transforms, vectors):
    """Apply every transformation in ``transforms`` to ``vectors``, in order."""
    for transform in transforms:
        vectors = transform.apply(vectors)
    return vectors


def bz_plot(cell: Cell, vectors: bool = False, paths=None, points=None,
            azim: Optional[float] = None, elev: Optional[float] = None,
            scale=1, interactive: bool = False,
            transforms: Optional[list] = None,
            repeat: Union[Tuple[int, int], Tuple[int, int, int]] = (1, 1, 1),
            pointstyle: Optional[dict] = None,
            ax=None, show=False, **kwargs):
    """Plot the Brillouin zone of the Cell

    Parameters
    ----------
    cell: Cell
        Cell object for BZ drawing.
    vectors : bool
        If True, draw the reciprocal lattice vectors as arrows.
    paths : list[tuple[str, np.ndarray]] | None
        Pairs of special-point names and their coordinates; each path is
        drawn as a red line and every point is labelled.
    points : np.ndarray
        Coordinate points (k-points) along the paths, drawn as a scatter.
    azim : float | None
        Azimuthal angle in radian for viewing 3D BZ.
    elev : float | None
        Elevation angle in radian for viewing 3D BZ.
    scale : float
        Not used. To be removed?
    interactive : bool
        If True, faces pointing away from the viewer are not drawn dotted
        (the view may be rotated interactively).
    transforms: List
        List of linear transformations (scipy.spatial.transform.Rotation)
        applied, in order, to everything that is drawn.  Face normals are
        transformed as well, so hidden-face detection stays consistent.
    repeat: Tuple[int, int] | Tuple[int, int, int]
        Set the repeating draw of BZ. default is (1, 1, 1), no repeat.
    pointstyle : Dict
        Style of the k-points; overrides the default scatter options.
    ax : Axes | Axes3D
        matplotlib Axes (Axes3D in 3D) object
    show : bool
        If true, show the figure.
    **kwargs
        Additional keyword arguments to pass to ax.plot

    Returns
    -------
    ax
        A matplotlib axis object.
    """
    import matplotlib.pyplot as plt

    if pointstyle is None:
        pointstyle = {}

    if transforms is None:
        transforms = [Rotation.from_rotvec((0, 0, 0))]

    cell = cell.copy()

    dimensions = cell.rank
    if dimensions == 3:
        plotter: Union[SpacePlot, FlatPlot] = SpacePlot(azim=azim, elev=elev)
    else:
        plotter = FlatPlot()
    assert dimensions > 0, 'No BZ for 0D!'

    if ax is None:
        ax = plotter.new_axes(plt.gcf())

    assert not np.array(cell)[dimensions:, :].any()
    assert not np.array(cell)[:, dimensions:].any()

    icell = cell.reciprocal()
    kpoints = points
    bz1 = bz_vertices(icell, dim=dimensions)
    # Normalise to a tuple so e.g. a list [1, 1, 1] is recognised as
    # "no repeat" by the comparison at the end.
    repeat = tuple(repeat)
    if len(repeat) == 2:
        repeat = (repeat[0], repeat[1], 1)

    maxp = 0.0
    minp = 0.0
    for bz_i in bz_index(repeat):
        # Offset of this repeated zone; it does not depend on the face.
        shift = _apply_transforms(
            transforms, np.dot(np.array(icell).T, np.array(bz_i)))
        for face, normal in bz1:
            ls = '-'
            xyz = _apply_transforms(
                transforms, np.concatenate([face, face[:1]]))
            x, y, z = xyz.T
            x, y, z = x + shift[0], y + shift[1], z + shift[2]
            if dimensions == 3 and not interactive:
                # The normal must live in the same (transformed) frame as
                # the drawn face for the visibility test to be meaningful.
                if _apply_transforms(transforms, normal) @ plotter.view < 0:
                    ls = ':'
            if plotter.axis_dim == 2:
                ax.plot(x, y, c='k', ls=ls, **kwargs)
            else:
                ax.plot(x, y, z, c='k', ls=ls, **kwargs)
            maxp = max(maxp, x.max(), y.max(), z.max())
            minp = min(minp, x.min(), y.min(), z.min())

    if vectors:
        # np.asarray also covers an empty ``transforms`` list, where the
        # Cell object would otherwise be passed through untouched.
        icell = np.asarray(_apply_transforms(transforms, icell))
        for i in range(dimensions):
            plotter.draw_arrow(ax, icell[i], color='k')

        # XXX Can this be removed?
        if dimensions == 3:
            maxp = max(maxp, 0.6 * icell.max())
        else:
            maxp = max(maxp, icell.max())

    if paths is not None:
        for names, path_points in paths:
            path_points = _apply_transforms(transforms, path_points)
            coords = np.array(path_points).T[:plotter.axis_dim, :]
            ax.plot(*coords, c='r', ls='-')

            # ``path_points`` are already transformed; transforming each
            # point again would misplace the labels for any non-identity
            # rotation.
            for name, point in zip(names, path_points):
                name = normalize_name(name)
                point = np.asarray(point)[:plotter.axis_dim]
                ax.text(*point, rf'$\mathrm{{{name}}}$',
                        color='g', **plotter.label_options(point))

    if kpoints is not None:
        kw = {'c': 'b', **plotter.point_options, **pointstyle}
        kpoints = np.asarray(_apply_transforms(transforms, kpoints))
        ax.scatter(*kpoints[:, :plotter.axis_dim].T, **kw)

    ax.set_axis_off()

    plotter.adjust_view(ax, minp, maxp, symmetric=(repeat == (1, 1, 1)))
    if show:
        plt.show()

    return ax
def bz_index(repeat):
    """Enumerate the Brillouin-zone indices selected by ``repeat``.

    A helper for drawing repeated Brillouin zones: each yielded tuple
    ``(i, j, k)`` is the integer offset of one zone in units of the
    reciprocal lattice vectors.

    Parameters
    ----------
    repeat: Tuple[int, int] | Tuple[int, int, int]
        Number of zones along each reciprocal lattice vector.  A negative
        count repeats towards negative indices.  A 2-tuple is treated as
        ``(repeat[0], repeat[1], 1)``.  No entry may be zero
        (AssertionError otherwise).

    Returns
    -------
    Iterator[Tuple[int, int, int]]

    >>> list(bz_index((1, 2, -2)))
    [(0, 0, 0), (0, 0, -1), (0, 1, 0), (0, 1, -1)]
    """
    if len(repeat) == 2:
        repeat = (repeat[0], repeat[1], 1)
    assert len(repeat) == 3
    for count in repeat:
        assert count != 0
    # 0, 1, ..., n-1 for positive n and 0, -1, ..., n+1 for negative n.
    ranges = [range(0, count, 1 if count > 0 else -1) for count in repeat]
    return product(*ranges)