Coverage for /builds/debichem-team/python-ase/ase/db/core.py: 85.57%

388 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-03-06 04:00 +0000

1import functools 

2import json 

3import numbers 

4import operator 

5import os 

6import re 

7import warnings 

8from time import time 

9from typing import Any, Dict, List 

10 

11import numpy as np 

12 

13from ase.atoms import Atoms 

14from ase.calculators.calculator import all_changes, all_properties 

15from ase.data import atomic_numbers 

16from ase.db.row import AtomsRow 

17from ase.formula import Formula 

18from ase.io.jsonio import create_ase_object 

19from ase.parallel import DummyMPI, parallel_function, parallel_generator, world 

20from ase.utils import Lock, PurePath 

21 

# Reference epoch for database timestamps, in seconds since the Unix epoch.
# NOTE(review): 946684800 would be midnight UTC on 2000-01-01; this value is
# 3600 s earlier -- confirm that the one-hour offset is intentional.
T2000 = 946681200.0  # January 1. 2000
# Number of seconds in one year of 365.25 days; used to convert between
# seconds and years.
YEAR = 31557600.0  # 365.25 days

24 

25 

@functools.total_ordering
class KeyDescription:
    """Description of a database key.

    Stores the key name, a short and a long description and a unit string.
    Instances are compared (and hence sorted) by their short description.
    """

    _subscript = re.compile(r'`(.)_(.)`')
    _superscript = re.compile(r'`(.*)\^\{?(.*?)\}?`')

    def __init__(self, key, shortdesc=None, longdesc=None, unit=''):
        """Create a description for *key*.

        key: str
            Name of the key.
        shortdesc: str or None
            Short description; the key itself is used when omitted.
        longdesc: str or None
            Long description; the short description is used when omitted.
        unit: str
            Unit string.  Backquoted sub-/superscript markup is rewritten
            to HTML tags and ``\\text{...}`` wrappers are removed.
        """
        # Missing descriptions fall back: key -> shortdesc -> longdesc.
        short = key if shortdesc is None else shortdesc
        long_ = short if longdesc is None else longdesc

        self.key = key
        self.shortdesc = short
        self.longdesc = long_

        # Somewhat arbitrary that we do this conversion.  Can we avoid that?
        # Previously done in create_key_descriptions().
        markup = self._subscript.sub(r'\1<sub>\2</sub>', unit)
        markup = self._superscript.sub(r'\1<sup>\2</sup>', markup)
        for token in (r'\text{', '}'):
            markup = markup.replace(token, '')

        self.unit = markup

    def __repr__(self):
        positional = ', '.join(
            repr(part)
            for part in (self.key, self.shortdesc, self.longdesc))
        return f'{type(self).__name__}({positional}, unit={self.unit!r})'

    # The templates like to sort key descriptions by shortdesc.
    def __eq__(self, other):
        theirs = getattr(other, 'shortdesc', None)
        return self.shortdesc == theirs

    def __lt__(self, other):
        mine = self.shortdesc
        # Objects without a shortdesc compare as "not less than".
        return mine < getattr(other, 'shortdesc', mine)

62 

63 

def get_key_descriptions():
    """Return descriptions of the built-in keys.

    Returns a dict mapping each key name to its KeyDescription.
    """
    KD = KeyDescription
    descriptions = [
        KD('id', 'ID', 'Unique row ID'),
        KD('age', 'Age', 'Time since creation'),
        KD('formula', 'Formula', 'Chemical formula'),
        KD('pbc', 'PBC', 'Periodic boundary conditions'),
        KD('user', 'Username'),
        KD('calculator', 'Calculator', 'ASE-calculator name'),
        KD('energy', 'Energy', 'Total energy', unit='eV'),
        KD('natoms', 'Number of atoms'),
        KD('fmax', 'Maximum force', unit='eV/Å'),
        KD('smax', 'Maximum stress', 'Maximum stress on unit cell',
           unit='eV/ų'),
        KD('charge', 'Charge', 'Net charge in unit cell', unit='|e|'),
        KD('mass', 'Mass', 'Sum of atomic masses in unit cell', unit='au'),
        KD('magmom', 'Magnetic moment', unit='μ_B'),
        KD('unique_id', 'Unique ID', 'Random (unique) ID'),
        KD('volume', 'Volume', 'Volume of unit cell', unit='ų'),
    ]
    return {keydesc.key: keydesc for keydesc in descriptions}

