Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import functools 

2import json 

3import numbers 

4import operator 

5import os 

6import re 

7import warnings 

8from time import time 

9from typing import List, Dict, Any 

10 

11import numpy as np 

12 

13from ase.atoms import Atoms 

14from ase.calculators.calculator import all_properties, all_changes 

15from ase.data import atomic_numbers 

16from ase.db.row import AtomsRow 

17from ase.formula import Formula 

18from ase.io.jsonio import create_ase_object 

19from ase.parallel import world, DummyMPI, parallel_function, parallel_generator 

20from ase.utils import Lock, PurePath 

21 

22 

# Reference epoch, in seconds since the Unix epoch, that now() subtracts from
# time().  NOTE(review): 946681200 is 1999-12-31 23:00:00 UTC, i.e. midnight
# of January 1. 2000 at UTC+1 -- presumably intentional; values derived from
# now() depend on this exact number, so confirm before ever changing it.
T2000 = 946681200.0  # January 1. 2000
# Length of one year in seconds; now() divides by it to express time in years.
YEAR = 31557600.0  # 365.25 days

25 

26 

# Descriptions of the built-in keys.
# Format of key description: ('short', 'long', 'unit')
# An empty string means "no long description" / "no unit".
default_key_descriptions = {
    'id': ('ID', 'Unique row ID', ''),
    'age': ('Age', 'Time since creation', ''),
    'formula': ('Formula', 'Chemical formula', ''),
    'pbc': ('PBC', 'Periodic boundary conditions', ''),
    'user': ('Username', '', ''),
    'calculator': ('Calculator', 'ASE-calculator name', ''),
    'energy': ('Energy', 'Total energy', 'eV'),
    'natoms': ('Number of atoms', '', ''),
    'fmax': ('Maximum force', '', 'eV/Ang'),
    'smax': ('Maximum stress', 'Maximum stress on unit cell',
             '`\\text{eV/Ang}^3`'),
    'charge': ('Charge', 'Net charge in unit cell', '|e|'),
    'mass': ('Mass', 'Sum of atomic masses in unit cell', 'au'),
    'magmom': ('Magnetic moment', '', 'au'),
    'unique_id': ('Unique ID', 'Random (unique) ID', ''),
    'volume': ('Volume', 'Volume of unit cell', '`\\text{Ang}^3`')}

45 

46 

def now():
    """Return the current time in years elapsed since the T2000 epoch."""
    seconds_since_t2000 = time() - T2000
    return seconds_since_t2000 / YEAR

50 

51 

# Length in seconds of each time unit, keyed by its one-letter code.
# Note the case distinction: 'm' is a minute, 'M' is a month (YEAR / 12).
seconds = dict(
    s=1,
    m=60,
    h=3600,
    d=86400,
    w=604800,
    M=2629800,
    y=YEAR,
)

59 

# Singular English name of each one-letter time unit code.
longwords = dict(zip(
    'smhdwMy',
    ['second', 'minute', 'hour', 'day', 'week', 'month', 'year'],
))

67 

# Comparison function for each operator symbol of the selection syntax.
ops = dict([
    ('<', operator.lt),
    ('<=', operator.le),
    ('=', operator.eq),
    ('>=', operator.ge),
    ('>', operator.gt),
    ('!=', operator.ne),
])

# Logical negation of each operator symbol (e.g. "not <" is ">=").
invop = {
    '<': '>=',
    '<=': '>',
    '>=': '<',
    '>': '<=',
    '=': '!=',
    '!=': '=',
}

76 

# Pattern of a valid key: ASCII letters, digits and underscores, not starting
# with a digit.  It is meant to be used with .match(), which anchors at the
# start; '\Z' anchors at the very end of the string.  ('$' would also match
# just before a trailing newline and thereby accept keys such as 'abc\n'.)
word = re.compile(r'[_a-zA-Z][_0-9a-zA-Z]*\Z')

78 

