Coverage for /builds/debichem-team/python-ase/ase/db/cli.py: 62.30%

244 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-03-06 04:00 +0000

1import json 

2import sys 

3from collections import defaultdict 

4from contextlib import contextmanager 

5from pathlib import Path 

6from typing import Iterable, Iterator 

7 

8import ase.io 

9from ase.db import connect 

10from ase.db.core import convert_str_to_int_float_bool_or_str 

11from ase.db.row import row2dct 

12from ase.db.table import Table, all_columns 

13from ase.utils import plural 

14 

15 

def count_keys(db, query):
    """Print, for every key-value-pair key, how many selected rows have it.

    Keys are listed in the order they are first encountered; the ``key:``
    labels are left-aligned to a common width.  When no selected row carries
    any key-value pairs nothing is printed, instead of failing on ``max()``
    of an empty sequence.

    db:
        Database object; only ``db.select`` is used.
    query:
        Selection passed unchanged to ``db.select``.
    """
    keys = defaultdict(int)
    # Only the key names are needed, so skip loading the data blobs.
    for row in db.select(query, include_data=False):
        for key in row._keys:
            keys[key] += 1

    if not keys:
        return

    width = max(len(key) for key in keys) + 1
    for key, number in keys.items():
        print('{:{}} {}'.format(key + ':', width, number))

26 

27 