84 

85 

def now():
    """Return the time elapsed since the T2000 reference, in years."""
    elapsed_seconds = time() - T2000
    return elapsed_seconds / YEAR

89 

90 

# Length in seconds of each one-letter time unit.  'M' (month) is YEAR / 12.
seconds = {'s': 1,
           'm': 60,
           'h': 3600,
           'd': 86400,
           'w': 604800,
           'M': 2629800,
           'y': YEAR}

# Singular English word for each one-letter time unit.
longwords = {'s': 'second',
             'm': 'minute',
             'h': 'hour',
             'd': 'day',
             'w': 'week',
             'M': 'month',
             'y': 'year'}

# Comparison functions for the operator strings of the selection syntax.
ops = {'<': operator.lt,
       '<=': operator.le,
       '=': operator.eq,
       '>=': operator.ge,
       '>': operator.gt,
       '!=': operator.ne}

# Maps each operator string to its logical negation.  parse_selection()
# applies it when an 'age' comparison is rewritten as a 'ctime' comparison.
invop = {'<': '>=', '<=': '>', '>=': '<', '>': '<=', '=': '!=', '!=': '='}

# Pattern for an identifier-like key name: a letter or underscore followed
# by letters, digits or underscores.
word = re.compile('[_a-zA-Z][_0-9a-zA-Z]*$')

# Names that check() rejects as user-defined keys.
reserved_keys = set(all_properties +
                    all_changes +
                    list(atomic_numbers) +
                    ['id', 'unique_id', 'ctime', 'mtime', 'user',
                     'fmax', 'smax',
                     'momenta', 'constraints', 'natoms', 'formula', 'age',
                     'calculator', 'calculator_parameters',
                     'key_value_pairs', 'data'])

# Keys whose comparison value must be an int or float in parse_selection().
numeric_keys = {'id', 'energy', 'magmom', 'charge', 'natoms'}

128 

129 

def check(key_value_pairs):
    """Validate key-value pairs.

    key_value_pairs: dict
        Mapping from key name to value.

    Raises ValueError if a key does not match the ``word`` pattern or is in
    ``reserved_keys``, if a value is not a real number, a string or a
    ``numpy.bool_``, or if a string value could be interpreted as a bool,
    int or float.  A warning is issued for keys that are also a valid
    chemical formula.  The "external_tables" entry is not checked.
    """
    for key, value in key_value_pairs.items():
        if key == "external_tables":
            # Checks for external_tables are not
            # performed
            continue

        if not word.match(key) or key in reserved_keys:
            raise ValueError(f'Bad key: {key}')
        try:
            Formula(key, strict=True)
        except ValueError:
            # Not a chemical formula: nothing to warn about.
            pass
        else:
            warnings.warn(
                'It is best not to use keys ({0}) that are also a '
                'chemical formula. If you do a "db.select({0!r})", '
                'you will not find rows with your key. Instead, you will '
                'get rows containing the atoms in the formula!'.format(key))
        if not isinstance(value, (numbers.Real, str, np.bool_)):
            raise ValueError(f'Bad value for {key!r}: {value}')
        if isinstance(value, str):
            # A string that parses as a number or bool would be ambiguous
            # when read back, so it is rejected.
            for t in [bool, int, float]:
                if str_represents(value, t):
                    raise ValueError(
                        'Value ' + value + ' is put in as string ' +
                        'but can be interpreted as ' +
                        f'{t.__name__}! Please convert ' +
                        f'to {t.__name__} before ' +
                        'writing to the database OR change ' +
                        'to a different string.')

161 

162 

def str_represents(value, t=int):
    """Tell whether the string *value* converts to an instance of *t*."""
    return isinstance(convert_str_to_int_float_bool_or_str(value), t)

166 

167 