# Names that check() refuses as user-defined keys: calculator property names,
# the names in all_changes, chemical symbols and the built-in row attributes.
reserved_keys = set(all_properties + all_changes + list(atomic_numbers))
reserved_keys.update([
    'id', 'unique_id', 'ctime', 'mtime', 'user',
    'fmax', 'smax',
    'momenta', 'constraints', 'natoms', 'formula', 'age',
    'calculator', 'calculator_parameters',
    'key_value_pairs', 'data',
])

# Built-in keys that parse_selection() only accepts numeric values for.
numeric_keys = {'id', 'energy', 'magmom', 'charge', 'natoms'}

89 

90 

def check(key_value_pairs):
    """Validate user-supplied key-value pairs.

    key_value_pairs: dict
        Mapping from key (str) to value.

    For every pair except the special key "external_tables" (which is
    skipped entirely):

    * the key must match the ``word`` pattern and must not be one of the
      ``reserved_keys`` - otherwise ValueError is raised;
    * a key that also parses as a chemical formula only triggers a warning;
    * the value must be a real number, a str or a numpy bool - otherwise
      ValueError is raised;
    * a str value that int() or float() can parse is rejected with
      ValueError, since it should be stored as a number instead.

    Returns None.
    """
    for key, value in key_value_pairs.items():
        if key == "external_tables":
            # Checks for external_tables are not performed
            continue

        if not word.match(key) or key in reserved_keys:
            raise ValueError('Bad key: {}'.format(key))
        try:
            Formula(key, strict=True)
        except ValueError:
            # Not a chemical formula: the key is unambiguous.
            pass
        else:
            warnings.warn(
                'It is best not to use keys ({0}) that are also a '
                'chemical formula. If you do a "db.select({0!r})", '
                'you will not find rows with your key. Instead, you will '
                'get rows containing the atoms in the formula!'.format(key))
        if not isinstance(value, (numbers.Real, str, np.bool_)):
            raise ValueError('Bad value for {!r}: {}'.format(key, value))
        if isinstance(value, str):
            # Try int before float so that the message names the most
            # specific numeric type.
            for t in [int, float]:
                if str_represents(value, t):
                    raise ValueError(
                        'Value {0} is put in as string '
                        'but can be interpreted as '
                        '{1}! Please convert '
                        'to {1} using '
                        '{1}(value) before '
                        'writing to the database OR change '
                        'to a different string.'.format(value, t.__name__))

123 

124 

def str_represents(value, t=int):
    """Return True if ``t(value)`` succeeds, False if it raises ValueError.

    Any other exception raised by ``t(value)`` is propagated.
    """
    try:
        t(value)
    except ValueError:
        convertible = False
    else:
        convertible = True
    return convertible

131 

132 

def connect(name, type='extract_from_name', create_indices=True,
            use_lock_file=True, append=True, serial=False):
    """Open a database and return the matching backend object.

    name: str
        Filename or address of database.  A PurePath is converted to str;
        None gives a plain Database() object.
    type: str
        One of 'json', 'db', 'postgresql', 'mysql'
        (JSON, SQLite, PostgreSQL, MySQL/MariaDB).
        The default, 'extract_from_name', guesses the type from the name:
        URL scheme for the SQL servers, file extension otherwise, and
        'json' for a name that is not a string.
    create_indices: bool
        Passed on to the SQLite backend.
    use_lock_file: bool
        You can turn this off if you know what you are doing ...
        Passed on to the JSON and SQLite backends.
    append: bool
        Use append=False to start a new database (an existing file is
        removed on rank 0).
    serial: bool
        Passed on to the JSON and SQLite backends.

    Raises ValueError if no type can be determined from the name or if
    the type is unknown.
    """

    if isinstance(name, PurePath):
        name = str(name)

    if type == 'extract_from_name':
        if name is None:
            type = None
        elif not isinstance(name, str):
            type = 'json'
        elif name.startswith(('postgresql://', 'postgres://')):
            type = 'postgresql'
        elif name.startswith(('mysql://', 'mariadb://')):
            type = 'mysql'
        else:
            # Extension without the leading dot
            extension = os.path.splitext(name)[1]
            type = extension[1:]
            if type == '':
                raise ValueError('No file extension or database type given')

    if type is None:
        return Database()

    name_is_str = isinstance(name, str)

    # Start from scratch: only the master removes the old file.
    if (not append and world.rank == 0 and
            name_is_str and os.path.isfile(name)):
        os.remove(name)

    if name_is_str and type not in ('postgresql', 'mysql'):
        name = os.path.abspath(name)

    # Backends are imported lazily, only when actually requested.
    if type == 'json':
        from ase.db.jsondb import JSONDatabase
        return JSONDatabase(name, use_lock_file=use_lock_file, serial=serial)

    if type == 'db':
        from ase.db.sqlite import SQLite3Database
        return SQLite3Database(name, create_indices, use_lock_file,
                               serial=serial)

    if type == 'postgresql':
        from ase.db.postgresql import PostgreSQLDatabase
        return PostgreSQLDatabase(name)

    if type == 'mysql':
        from ase.db.mysql import MySQLDatabase
        return MySQLDatabase(name)

    raise ValueError('Unknown database type: ' + type)

