Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import string 

2import numpy as np 

3from ase.io import string2index 

4from ase.io.formats import parse_filename 

5from ase.data import chemical_symbols 

6 

# default fields


def field_specs_on_conditions(calculator_outputs, rank_order):
    """Return the default list of column specifications.

    Parameters
    ----------
    calculator_outputs : bool-like
        Truthy selects the force-related columns ('df', 'rdf'); falsy
        selects the Cartesian displacement components instead.
    rank_order : str or None
        Field to sort by. When given, the index column gets priority 1 and
        ``rank_order`` gets priority 0 with direction 1 (it is appended if
        it is not already one of the defaults). When None, the index
        column is given direction 1.

    Returns
    -------
    list of str
        Specs of the form ``field[:priority[:direction]]``.
    """
    if calculator_outputs:
        specs = ['i:0', 'el', 'd', 'rd', 'df', 'rdf']
    else:
        specs = ['i:0', 'el', 'dx', 'dy', 'dz', 'd', 'rd']

    if rank_order is None:
        specs[0] += ':1'
        return specs

    # The index column becomes a secondary key behind rank_order.
    specs[0] = 'i:1'
    ranked = rank_order + ':0:1'
    if rank_order not in specs:
        specs.append(ranked)
        return specs
    return [ranked if spec == rank_order else spec for spec in specs]

26 

27 

def summary_functions_on_conditions(has_calc):
    """Return the summary functions printed below the table.

    RMSD is always included; the energy comparison is added only when
    ``has_calc`` is truthy.
    """
    functions = [rmsd]
    if has_calc:
        functions.append(energy_delta)
    return functions

32 

33 

def header_alias(h):
    """Turn a field name into a human-readable column header.

    The exact names 'i', 'an', 't' and 'el' become words. Otherwise the
    first character decides:

    - 'd': every 'd' is replaced with the Unicode symbol 'Δ';
    - 'r': the header becomes 'rank ' plus the alias of the remaining name;
    - 'a': every 'a' becomes '<' and a closing '>' is appended.

    Any other name is returned unchanged.
    """
    named = {'i': 'index', 'an': 'atomic #', 't': 'tag', 'el': 'element'}
    if h in named:
        return named[h]

    lead = h[0]
    if lead == 'd':
        return h.replace('d', 'Δ')
    if lead == 'r':
        return 'rank ' + header_alias(h[1:])
    if lead == 'a':
        return h.replace('a', '<') + '>'
    return h

53 

54 

def prec_round(a, prec=2):
    """Round ``a`` to ``prec`` decimals of its scientific-notation mantissa.

    The value is split into ``mantissa * 10**exponent`` with
    ``1 <= mantissa < 10`` (for ``|a|``), and only the mantissa is
    rounded. For example, ``prec_round(12345.0, 2)`` gives about
    ``12300.0``.

    This makes hierarchical sorting differ from non-hierarchical sorting
    with floats: values that agree after rounding tie, so a
    lower-priority key can decide their order. Zero is returned unchanged.
    """
    if a == 0:
        return a
    s = 1 if a > 0 else -1
    log_abs = np.log10(s * a)
    m = log_abs // 1   # exponent
    c = log_abs % 1    # log10 of the mantissa, in [0, 1)
    return s * np.round(10**c, prec) * 10**m


# Vectorize so whole arrays are rounded element-wise. otypes=[float] pins the
# output dtype. Without it, numpy infers the dtype from the first element, so
# an integer array starting with 0 would truncate every rounded value back to
# int (e.g. 113 -> 112.999... -> 112), merging distinct values.
prec_round = np.vectorize(prec_round, otypes=[float])

70 

# end most settings

# Integer codes used to sort the 'el' (element) column: num2sym maps a code
# back to its chemical symbol and sym2num is the inverse.
# This will sort alphabetically by chemical symbol: each code is the
# alphabetical rank of its symbol. (Zipping np.argsort(chemical_symbols)
# with the symbols would pair symbols with the sorting permutation rather
# than their rank, which is not alphabetical.)
num2sym = dict(enumerate(sorted(chemical_symbols)))
# to sort by atomic number, uncomment below
# num2sym = dict(zip(range(len(chemical_symbols)), chemical_symbols))
sym2num = {v: k for k, v in num2sym.items()}

# Fields that get_field_data computes from positions and per-atom properties;
# any field not listed here is treated as a force field.
atoms_props = [
    'dx',
    'dy',
    'dz',
    'd',
    't',
    'an',
    'i',
    'el',
    'p1',
    'p2',
    'p1x',
    'p1y',
    'p1z',
    'p2x',
    'p2y',
    'p2z']

