Coverage for /builds/debichem-team/python-ase/ase/dft/bz.py: 94.15%

188 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-03-06 04:00 +0000

1from itertools import product 

2from math import cos, pi, sin 

3from typing import Any, Dict, Optional, Tuple, Union 

4 

5import numpy as np 

6from matplotlib.patches import FancyArrowPatch 

7from mpl_toolkits.mplot3d import Axes3D, proj3d 

8from scipy.spatial.transform import Rotation 

9 

10from ase.cell import Cell 

11 

12 

def bz_vertices(icell, dim=3):
    """Return the faces of the first Brillouin zone as ``(vertices, normal)``.

    The zone is built as the Voronoi cell of the origin among the 27
    reciprocal-lattice points whose coefficients (with respect to the rows
    of *icell*) lie in {-1, 0, 1}.

    Parameters
    ----------
    icell : array-like, shape (3, 3)
        Reciprocal cell whose rows are the reciprocal lattice vectors.
        A copy is made; the argument itself is never modified.
    dim : int
        Number of periodic dimensions.  When ``dim < 3`` (and again when
        ``dim < 2``) the corresponding diagonal entry of the copy is set to
        ``1e-3`` so the Voronoi construction stays three-dimensional.

    Returns
    -------
    list of (ndarray, ndarray)
        One entry per bounded ridge between the origin and a neighbouring
        lattice point.  Each entry holds that ridge's Voronoi vertices and
        the unit vector pointing from the origin toward the neighbour.

    See https://xkcd.com/1421 ...
    """
    from scipy.spatial import Voronoi

    recip = icell.copy()
    # Give the non-periodic directions a tiny reciprocal length so that the
    # Voronoi problem below remains a genuine 3D one.
    for axis in (2, 1):
        if dim < axis + 1:
            recip[axis, axis] = 1e-3

    # Integer coefficients in {-1, 0, 1}; flat index 13 is (0, 0, 0).
    offsets = (np.indices((3, 3, 3)) - 1).reshape((3, 27))
    lattice = np.dot(recip.T, offsets).T
    origin = 13

    diagram = Voronoi(lattice)
    faces = []
    for ridge, pair in zip(diagram.ridge_vertices, diagram.ridge_points):
        # Skip unbounded ridges and ridges that do not touch the origin.
        if -1 in ridge or origin not in pair:
            continue
        # One generating point is the origin (the zero vector), so the sum
        # is just the neighbour's position, i.e. the outward face direction.
        direction = lattice[pair].sum(0)
        direction /= (direction ** 2).sum() ** 0.5
        faces.append((diagram.vertices[ridge], direction))
    return faces

35 

36 

class FlatPlot:
    """Helper class for 1D/2D Brillouin zone plots.

    Everything is drawn on an ordinary 2D matplotlib ``Axes``; only the
    first two components of each coordinate are used.
    """

    # The plotting surface is 2D even when the BZ itself is 1D.
    axis_dim = 2
    # Extra scatter options for k-points (raised zorder).
    point_options = {'zorder': 5}

    def new_axes(self, fig):
        """Return the current axes of *fig*."""
        return fig.gca()

    def adjust_view(self, ax, minp, maxp, symmetric: bool = True):
        """Adjust the view limits of the drawn BZ (1D/2D).

        Parameters
        ----------
        ax: Axes
            matplotlib Axes object.
        minp: float
            Lower bound of the plotting region (bottom-left corner of the
            figure).  Ignored when ``symmetric`` is True.
        maxp: float
            Upper bound of the plotting region (top-right corner of the
            figure).
        symmetric: bool
            If True, put the (0, 0) position (Gamma-bar point) at the center
            of the figure by using ``-maxp`` as the lower bound.

        Both bounds are multiplied by 1.05 before being applied to the x and
        y limits, and the aspect ratio is set to 'equal'.
        """
        ax.autoscale_view(tight=True)
        upper = maxp * 1.05
        lower = -upper if symmetric else minp * 1.05
        for set_limits in (ax.set_xlim, ax.set_ylim):
            set_limits(lower, upper)
        ax.set_aspect('equal')

    def draw_arrow(self, ax, vector, **kwargs):
        """Draw an arrow on *ax* from the origin to ``(vector[0], vector[1])``."""
        dx, dy = vector[0], vector[1]
        ax.arrow(0, 0, dx, dy, lw=1, length_includes_head=True,
                 head_width=0.03, head_length=0.05, **kwargs)

    def label_options(self, point):
        """Return text keywords that push a label away from the origin.

        Horizontal alignment is 'left' for x > 0 and 'right' otherwise;
        vertical alignment is 'top' for y < 0 and 'bottom' otherwise.
        """
        x, y = point
        horizontal = {-1: 'right', 0: 'right', 1: 'left'}
        vertical = {-1: 'top', 0: 'bottom', 1: 'bottom'}
        return {
            'ha': horizontal[int(np.sign(x))],
            'va': vertical[int(np.sign(y))],
            'zorder': 4,
        }

    def view(self):
        """Do nothing: flat plots have no viewing direction to set."""
        return None

