Coverage for /builds/debichem-team/python-ase/ase/spacegroup/symmetrize.py: 92.06%

126 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-03-06 04:00 +0000

1""" 

2Provides utility functions for FixSymmetry class 

3""" 

4from collections.abc import MutableMapping 

5from typing import Optional 

6 

7import numpy as np 

8 

9from ase.utils import atoms_to_spglib_cell 

10 

# Public names exported by a star-import of this module.
__all__ = ["refine_symmetry", "check_symmetry"]

12 

13 

def spglib_get_symmetry_dataset(*args, **kwargs):
    """Call ``spglib.get_symmetry_dataset`` and normalize its result.

    Temporary compatibility adapter: all arguments are forwarded unchanged.
    spglib < 2.5.0 returns a plain dict, which is wrapped in
    `SpglibDatasetWrapper` so that attribute access (``dataset.number``)
    works just like on the ``SpglibDataset`` objects of spglib >= 2.5.0.
    Any other result, including None, is returned as is.  This lets ASE
    code ignore which spglib version is installed.
    """
    import spglib

    raw = spglib.get_symmetry_dataset(*args, **kwargs)
    # spglib < 2.5.0: plain dict -> wrap it for attribute-style access.
    if isinstance(raw, dict):
        return SpglibDatasetWrapper(raw)
    # None (no dataset) or an spglib >= 2.5.0 dataset object.
    return raw

28 

29 

class SpglibDatasetWrapper(MutableMapping):
    """Give an spglib dataset dict attribute-style access.

    Spglib 2.5.0 returns a ``SpglibDataset`` whose ``__getitem__`` is
    deprecated; spglib 2.4.0 and earlier return a plain ``dict``.  Wrapping
    such a dict in this object makes both ``dataset.key`` and
    ``dataset['key']`` work.  Reads, writes and deletions all go straight
    through to the wrapped dict.
    """

    def __init__(self, spglib_dct):
        self._spglib_dct = spglib_dct

    def __getattr__(self, attr):
        # Python only calls __getattr__ after normal attribute lookup fails.
        # If the wrapped dict has not been set yet (an instance created via
        # ``__new__``, as copy/pickle do), touching ``self._spglib_dct``
        # would re-enter __getattr__ and recurse without end.
        if '_spglib_dct' not in self.__dict__:
            raise AttributeError(attr)
        try:
            return self[attr]
        except KeyError:
            # The attribute protocol requires AttributeError so that
            # hasattr(), getattr(obj, name, default), copy and pickle work.
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {attr!r}'
            ) from None

    def __getitem__(self, key):
        return self._spglib_dct[key]

    def __len__(self):
        return len(self._spglib_dct)

    def __iter__(self):
        return iter(self._spglib_dct)

    def __setitem__(self, key, value):
        self._spglib_dct[key] = value

    def __delitem__(self, item):
        del self._spglib_dct[item]

    def __repr__(self):
        return f'{type(self).__name__}({self._spglib_dct!r})'

56 

57 

def print_symmetry(symprec, dataset):
    """Print a one-line summary of a symmetry dataset.

    Shows the precision used plus the space-group number, the
    international (Hermann-Mauguin) symbol and the Hall symbol read from
    ``dataset``.
    """
    summary = (
        "ase.spacegroup.symmetrize: prec", symprec,
        "got symmetry group number", dataset.number,
        ", international (Hermann-Mauguin)", dataset.international,
        ", Hall ", dataset.hall,
    )
    print(*summary)

63 

64 

def refine_symmetry(atoms, symprec=0.01, verbose=False):
    """
    Refine the symmetry of an Atoms object in place.

    Parameters
    ----------
    atoms : Atoms
        Input Atoms object; its cell and then its positions are
        overwritten with symmetrized versions.
    symprec : float
        Symmetry precision used to detect the symmetry to impose.
    verbose : bool
        If True, print the symmetry found at each stage.

    Returns
    -------
    spglib dataset
        Symmetry dataset of the refined structure, detected at a fixed
        precision of 1e-4.
    """
    settings = {'symprec': symprec, 'verbose': verbose}
    _check_and_symmetrize_cell(atoms, **settings)
    _check_and_symmetrize_positions(atoms, **settings)
    # Re-detect the symmetry of the refined structure at a fixed tolerance.
    return check_symmetry(atoms, symprec=1e-4, verbose=verbose)