96 

97 

def _pick_component(vectors, field):
    """Return the x/y/z column of ``vectors`` chosen by the last character of
    ``field``, or the row-wise Euclidean norm when there is no axis suffix."""
    axis = {'x': 0, 'y': 1, 'z': 2}.get(field[-1:])
    if axis is None:
        return np.linalg.norm(vectors, axis=1)
    return vectors[:, axis]


def get_field_data(atoms1, atoms2, field):
    """Compute one per-atom column of the comparison table.

    ``atoms1``/``atoms2`` are expected to behave like ase.Atoms (the code
    uses ``positions``, ``numbers``, ``symbols``, ``get_tags`` and
    ``get_forces``).

    Fields listed in atoms_props:

    - 't', 'an', 'el', 'i': tags, atomic numbers, element sort codes
      (via sym2num) and indices, all taken from ``atoms1``;
    - 'd*': displacement ``atoms2 - atoms1``;
    - 'p1*' / 'p2*': positions of ``atoms1`` / ``atoms2``.

    Any other field is force data:

    - 'd*': ``atoms2 - atoms1`` forces;
    - 'a*': mean force of the two structures;
    - otherwise the second character picks ``atoms1`` ('1') or ``atoms2``.

    For vector quantities a trailing 'x'/'y'/'z' selects a component;
    otherwise the norm is returned. A leading 'r' returns the rank of each
    value instead (0 for the largest).
    """
    rank_order = field[0] == 'r'
    if rank_order:
        field = field[1:]

    if field not in atoms_props:
        if field[0] == 'd':
            vectors = atoms2.get_forces() - atoms1.get_forces()
        elif field[0] == 'a':
            vectors = (atoms2.get_forces() + atoms1.get_forces()) / 2
        elif field[1] == '1':
            vectors = atoms1.get_forces()
        else:
            vectors = atoms2.get_forces()
        data = _pick_component(vectors, field)
    elif field == 't':
        data = atoms1.get_tags()
    elif field == 'an':
        data = atoms1.numbers
    elif field == 'el':
        data = np.array([sym2num[sym] for sym in atoms1.symbols])
    elif field == 'i':
        data = np.arange(len(atoms1))
    else:
        if field.startswith('d'):
            vectors = atoms2.positions - atoms1.positions
        elif field.startswith('p'):
            vectors = atoms1.positions if field[1] == '1' else atoms2.positions
        data = _pick_component(vectors, field)

    if not rank_order:
        return data
    # Double argsort turns values into ranks; negation ranks largest first.
    return np.argsort(np.argsort(-data))

155 

156 

# Summary Functions

def rmsd(atoms1, atoms2):
    """Return a summary line with the root-mean-square atomic displacement
    between ``atoms1`` and ``atoms2``."""
    step = atoms2.positions - atoms1.positions
    mean_square = (np.linalg.norm(step, axis=1) ** 2).mean()
    value = np.sqrt(mean_square)
    return 'RMSD={:+.1E}'.format(value)

163 

164 

def energy_delta(atoms1, atoms2):
    """Return a summary line with both potential energies and their
    difference ``E2 - E1``."""
    e_first = atoms1.get_potential_energy()
    e_second = atoms2.get_potential_energy()
    template = 'E1 = {:+.1E}, E2 = {:+.1E}, dE = {:+1.1E}'
    return template.format(e_first, e_second, e_second - e_first)

169 

170 

def parse_field_specs(field_specs):
    """Parse column specifications of the form ``field[:priority[:direction]]``.

    Parameters
    ----------
    field_specs : iterable of str
        ``priority`` is an integer sort priority (lower sorts first). Specs
        without one, or with a negative one, are ranked after all explicit
        priorities, in order of appearance. ``direction`` is an integer
        multiplier applied to the data before sorting; it defaults to -1.

    Returns
    -------
    fields : list of str
        The field names.
    hier : numpy.ndarray of int
        Field indices ordered from lowest to highest sort priority, the
        order numpy.lexsort expects (its last key is the primary one).
    scent : numpy.ndarray
        The direction multiplier of each field.

    Raises
    ------
    ValueError
        If a spec has more than three ':'-separated parts, or a priority
        or direction is not an integer.
    """
    fields = []
    hier = []
    scent = []
    for fs in field_specs:
        parts = fs.split(':')
        if len(parts) > 3:
            # Previously such specs were silently dropped from the table.
            raise ValueError(
                'Invalid field specification {!r}: expected '
                'field[:priority[:direction]]'.format(fs))
        fields.append(parts[0])
        hier.append(int(parts[1]) if len(parts) > 1 else -1)
        scent.append(int(parts[2]) if len(parts) > 2 else -1)

    # Give unprioritised fields successive priorities after the largest
    # explicit one; default=-1 keeps an empty spec list from crashing.
    mxm = max(hier, default=-1)
    for c in range(len(hier)):
        if hier[c] < 0:
            mxm += 1
            hier[c] = mxm
    # reversed by convention of numpy lexsort
    hier = np.argsort(hier)[::-1]
    return fields, hier, np.array(scent)

