Hot-keys on this page:
r, m, x, p: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1import string
2import numpy as np
3from ase.io import string2index
4from ase.io.formats import parse_filename
5from ase.data import chemical_symbols
# default fields


def field_specs_on_conditions(calculator_outputs, rank_order):
    """Return the default list of column specs for a comparison table.

    With *calculator_outputs* the columns include the force-derived fields
    ``df``/``rdf``; otherwise the per-axis displacements are shown instead.

    If *rank_order* is None, the index column gets a third ``:1`` component.
    Otherwise the index column becomes ``'i:1'`` and the *rank_order* field is
    made the primary sort key (``':0:1'``), being appended if it is not one of
    the defaults.
    """
    if calculator_outputs:
        specs = ['i:0', 'el', 'd', 'rd', 'df', 'rdf']
    else:
        specs = ['i:0', 'el', 'dx', 'dy', 'dz', 'd', 'rd']

    if rank_order is None:
        specs[0] += ':1'
        return specs

    specs[0] = 'i:1'
    ranked = rank_order + ':0:1'
    if rank_order not in specs:
        specs.append(ranked)
        return specs
    return [ranked if spec == rank_order else spec for spec in specs]
def summary_functions_on_conditions(has_calc):
    """Return the summary callables to run below the table.

    :func:`rmsd` is always included; :func:`energy_delta` is added when
    *has_calc* is true (presumably when both structures have a calculator
    attached -- confirm with callers).
    """
    functions = [rmsd]
    if has_calc:
        functions.append(energy_delta)
    return functions
def header_alias(h):
    """Return a human-readable column header for field name *h*.

    Whole-name aliases (``i``, ``an``, ``t``, ``el``) are looked up first.
    Otherwise the first character decides:

    * ``d`` -> every ``d`` becomes ``Δ``;
    * ``r`` -> ``'rank '`` followed by the alias of the rest of the name;
    * ``a`` -> every ``a`` becomes ``<`` and ``>`` is appended.

    Anything else is returned unchanged.
    """
    whole_names = {'i': 'index', 'an': 'atomic #', 't': 'tag', 'el': 'element'}
    if h in whole_names:
        return whole_names[h]

    lead = h[0]
    if lead == 'd':
        return h.replace('d', 'Δ')
    if lead == 'r':
        return 'rank ' + header_alias(h[1:])
    if lead == 'a':
        return h.replace('a', '<') + '>'
    return h
def prec_round(a, prec=2):
    """Round *a* to ``prec + 1`` significant figures.

    Used on the sort keys so that float values which are nearly equal become
    exactly equal, letting lower-priority keys of the hierarchical lexsort
    decide their order (making hierarchical sorting different from
    non-hierarchical sorting with floats).
    """
    if a == 0:
        return a
    s = 1 if a > 0 else -1
    log_mag = np.log10(s * a)
    m = log_mag // 1  # decimal exponent
    c = log_mag % 1   # log10 of the mantissa, in [0, 1)
    return s * np.round(10**c, prec) * 10**m


# Element-wise version for arrays. ``otypes=[float]`` pins the output dtype.
# Without it, numpy infers the dtype from the first element: an integer 0 first
# (e.g. index data) would make the whole result int and truncate every later
# rounded value (-114.99999999999999 -> -114). numpy also refuses size-0 input
# unless otypes is given.
prec_round = np.vectorize(prec_round, otypes=[float])
# end most settings

# Integer sort keys for chemical symbols: the element column is sorted on
# sym2num values and mapped back to symbols with num2sym for display.
# num2sym[k] is the k-th symbol in alphabetical order, so sorting by these
# keys sorts alphabetically by chemical symbol. (Zipping np.argsort's output
# with the symbols would pair each symbol with an argsort *position* rather
# than with its rank, which is not alphabetical order.)
num2sym = dict(enumerate(sorted(chemical_symbols)))
# to sort by atomic number instead, uncomment below
# num2sym = dict(zip(range(len(chemical_symbols)), chemical_symbols))
sym2num = {v: k for k, v in num2sym.items()}
# Field names that get_field_data computes from the structures themselves
# (tags, numbers, symbols, indices, displacements, positions). Any field not
# listed here is treated as force-derived.
atoms_props = (
    ['dx', 'dy', 'dz', 'd', 't', 'an', 'i', 'el', 'p1', 'p2']
    + [prefix + axis for prefix in ('p1', 'p2') for axis in 'xyz'])