94 

95 

class SpacePlot:
    """Helper class for ordinary (3D) Brillouin zone plots.

    Attributes
    ----------
    azim : float
        Azimuthal angle in radians for viewing the 3D BZ.
        Default value is pi/5.
    elev : float
        Elevation angle in radians for viewing the 3D BZ.
        Default value is pi/6.
    view : list[float]
        Unit vector pointing toward the camera that ``adjust_view`` sets up
        via ``view_init``.  ``bz_plot`` draws faces whose outward normal has
        a negative dot product with it as dotted (hidden) lines.

    """
    axis_dim = 3
    point_options: Dict[str, Any] = {}

    def __init__(self, *, azim: Optional[float] = None,
                 elev: Optional[float] = None):
        class Arrow3D(FancyArrowPatch):
            """Arrow patch whose 3D end points are projected at draw time."""

            def __init__(self, ax, xs, ys, zs, *args, **kwargs):
                FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
                self._verts3d = xs, ys, zs
                self.ax = ax

            def draw(self, renderer):
                xs3d, ys3d, zs3d = self._verts3d
                xs, ys, _zs = proj3d.proj_transform(xs3d, ys3d,
                                                    zs3d, self.ax.axes.M)
                self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
                FancyArrowPatch.draw(self, renderer)

            # FIXME: Compatibility fix for matplotlib 3.5.0: Handling of 3D
            # artists have changed and all 3D artists now need
            # "do_3d_projection". Since this class is a hack that manually
            # projects onto the 3D axes we don't need to do anything in this
            # method. Ideally we shouldn't resort to a hack like this.
            def do_3d_projection(self, *_, **__):
                return 0

        self.arrow3d = Arrow3D
        self.azim: float = pi / 5 if azim is None else azim
        self.elev: float = pi / 6 if elev is None else elev
        # matplotlib places the camera for ``view_init(elev, azim)`` (with the
        # default vertical axis 'z') along
        # (cos(elev)cos(azim), cos(elev)sin(azim), sin(elev)).
        # The hidden-face test must use the same direction as the camera
        # configured in adjust_view, otherwise the wrong faces are dotted.
        self.view = [
            cos(self.azim) * cos(self.elev),
            sin(self.azim) * cos(self.elev),
            sin(self.elev),
        ]

    def new_axes(self, fig):
        """Add and return a new 3D subplot on *fig*."""
        return fig.add_subplot(projection='3d')

    def draw_arrow(self, ax: Axes3D, vector, **kwargs):
        """Draw a 3D arrow on *ax* from the origin to *vector*."""
        ax.add_artist(self.arrow3d(
            ax,
            [0, vector[0]],
            [0, vector[1]],
            [0, vector[2]],
            mutation_scale=20,
            arrowstyle='-|>',
            **kwargs))

    def adjust_view(self, ax, minp, maxp, symmetric=True):
        """Adjust the view of the drawn BZ (3D).

        Parameters
        ----------
        ax: Axes3D
            matplotlib 3D Axes object.
        minp: float
            Lower bound of the plotting region; scaled by 0.9 and used as
            the lower limit of all three axes.
        maxp: float
            Upper bound of the plotting region; scaled by 0.9 and used as
            the upper limit of all three axes.
        symmetric: bool
            Currently not used; kept for consistency with the 2D version.

        """
        import matplotlib.pyplot as plt

        # ax.set_aspect('equal') <-- won't work anymore in 3.1.0
        ax.view_init(azim=np.rad2deg(self.azim), elev=np.rad2deg(self.elev))
        # We want aspect 'equal', but apparently there was a bug in
        # matplotlib causing wrong behaviour. Matplotlib raises
        # NotImplementedError as of v3.1.0. This is a bit unfortunate
        # because the workarounds known to StackOverflow and elsewhere
        # all involve using set_aspect('equal') and then doing
        # something more.
        #
        # We try to get square axes here by setting a square figure,
        # but this is probably rather inexact.
        fig = ax.get_figure()
        xx = plt.figaspect(1.0)
        fig.set_figheight(xx[1])
        fig.set_figwidth(xx[0])

        ax.set_proj_type('ortho')

        minp0 = 0.9 * minp  # Here we cheat a bit to trim spacings
        maxp0 = 0.9 * maxp
        ax.set_xlim3d(minp0, maxp0)
        ax.set_ylim3d(minp0, maxp0)
        ax.set_zlim3d(minp0, maxp0)

        ax.set_box_aspect([1, 1, 1])

    def label_options(self, point):
        """Return text keywords for a special-point label (centered above)."""
        return dict(ha='center', va='bottom')