def main(args):
    """Run the ``ase db`` command described by the parsed namespace *args*.

    The selection is built by joining ``args.query`` with commas; an
    all-digit query is treated as a row id.  The action flags are checked in
    a fixed order and the first one that is set is performed, after which the
    function returns.  If no action flag is set, a table of the matching rows
    is printed (or written as CSV with ``args.csv``).

    Output produced through the local ``out`` helper is suppressed when the
    verbosity (``1 - args.quiet + args.verbose``) drops to zero or below.
    """
    verbosity = 1 - args.quiet + args.verbose
    query = ','.join(args.query)

    if args.sort.endswith('-'):
        # Allow using "key-" instead of "-key" for reverse sorting
        args.sort = '-' + args.sort[:-1]

    if query.isdigit():
        query = int(query)

    add_key_value_pairs = {}
    if args.add_key_value_pairs:
        for pair in args.add_key_value_pairs.split(','):
            # Split on the first '=' only so that values may contain '='.
            # A pair without '=' still raises ValueError on unpacking.
            key, value = pair.split('=', 1)
            add_key_value_pairs[key] = \
                convert_str_to_int_float_bool_or_str(value)

    delete_keys = args.delete_keys.split(',') if args.delete_keys else []

    db = connect(args.database, use_lock_file=not args.no_lock_file)

    def out(*items):
        # Informational output, silenced when verbosity <= 0.
        if verbosity > 0:
            print(*items)

    if args.analyse:
        db.analyse()
        return

    if args.show_keys:
        count_keys(db, query)
        return

    if args.show_values:
        keys = args.show_values.split(',')
        values = {key: defaultdict(int) for key in keys}
        numbers = set()  # keys that had at least one non-string value
        for row in db.select(query, include_data=False):
            kvp = row.key_value_pairs
            for key in keys:
                value = kvp.get(key)
                if value is not None:
                    values[key][value] += 1
                    if not isinstance(value, str):
                        numbers.add(key)

        width = max(len(key) for key in keys) + 1
        for key in keys:
            vals = values[key]
            if key in numbers:
                # Numeric values: show the range only.
                print('{:{}} [{}..{}]'
                      .format(key + ':', width, min(vals), max(vals)))
            else:
                # String values: show each value with its count.
                print('{:{}} {}'
                      .format(key + ':', width,
                              ', '.join(f'{v}({count})'
                                        for v, count in vals.items())))
        return

    if args.add_from_file:
        filename = args.add_from_file
        configs = ase.io.read(filename)
        if not isinstance(configs, list):
            configs = [configs]
        for atoms in configs:
            db.write(atoms, key_value_pairs=add_key_value_pairs)
        out('Added ' + plural(len(configs), 'row'))
        return

    if args.count:
        print(plural(db.count(query), 'row'))
        return

    if args.insert_into:
        if args.limit == -1:
            args.limit = 0

        progressbar = no_progressbar
        length = None

        if args.progress_bar:
            # Try to import the one from click.
            # People using ase.db will most likely have flask installed
            # and therefore also click.
            try:
                from click import progressbar
            except ImportError:
                pass
            else:
                length = db.count(query)

        nkvp = 0  # number of key-value pairs newly added (not overwritten)
        nrows = 0
        with connect(args.insert_into,
                     use_lock_file=not args.no_lock_file) as db2:
            with progressbar(db.select(query,
                                       sort=args.sort,
                                       limit=args.limit,
                                       offset=args.offset),
                             length=length) as rows:
                for row in rows:
                    kvp = row.get('key_value_pairs', {})
                    # The growth in size counts only genuinely new keys.
                    nkvp -= len(kvp)
                    kvp.update(add_key_value_pairs)
                    nkvp += len(kvp)
                    if args.strip_data:
                        db2.write(row.toatoms(), **kvp)
                    else:
                        db2.write(row, data=row.get('data'), **kvp)
                    nrows += 1

        out('Added %s (%s updated)' %
            (plural(nkvp, 'key-value pair'),
             plural(len(add_key_value_pairs) * nrows - nkvp, 'pair')))
        out(f'Inserted {plural(nrows, "row")}')
        return

    if args.limit == -1:
        args.limit = 20

    if args.explain:
        for row in db.select(query, explain=True,
                             verbosity=verbosity,
                             limit=args.limit, offset=args.offset):
            print(row['explain'])
        return

    if args.show_metadata:
        print(json.dumps(db.metadata, sort_keys=True, indent=4))
        return

    if args.set_metadata:
        with open(args.set_metadata) as fd:
            db.metadata = json.load(fd)
        return

    if add_key_value_pairs or delete_keys:
        # Only the ids are needed, so skip loading the data blobs.
        ids = [row['id'] for row in db.select(query, include_data=False)]
        n_added = 0
        n_removed = 0
        with db:
            for row_id in ids:
                added, removed = db.update(row_id, delete_keys=delete_keys,
                                           **add_key_value_pairs)
                n_added += added
                n_removed += removed
        out('Added %s (%s updated)' %
            (plural(n_added, 'key-value pair'),
             plural(len(add_key_value_pairs) * len(ids) - n_added, 'pair')))
        out('Removed', plural(n_removed, 'key-value pair'))

        return

    if args.delete:
        ids = [row['id'] for row in db.select(query, include_data=False)]
        if ids and not args.yes:
            msg = f'Delete {plural(len(ids), "row")}? (yes/No): '
            if input(msg).lower() != 'yes':
                return
        db.delete(ids)
        out(f'Deleted {plural(len(ids), "row")}')
        return

    if args.plot:
        # Syntax: "[tag1,tag2,...:]xkey,ykey1,ykey2,..."
        if ':' in args.plot:
            tags, keys = args.plot.split(':')
            tags = tags.split(',')
        else:
            tags = []
            keys = args.plot
        keys = keys.split(',')
        plots = defaultdict(list)
        X = {}  # string x-value -> integer position on the x axis
        labels = []
        for row in db.select(query, sort=args.sort, include_data=False):
            name = ','.join(str(row[tag]) for tag in tags)
            x = row.get(keys[0])
            if x is not None:
                if isinstance(x, str):
                    if x not in X:
                        X[x] = len(X)
                        labels.append(x)
                    x = X[x]
                plots[name].append([x] + [row.get(key) for key in keys[1:]])
        import matplotlib.pyplot as plt
        for name, plot in plots.items():
            xyy = list(zip(*plot))
            x = xyy[0]
            for y, key in zip(xyy[1:], keys[1:]):
                plt.plot(x, y, label=name + ':' + key)
        if X:
            plt.xticks(range(len(labels)), labels, rotation=90)
        plt.legend()
        plt.show()
        return

    if args.json:
        row = db.get(query)
        db2 = connect(sys.stdout, 'json', use_lock_file=False)
        kvp = row.get('key_value_pairs', {})
        db2.write(row, data=row.get('data'), **kvp)
        return

    if args.long:
        row = db.get(query)
        print(row2str(row))
        return

    if args.open_web_browser:
        try:
            import flask  # noqa
        except ImportError:
            print('Please install Flask: python3 -m pip install flask')
            return
        check_jsmol()
        import ase.db.app as app
        app.DBApp().run_db(db)
        return

    columns = list(all_columns)
    c = args.columns
    if c and c.startswith('++'):
        # "++" adds every key-value-pair key found in the selected rows.
        keys = set()
        for row in db.select(query,
                             limit=args.limit, offset=args.offset,
                             include_data=False):
            keys.update(row._keys)
        # Sort for a reproducible column order (set iteration order of
        # strings varies between interpreter runs).
        columns.extend(sorted(keys))
        if c[2:3] == ',':
            c = c[3:]
        else:
            c = ''
    if c:
        if c[0] == '+':
            c = c[1:]
        elif c[0] != '-':
            # A plain list replaces the default columns entirely.
            columns = []
        for col in c.split(','):
            if not col:
                # Tolerate empty entries such as a trailing comma.
                continue
            if col[0] == '-':
                columns.remove(col[1:])
            else:
                columns.append(col.lstrip('+'))

    table = Table(db, verbosity=verbosity, cut=args.cut)
    table.select(query, columns, args.sort, args.limit, args.offset)
    if args.csv:
        table.write_csv()
    else:
        table.write(query)