84 

85 

class IntermediateDatasetError(Exception):
    """Signal inconsistent symmetry between two symmetrization stages.

    The symmetry dataset found in `_check_and_symmetrize_positions` can
    disagree with the original dataset found in
    `_check_and_symmetrize_cell`.  Carrying on would produce a faulty,
    partially symmetrized structure, so this exception is raised instead.
    """

90 

91 

def get_symmetrized_atoms(atoms,
                          symprec: float = 0.01,
                          final_symprec: Optional[float] = None):
    """Get new Atoms object with refined symmetries.

    Checks internal consistency of the found symmetries.

    Parameters
    ----------
    atoms : Atoms
        Input atoms object.  It is not modified; a copy is symmetrized.
    symprec : float
        Symmetry precision used to identify symmetries with spglib.
    final_symprec : float, optional
        Symmetry precision used for testing the symmetrization.  When None
        (the default), ``symprec`` is used.

    Returns
    -------
    symatoms : Atoms
        New atoms object symmetrized according to the input symprec.
    dataset
        Symmetry dataset of ``symatoms`` detected with ``final_symprec``.

    Raises
    ------
    IntermediateDatasetError
        If the space group found after symmetrizing the cell differs from
        the one found on the input structure.
    AssertionError
        If the space group of the fully symmetrized structure, detected
        with ``final_symprec``, differs from the one found on the input.
    """
    atoms = atoms.copy()
    original_dataset = _check_and_symmetrize_cell(atoms, symprec=symprec)
    intermediate_dataset = _check_and_symmetrize_positions(
        atoms, symprec=symprec)
    if intermediate_dataset.number != original_dataset.number:
        raise IntermediateDatasetError()
    if final_symprec is None:
        final_symprec = symprec
    final_dataset = check_symmetry(atoms, symprec=final_symprec)
    # Explicit raise rather than ``assert`` so the consistency check still
    # runs under ``python -O``; AssertionError is kept for compatibility.
    if final_dataset.number != original_dataset.number:
        raise AssertionError(
            f'symmetrized structure has space group {final_dataset.number} '
            f'at final_symprec={final_symprec}, expected '
            f'{original_dataset.number}')
    return atoms, final_dataset

123 

124 

def _check_and_symmetrize_cell(atoms, **kwargs):
    """Detect the symmetry of ``atoms``, then symmetrize its cell in place.

    ``kwargs`` are forwarded to `check_symmetry`.  Returns the dataset
    detected *before* the cell was changed.
    """
    found = check_symmetry(atoms, **kwargs)
    _symmetrize_cell(atoms, found)
    return found

129 

130 

def _symmetrize_cell(atoms, dataset):
    """Replace the cell of ``atoms`` with the symmetrized cell of ``dataset``.

    The new cell is the standard lattice, transformed by the transposed
    ``transformation_matrix`` and rotated by ``std_rotation_matrix``.  It is
    applied with ``scale_atoms=True``.
    """
    new_cell = (dataset.transformation_matrix.T
                @ dataset.std_lattice
                @ dataset.std_rotation_matrix)
    atoms.set_cell(new_cell, scale_atoms=True)

138 

139 

def _check_and_symmetrize_positions(atoms, *, symprec, **kwargs):
    """Detect the symmetry of ``atoms``, then snap its positions in place.

    ``symprec`` is used both for `check_symmetry` (which also receives
    ``kwargs``) and for ``spglib.find_primitive``.  Returns the dataset
    detected before the positions were changed.
    """
    import spglib

    found = check_symmetry(atoms, symprec=symprec, **kwargs)
    # Assumption: the primitive vectors returned by find_primitive are
    # compatible with the std_lattice from get_symmetry_dataset.
    primitive = spglib.find_primitive(atoms_to_spglib_cell(atoms),
                                      symprec=symprec)
    _symmetrize_positions(atoms, found, primitive)
    return found