def connect(name, type='extract_from_name', create_indices=True,
            use_lock_file=True, append=True, serial=False):
    """Create connection to database.

    name: str
        Filename or address of database.
    type: str
        One of 'json', 'db', 'postgresql', 'mysql'
        (JSON, SQLite, PostgreSQL, MySQL).
        Default is 'extract_from_name', which will guess the type
        from the name.
    create_indices: bool
        Passed on to the SQLite backend.
    use_lock_file: bool
        You can turn this off if you know what you are doing ...
    append: bool
        Use append=False to start a new database.
    serial: bool
        Passed on to the JSON and SQLite backends.
    """

    if isinstance(name, PurePath):
        name = str(name)

    if type == 'extract_from_name':
        if name is None:
            type = None
        elif not isinstance(name, str):
            # Anything that is not a path string is handled as JSON.
            type = 'json'
        elif name.startswith(('postgresql://', 'postgres://')):
            type = 'postgresql'
        elif name.startswith(('mysql://', 'mariadb://')):
            type = 'mysql'
        else:
            # File extension without the leading dot.
            type = os.path.splitext(name)[1][1:]
            if type == '':
                raise ValueError('No file extension or database type given')

    if type is None:
        return Database()

    # Starting a new database: only rank 0 removes an existing file.
    if (not append and world.rank == 0 and
            isinstance(name, str) and os.path.isfile(name)):
        os.remove(name)

    if isinstance(name, str) and type not in ['postgresql', 'mysql']:
        name = os.path.abspath(name)

    # Backends are imported only when they are actually requested.
    if type == 'json':
        from ase.db.jsondb import JSONDatabase
        return JSONDatabase(name, use_lock_file=use_lock_file, serial=serial)
    if type == 'db':
        from ase.db.sqlite import SQLite3Database
        return SQLite3Database(name, create_indices, use_lock_file,
                               serial=serial)
    if type == 'postgresql':
        from ase.db.postgresql import PostgreSQLDatabase
        return PostgreSQLDatabase(name)
    if type == 'mysql':
        from ase.db.mysql import MySQLDatabase
        return MySQLDatabase(name)

    raise ValueError('Unknown database type: ' + type)

228 

229 

def lock(method):
    """Decorator for using a lock-file.

    The wrapped method runs inside ``with self.lock:`` unless ``self.lock``
    is None, in which case it is called directly.
    """
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        guard = self.lock
        if guard is not None:
            with guard:
                return method(self, *args, **kwargs)
        return method(self, *args, **kwargs)
    return wrapper

240 

241 

def convert_str_to_int_float_bool_or_str(value):
    """Safe eval()

    Try to read *value* as an int, then as a float, then as the literal
    'True' or 'False'.  If nothing fits, *value* is returned unchanged.
    """
    for parse in (int, float):
        try:
            return parse(value)
        except ValueError:
            pass
    return {'True': True, 'False': False}.get(value, value)

252 

253 