def _xyz_component(vectors, field):
    """Select the column named by *field*'s last character, or row norms.

    A trailing ``x``/``y``/``z`` selects column 0/1/2 of the 2-D array
    *vectors* (presumably shape (natoms, 3)); any other ending returns the
    Euclidean norm of each row.
    """
    axis = {'x': 0, 'y': 1, 'z': 2}.get(field[-1])
    if axis is None:
        return np.linalg.norm(vectors, axis=1)
    return vectors[:, axis]


def get_field_data(atoms1, atoms2, field):
    """Return one value per atom for table column *field*.

    A leading ``r`` asks for ranks instead of values: the largest value gets
    rank 0, the next 1, and so on.

    Fields in ``atoms_props`` come from the structures themselves:

    * ``t``: tags of *atoms1*;
    * ``an``: atomic numbers of *atoms1*;
    * ``el``: ``sym2num`` sort keys of *atoms1*'s symbols;
    * ``i``: atom indices;
    * ``d*``: displacement ``atoms2 - atoms1``;
    * ``p1*``/``p2*``: positions of *atoms1*/*atoms2*.

    All other fields use forces:

    * ``d*``: difference ``atoms2 - atoms1``;
    * ``a*``: average of the two;
    * ``?1*``: forces of *atoms1*;
    * anything else: forces of *atoms2*.

    For vector quantities, a trailing ``x``/``y``/``z`` picks that component;
    otherwise the per-atom norm is returned.
    """
    rank_order = field[0] == 'r'
    if rank_order:
        field = field[1:]

    if field not in atoms_props:
        if field[0] == 'd':
            vectors = atoms2.get_forces() - atoms1.get_forces()
        elif field[0] == 'a':
            vectors = (atoms2.get_forces() + atoms1.get_forces()) / 2
        elif field[1] == '1':
            vectors = atoms1.get_forces()
        else:
            vectors = atoms2.get_forces()
        data = _xyz_component(vectors, field)
    elif field == 't':
        data = atoms1.get_tags()
    elif field == 'an':
        data = atoms1.numbers
    elif field == 'el':
        data = np.array([sym2num[sym] for sym in atoms1.symbols])
    elif field == 'i':
        data = np.arange(len(atoms1))
    else:
        if field.startswith('d'):
            vectors = atoms2.positions - atoms1.positions
        elif field.startswith('p'):
            vectors = atoms1.positions if field[1] == '1' else atoms2.positions
        data = _xyz_component(vectors, field)

    if not rank_order:
        return data
    # Double argsort turns values into ranks; negation makes rank 0 the largest.
    return np.argsort(np.argsort(-data))
# Summary Functions


def rmsd(atoms1, atoms2):
    """Return ``'RMSD=...'``: the root-mean-square per-atom displacement.

    The displacement is ``atoms2.positions - atoms1.positions``; the value is
    formatted with one decimal in scientific notation and an explicit sign.
    """
    displacement = atoms2.positions - atoms1.positions
    squared_lengths = np.linalg.norm(displacement, axis=1) ** 2
    root_mean_square = np.sqrt(squared_lengths.mean())
    return 'RMSD={:+.1E}'.format(root_mean_square)
def energy_delta(atoms1, atoms2):
    """Return the two potential energies and their difference ``E2 - E1``.

    Each value is formatted with one decimal in scientific notation and an
    explicit sign.
    """
    first = atoms1.get_potential_energy()
    second = atoms2.get_potential_energy()
    template = 'E1 = {:+.1E}, E2 = {:+.1E}, dE = {:+1.1E}'
    return template.format(first, second, second - first)