193 

194 

def lock(method):
    """Decorator that runs a method while holding ``self.lock``.

    If ``self.lock`` is None the method is called directly; otherwise the
    lock is used as a context manager around the call.
    """
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if self.lock is not None:
            with self.lock:
                return method(self, *args, **kwargs)
        return method(self, *args, **kwargs)
    return wrapper

205 

206 

def convert_str_to_int_float_or_str(value):
    """Safe eval(): turn a string into int, float, bool or leave it as is.

    int() is tried first, then float(); if both raise ValueError the
    strings 'True' and 'False' become booleans and anything else is
    returned unchanged.
    """
    for convert in (int, float):
        try:
            return convert(value)
        except ValueError:
            pass
    return {'True': True, 'False': False}.get(value, value)

217 

218 

def parse_selection(selection, **kwargs):
    """Translate a selection into keys and comparisons.

    selection: None, int, str or list
        * None or '' selects everything;
        * an int is taken as a row id;
        * a str is split at commas into expressions;
        * a list holds expressions (str) and/or ready-made
          (key, op, value) comparisons (list or tuple).
    kwargs:
        Each keyword argument adds a ``key = value`` comparison.

    An expression without an operator is either a chemical symbol or
    formula (turned into atom-count comparisons) or a plain key that
    must be present.  An expression with two '<' characters, like
    ``1<x<=5``, gives two comparisons.

    Returns the tuple (keys, cmps): a list of plain keys and a list of
    (key, op, value) tuples.

    Raises ValueError for 'formula' combined with an operator other than
    '=' and for a non-numeric value given for one of the numeric_keys.
    """
    if selection is None or selection == '':
        expressions = []
    elif isinstance(selection, int):
        expressions = [('id', '=', selection)]
    elif isinstance(selection, list):
        expressions = selection
    else:
        expressions = [w.strip() for w in selection.split(',')]
    keys = []
    comparisons = []
    for expression in expressions:
        if isinstance(expression, (list, tuple)):
            comparisons.append(expression)
            continue
        if expression.count('<') == 2:
            # "value<key<..." or "value<=key<...": emit the left-hand
            # comparison here (with the operator mirrored) and leave
            # "key<..." in expression for the general code below.
            value, expression = expression.split('<', 1)
            if expression[0] == '=':
                op = '>='
                expression = expression[1:]
            else:
                op = '>'
            key = expression.split('<', 1)[0]
            comparisons.append((key, op, value))
        # Two-character operators are tried before '<', '>' and '='.
        for op in ['!=', '<=', '>=', '<', '>', '=']:
            if op in expression:
                break
        else:  # no break
            if expression in atomic_numbers:
                comparisons.append((expression, '>', 0))
            else:
                try:
                    count = Formula(expression).count()
                except ValueError:
                    # Not a formula: a plain key
                    keys.append(expression)
                else:
                    # At least n atoms of each symbol in the formula
                    comparisons.extend((symbol, '>', n - 1)
                                       for symbol, n in count.items())
            continue
        key, value = expression.split(op)
        comparisons.append((key, op, value))

    cmps = []
    for key, value in kwargs.items():
        comparisons.append((key, '=', value))

    for key, op, value in comparisons:
        if key == 'age':
            # A condition on age becomes a condition on ctime: the
            # operator is replaced via invop and the age is subtracted
            # from the current time.
            key = 'ctime'
            op = invop[op]
            value = now() - time_string_to_float(value)
        elif key == 'formula':
            if op != '=':
                raise ValueError('Use formula=...')
            f = Formula(value)
            count = f.count()
            # Exact count for every element plus the total number of atoms
            cmps.extend((atomic_numbers[symbol], '=', n)
                        for symbol, n in count.items())
            key = 'natoms'
            value = len(f)
        elif key in atomic_numbers:
            # Chemical symbols are replaced by their atomic number.
            key = atomic_numbers[key]
            value = int(value)
        elif isinstance(value, str):
            value = convert_str_to_int_float_or_str(value)
        if key in numeric_keys and not isinstance(value, (int, float)):
            msg = 'Wrong type for "{}{}{}" - must be a number'
            raise ValueError(msg.format(key, op, value))
        cmps.append((key, op, value))

    return keys, cmps