def parse_selection(selection, **kwargs):
    """Parse a selection into bare keys and comparison triples.

    selection: None, int, str or list
        None or '' selects everything, an int is a row id, a string is a
        comma-separated list of expressions and a list may mix expression
        strings with ready-made (key, op, value) tuples.
    kwargs:
        Each keyword adds a ``key=value`` comparison.

    Returns ``(keys, cmps)``: the list of bare key names and a list of
    (key, op, value) tuples.

    Raises ValueError for a 'formula' comparison that does not use '=' and
    for non-numeric values compared with one of the ``numeric_keys``.
    """
    if selection is None or selection == '':
        expressions = []
    elif isinstance(selection, int):
        expressions = [('id', '=', selection)]
    elif isinstance(selection, list):
        expressions = selection
    else:
        expressions = [w.strip() for w in selection.split(',')]
    keys = []
    comparisons = []
    for expression in expressions:
        if isinstance(expression, (list, tuple)):
            comparisons.append(expression)
            continue
        if expression.count('<') == 2:
            # Range expression "value<key<value2": emit the lower bound
            # here; the remaining "key<value2" is handled further down.
            value, expression = expression.split('<', 1)
            if expression[0] == '=':
                op = '>='
                expression = expression[1:]
            else:
                op = '>'
            key = expression.split('<', 1)[0]
            comparisons.append((key, op, value))
        # Two-character operators are tried before their one-character
        # prefixes.
        for op in ['!=', '<=', '>=', '<', '>', '=']:
            if op in expression:
                break
        else:  # no break
            # No operator: a chemical symbol, a formula or a bare key.
            if expression in atomic_numbers:
                comparisons.append((expression, '>', 0))
            else:
                try:
                    count = Formula(expression).count()
                except ValueError:
                    keys.append(expression)
                else:
                    comparisons.extend((symbol, '>', n - 1)
                                       for symbol, n in count.items())
            continue
        key, value = expression.split(op)
        comparisons.append((key, op, value))

    cmps = []
    for key, value in kwargs.items():
        comparisons.append((key, '=', value))

    for key, op, value in comparisons:
        if key == 'age':
            # An age is a time span before now(): compare creation times
            # instead, using the operator given by invop.
            key = 'ctime'
            op = invop[op]
            value = now() - time_string_to_float(value)
        elif key == 'formula':
            if op != '=':
                raise ValueError('Use formula=...')
            f = Formula(value)
            count = f.count()
            # One exact-count comparison per element, plus the total
            # number of atoms below.
            cmps.extend((atomic_numbers[symbol], '=', n)
                        for symbol, n in count.items())
            key = 'natoms'
            value = len(f)
        elif key in atomic_numbers:
            key = atomic_numbers[key]
            value = int(value)
        elif isinstance(value, str):
            value = convert_str_to_int_float_bool_or_str(value)
        if key in numeric_keys and not isinstance(value, (int, float)):
            msg = 'Wrong type for "{}{}{}" - must be a number'
            raise ValueError(msg.format(key, op, value))
        cmps.append((key, op, value))

    return keys, cmps

325 

326 