197 

# Class definitions


class MapFormatter(string.Formatter):
    """Formatter with an extra 'h' presentation type.

    A format spec ending in 'h' treats the value as an integer code, looks
    the code up in num2sym and formats the resulting chemical symbol as a
    string ('h' is swapped for 's', so width and alignment are kept). All
    other specs are formatted normally.
    """

    def format_field(self, value, spec):
        if not spec.endswith('h'):
            return super().format_field(value, spec)
        symbol = num2sym[int(value)]
        return super().format_field(symbol, spec[:-1] + 's')

211 

212 

class TableFormat:
    """Formatting settings for Table.

    Stores the rule characters and column layout, and builds two lookups:

    - ``fmt_class``: a format-string fragment per kind of value;
    - ``fmt``: the fragment used for each field name.

    ``formatter`` is ``MapFormatter().format``, which also understands the
    'h' type used for the element column.
    """

    def __init__(self,
                 columnwidth=9,
                 precision=2,
                 representation='E',
                 toprule='=',
                 midrule='-',
                 bottomrule='='):
        """
        Parameters
        ----------
        columnwidth : int
            Width of every column, in characters.
        precision : int
            Floats are printed with ``precision - 1`` digits after the
            decimal point.
        representation : str
            Float presentation type used in the format specs (e.g. 'E').
        toprule, midrule, bottomrule : str
            Characters repeated to draw the horizontal rules.
        """
        self.precision = precision
        self.representation = representation
        self.columnwidth = columnwidth
        self.formatter = MapFormatter().format
        self.toprule = toprule
        self.midrule = midrule
        self.bottomrule = bottomrule

        self.fmt_class = {
            # '^' centres the value. The space *after* '^' is the sign
            # option: it reserves a column for the sign of positive numbers
            # so they line up with negative ones. A space before '^' would be
            # parsed as the fill character and have no effect.
            'signed float': "{{:^ {}.{}{}}}".format(
                self.columnwidth,
                self.precision - 1,
                self.representation),
            'unsigned float': "{{:^{}.{}{}}}".format(
                self.columnwidth,
                self.precision - 1,
                self.representation),
            'int': "{{:^{}n}}".format(
                self.columnwidth),
            'str': "{{:^{}s}}".format(
                self.columnwidth),
            # 'h' is resolved by MapFormatter (element code -> symbol).
            'conv': "{{:^{}h}}".format(
                self.columnwidth)}

        signed_floats = [
            'dx',
            'dy',
            'dz',
            'dfx',
            'dfy',
            'dfz',
            'afx',
            'afy',
            'afz',
            'p1x',
            'p2x',
            'p1y',
            'p2y',
            'p1z',
            'p2z',
            'f1x',
            'f2x',
            'f1y',
            'f2y',
            'f1z',
            'f2z']
        unsigned_floats = ['d', 'df', 'af', 'p1', 'p2', 'f1', 'f2']
        integer_props = ['i', 'an', 't']
        # Rank columns ('r' + field) always hold integers. get_field_data
        # accepts the 'r' prefix on any field, so the integer and element
        # columns need rank entries too ('ri', 'ran', 'rt', 'rel'); without
        # them Table.make_body raises KeyError for those fields.
        rank_fields = ['r' + name for name in
                       signed_floats + unsigned_floats + integer_props + ['el']]

        fmt = {}
        for name in signed_floats:
            fmt[name] = self.fmt_class['signed float']
        for name in unsigned_floats:
            fmt[name] = self.fmt_class['unsigned float']
        for name in integer_props + rank_fields:
            fmt[name] = self.fmt_class['int']
        fmt['el'] = self.fmt_class['conv']

        self.fmt = fmt

280 

281 

