Coverage for /builds/debichem-team/python-ase/ase/dft/bandgap.py: 87.69%

130 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-03-06 04:00 +0000

1import warnings 

2from dataclasses import dataclass 

3 

4import numpy as np 

5 

# Message raised by bandgap() when the removed ``spin`` keyword is passed.
spin_error = ('The spin keyword is no longer supported. '
              'Please call the function with the energies '
              'corresponding to the desired spins.')

# Sentinel default for ``spin``: bandgap() compares against it by identity,
# so passing any value at all (including None) is detected and rejected.
_deprecated = object()

10 

11 

def get_band_gap(calc, direct=False, spin=_deprecated):
    """Deprecated wrapper around :func:`bandgap`.

    Emits a warning and delegates to ``bandgap(calc, direct, spin=spin)``,
    then drops the band indices from the result.

    Parameters:

    calc: Calculator object
        Electronic structure calculator object.
    direct: bool
        Calculate the direct band gap.
    spin:
        No longer supported; passing it makes bandgap() raise RuntimeError.

    Returns ``(gap, (s1, k1), (s2, k2))`` when the calculator reports two
    spin channels, otherwise ``(gap, k1, k2)``.
    """
    # stacklevel=2 attributes the warning to the caller's line, which is the
    # call that actually needs updating.
    warnings.warn('Please use ase.dft.bandgap.bandgap() instead!',
                  stacklevel=2)
    gap, (s1, k1, _n1), (s2, k2, _n2) = bandgap(calc, direct, spin=spin)
    if calc.get_number_of_spins() == 2:
        return gap, (s1, k1), (s2, k2)
    return gap, k1, k2

19 

20 

@dataclass
class GapInfo:
    """Direct and indirect band-gap information for a set of eigenvalues.

    Both gaps are computed once, in ``__post_init__``, via ``_bandgap``.
    Each stored result is a ``(gap, (s1, k1, n1), (s2, k2, n2))`` tuple:
    the gap and the (spin, k-point, band) indices of the valence and
    conduction states.  When no gap is found the gap is 0.0 and all indices
    are None.
    """

    # Eigenvalues referenced to the Fermi level (energies below 0.0 count as
    # occupied), shape (nspin, nkpt, nband) or (nkpt, nband).  A 2-D array is
    # treated as a single spin channel.
    eigenvalues: np.ndarray

    def __post_init__(self):
        e_skn = np.asarray(self.eigenvalues)
        if e_skn.ndim == 2:
            # _bandgap unpacks a (nspin, nkpt, nband) shape; give a 2-D array
            # a length-1 spin axis instead of failing.  self.eigenvalues
            # itself is kept exactly as given.
            e_skn = e_skn[np.newaxis]
        self._gapinfo = _bandgap(e_skn, direct=False)
        self._direct_gapinfo = _bandgap(e_skn, direct=True)

    @classmethod
    def fromcalc(cls, calc):
        """Build a GapInfo from a calculator.

        Eigenvalues for every spin and irreducible k-point are collected and
        shifted by the calculator's Fermi level.
        """
        nk = len(calc.get_ibz_k_points())
        ns = calc.get_number_of_spins()
        eigenvalues = np.array([[calc.get_eigenvalues(kpt=k, spin=s)
                                 for k in range(nk)]
                                for s in range(ns)])

        efermi = calc.get_fermi_level()
        return cls(eigenvalues - efermi)

    def gap(self):
        """Return the gap tuple computed with ``direct=False``."""
        return self._gapinfo

    def direct_gap(self):
        """Return the gap tuple computed with ``direct=True``."""
        return self._direct_gapinfo

    @property
    def is_metallic(self) -> bool:
        """True if the (indirect) gap is exactly 0.0, i.e. no gap was found."""
        return self._gapinfo[0] == 0.0

    @property
    def gap_is_direct(self) -> bool:
        """Whether the direct and indirect gaps are the same transition."""
        return self._gapinfo[1:] == self._direct_gapinfo[1:]

    def description(self, *, ibz_kpoints=None) -> str:
        """Return a human-friendly description of the direct/indirect gap.

        Parameters:

        ibz_kpoints: sequence of 3-vectors, optional
            If given, the coordinates of each referenced k-point (indexed by
            its k index) are printed next to its (s, k, n) label.
        """
        lines = []
        add = lines.append

        def skn(indices):
            """Format (s, k, n) indices, plus k-point coordinates if known."""
            text = 's={}, k={}, n={}'.format(*indices)
            if ibz_kpoints is not None:
                coordtxt = '[{:.2f}, {:.2f}, {:.2f}]'.format(
                    *ibz_kpoints[indices[1]])
                text = f'{text}, {coordtxt}'
            return f'({text})'

        gap, skn1, skn2 = self.gap()
        direct_gap, skn_direct1, skn_direct2 = self.direct_gap()

        if self.is_metallic:
            add('No gap')
        else:
            add(f'Gap: {gap:.3f} eV')
            add('Transition (v -> c):')
            add(f' {skn(skn1)} -> {skn(skn2)}')

        if self.gap_is_direct:
            add('No difference between direct/indirect transitions')
        else:
            add('Direct/indirect transitions are different')
            add(f'Direct gap: {direct_gap:.3f} eV')
            if skn_direct1[0] == skn_direct2[0]:
                add(f'Transition at: {skn(skn_direct1)}')
            else:
                # Spin-flip transition: show both spins, valence k and n.
                transition = skn((f'{skn_direct1[0]}->{skn_direct2[0]}',
                                  *skn_direct1[1:]))
                add(f'Transition at: {transition}')

        return '\n'.join(lines)