290 

291 

class Database:
    """Base class for all databases.

    Provides the public API (write, reserve, get, select, count, update,
    ...) on top of the private methods _write, _select, _get_row and
    _update.  Only a validating _write is defined here; the other private
    methods, as well as metadata and delete, are not implemented in this
    class.
    """
    def __init__(self, filename=None, create_indices=True,
                 use_lock_file=False, serial=False):
        """Database object.

        filename: str or other
            A str has "~" expanded; anything else is stored as is.
        create_indices: bool
            Stored on the object; not used in this class.
        use_lock_file: bool
            Guard locked methods with a "<filename>.lock" lock file.
            Only has an effect when filename is a str.
        serial: bool
            Let someone else handle parallelization.  Default behavior is
            to interact with the database on the master only and then
            distribute results to all slaves.
        """
        if isinstance(filename, str):
            filename = os.path.expanduser(filename)
        self.filename = filename
        self.create_indices = create_indices
        if use_lock_file and isinstance(filename, str):
            self.lock = Lock(filename + '.lock', world=DummyMPI())
        else:
            self.lock = None
        self.serial = serial

        # Description of columns and other stuff:
        self._metadata: Dict[str, Any] = None

    @property
    def metadata(self) -> Dict[str, Any]:
        """Metadata dictionary; not implemented in this class."""
        raise NotImplementedError

    @parallel_function
    @lock
    def write(self, atoms, key_value_pairs=None, data=None, id=None,
              **kwargs):
        """Write atoms to database with key-value pairs.

        atoms: Atoms object
            Write atomic numbers, positions, unit cell and boundary
            conditions.  If a calculator is attached, write also already
            calculated properties such as the energy and forces.
            None is replaced by an empty Atoms object.
        key_value_pairs: dict
            Dictionary of key-value pairs.  Values must be strings or numbers.
            The dictionary itself is not modified.
        data: dict
            Extra stuff (not for searching).  None means no data.
        id: int
            Overwrite existing row.

        Key-value pairs can also be set using keyword arguments::

            connection.write(atoms, name='ABC', frequency=42.0)

        Returns integer id of the new row.
        """

        if atoms is None:
            atoms = Atoms()

        # None instead of {} as default: a mutable default would be shared
        # between all calls.
        kvp = dict(key_value_pairs or {})  # modify a copy
        kvp.update(kwargs)

        if data is None:
            data = {}

        id = self._write(atoms, kvp, data, id)
        return id

    def _write(self, atoms, key_value_pairs, data, id=None):
        """Validate the key-value pairs and return 1.

        Nothing is stored by this base implementation.
        """
        check(key_value_pairs)
        return 1

    @parallel_function
    @lock
    def reserve(self, **key_value_pairs):
        """Write empty row if not already present.

        Usage::

            id = conn.reserve(key1=value1, key2=value2, ...)

        Write an empty row with the given key-value pairs and
        return the integer id.  If such a row already exists, don't write
        anything and return None.
        """

        # Any match at all means that the row is already there.
        for dct in self._select([],
                                [(key, '=', value)
                                 for key, value in key_value_pairs.items()]):
            return None

        atoms = Atoms()

        calc_name = key_value_pairs.pop('calculator', None)

        if calc_name:
            # Allow use of calculator key
            assert calc_name.lower() == calc_name

            # Fake calculator class:
            class Fake:
                name = calc_name

                def todict(self):
                    return {}

                def check_state(self, atoms):
                    return ['positions']

            atoms.calc = Fake()

        id = self._write(atoms, key_value_pairs, {}, None)

        return id

    def __delitem__(self, id):
        """Delete the row with the given id (``del db[id]``)."""
        self.delete([id])

    def get_atoms(self, selection=None, attach_calculator=False,
                  add_additional_information=False, **kwargs):
        """Get Atoms object.

        selection: int, str or list
            See the select() method.
        attach_calculator: bool
            Attach calculator object to Atoms object (default value is
            False).
        add_additional_information: bool
            Put key-value pairs and data into Atoms.info dictionary.

        In addition, one can use keyword arguments to select specific
        key-value pairs.
        """

        row = self.get(selection, **kwargs)
        return row.toatoms(attach_calculator, add_additional_information)

    def __getitem__(self, selection):
        """Return the single row matching selection (``db[selection]``)."""
        return self.get(selection)

    def get(self, selection=None, **kwargs):
        """Select a single row and return it.

        selection: int, str or list
            See the select() method.

        Raises KeyError if nothing matches and AssertionError if more
        than one row matches.
        """
        # limit=2 is enough to tell "exactly one" from "more than one".
        rows = list(self.select(selection, limit=2, **kwargs))
        if not rows:
            raise KeyError('no match')
        assert len(rows) == 1, 'more than one row matched'
        return rows[0]

    @parallel_generator
    def select(self, selection=None, filter=None, explain=False,
               verbosity=1, limit=None, offset=0, sort=None,
               include_data=True, columns='all', **kwargs):
        """Select rows.

        Return AtomsRow iterator with results.  Selection is done
        using key-value pairs and the special keys:

            formula, age, user, calculator, natoms, energy, magmom
            and/or charge.

        selection: int, str or list
            Can be:

            * an integer id
            * a string like 'key=value', where '=' can also be one of
              '<=', '<', '>', '>=' or '!='.
            * a string like 'key'
            * comma separated strings like 'key1<value1,key2=value2,key'
            * list of strings or tuples: [('charge', '=', 1)].
        filter: function
            A function that takes as input a row and returns True or False.
        explain: bool
            Explain query plan.
        verbosity: int
            Possible values: 0, 1 or 2.
        limit: int or None
            Limit selection.
        offset: int
            Offset into selected rows.
        sort: str
            Sort rows after key.  Prepend with minus sign for a descending
            sort.
        include_data: bool
            Use include_data=False to skip reading data from rows.
        columns: 'all' or list of str
            Specify which columns from the SQL table to include.
            For example, if only the row id and the energy is needed,
            queries can be speeded up by setting columns=['id', 'energy'].
        """

        if sort:
            # Sorting after age is done on ctime with the direction
            # reversed; 'user' is sorted as 'username'.
            if sort == 'age':
                sort = '-ctime'
            elif sort == '-age':
                sort = 'ctime'
            elif sort.lstrip('-') == 'user':
                sort += 'name'

        keys, cmps = parse_selection(selection, **kwargs)
        for row in self._select(keys, cmps, explain=explain,
                                verbosity=verbosity,
                                limit=limit, offset=offset, sort=sort,
                                include_data=include_data,
                                columns=columns):
            if filter is None or filter(row):
                yield row

    def count(self, selection=None, **kwargs):
        """Count rows.

        See the select() method for the selection syntax.  Use db.count() or
        len(db) to count all rows.
        """
        return sum(1 for _ in self.select(selection, **kwargs))

    def __len__(self):
        """Return the total number of rows."""
        return self.count()

    @parallel_function
    @lock
    def update(self, id, atoms=None, delete_keys=(), data=None,
               **add_key_value_pairs):
        """Update and/or delete key-value pairs of row(s).

        id: int
            ID of row to update.
        atoms: Atoms object
            Optionally update the Atoms data (positions, cell, ...).
        data: dict
            Data dict to be added to the existing data.
        delete_keys: list of str
            Keys to remove.

        Use keyword arguments to add new key-value pairs.

        Returns number of key-value pairs added and removed.

        Raises ValueError if id is a list and TypeError if it is any
        other non-integer.
        """

        if not isinstance(id, numbers.Integral):
            if isinstance(id, list):
                err = ('First argument must be an int and not a list.\n'
                       'Do something like this instead:\n\n'
                       'with db:\n'
                       '    for id in ids:\n'
                       '        db.update(id, ...)')
                raise ValueError(err)
            raise TypeError('id must be an int')

        check(add_key_value_pairs)

        row = self._get_row(id)
        kvp = row.key_value_pairs

        # n: number of keys removed, m: number of keys added
        n = len(kvp)
        for key in delete_keys:
            kvp.pop(key, None)
        n -= len(kvp)
        m = -len(kvp)
        kvp.update(add_key_value_pairs)
        m += len(kvp)

        moredata = data
        data = row.get('data', {})
        if moredata:
            data.update(moredata)
        if not data:
            data = None

        if atoms:
            oldrow = row
            row = AtomsRow(atoms)
            # Copy over data, kvp, ctime, user and id
            row._data = oldrow._data
            row.__dict__.update(kvp)
            row._keys = list(kvp)
            row.ctime = oldrow.ctime
            row.user = oldrow.user
            row.id = id

        # New atoms, or a ".json" file: write the complete row again.
        # Otherwise only update key-value pairs and data.
        if atoms or os.path.splitext(self.filename)[1] == '.json':
            self._write(row, kvp, data, row.id)
        else:
            self._update(row.id, kvp, data)
        return m, n

    def delete(self, ids):
        """Delete rows."""
        raise NotImplementedError