class Database:
    """Base class for all databases.

    This class does not define the hooks ``_select``, ``_get_row`` and
    ``_update`` that its methods call.
    NOTE(review): presumably the backend subclasses provide them -- confirm.
    """

    def __init__(self, filename=None, create_indices=True,
                 use_lock_file=False, serial=False):
        """Database object.

        filename: str or other
            A string has ``~`` expanded; any other value is stored as is.
        create_indices: bool
            Stored on the instance; not used by this base class.
        use_lock_file: bool
            If true and filename is a string, a lock on
            ``filename + '.lock'`` is created for the @lock decorator.
        serial: bool
            Let someone else handle parallelization.  Default behavior is
            to interact with the database on the master only and then
            distribute results to all slaves.
        """
        if isinstance(filename, str):
            filename = os.path.expanduser(filename)
        self.filename = filename
        self.create_indices = create_indices
        if use_lock_file and isinstance(filename, str):
            self.lock = Lock(filename + '.lock', world=DummyMPI())
        else:
            self.lock = None
        self.serial = serial

        # Description of columns and other stuff:
        self._metadata: Dict[str, Any] = None

    @property
    def metadata(self) -> Dict[str, Any]:
        """Metadata dictionary; not implemented in the base class."""
        raise NotImplementedError

    @parallel_function
    @lock
    def write(self, atoms, key_value_pairs={}, data={}, id=None, **kwargs):
        """Write atoms to database with key-value pairs.

        atoms: Atoms object
            Write atomic numbers, positions, unit cell and boundary
            conditions.  If a calculator is attached, write also already
            calculated properties such as the energy and forces.
            None is replaced by an empty Atoms object.
        key_value_pairs: dict
            Dictionary of key-value pairs.  Values must be strings or numbers.
        data: dict
            Extra stuff (not for searching).
        id: int
            Overwrite existing row.

        Key-value pairs can also be set using keyword arguments::

            connection.write(atoms, name='ABC', frequency=42.0)

        Returns integer id of the new row.
        """

        if atoms is None:
            atoms = Atoms()

        # The mutable defaults are not mutated in this method:
        # key_value_pairs is copied and data is passed on unchanged.
        kvp = dict(key_value_pairs)  # modify a copy
        kvp.update(kwargs)

        id = self._write(atoms, kvp, data, id)
        return id

    def _write(self, atoms, key_value_pairs, data, id=None):
        """Validate the key-value pairs and return 1.

        The base class stores nothing.
        """
        check(key_value_pairs)
        return 1

    @parallel_function
    @lock
    def reserve(self, **key_value_pairs):
        """Write empty row if not already present.

        Usage::

            id = conn.reserve(key1=value1, key2=value2, ...)

        Write an empty row with the given key-value pairs and
        return the integer id.  If such a row already exists, don't write
        anything and return None.
        """

        # Any matching row means the reservation already exists.
        for _ in self._select([],
                              [(key, '=', value)
                               for key, value in key_value_pairs.items()]):
            return None

        atoms = Atoms()

        calc_name = key_value_pairs.pop('calculator', None)

        if calc_name:
            # Allow use of calculator key
            assert calc_name.lower() == calc_name

            # Fake calculator class:
            class Fake:
                name = calc_name

                def todict(self):
                    return {}

                def check_state(self, atoms):
                    return ['positions']

            atoms.calc = Fake()

        id = self._write(atoms, key_value_pairs, {}, None)

        return id

    def __delitem__(self, id):
        """Delete the row with the given id (``del db[id]``)."""
        self.delete([id])

    def get_atoms(self, selection=None,
                  add_additional_information=False, **kwargs):
        """Get Atoms object.

        selection: int, str or list
            See the select() method.
        add_additional_information: bool
            Put key-value pairs and data into Atoms.info dictionary.

        In addition, one can use keyword arguments to select specific
        key-value pairs.
        """

        row = self.get(selection, **kwargs)
        return row.toatoms(add_additional_information)

    def __getitem__(self, selection):
        """Return the single row matching selection (``db[selection]``)."""
        return self.get(selection)

    def get(self, selection=None, **kwargs):
        """Select a single row and return it.

        selection: int, str or list
            See the select() method.

        Raises KeyError if no row matches and AssertionError if more than
        one row matches.
        """
        # limit=2 is enough to tell "exactly one" from "more than one".
        rows = list(self.select(selection, limit=2, **kwargs))
        if not rows:
            raise KeyError('no match')
        assert len(rows) == 1, 'more than one row matched'
        return rows[0]

    @parallel_generator
    def select(self, selection=None, filter=None, explain=False,
               verbosity=1, limit=None, offset=0, sort=None,
               include_data=True, columns='all', **kwargs):
        """Select rows.

        Return AtomsRow iterator with results.  Selection is done
        using key-value pairs and the special keys:

            formula, age, user, calculator, natoms, energy, magmom
            and/or charge.

        selection: int, str or list
            Can be:

            * an integer id
            * a string like 'key=value', where '=' can also be one of
              '<=', '<', '>', '>=' or '!='.
            * a string like 'key'
            * comma separated strings like 'key1<value1,key2=value2,key'
            * list of strings or tuples: [('charge', '=', 1)].
        filter: function
            A function that takes as input a row and returns True or False.
        explain: bool
            Explain query plan.
        verbosity: int
            Possible values: 0, 1 or 2.
        limit: int or None
            Limit selection.
        offset: int
            Offset into selected rows.
        sort: str
            Sort rows after key.  Prepend with minus sign for a descending
            sort.
        include_data: bool
            Use include_data=False to skip reading data from rows.
        columns: 'all' or list of str
            Specify which columns from the SQL table to include.
            For example, if only the row id and the energy is needed,
            queries can be speeded up by setting columns=['id', 'energy'].
        """

        if sort:
            # Sorting by age is sorting by creation time in the opposite
            # direction; 'user' is sorted via 'username'.
            if sort == 'age':
                sort = '-ctime'
            elif sort == '-age':
                sort = 'ctime'
            elif sort.lstrip('-') == 'user':
                sort += 'name'

        keys, cmps = parse_selection(selection, **kwargs)
        for row in self._select(keys, cmps, explain=explain,
                                verbosity=verbosity,
                                limit=limit, offset=offset, sort=sort,
                                include_data=include_data,
                                columns=columns):
            # The filter runs on rows already returned by _select().
            if filter is None or filter(row):
                yield row

    def count(self, selection=None, **kwargs):
        """Count rows.

        See the select() method for the selection syntax.  Use db.count() or
        len(db) to count all rows.
        """
        n = 0
        for _ in self.select(selection, **kwargs):
            n += 1
        return n

    def __len__(self):
        """Return the total number of rows."""
        return self.count()

    @parallel_function
    @lock
    def update(self, id, atoms=None, delete_keys=[], data=None,
               **add_key_value_pairs):
        """Update and/or delete key-value pairs of row(s).

        id: int
            ID of row to update.
        atoms: Atoms object
            Optionally update the Atoms data (positions, cell, ...).
        data: dict
            Data dict to be added to the existing data.
        delete_keys: list of str
            Keys to remove.

        Use keyword arguments to add new key-value pairs.

        Returns number of key-value pairs added and removed.

        Raises ValueError if id is a list and TypeError if it is any other
        non-integer.
        """

        if not isinstance(id, numbers.Integral):
            if isinstance(id, list):
                err = ('First argument must be an int and not a list.\n'
                       'Do something like this instead:\n\n'
                       'with db:\n'
                       '    for id in ids:\n'
                       '        db.update(id, ...)')
                raise ValueError(err)
            raise TypeError('id must be an int')

        check(add_key_value_pairs)

        row = self._get_row(id)
        kvp = row.key_value_pairs

        # n: number of keys removed; m: number of keys added.  Both are
        # computed from the change in len(kvp).
        n = len(kvp)
        for key in delete_keys:
            kvp.pop(key, None)
        n -= len(kvp)
        m = -len(kvp)
        kvp.update(add_key_value_pairs)
        m += len(kvp)

        moredata = data
        data = row.get('data', {})
        if moredata:
            data.update(moredata)
        if not data:
            data = None

        if atoms:
            oldrow = row
            row = AtomsRow(atoms)
            # Copy over data, kvp, ctime, user and id
            row._data = oldrow._data
            row.__dict__.update(kvp)
            row._keys = list(kvp)
            row.ctime = oldrow.ctime
            row.user = oldrow.user
            row.id = id

        # The whole row is rewritten for new atoms and for '.json' files;
        # otherwise only key-value pairs and data go through _update().
        if atoms or os.path.splitext(self.filename)[1] == '.json':
            self._write(row, kvp, data, row.id)
        else:
            self._update(row.id, kvp, data)
        return m, n

    def delete(self, ids):
        """Delete rows."""
        raise NotImplementedError