def parse_field_specs(field_specs):
    """Parse ``'field[:hier[:scent]]'`` column specifications.

    Parameters
    ----------
    field_specs : iterable of str
        Each spec is a field name, optionally followed by ``:hier`` and
        ``:scent``.
        - ``hier``: sort priority; a smaller value is a more significant key.
          Negative or missing means unspecified.
        - ``scent``: multiplier applied to the data before the ascending
          sort; defaults to -1, i.e. descending.

    Returns
    -------
    fields : list of str
        The field names, in spec order.
    hier : numpy.ndarray
        Field indices ordered for :func:`numpy.lexsort` (least significant
        key first, primary key last).
    scent : numpy.ndarray
        Per-field sort multipliers.

    Raises
    ------
    ValueError
        If a spec has more than three ``:``-separated parts (previously such
        specs were silently dropped), or if ``hier``/``scent`` is not an
        integer.
    """
    fields = []
    hier = []
    scent = []
    for fs in field_specs:
        parts = fs.split(':')
        if len(parts) > 3:
            raise ValueError(
                'invalid field spec {!r}: expected "field[:hier[:scent]]"'
                .format(fs))
        fields.append(parts[0])
        hier.append(int(parts[1]) if len(parts) > 1 else -1)
        scent.append(int(parts[2]) if len(parts) > 2 else -1)

    # Fields without an explicit priority rank after all explicit ones, in
    # the order they were given. default=-1 keeps an empty spec list valid.
    mxm = max(hier, default=-1)
    for c, level in enumerate(hier):
        if level < 0:
            mxm += 1
            hier[c] = mxm
    # reversed by convention of numpy lexsort (its last key is the primary one)
    hier = np.argsort(hier)[::-1]
    return fields, hier, np.array(scent)
# Class definitions


class MapFormatter(string.Formatter):
    """:class:`string.Formatter` with an extra ``h`` presentation type.

    A replacement field whose spec ends in ``h`` treats the value as an
    integer sort key: the key is mapped back to its chemical symbol through
    ``num2sym`` and then formatted as a string (the ``h`` becomes ``s``).
    All other specs are formatted normally.
    """

    def format_field(self, value, spec):
        if not spec.endswith('h'):
            return super(MapFormatter, self).format_field(value, spec)
        symbol = num2sym[int(value)]
        return super(MapFormatter, self).format_field(symbol, spec[:-1] + 's')
class TableFormat:
    """Column format strings and rule characters for :class:`Table`.

    Parameters
    ----------
    columnwidth : int
        Width of every column; all cells are centred in it.
    precision : int
        Float columns use ``precision - 1`` digits after the point.
    representation : str
        Presentation type for float columns (``'E'`` by default).
    toprule, midrule, bottomrule : str
        Characters repeated to draw the horizontal rules.

    Attributes ``fmt_class`` (format string per kind of column) and ``fmt``
    (format string per field name) are built from these settings.
    """

    def __init__(self,
                 columnwidth=9,
                 precision=2,
                 representation='E',
                 toprule='=',
                 midrule='-',
                 bottomrule='='):

        self.precision = precision
        self.representation = representation
        self.columnwidth = columnwidth
        self.formatter = MapFormatter().format
        self.toprule = toprule
        self.midrule = midrule
        self.bottomrule = bottomrule

        width = self.columnwidth
        decimals = self.precision - 1
        rep = self.representation
        self.fmt_class = {
            # the space flag reserves a sign position for positive numbers
            'signed float': '{{: ^{}.{}{}}}'.format(width, decimals, rep),
            'unsigned float': '{{:^{}.{}{}}}'.format(width, decimals, rep),
            'int': '{{:^{}n}}'.format(width),
            'str': '{{:^{}s}}'.format(width),
            # 'h' is MapFormatter's sort-key -> chemical symbol conversion
            'conv': '{{:^{}h}}'.format(width)}

        signed_floats = ['dx', 'dy', 'dz',
                         'dfx', 'dfy', 'dfz',
                         'afx', 'afy', 'afz',
                         'p1x', 'p2x', 'p1y', 'p2y', 'p1z', 'p2z',
                         'f1x', 'f2x', 'f1y', 'f2y', 'f1z', 'f2z']
        unsigned_floats = ['d', 'df', 'af', 'p1', 'p2', 'f1', 'f2']
        # ranked ('r'-prefixed) variants of every float field are integers
        integers = ['i', 'an', 't'] + [
            'r' + name for name in signed_floats + unsigned_floats]

        fmt = dict.fromkeys(signed_floats, self.fmt_class['signed float'])
        fmt.update(dict.fromkeys(unsigned_floats,
                                 self.fmt_class['unsigned float']))
        fmt.update(dict.fromkeys(integers, self.fmt_class['int']))
        fmt['el'] = self.fmt_class['conv']

        self.fmt = fmt