578 

579 

def time_string_to_float(s):
    """Convert a time string like '2d', '3 weeks' or '1h+30m' to years.

    A float or int is returned unchanged.  Spaces are ignored, terms
    separated by '+' are summed and a plural 's' after the unit is
    dropped.  Each term is an integer followed by a unit from the
    ``seconds`` table.

    Malformed input is not reported in a uniform way: depending on the
    string, IndexError, KeyError or ValueError is raised.
    """
    if isinstance(s, (float, int)):
        return s
    text = s.replace(' ', '')
    if '+' in text:
        return sum(time_string_to_float(term) for term in text.split('+'))
    # Plural form: letter followed by 's'
    if text[-2].isalpha() and text[-1] == 's':
        text = text[:-1]
    # The first character is taken to be a digit; find where the unit starts.
    split = 1
    while text[split].isdigit():
        split += 1
    amount = text[:split]
    unit = text[split:]
    return seconds[unit] * int(amount) / YEAR

592 

593 

def float_to_time_string(t, long=False):
    """Format a time given in years using the largest suitable unit.

    The units are tried from years down to seconds and the first one
    giving a value above 5 is used (seconds if none does).

    long: bool
        If true, return e.g. '7.250 days'; otherwise a rounded short
        form such as '7d'.
    """
    t *= YEAR  # now in seconds
    for unit in 'yMwdhms':
        amount = t / seconds[unit]
        if amount > 5:
            break
    if not long:
        return '{:.0f}{}'.format(round(amount), unit)
    return '{:.3f} {}s'.format(amount, longwords[unit])