282 

283 

def row2str(row) -> str:
    """Return a multi-line, human-readable summary of a database row.

    The text holds the formula, a unit-cell table (periodicity, cell vectors,
    lengths and angles per axis), optional stress / dipole / constraints /
    data lines, and a ``Key | Description | Value`` table built from
    ``row2dct``'s ``'table'`` entries.
    """
    t = row2dct(row, key_descriptions={})

    # One format shared by the header and the per-axis lines so that the
    # column separators always line up.
    cell_fmt = '{:>4}|{:>8}|{:>11}|{:>11}|{:>11}|{:>10}|{:>10}'
    S = [t['formula'] + ':',
         'Unit cell in Ang:',
         cell_fmt.format('axis', 'periodic', 'x', 'y', 'z',
                         'length', 'angle')]
    cell_rows = zip(row.pbc, t['cell'], t['lengths'], t['angles'])
    for number, (periodic, axis, length, angle) in enumerate(cell_rows,
                                                             start=1):
        S.append(cell_fmt.format(number, 'yes' if periodic else 'no',
                                 axis[0], axis[1], axis[2], length, angle))
    S.append('')

    if 'stress' in t:
        S += ['Stress tensor (xx, yy, zz, zy, zx, yx) in eV/Ang^3:',
              ' {}\n'.format(t['stress'])]

    if 'dipole' in t:
        S.append('Dipole moment in e*Ang: ({})\n'.format(t['dipole']))

    if 'constraints' in t:
        S.append('Constraints: {}\n'.format(t['constraints']))

    if 'data' in t:
        S.append('Data: {}\n'.format(t['data']))

    table = t['table']
    # Column widths are at least as wide as the header words; the default
    # entries also keep an empty table from failing in max().
    width0 = max([3] + [len(key) for key, _, _ in table])
    width1 = max([11] + [len(desc) for _, desc, _ in table])
    S.append('{:{}} | {:{}} | Value'
             .format('Key', width0, 'Description', width1))
    for key, desc, value in table:
        S.append('{:{}} | {:{}} | {}'
                 .format(key, width0, desc, width1, value))
    return '\n'.join(S)

319 

320 

@contextmanager
def no_progressbar(iterable: Iterable,
                   length: 'int | None' = None) -> Iterator[Iterable]:
    """Do-nothing stand-in for a progress-bar context manager.

    ``with no_progressbar(items) as it:`` yields *items* unchanged.
    *length* is accepted only so the call matches the
    ``progressbar(iterable, length=...)`` form used by the caller; it is
    ignored.
    """
    yield iterable

326 

327 

def check_jsmol():
    """Warn on stderr if the JSmol files are missing from ``static/jsmol``.

    Looks for ``static/jsmol/JSmol.min.js`` next to this module.  If it is
    absent, prints installation instructions, including a ready-to-paste
    ``ln -s`` command, to ``sys.stderr``.

    Returns True when JSmol was found and False otherwise.
    """
    import shlex

    static = Path(__file__).parent / 'static'
    if (static / 'jsmol/JSmol.min.js').is_file():
        return True

    # Quote both paths so the suggested shell command also works when a
    # directory name contains spaces or other shell metacharacters.
    target = shlex.quote(str(static / 'jsmol'))
    lines = [
        '',
        'WARNING:',
        "    You don't have jsmol on your system.",
        '',
        '    Download Jmol-*-binary.tar.gz from',
        '    https://sourceforge.net/projects/jmol/files/Jmol/,',
        '    extract jsmol.zip, unzip it and create a soft-link:',
        '',
        '        $ tar -xf Jmol-*-binary.tar.gz',
        '        $ unzip jmol-*/jsmol.zip',
        f'        $ ln -s "$PWD/jsmol" {target}',
        '',
    ]
    print('\n'.join(lines), file=sys.stderr)
    return False