611 

612 

def time_string_to_float(s):
    """Convert a time string such as '2d' or '1w+3h' to years.

    s: str, int or float
        Numbers are returned unchanged.  A string is a sum ('+') of terms,
        each an integer followed by a one-letter unit from the ``seconds``
        table and an optional plural 's'.  Spaces are ignored.

    Returns the time in years.

    Raises ValueError if a string does not have this form.
    """
    if isinstance(s, (float, int)):
        return s
    s = s.replace(' ', '')
    if '+' in s:
        return sum(time_string_to_float(x) for x in s.split('+'))
    # Validate the shape up front so that malformed input gives a clear
    # ValueError instead of an IndexError/KeyError from indexing.
    match = re.fullmatch(r'(-?\d+)([A-Za-z])s?', s)
    if match is None:
        raise ValueError(f'Bad time string: {s!r}')
    number, unit = match.groups()
    if unit not in seconds:
        raise ValueError(f'Unknown time unit {unit!r} in {s!r}')
    return seconds[unit] * int(number) / YEAR

625 

626 

def float_to_time_string(t, long=False):
    """Format a time given in years as a string.

    Units are tried from years down to seconds and the first one giving a
    value above 5 is used; seconds are used if none does.  With long=True
    the result has three decimals and the plural unit word, otherwise it is
    a rounded number followed by the one-letter unit.
    """
    t *= YEAR
    for unit in 'yMwdhms':
        amount = t / seconds[unit]
        if amount > 5:
            break
    if not long:
        return f'{round(amount):.0f}{unit}'
    return f'{amount:.3f} {longwords[unit]}s'