class Table:
    """Render a per-atom comparison of two structures as a text table.

    Columns are described by *field_specs* (``'field[:hier[:scent]]'``
    strings, parsed by ``parse_field_specs``); rows are sorted hierarchically
    according to those specs. Output from each of *summary_functions* is
    appended below the table.
    """

    def __init__(self,
                 field_specs,
                 summary_functions=None,
                 tableformat=None,
                 max_lines=None,
                 title='',
                 tablewidth=None):

        self.max_lines = max_lines
        # Fresh list per table: a literal ``[]`` default would be one list
        # shared (and mutable via self.summary_functions) by every instance.
        if summary_functions is None:
            summary_functions = []
        self.summary_functions = summary_functions
        self.field_specs = field_specs

        self.fields, self.hier, self.scent = parse_field_specs(self.field_specs)
        self.nfields = len(self.fields)

        # formatting
        if tableformat is None:
            self.tableformat = TableFormat()
        else:
            self.tableformat = tableformat

        if tablewidth is None:
            self.tablewidth = self.tableformat.columnwidth * self.nfields
        else:
            self.tablewidth = tablewidth

        self.title = title

    def make(self, atoms1, atoms2, csv=False):
        """Return the complete table as one string.

        The parts are: title, top rule, header, mid rule, body (at most
        ``max_lines`` rows), bottom rule, and the summary lines.
        """
        header = self.make_header(csv=csv)
        body = self._limit_lines(self.make_body(atoms1, atoms2, csv=csv),
                                 self.max_lines)
        summary = self.make_summary(atoms1, atoms2)

        return '\n'.join([self.title,
                          self.tableformat.toprule * self.tablewidth,
                          header,
                          self.tableformat.midrule * self.tablewidth,
                          body,
                          self.tableformat.bottomrule * self.tablewidth,
                          summary])

    @staticmethod
    def _limit_lines(text, max_lines):
        """Keep only the first *max_lines* lines of *text* (all if None).

        The body is a newline-joined string, so slicing it directly would
        truncate characters rather than rows.
        """
        if max_lines is None:
            return text
        return '\n'.join(text.split('\n')[:max_lines])

    def make_header(self, csv=False):
        """Return the header row, either comma-separated or in fixed columns."""
        if csv:
            return ','.join([header_alias(field) for field in self.fields])

        fields = self.tableformat.fmt_class['str'] * self.nfields
        headers = [header_alias(field) for field in self.fields]

        return self.tableformat.formatter(fields, *headers)

    def make_summary(self, atoms1, atoms2):
        """Return the output of every summary function, one per line."""
        return '\n'.join([summary_function(atoms1, atoms2)
                          for summary_function in self.summary_functions])

    def make_body(self, atoms1, atoms2, csv=False):
        """Return the data rows (one per atom), sorted per the field specs."""
        field_data = np.array([get_field_data(atoms1, atoms2, field)
                               for field in self.fields])

        # Apply each field's sort multiplier, order the keys for lexsort, and
        # round so nearly equal floats tie and lower-priority keys decide.
        sorting_array = field_data * self.scent[:, np.newaxis]
        sorting_array = sorting_array[self.hier]
        sorting_array = prec_round(sorting_array, self.tableformat.precision)

        field_data = field_data[:, np.lexsort(sorting_array)].transpose()

        if csv:
            rowformat = ','.join(
                ['{:h}' if field == 'el'
                 else '{{:.{}E}}'.format(self.tableformat.precision)
                 for field in self.fields])
        else:
            rowformat = ''.join([self.tableformat.fmt[field]
                                 for field in self.fields])
        body = [self.tableformat.formatter(rowformat, *row)
                for row in field_data]
        return '\n'.join(body)
default_index = string2index(':')


def slice_split(filename):
    """Return ``(filename, index)`` as parsed by ``parse_filename``.

    The fallback index given to ``parse_filename`` depends on the name:
    ``None`` when *filename* contains ``@``, otherwise the module-level
    ``default_index`` (``string2index(':')``).
    """
    fallback = None if '@' in filename else default_index
    name, index = parse_filename(filename, fallback)
    return name, index