97 

98 

def bandgap(calc=None, direct=False, spin=_deprecated,
            eigenvalues=None, efermi=None, output=None, kpts=None):
    """Calculates the band-gap.

    Parameters:

    calc: Calculator object
        Electronic structure calculator object.  If given, eigenvalues for
        all spins and irreducible k-points are read from it (replacing
        *eigenvalues*), and so is the Fermi level unless *efermi* is given.
    direct: bool
        Calculate direct band-gap.
    eigenvalues: array-like of shape (nspin, nkpt, nband) or (nkpt, nband)
        Eigenvalues.
    efermi: float
        Fermi level (defaults to 0.0).
    spin:
        No longer supported; passing it raises RuntimeError.
    output, kpts:
        Not used by this function.

    Returns a (gap, p1, p2) tuple where p1 and p2 are tuples of indices of the
    valence and conduction points: (s, k, n) for 3-D eigenvalues, (k, n) for
    2-D eigenvalues.  If no gap is found, gap is 0.0 and the indices are None.

    Raises ValueError if any eigenvalue is not finite, or if some
    (spin, k-point) has no occupied or no unoccupied band.

    Example:

    >>> gap, p1, p2 = bandgap(silicon.calc)
    >>> print(gap, p1, p2)
    1.2 (0, 0, 3), (0, 5, 4)
    >>> gap, p1, p2 = bandgap(silicon.calc, direct=True)
    >>> print(gap, p1, p2)
    3.4 (0, 0, 3), (0, 0, 4)
    """

    if spin is not _deprecated:
        raise RuntimeError(spin_error)

    if calc is not None:
        nk = len(calc.get_ibz_k_points())
        ns = calc.get_number_of_spins()
        eigenvalues = np.array([[calc.get_eigenvalues(kpt=k, spin=s)
                                 for k in range(nk)]
                                for s in range(ns)])
        if efermi is None:
            efermi = calc.get_fermi_level()

    efermi = efermi or 0.0

    eigenvalues = np.asarray(eigenvalues)
    e_skn = eigenvalues - efermi
    if eigenvalues.ndim == 2:
        e_skn = e_skn[np.newaxis]  # spinors: a single spin channel

    # Validate before computing anything, so bad input is reported as such
    # instead of surfacing as an unrelated error from _bandgap.
    if not np.isfinite(e_skn).all():
        raise ValueError('Bad eigenvalues!')

    gap, (s1, k1, n1), (s2, k2, n2) = _bandgap(e_skn, direct)

    if eigenvalues.ndim != 3:
        return gap, (k1, n1), (k2, n2)
    return gap, (s1, k1, n1), (s2, k2, n2)