637 

638 

def object_to_bytes(obj: Any) -> bytes:
    """Serialize Python object to bytes.

    The result is an 8-byte little-endian int64 holding the offset of the
    JSON text, then the raw bytes appended by o2b(), then the JSON text.
    """
    # Placeholder for the header, so that o2b() counts it in its offsets.
    chunks = [b'12345678']
    tree = o2b(obj, chunks)
    json_offset = np.array(sum(map(len, chunks)), np.int64)
    if not np.little_endian:
        json_offset.byteswap(True)
    chunks[0] = json_offset.tobytes()
    text = json.dumps(tree, separators=(',', ':'))
    chunks.append(text.encode())
    return b''.join(chunks)

650 

651 

def bytes_to_object(b: bytes) -> Any:
    """Deserialize bytes to Python object.

    The first 8 bytes are read as a little-endian int64 giving the offset
    at which the JSON text starts; b2o() then rebuilds the object.
    """
    header = np.frombuffer(b[:8], np.int64)
    if not np.little_endian:
        header = header.byteswap()
    json_start = header.item()
    tree = json.loads(b[json_start:].decode())
    return b2o(tree, b)

660 

661 

def o2b(obj: Any, parts: List[bytes]):
    """Convert *obj* to a JSON-compatible structure.

    obj:
        int, float, bool, str, None, complex, ndarray, an object with a
        truthy ``ase_objtype`` attribute and a ``todict()`` method, or a
        dict, list or tuple of these.
    parts: list of bytes
        The raw bytes of each ndarray are appended here.  The array is
        replaced by ``{'__ndarray__': [shape, dtype name, offset]}``, where
        offset is the total length of *parts* before the append.

    Raises ValueError for objects of any other type.
    """
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj
    if isinstance(obj, dict):
        return {key: o2b(value, parts) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [o2b(value, parts) for value in obj]
    if isinstance(obj, np.ndarray):
        assert obj.dtype != object, \
            'Cannot convert ndarray of type "object" to bytes.'
        offset = sum(len(part) for part in parts)
        # Array bytes are always stored little-endian.
        if not np.little_endian:
            obj = obj.byteswap()
        parts.append(obj.tobytes())
        return {'__ndarray__': [obj.shape,
                                obj.dtype.name,
                                offset]}
    if isinstance(obj, complex):
        return {'__complex__': [obj.real, obj.imag]}
    # getattr with a default lets unsupported objects reach the ValueError
    # below instead of failing with AttributeError.
    objtype = getattr(obj, 'ase_objtype', None)
    if objtype:
        dct = o2b(obj.todict(), parts)
        dct['__ase_objtype__'] = objtype
        return dct
    raise ValueError('Objects of type {type} not allowed'
                     .format(type=type(obj)))

688 

689 

def b2o(obj: Any, b: bytes) -> Any:
    """Rebuild an object from its JSON-compatible form.

    obj:
        Decoded JSON: a scalar, a list or a dict.
    b: bytes
        The buffer that ``__ndarray__`` offsets refer to.

    ``__complex__`` and ``__ndarray__`` dicts become complex numbers and
    arrays.  A dict with an ``__ase_objtype__`` entry is passed to
    create_ase_object(); any other dict is returned with converted values.
    """
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj

    if isinstance(obj, list):
        return [b2o(item, b) for item in obj]

    assert isinstance(obj, dict)

    real_imag = obj.get('__complex__')
    if real_imag is not None:
        return complex(*real_imag)

    spec = obj.get('__ndarray__')
    if spec is not None:
        shape, name, offset = spec
        dtype = np.dtype(name)
        nbytes = dtype.itemsize * np.prod(shape).astype(int)
        array = np.frombuffer(b[offset:offset + nbytes], dtype)
        array.shape = shape
        # The bytes are read as little-endian.
        return array if np.little_endian else array.byteswap()

    decoded = {}
    for key, value in obj.items():
        decoded[key] = b2o(value, b)
    objtype = decoded.pop('__ase_objtype__', None)
    if objtype is not None:
        return create_ase_object(objtype, decoded)
    return decoded