206 

207 

def normalize_name(name):
    """Convert a special-point label into LaTeX math-mode text.

    ``'G'`` becomes ``'\\Gamma'``.  Other single-character labels are
    returned unchanged.  A longer label must consist of non-digit
    characters followed by optional trailing digits; the digits, if any,
    become a subscript (``'X1'`` -> ``'X_{1}'``).

    Raises
    ------
    ValueError
        If a multi-character label does not match that pattern (for
        example when it starts with a digit).
    """
    import re

    if name == 'G':
        return '\\Gamma'
    if len(name) <= 1:
        return name

    match = re.match(r'^(\D+?)(\d*)$', name)
    if match is None:
        raise ValueError(f'Bad label: {name}')
    stem, digits = match.group(1, 2)
    return f'{stem}_{{{digits}}}' if digits else stem

222 

223 

def _apply_transforms(transforms, array):
    """Apply every transform's ``apply`` method to *array*, in list order."""
    for transform in transforms:
        array = transform.apply(array)
    return array


def bz_plot(cell: Cell, vectors: bool = False, paths=None, points=None,
            azim: Optional[float] = None, elev: Optional[float] = None,
            scale=1, interactive: bool = False,
            transforms: Optional[list] = None,
            repeat: Union[Tuple[int, int], Tuple[int, int, int]] = (1, 1, 1),
            pointstyle: Optional[dict] = None,
            ax=None, show=False, **kwargs):
    """Plot the Brillouin zone of the Cell.

    Parameters
    ----------
    cell: Cell
        Cell object for BZ drawing.  Lattice components beyond
        ``cell.rank`` must be zero (this is asserted).
    vectors : bool
        If True, draw arrows for the reciprocal lattice vectors.
    paths : list[tuple[Sequence[str], np.ndarray]] | None
        Band paths as ``(names, points)`` pairs: special-point labels and
        their coordinates.  Each path is drawn as a red line, and each label
        as green text at its point.
    points : np.ndarray | None
        k-points of shape (N, 3), drawn with ``ax.scatter``.  They are
        plotted on the same axes as the zone edges, so they are assumed to
        be in the Cartesian frame of the rows of ``cell.reciprocal()``.
    azim : float | None
        Azimuthal angle in radians for viewing the 3D BZ.
    elev : float | None
        Elevation angle in radians for viewing the 3D BZ.
    scale : float
        Not used. To be removed?
    interactive : bool
        If True, back faces of a 3D BZ are not drawn as dotted lines.
    transforms: list | None
        Rotations (``scipy.spatial.transform.Rotation``) applied, in order,
        to every drawn coordinate: zone edges, face normals, vectors, paths
        and points.  Defaults to the identity rotation.
    repeat: Tuple[int, int] | Tuple[int, int, int]
        Number of BZ copies along each reciprocal lattice vector (negative
        values repeat in the negative direction).  The default (1, 1, 1)
        means no repetition.
    pointstyle : dict | None
        Extra keyword arguments for ``ax.scatter`` when drawing ``points``.
    ax : Axes | Axes3D | None
        matplotlib Axes (Axes3D in 3D) object.  If None, the plotter's
        ``new_axes`` is called on the current figure.
    show : bool
        If True, call ``plt.show()`` at the end.
    **kwargs
        Additional keyword arguments passed to ``ax.plot`` when drawing the
        zone edges.

    Returns
    -------
    ax
        A matplotlib axis object.
    """
    import matplotlib.pyplot as plt

    if pointstyle is None:
        pointstyle = {}

    if transforms is None:
        transforms = [Rotation.from_rotvec((0, 0, 0))]

    cell = cell.copy()

    dimensions = cell.rank
    if dimensions == 3:
        plotter: Union[SpacePlot, FlatPlot] = SpacePlot(azim=azim, elev=elev)
    else:
        plotter = FlatPlot()
    assert dimensions > 0, 'No BZ for 0D!'

    if ax is None:
        ax = plotter.new_axes(plt.gcf())

    assert not np.array(cell)[dimensions:, :].any()
    assert not np.array(cell)[:, dimensions:].any()

    icell = cell.reciprocal()
    kpoints = points
    bz1 = bz_vertices(icell, dim=dimensions)
    # Normalize to a tuple so that e.g. a list [1, 1, 1] is recognised as
    # "no repetition" in the symmetric-view check at the end.
    repeat = tuple(repeat)
    if len(repeat) == 2:
        repeat = (repeat[0], repeat[1], 1)

    maxp = 0.0
    minp = 0.0
    for bz_i in bz_index(repeat):
        # Translation of this BZ copy; the same for all of its faces.
        shift = _apply_transforms(
            transforms, np.dot(np.array(icell).T, np.array(bz_i)))
        for face_points, normal in bz1:
            # Close the polygon by repeating its first vertex.
            xyz = np.concatenate([face_points, face_points[:1]])
            xyz = _apply_transforms(transforms, xyz) + shift
            x, y, z = xyz.T

            ls = '-'
            if dimensions == 3 and not interactive:
                # The face was rotated, so its normal must be rotated too
                # before deciding whether it faces away from the camera.
                rotated_normal = _apply_transforms(transforms, normal)
                if rotated_normal @ plotter.view < 0:
                    ls = ':'

            if plotter.axis_dim == 2:
                ax.plot(x, y, c='k', ls=ls, **kwargs)
            else:
                ax.plot(x, y, z, c='k', ls=ls, **kwargs)
            maxp = max(maxp, x.max(), y.max(), z.max())
            minp = min(minp, x.min(), y.min(), z.min())

    if vectors:
        icell = _apply_transforms(transforms, icell)
        assert isinstance(icell, np.ndarray)
        for i in range(dimensions):
            plotter.draw_arrow(ax, icell[i], color='k')

        # XXX Can this be removed?
        if dimensions == 3:
            maxp = max(maxp, 0.6 * icell.max())
        else:
            maxp = max(maxp, icell.max())

    if paths is not None:
        for names, path_points in paths:
            path_points = np.asarray(
                _apply_transforms(transforms, path_points))
            coords = path_points.T[:plotter.axis_dim, :]
            ax.plot(*coords, c='r', ls='-')

            # path_points is already transformed; transforming each label
            # position again would misplace it for non-identity rotations.
            for name, point in zip(names, path_points):
                name = normalize_name(name)
                point = point[:plotter.axis_dim]
                ax.text(*point, rf'$\mathrm{{{name}}}$',
                        color='g', **plotter.label_options(point))

    if kpoints is not None:
        kw = {'c': 'b', **plotter.point_options, **pointstyle}
        kpoints = np.asarray(_apply_transforms(transforms, kpoints))
        ax.scatter(*kpoints[:, :plotter.axis_dim].T, **kw)

    ax.set_axis_off()

    if repeat == (1, 1, 1):
        plotter.adjust_view(ax, minp, maxp)
    else:
        plotter.adjust_view(ax, minp, maxp, symmetric=False)
    if show:
        plt.show()

    return ax

365 

366 

def bz_index(repeat):
    """BZ indices to draw for the given repeat.

    A helper function for drawing repeated Brillouin zones.

    Parameters
    ----------
    repeat: Tuple[int, int] | Tuple[int, int, int]
        Number of BZ copies along each reciprocal lattice vector.  A
        negative entry repeats in the negative direction.  A 2-tuple is
        treated as ``(repeat[0], repeat[1], 1)``.  No entry may be zero.

    Returns
    -------
    Iterator[Tuple[int, int, int]]
        Integer index triples from ``itertools.product``, with the last
        index varying fastest.

    >>> list(bz_index((1, 2, -2)))
    [(0, 0, 0), (0, 0, -1), (0, 1, 0), (0, 1, -1)]

    """
    if len(repeat) == 2:
        repeat = (repeat[0], repeat[1], 1)
    assert len(repeat) == 3, f'repeat must have 2 or 3 entries: {repeat!r}'
    assert all(n != 0 for n in repeat), \
        f'repeat entries must be nonzero: {repeat!r}'
    # Count 0, 1, ..., n-1 for positive n and 0, -1, ..., n+1 for negative n.
    ranges = [range(0, n, 1 if n > 0 else -1) for n in repeat]
    return product(*ranges)