Coverage for /builds/debichem-team/python-ase/ase/db/cli.py: 62.30%
244 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
1import json
2import sys
3from collections import defaultdict
4from contextlib import contextmanager
5from pathlib import Path
6from typing import Iterable, Iterator
8import ase.io
9from ase.db import connect
10from ase.db.core import convert_str_to_int_float_bool_or_str
11from ase.db.row import row2dct
12from ase.db.table import Table, all_columns
13from ase.utils import plural
def count_keys(db, query):
    """Print, for every key-value-pair key, how many selected rows carry it.

    Parameters
    ----------
    db:
        Database object; only ``db.select(query)`` is used.  Each returned
        row must expose its key names through ``row._keys``.
    query:
        Selection passed unchanged to ``db.select``.

    One line ``"<key>: <count>"`` is printed per key, in first-seen order,
    with the counts aligned in a single column.  Nothing is printed when no
    selected row has any key.  Returns ``None``.
    """
    keys = defaultdict(int)
    for row in db.select(query):
        for key in row._keys:
            keys[key] += 1

    if not keys:
        # No rows (or no keys at all): max() below would raise on an empty
        # sequence, and there is nothing to report anyway.
        return

    # Width of the "<key>:" column: longest key plus room for the colon.
    n = max(len(key) for key in keys) + 1
    for key, number in keys.items():
        print('{:{}} {}'.format(key + ':', n, number))