604 

605 

def object_to_bytes(obj: Any) -> bytes:
    """Serialize Python object to bytes.

    Layout of the result: an 8-byte little-endian integer holding the
    position of the JSON text, then the raw array data collected by
    o2b(), then the JSON text itself.
    """
    # The 8-byte placeholder must be in place before o2b() runs, so that
    # the array offsets it records include the header.
    chunks = [b'12345678']
    tree = o2b(obj, chunks)
    json_start = sum(map(len, chunks))
    chunks[0] = np.array(json_start, dtype='<i8').tobytes()
    text = json.dumps(tree, separators=(',', ':'))
    chunks.append(text.encode())
    return b''.join(chunks)

617 

618 

def bytes_to_object(b: bytes) -> Any:
    """Deserialize bytes to Python object.

    The first 8 bytes are read as a little-endian integer giving the
    position of the JSON text; the decoded JSON is then handed to b2o()
    together with the complete byte string.
    """
    header = np.frombuffer(b[:8], dtype='<i8')
    json_start = header.item()
    tree = json.loads(b[json_start:].decode())
    return b2o(tree, b)

627 

628 

def o2b(obj: Any, parts: List[bytes]):
    """Convert obj to a JSON-compatible structure.

    obj:
        int, float, bool, str, None, complex, dict, list, tuple,
        numpy array or an object with a truthy ``ase_objtype`` attribute
        and a ``todict()`` method.
    parts: list of bytes
        Raw array data is appended here; the returned structure only
        refers to it by offset, counted over everything already in parts.

    Returns the converted structure: tuples become lists, arrays become
    ``{'__ndarray__': [shape, dtype-name, offset]}``, complex numbers
    become ``{'__complex__': [real, imag]}`` and ASE objects become
    their todict() result plus an ``'__ase_objtype__'`` entry.

    Raises ValueError for objects of any other type.
    """
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj
    if isinstance(obj, dict):
        return {key: o2b(value, parts) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [o2b(value, parts) for value in obj]
    if isinstance(obj, np.ndarray):
        assert obj.dtype != object, \
            'Cannot convert ndarray of type "object" to bytes.'
        offset = sum(len(part) for part in parts)
        # Array data is always stored little-endian.
        if not np.little_endian:
            obj = obj.byteswap()
        parts.append(obj.tobytes())
        return {'__ndarray__': [obj.shape,
                                obj.dtype.name,
                                offset]}
    if isinstance(obj, complex):
        return {'__complex__': [obj.real, obj.imag]}
    # Default of None: objects without the attribute must end up in the
    # ValueError below instead of raising AttributeError here.
    objtype = getattr(obj, 'ase_objtype', None)
    if objtype:
        dct = o2b(obj.todict(), parts)
        dct['__ase_objtype__'] = objtype
        return dct
    raise ValueError('Objects of type {type} not allowed'
                     .format(type=type(obj)))

655 

656 

def b2o(obj: Any, b: bytes) -> Any:
    """Rebuild Python objects from a decoded JSON structure.

    obj:
        Structure of the kind produced by o2b().
    b: bytes
        The complete byte string; array data is read from it at the
        offsets stored in obj.

    Returns obj with ``'__complex__'`` and ``'__ndarray__'`` dictionaries
    turned back into complex numbers and numpy arrays, and dictionaries
    carrying ``'__ase_objtype__'`` turned into objects by
    create_ase_object().
    """
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj

    if isinstance(obj, list):
        return [b2o(item, b) for item in obj]

    assert isinstance(obj, dict)

    real_imag = obj.get('__complex__')
    if real_imag is not None:
        return complex(*real_imag)

    array_spec = obj.get('__ndarray__')
    if array_spec is not None:
        shape, name, offset = array_spec
        dtype = np.dtype(name)
        nbytes = dtype.itemsize * np.prod(shape).astype(int)
        array = np.frombuffer(b[offset:offset + nbytes], dtype)
        array.shape = shape
        # Stored data is little-endian
        if not np.little_endian:
            array = array.byteswap()
        return array

    decoded = {key: b2o(value, b) for key, value in obj.items()}
    objtype = decoded.pop('__ase_objtype__', None)
    if objtype is not None:
        return create_ase_object(objtype, decoded)
    return decoded