161 

162 

def _bandgap(e_skn, direct):
    """Locate the band gap in Fermi-level-referenced eigenvalues.

    e_skn is unpacked as (nspin, nkpt, nband); energies below 0.0 count as
    occupied.  Returns ``(gap, (s1, k1, n1), (s2, k2, n2))`` for the valence
    and conduction states, or ``(0.0, (None,)*3, (None,)*3)`` when the
    number of occupied bands varies with k inside a spin channel (a band
    crosses the Fermi level).  With ``direct`` true, only same-k transitions
    are considered.

    Raises ValueError('Too few bands!') if some (spin, k-point) has no
    occupied or no unoccupied band.
    """
    ns, nk, nb = e_skn.shape
    no_gap = 0.0, (None, None, None), (None, None, None)

    # Number of occupied bands for every (spin, k-point).
    nocc_sk = (e_skn < 0.0).sum(2)

    # A band crossing the Fermi level shows up as an occupation count that
    # changes with k within one spin channel.
    if (np.ptp(nocc_sk, axis=1) > 0).any():
        return no_gap

    if (nocc_sk == 0).any() or (nocc_sk == nb).any():
        raise ValueError('Too few bands!')

    # For each (spin, k): [highest occupied, lowest unoccupied] energy.
    edges_sk2 = np.array([[e_kn[n - 1:n + 1] for e_kn, n in zip(e_s, nocc_k)]
                          for e_s, nocc_k in zip(e_skn, nocc_sk)])
    ev_sk = edges_sk2[:, :, 0]  # valence band edge
    ec_sk = edges_sk2[:, :, 1]  # conduction band edge

    if ns == 1:
        gap, k1, k2 = find_gap(ev_sk[0], ec_sk[0], direct)
        n1 = nocc_sk[0, 0] - 1
        return gap, (0, k1, n1), (0, k2, n1 + 1)

    # Search all (spin, k) pairs at once through the row-major flattened
    # index i = s * nk + k; divmod(i, nk) recovers (s, k) below.
    gap, i1, i2 = find_gap(ev_sk.ravel(), ec_sk.ravel(), direct)
    if direct:
        # A same-k transition may also flip spin: valence in s, conduction
        # in 1 - s.
        for s in (0, 1):
            g, k, _ = find_gap(ev_sk[s], ec_sk[1 - s], direct)
            if g < gap:
                gap = g
                i1 = k + nk * s
                i2 = k + nk * (1 - s)

    if not gap > 0.0:
        return no_gap

    s1, k1 = divmod(i1, nk)
    s2, k2 = divmod(i2, nk)
    return gap, (s1, k1, nocc_sk[s1, k1] - 1), (s2, k2, nocc_sk[s2, k2])

212 

213 

def find_gap(ev_k, ec_k, direct):
    """Find the gap between valence and conduction band edges.

    ev_k and ec_k hold the valence and conduction energies per k-point
    (array-like supporting ``argmin``/``argmax``).  Returns
    ``(gap, k_valence, k_conduction)``.  With ``direct`` true, only equal
    k-indices are paired; otherwise the valence maximum and conduction
    minimum may sit at different k-points.
    """
    if not direct:
        k_top = ev_k.argmax()
        k_bottom = ec_k.argmin()
        return ec_k[k_bottom] - ev_k[k_top], k_top, k_bottom
    gap_per_k = ec_k - ev_k
    k_best = gap_per_k.argmin()
    return gap_per_k[k_best], k_best, k_best