def main(args):
    """Run the ``ase db`` command line tool for the parsed options *args*.

    Exactly one action is performed.  The options are tested in this fixed
    order and the first one that is set wins: ``analyse``, ``show_keys``,
    ``show_values``, ``add_from_file``, ``count``, ``insert_into``,
    ``explain``, ``show_metadata``, ``set_metadata``, adding/deleting keys
    (``add_key_value_pairs`` / ``delete_keys``), ``delete``, ``plot``,
    ``json``, ``long``, ``open_web_browser``.  If none of them applies, the
    selected rows are printed as a table (CSV with ``args.csv``).

    Note that ``args.sort`` and ``args.limit`` may be modified in place.

    Raises
    ------
    ValueError
        If an item of ``args.add_key_value_pairs`` is not of the form
        ``key=value``.
    """
    verbosity = 1 - args.quiet + args.verbose
    query = ','.join(args.query)

    if args.sort.endswith('-'):
        # Allow using "key-" instead of "-key" for reverse sorting
        args.sort = '-' + args.sort[:-1]

    if query.isdigit():
        # A purely numeric query is passed on as an int (presumably a row
        # id -- interpreted by db.select/db.get).
        query = int(query)

    add_key_value_pairs = {}
    if args.add_key_value_pairs:
        for pair in args.add_key_value_pairs.split(','):
            # Split on the first '=' only so that string values may
            # themselves contain '=' (e.g. "url=a=b").
            key, sep, value = pair.partition('=')
            if not sep:
                raise ValueError(
                    f'Expected key=value, got {pair!r} in '
                    f'{args.add_key_value_pairs!r}')
            add_key_value_pairs[key] = \
                convert_str_to_int_float_bool_or_str(value)

    if args.delete_keys:
        delete_keys = args.delete_keys.split(',')
    else:
        delete_keys = []

    db = connect(args.database, use_lock_file=not args.no_lock_file)

    def out(*messages):
        # Informational output; silenced when verbosity drops to 0 or less.
        if verbosity > 0:
            print(*messages)

    if args.analyse:
        db.analyse()
        return

    if args.show_keys:
        count_keys(db, query)
        return

    if args.show_values:
        keys = args.show_values.split(',')
        values = {key: defaultdict(int) for key in keys}
        numbers = set()  # keys seen with at least one non-string value
        strings = set()  # keys seen with at least one string value
        for row in db.select(query):
            kvp = row.key_value_pairs
            for key in keys:
                value = kvp.get(key)
                if value is not None:
                    values[key][value] += 1
                    if isinstance(value, str):
                        strings.add(key)
                    else:
                        numbers.add(key)

        width = max(len(key) for key in keys) + 1
        for key in keys:
            vals = values[key]
            if key in numbers and key not in strings:
                # Purely numeric values: summarise as a [min..max] range.
                # Mixed str/number values cannot be ordered by min()/max(),
                # so they use the per-value listing below instead.
                print('{:{}} [{}..{}]'
                      .format(key + ':', width, min(vals), max(vals)))
            else:
                print('{:{}} {}'
                      .format(key + ':', width,
                              ', '.join(f'{v}({count})'
                                        for v, count in vals.items())))
        return

    if args.add_from_file:
        filename = args.add_from_file
        configs = ase.io.read(filename)
        if not isinstance(configs, list):
            configs = [configs]
        for atoms in configs:
            db.write(atoms, key_value_pairs=add_key_value_pairs)
        out('Added ' + plural(len(configs), 'row'))
        return

    if args.count:
        n = db.count(query)
        print(f'{plural(n, "row")}')
        return

    if args.insert_into:
        # For copying, a limit of -1 becomes 0 here (the listing further
        # down uses 20 instead); presumably 0 means "no limit" in
        # db.select -- confirm there.
        if args.limit == -1:
            args.limit = 0

        progressbar = no_progressbar
        length = None

        if args.progress_bar:
            # Try to import the one from click.
            # People using ase.db will most likely have flask installed
            # and therefore also click.
            try:
                from click import progressbar
            except ImportError:
                pass
            else:
                length = db.count(query)

        nkvp = 0
        nrows = 0
        with connect(args.insert_into,
                     use_lock_file=not args.no_lock_file) as db2:
            with progressbar(db.select(query,
                                       sort=args.sort,
                                       limit=args.limit,
                                       offset=args.offset),
                             length=length) as rows:
                for row in rows:
                    kvp = row.get('key_value_pairs', {})
                    # Count only the pairs that are new to this row.
                    nkvp -= len(kvp)
                    kvp.update(add_key_value_pairs)
                    nkvp += len(kvp)
                    if args.strip_data:
                        db2.write(row.toatoms(), **kvp)
                    else:
                        db2.write(row, data=row.get('data'), **kvp)
                    nrows += 1

        out('Added %s (%s updated)' %
            (plural(nkvp, 'key-value pair'),
             plural(len(add_key_value_pairs) * nrows - nkvp, 'pair')))
        out(f'Inserted {plural(nrows, "row")}')
        return

    if args.limit == -1:
        args.limit = 20

    if args.explain:
        for row in db.select(query, explain=True,
                             verbosity=verbosity,
                             limit=args.limit, offset=args.offset):
            print(row['explain'])
        return

    if args.show_metadata:
        print(json.dumps(db.metadata, sort_keys=True, indent=4))
        return

    if args.set_metadata:
        with open(args.set_metadata) as fd:
            db.metadata = json.load(fd)
        return

    if add_key_value_pairs or delete_keys:
        ids = [row['id'] for row in db.select(query)]
        nadded = 0
        nremoved = 0
        with db:
            for row_id in ids:
                added, removed = db.update(row_id, delete_keys=delete_keys,
                                           **add_key_value_pairs)
                nadded += added
                nremoved += removed
        out('Added %s (%s updated)' %
            (plural(nadded, 'key-value pair'),
             plural(len(add_key_value_pairs) * len(ids) - nadded, 'pair')))
        out('Removed', plural(nremoved, 'key-value pair'))

        return

    if args.delete:
        ids = [row['id'] for row in db.select(query, include_data=False)]
        if ids and not args.yes:
            msg = f'Delete {plural(len(ids), "row")}? (yes/No): '
            if input(msg).lower() != 'yes':
                return
        db.delete(ids)
        out(f'Deleted {plural(len(ids), "row")}')
        return

    if args.plot:
        # Syntax: "[tag1,tag2:]xkey,ykey1,ykey2,...".  Rows are grouped into
        # one curve per distinct combination of tag values.
        if ':' in args.plot:
            tags, keys = args.plot.split(':')
            tags = tags.split(',')
        else:
            tags = []
            keys = args.plot
        keys = keys.split(',')
        plots = defaultdict(list)
        X = {}  # string x-value -> integer position on the x-axis
        labels = []
        for row in db.select(query, sort=args.sort, include_data=False):
            name = ','.join(str(row[tag]) for tag in tags)
            x = row.get(keys[0])
            if x is not None:
                if isinstance(x, str):
                    # Map categorical x-values to consecutive integers and
                    # remember them as tick labels.
                    if x not in X:
                        X[x] = len(X)
                        labels.append(x)
                    x = X[x]
                plots[name].append([x] + [row.get(key) for key in keys[1:]])
        import matplotlib.pyplot as plt
        for name, plot in plots.items():
            xyy = list(zip(*plot))
            x = xyy[0]
            for y, key in zip(xyy[1:], keys[1:]):
                plt.plot(x, y, label=name + ':' + key)
        if X:
            plt.xticks(range(len(labels)), labels, rotation=90)
        plt.legend()
        plt.show()
        return

    if args.json:
        row = db.get(query)
        db2 = connect(sys.stdout, 'json', use_lock_file=False)
        kvp = row.get('key_value_pairs', {})
        db2.write(row, data=row.get('data'), **kvp)
        return

    if args.long:
        row = db.get(query)
        print(row2str(row))
        return

    if args.open_web_browser:
        try:
            import flask  # noqa
        except ImportError:
            print('Please install Flask: python3 -m pip install flask')
            return
        check_jsmol()
        import ase.db.app as app
        app.DBApp().run_db(db)
        return

    # Column selection: "++" adds every key found in the selected rows,
    # "+col" adds, "-col" removes, and a bare list replaces the defaults.
    columns = list(all_columns)
    c = args.columns
    if c and c.startswith('++'):
        keys = set()
        for row in db.select(query,
                             limit=args.limit, offset=args.offset,
                             include_data=False):
            keys.update(row._keys)
        columns.extend(keys)
        if c[2:3] == ',':
            c = c[3:]
        else:
            c = ''
    if c:
        if c[0] == '+':
            c = c[1:]
        elif c[0] != '-':
            columns = []
        for col in c.split(','):
            if col[0] == '-':
                columns.remove(col[1:])
            else:
                columns.append(col.lstrip('+'))

    table = Table(db, verbosity=verbosity, cut=args.cut)
    table.select(query, columns, args.sort, args.limit, args.offset)
    if args.csv:
        table.write_csv()
    else:
        table.write(query)