class Table:
    """Build a text (or CSV) table comparing per-atom data of two structures.

    Columns are described by field specs of the form
    ``field[:priority[:direction]]`` (parsed by parse_field_specs). Rows are
    sorted on those columns before printing.
    """

    def __init__(self,
                 field_specs,
                 summary_functions=None,
                 tableformat=None,
                 max_lines=None,
                 title='',
                 tablewidth=None):
        """
        Parameters
        ----------
        field_specs : list of str
            Column specifications.
        summary_functions : list of callables, optional
            Each is called as ``f(atoms1, atoms2)`` and must return a string
            printed below the table. Defaults to no summary lines.
        tableformat : TableFormat, optional
            Formatting settings; a default TableFormat is created if omitted.
        max_lines : int, optional
            Maximum number of body rows to print; all rows when None.
        title : str
            First line of the output.
        tablewidth : int, optional
            Length of the horizontal rules; defaults to
            ``columnwidth * number of fields``.
        """
        self.max_lines = max_lines
        # A None default avoids a mutable list shared by every Table.
        if summary_functions is None:
            summary_functions = []
        self.summary_functions = summary_functions
        self.field_specs = field_specs

        self.fields, self.hier, self.scent = parse_field_specs(self.field_specs)
        self.nfields = len(self.fields)

        # formatting
        if tableformat is None:
            self.tableformat = TableFormat()
        else:
            self.tableformat = tableformat

        if tablewidth is None:
            self.tablewidth = self.tableformat.columnwidth * self.nfields
        else:
            self.tablewidth = tablewidth

        self.title = title

    def make(self, atoms1, atoms2, csv=False):
        """Return the whole table as one string.

        The parts are joined by newlines: title, top rule, header, middle
        rule, body, bottom rule and summary. When ``max_lines`` is set, only
        the first ``max_lines`` body rows are kept.
        """
        header = self.make_header(csv=csv)
        body = self.make_body(atoms1, atoms2, csv=csv)
        if self.max_lines is not None:
            # make_body returns one newline-joined string. Slicing it
            # directly would cut characters, so truncate whole rows instead.
            body = '\n'.join(body.split('\n')[:self.max_lines])
        summary = self.make_summary(atoms1, atoms2)

        return '\n'.join([self.title,
                          self.tableformat.toprule * self.tablewidth,
                          header,
                          self.tableformat.midrule * self.tablewidth,
                          body,
                          self.tableformat.bottomrule * self.tablewidth,
                          summary])

    def make_header(self, csv=False):
        """Return the header row: comma-separated aliases in CSV mode,
        otherwise aliases centred in fixed-width columns."""
        if csv:
            return ','.join([header_alias(field) for field in self.fields])

        fields = self.tableformat.fmt_class['str'] * self.nfields
        headers = [header_alias(field) for field in self.fields]

        return self.tableformat.formatter(fields, *headers)

    def make_summary(self, atoms1, atoms2):
        """Return the summary lines, one per summary function."""
        return '\n'.join([summary_function(atoms1, atoms2)
                          for summary_function in self.summary_functions])

    def make_body(self, atoms1, atoms2, csv=False):
        """Return the sorted table rows as one newline-joined string."""
        # One row per field, one column per atom.
        field_data = np.array([get_field_data(atoms1, atoms2, field)
                               for field in self.fields])

        # Multiply each field by its direction (the default -1 gives
        # descending order). Order keys for np.lexsort, whose last key is
        # the primary one. Round so near-equal floats tie and fall through
        # to lower-priority keys.
        sorting_array = field_data * self.scent[:, np.newaxis]
        sorting_array = sorting_array[self.hier]
        sorting_array = prec_round(sorting_array, self.tableformat.precision)

        # Reorder atoms, then transpose to one row per atom.
        field_data = field_data[:, np.lexsort(sorting_array)].transpose()

        if csv:
            rowformat = ','.join(['{:h}' if field == 'el' else '{{:.{}E}}'.format(
                self.tableformat.precision) for field in self.fields])
        else:
            rowformat = ''.join([self.tableformat.fmt[field]
                                 for field in self.fields])
        body = [
            self.tableformat.formatter(
                rowformat,
                *row) for row in field_data]
        return '\n'.join(body)

360 

361 

# Index passed to parse_filename when the file name carries no '@' part.
default_index = string2index(':')


def slice_split(filename):
    """Split ``filename`` into a file name and an image index.

    When the name contains '@', parse_filename is called with no default
    index (presumably it then reads the index from the text after '@' —
    confirm in ase.io.formats). Otherwise default_index is passed as the
    fallback.
    """
    fallback = None if '@' in filename else default_index
    name, index = parse_filename(filename, fallback)
    return name, index