148 

149 

def _symmetrize_positions(atoms, dataset, primitive_spglib_cell):
    """Move every atom of ``atoms`` onto its ideal symmetric position.

    The standard-cell positions of ``dataset`` are rotated by
    ``std_rotation_matrix``.  They are then shifted so that the first atom
    mapped to primitive index 0 coincides with its counterpart in
    ``atoms``.  Each atom is placed at the aligned standard-cell site of its
    primitive index, offset by the integer combination of (rotated)
    primitive vectors obtained by rounding its fractional displacement from
    that site.

    Parameters
    ----------
    atoms : Atoms
        Structure whose positions are overwritten via ``set_positions``.
    dataset
        spglib symmetry dataset with attribute access.
    primitive_spglib_cell : tuple
        ``(lattice, scaled_positions, types)`` as returned by
        ``spglib.find_primitive``; only the lattice is used.

    Raises
    ------
    ValueError
        If primitive index 0, or the primitive index of some atom, has no
        atom in the standard cell.
    """
    prim_cell = primitive_spglib_cell[0]
    mapping_to_primitive = list(dataset.mapping_to_primitive)
    std_mapping_to_primitive = list(dataset.std_mapping_to_primitive)
    pos = atoms.get_positions()

    # Offset between the standard cell and the actual cell, fixed by the
    # first atom of primitive index 0 in each.
    rot_std_cell = dataset.std_lattice @ dataset.std_rotation_matrix
    rot_std_pos = dataset.std_positions @ rot_std_cell
    dp0 = (pos[mapping_to_primitive.index(0)]
           - rot_std_pos[std_mapping_to_primitive.index(0)])
    aligned_std_pos = rot_std_pos + dp0

    rot_prim_cell = prim_cell @ dataset.std_rotation_matrix
    inv_rot_prim_cell = np.linalg.inv(rot_prim_cell)

    # Primitive index -> first standard-cell atom with that index (what
    # list.index returns).  Built once, this avoids an O(n) search per atom.
    first_std_index = {}
    for std_i, prim_i in enumerate(std_mapping_to_primitive):
        first_std_index.setdefault(int(prim_i), std_i)
    try:
        std_indices = np.asarray(
            [first_std_index[int(p)] for p in mapping_to_primitive],
            dtype=int)
    except KeyError as err:
        raise ValueError(
            f'primitive atom {err.args[0]} has no counterpart in the '
            'standard cell') from None

    # Ideal position = matching aligned std site minus the rounded number
    # of primitive vectors separating it from the current position.
    target = aligned_std_pos[std_indices]
    lattice_shift = np.round((target - pos) @ inv_rot_prim_cell)
    atoms.set_positions(target - lattice_shift @ rot_prim_cell)

177 

178 

def check_symmetry(atoms, symprec=1.0e-6, verbose=False):
    """
    Check symmetry of `atoms` with precision `symprec` using `spglib`

    Returns the result of `spglib.get_symmetry_dataset()`, obtained through
    the `spglib_get_symmetry_dataset` adapter.  A one-line summary is
    printed only when `verbose` is True.
    """
    cell = atoms_to_spglib_cell(atoms)
    dataset = spglib_get_symmetry_dataset(cell, symprec=symprec)
    if verbose:
        print_symmetry(symprec, dataset)
    return dataset

190 

191 

def _same_symmetry_operation(rot1, trns1, rot2, trns2, tol):
    """Return True if two (rotation, fractional translation) pairs coincide.

    Rotations must match exactly.  Translations are compared modulo integer
    lattice vectors: each component of the difference is reduced by its
    nearest integer before taking the norm, which must be below ``tol``.
    """
    if not np.all(rot1 == rot2):
        return False
    dt = trns1 - trns2
    dt -= np.round(dt)
    return bool(np.linalg.norm(dt) < tol)