def row2str(row) -> str:
    """Format a database row as a human-readable multi-line text report.

    The report contains, in order:

    * the formula followed by a unit-cell table with one line per axis
      (periodicity, cell vector, length and angle);
    * optional stress, dipole, constraints and data lines (only when
      ``row2dct`` provides them);
    * a Key | Description | Value table.

    Parameters
    ----------
    row:
        Database row.  It is converted with ``row2dct`` and its ``pbc``
        attribute is read directly.

    Returns
    -------
    str
        The report, with lines joined by ``'\\n'``.
    """
    t = row2dct(row, key_descriptions={})

    # One template is used for both the header and the per-axis lines, so
    # the columns always line up: axis(4) | periodic(8) | x, y, z (11 each)
    # | length(10) | angle(10).
    fmt = ('{0:>4}|{1:>8}|{2[0]:>11}|{2[1]:>11}|{2[2]:>11}|'
           '{3:>10}|{4:>10}')
    S = [t['formula'] + ':',
         'Unit cell in Ang:',
         fmt.format('axis', 'periodic', ('x', 'y', 'z'), 'length', 'angle')]
    cell_lines = zip(row.pbc, t['cell'], t['lengths'], t['angles'])
    for c, (p, axis, L, A) in enumerate(cell_lines, start=1):
        # A conditional (rather than indexing a list with the flag) works
        # for any truthy/falsy pbc entry.
        S.append(fmt.format(c, 'yes' if p else 'no', axis, L, A))
    S.append('')

    if 'stress' in t:
        S += ['Stress tensor (xx, yy, zz, zy, zx, yx) in eV/Ang^3:',
              ' {}\n'.format(t['stress'])]

    if 'dipole' in t:
        S.append('Dipole moment in e*Ang: ({})\n'.format(t['dipole']))

    if 'constraints' in t:
        S.append('Constraints: {}\n'.format(t['constraints']))

    if 'data' in t:
        S.append('Data: {}\n'.format(t['data']))

    table = t['table']
    # Columns are at least as wide as their titles ("Key", "Description");
    # the seed values also keep an empty table from breaking max().
    width0 = max([3] + [len(entry[0]) for entry in table])
    width1 = max([11] + [len(entry[1]) for entry in table])
    S.append('{:{}} | {:{}} | Value'
             .format('Key', width0, 'Description', width1))
    for key, desc, value in table:
        S.append('{:{}} | {:{}} | {}'
                 .format(key, width0, desc, width1, value))
    return '\n'.join(S)
@contextmanager
def no_progressbar(iterable: Iterable,
                   length: int = None) -> Iterator[Iterable]:
    """Context manager that hands *iterable* back without any decoration.

    It is the fallback for the ``progressbar(iterable, length=...)`` call
    made in ``main`` when no real progress bar is available.  *length* is
    accepted only for call compatibility and is ignored.
    """
    undecorated = iterable  # the caller iterates the original object
    yield undecorated
def check_jsmol():
    """Warn on stderr when the JSmol viewer files are not installed.

    Looks for ``static/jsmol/JSmol.min.js`` in the directory containing this
    module.  If it is missing, download and symlink instructions (pointing
    at that ``static`` directory) are printed to ``sys.stderr``.  Returns
    ``None`` in either case.
    """
    static_dir = Path(__file__).parent / 'static'
    if (static_dir / 'jsmol/JSmol.min.js').is_file():
        return  # JSmol is present: nothing to report

    message = f"""
    WARNING:
        You don't have jsmol on your system.

        Download Jmol-*-binary.tar.gz from
        https://sourceforge.net/projects/jmol/files/Jmol/,
        extract jsmol.zip, unzip it and create a soft-link:

        $ tar -xf Jmol-*-binary.tar.gz
        $ unzip jmol-*/jsmol.zip
        $ ln -s $PWD/jsmol {static_dir}/jsmol
    """
    print(message, file=sys.stderr)