def is_subgroup(sup_data, sub_data, tol=1e-10):
    """
    Test if spglib dataset `sub_data` is a subgroup of dataset `sup_data`

    Every operation of `sub_data` (``rotations[k]``, ``translations[k]``)
    must appear in `sup_data`.  Rotations are compared exactly.  Fractional
    translations are compared modulo lattice translations, so that e.g.
    ``1 - 1e-12`` and ``0`` count as the same; the norm of the wrapped
    difference must be below `tol`.
    """
    sup_ops = list(zip(sup_data.rotations, sup_data.translations))
    for rot1, trns1 in zip(sub_data.rotations, sub_data.translations):
        if not any(_same_symmetry_operation(rot1, trns1, rot2, trns2, tol)
                   for rot2, trns2 in sup_ops):
            return False
    return True

203 

204 

def prep_symmetry(atoms, symprec=1.0e-6, verbose=False):
    """
    Prepare `atoms` for symmetry-preserving minimisation at precision `symprec`

    Detects the symmetry with spglib (printing a summary when `verbose`)
    and, for every symmetry operation, works out which atom each atom is
    sent to.

    Returns a tuple `(rotations, translations, symm_map)`: copies of the
    dataset's rotations and fractional translations, and a list with one
    entry per operation.  ``symm_map[k][i]`` is the index of the atom whose
    scaled position is nearest to the image of atom ``i`` under operation
    ``k``.  Nearness is measured in fractional coordinates after removing
    whole lattice translations from the difference.
    """
    dataset = check_symmetry(atoms, symprec=symprec, verbose=verbose)
    rotations = dataset.rotations.copy()
    translations = dataset.translations.copy()
    scaled_pos = atoms.get_scaled_positions()
    n_atoms = len(atoms)

    def nearest_atom(frac_point):
        # Wrap each fractional difference by its nearest integer, then pick
        # the atom with the smallest remaining offset.
        delta = scaled_pos - frac_point
        delta -= np.round(delta)
        return np.argmin(np.linalg.norm(delta, axis=1))

    symm_map = [
        [nearest_atom(rot @ scaled_pos[i_at, :] + trans)
         for i_at in range(n_atoms)]
        for rot, trans in zip(rotations, translations)
    ]
    return (rotations, translations, symm_map)

229 

230 

def symmetrize_rank1(lattice, inv_lattice, forces, rot, trans, symm_map):
    """
    Return symmetrized forces

    lattice vectors expected as row vectors (same as ASE get_cell() convention),
    inv_lattice is its matrix inverse (reciprocal().T)

    The forces are expressed in scaled (lattice) components.  Each
    rotation is applied, and the rotated vectors are accumulated onto the
    atoms given by that operation's ``symm_map`` row.  The sum is averaged
    over ``len(rot)`` and converted back to Cartesian components.  ``trans``
    is iterated alongside ``rot`` but does not enter the result.
    """
    frac_forces = inv_lattice.T @ forces.T
    accum = np.zeros(forces.T.shape)
    for op_rot, _op_trans, op_map in zip(rot, trans, symm_map):
        accum[:, op_map] += op_rot @ frac_forces
    accum /= len(rot)
    return (lattice.T @ accum).T

248 

249 

def symmetrize_rank2(lattice, lattice_inv, stress_3_3, rot):
    """
    Return symmetrized stress

    lattice vectors expected as row vectors (same as ASE get_cell() convention),
    lattice_inv is its matrix inverse (reciprocal().T)

    The stress is transformed to ``lattice @ stress @ lattice.T``, and
    ``r.T @ (.) @ r`` is averaged over all rotations ``r``.  The result is
    transformed back with ``lattice_inv``.
    """
    scaled = lattice @ stress_3_3 @ lattice.T
    averaged = sum((op.T @ scaled @ op for op in rot), np.zeros((3, 3)))
    averaged /= len(rot)
    return lattice_inv @ averaged @ lattice_inv.T