Coverage for /builds/debichem-team/python-ase/ase/db/core.py: 85.57%
388 statements
coverage.py v7.5.3, created at 2025-03-06 04:00 +0000

import functools
import json
import numbers
import operator
import os
import re
import warnings
from time import time
from typing import Any, Dict, List

import numpy as np

from ase.atoms import Atoms
from ase.calculators.calculator import all_changes, all_properties
from ase.data import atomic_numbers
from ase.db.row import AtomsRow
from ase.formula import Formula
from ase.io.jsonio import create_ase_object
from ase.parallel import DummyMPI, parallel_function, parallel_generator, world
from ase.utils import Lock, PurePath

T2000 = 946681200.0  # January 1, 2000
YEAR = 31557600.0  # 365.25 days


@functools.total_ordering
class KeyDescription:
    _subscript = re.compile(r'`(.)_(.)`')
    _superscript = re.compile(r'`(.*)\^\{?(.*?)\}?`')

    def __init__(self, key, shortdesc=None, longdesc=None, unit=''):
        self.key = key

        if shortdesc is None:
            shortdesc = key

        if longdesc is None:
            longdesc = shortdesc

        self.shortdesc = shortdesc
        self.longdesc = longdesc

        # Somewhat arbitrary that we do this conversion. Can we avoid that?
        # Previously done in create_key_descriptions().
        unit = self._subscript.sub(r'\1<sub>\2</sub>', unit)
        unit = self._superscript.sub(r'\1<sup>\2</sup>', unit)
        unit = unit.replace(r'\text{', '').replace('}', '')

        self.unit = unit

    def __repr__(self):
        cls = type(self).__name__
        return (f'{cls}({self.key!r}, {self.shortdesc!r}, {self.longdesc!r}, '
                f'unit={self.unit!r})')

    # The templates like to sort key descriptions by shortdesc.
    def __eq__(self, other):
        return self.shortdesc == getattr(other, 'shortdesc', None)

    def __lt__(self, other):
        return self.shortdesc < getattr(other, 'shortdesc', self.shortdesc)
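

# Illustrative sketch (not part of the original module): the unit markup
# conversion in KeyDescription.__init__ turns simple LaTeX-like strings into
# HTML, e.g.
#
#     KeyDescription('volume', unit='`Å^3`').unit   # -> 'Å<sup>3</sup>'
#     KeyDescription('magmom', unit='`μ_B`').unit   # -> 'μ<sub>B</sub>'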


def get_key_descriptions():
    KD = KeyDescription
    return {keydesc.key: keydesc for keydesc in [
        KD('id', 'ID', 'Unique row ID'),
        KD('age', 'Age', 'Time since creation'),
        KD('formula', 'Formula', 'Chemical formula'),
        KD('pbc', 'PBC', 'Periodic boundary conditions'),
        KD('user', 'Username'),
        KD('calculator', 'Calculator', 'ASE-calculator name'),
        KD('energy', 'Energy', 'Total energy', unit='eV'),
        KD('natoms', 'Number of atoms'),
        KD('fmax', 'Maximum force', unit='eV/Å'),
        KD('smax', 'Maximum stress', 'Maximum stress on unit cell',
           unit='eV/ų'),
        KD('charge', 'Charge', 'Net charge in unit cell', unit='|e|'),
        KD('mass', 'Mass', 'Sum of atomic masses in unit cell', unit='au'),
        KD('magmom', 'Magnetic moment', unit='μ_B'),
        KD('unique_id', 'Unique ID', 'Random (unique) ID'),
        KD('volume', 'Volume', 'Volume of unit cell', unit='ų')
    ]}


def now():
    """Return time since January 1, 2000 in years."""
    return (time() - T2000) / YEAR


seconds = {'s': 1,
           'm': 60,
           'h': 3600,
           'd': 86400,
           'w': 604800,
           'M': 2629800,
           'y': YEAR}

longwords = {'s': 'second',
             'm': 'minute',
             'h': 'hour',
             'd': 'day',
             'w': 'week',
             'M': 'month',
             'y': 'year'}

ops = {'<': operator.lt,
       '<=': operator.le,
       '=': operator.eq,
       '>=': operator.ge,
       '>': operator.gt,
       '!=': operator.ne}

invop = {'<': '>=', '<=': '>', '>=': '<', '>': '<=', '=': '!=', '!=': '='}

word = re.compile('[_a-zA-Z][_0-9a-zA-Z]*$')

reserved_keys = set(all_properties +
                    all_changes +
                    list(atomic_numbers) +
                    ['id', 'unique_id', 'ctime', 'mtime', 'user',
                     'fmax', 'smax',
                     'momenta', 'constraints', 'natoms', 'formula', 'age',
                     'calculator', 'calculator_parameters',
                     'key_value_pairs', 'data'])

numeric_keys = {'id', 'energy', 'magmom', 'charge', 'natoms'}


def check(key_value_pairs):
    for key, value in key_value_pairs.items():
        if key == "external_tables":
            # Checks for external_tables are not performed
            continue

        if not word.match(key) or key in reserved_keys:
            raise ValueError(f'Bad key: {key}')
        try:
            Formula(key, strict=True)
        except ValueError:
            pass
        else:
            warnings.warn(
                'It is best not to use keys ({0}) that are also a '
                'chemical formula. If you do a "db.select({0!r})", '
                'you will not find rows with your key. Instead, you will get '
                'rows containing the atoms in the formula!'.format(key))
        if not isinstance(value, (numbers.Real, str, np.bool_)):
            raise ValueError(f'Bad value for {key!r}: {value}')
        if isinstance(value, str):
            for t in [bool, int, float]:
                if str_represents(value, t):
                    raise ValueError(
                        'Value ' + value + ' is put in as string ' +
                        'but can be interpreted as ' +
                        f'{t.__name__}! Please convert ' +
                        f'to {t.__name__} before ' +
                        'writing to the database OR change ' +
                        'to a different string.')
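

# Illustrative sketch (not part of the original module): what check() accepts
# and rejects for user-supplied key-value pairs:
#
#     check({'my_tag': 'relaxed'})  # OK
#     check({'energy': 1.0})        # ValueError: 'energy' is a reserved key
#     check({'my_tag': '2.5'})      # ValueError: store as float, not string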


def str_represents(value, t=int):
    new_value = convert_str_to_int_float_bool_or_str(value)
    return isinstance(new_value, t)


def connect(name, type='extract_from_name', create_indices=True,
            use_lock_file=True, append=True, serial=False):
    """Create connection to database.

    name: str
        Filename or address of database.
    type: str
        One of 'json', 'db', 'postgresql' or 'mysql'
        (JSON, SQLite, PostgreSQL, MySQL/MariaDB).
        Default is 'extract_from_name', which will guess the type
        from the name.
    use_lock_file: bool
        You can turn this off if you know what you are doing ...
    append: bool
        Use append=False to start a new database.
    """

    if isinstance(name, PurePath):
        name = str(name)

    if type == 'extract_from_name':
        if name is None:
            type = None
        elif not isinstance(name, str):
            type = 'json'
        elif (name.startswith('postgresql://') or
              name.startswith('postgres://')):
            type = 'postgresql'
        elif name.startswith('mysql://') or name.startswith('mariadb://'):
            type = 'mysql'
        else:
            type = os.path.splitext(name)[1][1:]
            if type == '':
                raise ValueError('No file extension or database type given')

    if type is None:
        return Database()

    if not append and world.rank == 0:
        if isinstance(name, str) and os.path.isfile(name):
            os.remove(name)

    if type not in ['postgresql', 'mysql'] and isinstance(name, str):
        name = os.path.abspath(name)

    if type == 'json':
        from ase.db.jsondb import JSONDatabase
        return JSONDatabase(name, use_lock_file=use_lock_file, serial=serial)
    if type == 'db':
        from ase.db.sqlite import SQLite3Database
        return SQLite3Database(name, create_indices, use_lock_file,
                               serial=serial)
    if type == 'postgresql':
        from ase.db.postgresql import PostgreSQLDatabase
        return PostgreSQLDatabase(name)

    if type == 'mysql':
        from ase.db.mysql import MySQLDatabase
        return MySQLDatabase(name)
    raise ValueError('Unknown database type: ' + type)
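

# Illustrative sketch (not part of the original module): the backend is
# normally inferred from the name (the file names below are hypothetical):
#
#     db = connect('abc.db')                     # SQLite3Database
#     db = connect('abc.json')                   # JSONDatabase
#     db = connect('postgresql://user@host/db')  # PostgreSQLDatabase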


def lock(method):
    """Decorator for using a lock-file."""
    @functools.wraps(method)
    def new_method(self, *args, **kwargs):
        if self.lock is None:
            return method(self, *args, **kwargs)
        else:
            with self.lock:
                return method(self, *args, **kwargs)
    return new_method


def convert_str_to_int_float_bool_or_str(value):
    """Safe eval()"""
    try:
        return int(value)
    except ValueError:
        try:
            value = float(value)
        except ValueError:
            value = {'True': True, 'False': False}.get(value, value)
        return value
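

# Illustrative sketch (not part of the original module):
#
#     convert_str_to_int_float_bool_or_str('42')     # -> 42
#     convert_str_to_int_float_bool_or_str('2.5')    # -> 2.5
#     convert_str_to_int_float_bool_or_str('True')   # -> True
#     convert_str_to_int_float_bool_or_str('abc')    # -> 'abc'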


def parse_selection(selection, **kwargs):
    if selection is None or selection == '':
        expressions = []
    elif isinstance(selection, int):
        expressions = [('id', '=', selection)]
    elif isinstance(selection, list):
        expressions = selection
    else:
        expressions = [w.strip() for w in selection.split(',')]
    keys = []
    comparisons = []
    for expression in expressions:
        if isinstance(expression, (list, tuple)):
            comparisons.append(expression)
            continue
        if expression.count('<') == 2:
            value, expression = expression.split('<', 1)
            if expression[0] == '=':
                op = '>='
                expression = expression[1:]
            else:
                op = '>'
            key = expression.split('<', 1)[0]
            comparisons.append((key, op, value))
        for op in ['!=', '<=', '>=', '<', '>', '=']:
            if op in expression:
                break
        else:  # no break
            if expression in atomic_numbers:
                comparisons.append((expression, '>', 0))
            else:
                try:
                    count = Formula(expression).count()
                except ValueError:
                    keys.append(expression)
                else:
                    comparisons.extend((symbol, '>', n - 1)
                                       for symbol, n in count.items())
            continue
        key, value = expression.split(op)
        comparisons.append((key, op, value))

    cmps = []
    for key, value in kwargs.items():
        comparisons.append((key, '=', value))

    for key, op, value in comparisons:
        if key == 'age':
            key = 'ctime'
            op = invop[op]
            value = now() - time_string_to_float(value)
        elif key == 'formula':
            if op != '=':
                raise ValueError('Use formula=...')
            f = Formula(value)
            count = f.count()
            cmps.extend((atomic_numbers[symbol], '=', n)
                        for symbol, n in count.items())
            key = 'natoms'
            value = len(f)
        elif key in atomic_numbers:
            key = atomic_numbers[key]
            value = int(value)
        elif isinstance(value, str):
            value = convert_str_to_int_float_bool_or_str(value)
        if key in numeric_keys and not isinstance(value, (int, float)):
            msg = 'Wrong type for "{}{}{}" - must be a number'
            raise ValueError(msg.format(key, op, value))
        cmps.append((key, op, value))

    return keys, cmps
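

# Illustrative sketch (not part of the original module): a query string is
# split into plain keys and (key, op, value) comparisons, with element symbols
# mapped to atomic numbers, e.g.
#
#     parse_selection('H2O,abc,natoms=3')
#     # -> keys ['abc'] and comparisons like
#     #    [(1, '>', 1), (8, '>', 0), ('natoms', '=', 3)]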


class Database:
    """Base class for all databases."""

    def __init__(self, filename=None, create_indices=True,
                 use_lock_file=False, serial=False):
        """Database object.

        serial: bool
            Let someone else handle parallelization. Default behavior is
            to interact with the database on the master only and then
            distribute results to all slaves.
        """
        if isinstance(filename, str):
            filename = os.path.expanduser(filename)
        self.filename = filename
        self.create_indices = create_indices
        if use_lock_file and isinstance(filename, str):
            self.lock = Lock(filename + '.lock', world=DummyMPI())
        else:
            self.lock = None
        self.serial = serial

        # Description of columns and other stuff:
        self._metadata: Dict[str, Any] = None

    @property
    def metadata(self) -> Dict[str, Any]:
        raise NotImplementedError

    @parallel_function
    @lock
    def write(self, atoms, key_value_pairs={}, data={}, id=None, **kwargs):
        """Write atoms to database with key-value pairs.

        atoms: Atoms object
            Write atomic numbers, positions, unit cell and boundary
            conditions. If a calculator is attached, write also already
            calculated properties such as the energy and forces.
        key_value_pairs: dict
            Dictionary of key-value pairs. Values must be strings or numbers.
        data: dict
            Extra stuff (not for searching).
        id: int
            Overwrite existing row.

        Key-value pairs can also be set using keyword arguments::

            connection.write(atoms, name='ABC', frequency=42.0)

        Returns integer id of the new row.
        """

        if atoms is None:
            atoms = Atoms()

        kvp = dict(key_value_pairs)  # modify a copy
        kvp.update(kwargs)

        id = self._write(atoms, kvp, data, id)
        return id

    def _write(self, atoms, key_value_pairs, data, id=None):
        check(key_value_pairs)
        return 1

    @parallel_function
    @lock
    def reserve(self, **key_value_pairs):
        """Write empty row if not already present.

        Usage::

            id = conn.reserve(key1=value1, key2=value2, ...)

        Write an empty row with the given key-value pairs and
        return the integer id. If such a row already exists, don't write
        anything and return None.
        """

        for _ in self._select([],
                              [(key, '=', value)
                               for key, value in key_value_pairs.items()]):
            return None

        atoms = Atoms()

        calc_name = key_value_pairs.pop('calculator', None)

        if calc_name:
            # Allow use of calculator key
            assert calc_name.lower() == calc_name

            # Fake calculator class:
            class Fake:
                name = calc_name

                def todict(self):
                    return {}

                def check_state(self, atoms):
                    return ['positions']

            atoms.calc = Fake()

        id = self._write(atoms, key_value_pairs, {}, None)

        return id
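
    # Illustrative sketch (not part of the original module): a typical
    # pattern for avoiding duplicate calculations, where ``db`` is a
    # connection and ``run_calculation`` is a hypothetical helper:
    #
    #     id = db.reserve(name='H2O', xc='PBE')
    #     if id is not None:
    #         atoms = run_calculation('H2O', xc='PBE')
    #         db.write(atoms, id=id, name='H2O', xc='PBE')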

    def __delitem__(self, id):
        self.delete([id])

    def get_atoms(self, selection=None,
                  add_additional_information=False, **kwargs):
        """Get Atoms object.

        selection: int, str or list
            See the select() method.
        add_additional_information: bool
            Put key-value pairs and data into Atoms.info dictionary.

        In addition, one can use keyword arguments to select specific
        key-value pairs.
        """

        row = self.get(selection, **kwargs)
        return row.toatoms(add_additional_information)

    def __getitem__(self, selection):
        return self.get(selection)

    def get(self, selection=None, **kwargs):
        """Select a single row and return it as a dictionary.

        selection: int, str or list
            See the select() method.
        """
        rows = list(self.select(selection, limit=2, **kwargs))
        if not rows:
            raise KeyError('no match')
        assert len(rows) == 1, 'more than one row matched'
        return rows[0]

    @parallel_generator
    def select(self, selection=None, filter=None, explain=False,
               verbosity=1, limit=None, offset=0, sort=None,
               include_data=True, columns='all', **kwargs):
        """Select rows.

        Return AtomsRow iterator with results. Selection is done
        using key-value pairs and the special keys:

            formula, age, user, calculator, natoms, energy, magmom
            and/or charge.

        selection: int, str or list
            Can be:

            * an integer id
            * a string like 'key=value', where '=' can also be one of
              '<=', '<', '>', '>=' or '!='.
            * a string like 'key'
            * comma separated strings like 'key1<value1,key2=value2,key'
            * list of strings or tuples: [('charge', '=', 1)].
        filter: function
            A function that takes as input a row and returns True or False.
        explain: bool
            Explain query plan.
        verbosity: int
            Possible values: 0, 1 or 2.
        limit: int or None
            Limit selection.
        offset: int
            Offset into selected rows.
        sort: str
            Sort rows by key. Prepend with minus sign for a descending sort.
        include_data: bool
            Use include_data=False to skip reading data from rows.
        columns: 'all' or list of str
            Specify which columns from the SQL table to include.
            For example, if only the row id and the energy are needed,
            queries can be sped up by setting columns=['id', 'energy'].
        """

        if sort:
            if sort == 'age':
                sort = '-ctime'
            elif sort == '-age':
                sort = 'ctime'
            elif sort.lstrip('-') == 'user':
                sort += 'name'

        keys, cmps = parse_selection(selection, **kwargs)
        for row in self._select(keys, cmps, explain=explain,
                                verbosity=verbosity,
                                limit=limit, offset=offset, sort=sort,
                                include_data=include_data,
                                columns=columns):
            if filter is None or filter(row):
                yield row
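
    # Illustrative sketch (not part of the original module): some selections
    # one might run against a connection ``db``:
    #
    #     db.select('Cu>1,energy<-0.5')       # >1 Cu atom, energy below -0.5
    #     db.select(natoms=3, sort='-energy')
    #     db.select('age<1d', filter=lambda row: row.natoms > 2)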

    def count(self, selection=None, **kwargs):
        """Count rows.

        See the select() method for the selection syntax. Use db.count() or
        len(db) to count all rows.
        """
        n = 0
        for _ in self.select(selection, **kwargs):
            n += 1
        return n

    def __len__(self):
        return self.count()

    @parallel_function
    @lock
    def update(self, id, atoms=None, delete_keys=[], data=None,
               **add_key_value_pairs):
        """Update and/or delete key-value pairs of row(s).

        id: int
            ID of row to update.
        atoms: Atoms object
            Optionally update the Atoms data (positions, cell, ...).
        data: dict
            Data dict to be added to the existing data.
        delete_keys: list of str
            Keys to remove.

        Use keyword arguments to add new key-value pairs.

        Returns number of key-value pairs added and removed.
        """

        if not isinstance(id, numbers.Integral):
            if isinstance(id, list):
                err = ('First argument must be an int and not a list.\n'
                       'Do something like this instead:\n\n'
                       'with db:\n'
                       '    for id in ids:\n'
                       '        db.update(id, ...)')
                raise ValueError(err)
            raise TypeError('id must be an int')

        check(add_key_value_pairs)

        row = self._get_row(id)
        kvp = row.key_value_pairs

        n = len(kvp)
        for key in delete_keys:
            kvp.pop(key, None)
        n -= len(kvp)
        m = -len(kvp)
        kvp.update(add_key_value_pairs)
        m += len(kvp)

        moredata = data
        data = row.get('data', {})
        if moredata:
            data.update(moredata)
        if not data:
            data = None

        if atoms:
            oldrow = row
            row = AtomsRow(atoms)
            # Copy over data, kvp, ctime, user and id
            row._data = oldrow._data
            row.__dict__.update(kvp)
            row._keys = list(kvp)
            row.ctime = oldrow.ctime
            row.user = oldrow.user
            row.id = id

        if atoms or os.path.splitext(self.filename)[1] == '.json':
            self._write(row, kvp, data, row.id)
        else:
            self._update(row.id, kvp, data)
        return m, n

    def delete(self, ids):
        """Delete rows."""
        raise NotImplementedError
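

# Illustrative sketch (not part of the original module): a typical
# write/update cycle against a concrete backend; 'results.db' and the
# relaxation step are hypothetical:
#
#     db = connect('results.db')
#     id = db.write(atoms, relaxed=False)
#     ...                                   # relax the structure
#     db.update(id, relaxed=True, atoms=atoms)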


def time_string_to_float(s):
    if isinstance(s, (float, int)):
        return s
    s = s.replace(' ', '')
    if '+' in s:
        return sum(time_string_to_float(x) for x in s.split('+'))
    if s[-2].isalpha() and s[-1] == 's':
        s = s[:-1]
    i = 1
    while s[i].isdigit():
        i += 1
    return seconds[s[i:]] * int(s[:i]) / YEAR
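

# Illustrative sketch (not part of the original module): ages are stored as
# fractions of a year, e.g.
#
#     time_string_to_float('1h')      # -> 3600 / YEAR
#     time_string_to_float('2d+12h')  # -> (2 * 86400 + 12 * 3600) / YEAR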


def float_to_time_string(t, long=False):
    t *= YEAR
    for s in 'yMwdhms':
        x = t / seconds[s]
        if x > 5:
            break
    if long:
        return f'{x:.3f} {longwords[s]}s'
    else:
        return f'{round(x):.0f}{s}'
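

# Illustrative sketch (not part of the original module):
#
#     float_to_time_string(1.0)             # -> '12M'
#     float_to_time_string(2 / 365.25)      # -> '48h'
#     float_to_time_string(1.0, long=True)  # -> '12.000 months'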


def object_to_bytes(obj: Any) -> bytes:
    """Serialize Python object to bytes."""
    parts = [b'12345678']
    obj = o2b(obj, parts)
    offset = sum(len(part) for part in parts)
    x = np.array(offset, np.int64)
    if not np.little_endian:
        x.byteswap(True)
    parts[0] = x.tobytes()
    parts.append(json.dumps(obj, separators=(',', ':')).encode())
    return b''.join(parts)


def bytes_to_object(b: bytes) -> Any:
    """Deserialize bytes to Python object."""
    x = np.frombuffer(b[:8], np.int64)
    if not np.little_endian:
        x = x.byteswap()
    offset = x.item()
    obj = json.loads(b[offset:].decode())
    return b2o(obj, b)
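

# Illustrative sketch (not part of the original module): the pair
# object_to_bytes()/bytes_to_object() round-trips JSON-compatible objects
# plus ndarrays and complex numbers:
#
#     buf = object_to_bytes({'x': np.arange(3), 'z': 1 + 2j})
#     obj = bytes_to_object(buf)
#     # obj['x'] -> array([0, 1, 2]);  obj['z'] -> (1+2j)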


def o2b(obj: Any, parts: List[bytes]):
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj
    if isinstance(obj, dict):
        return {key: o2b(value, parts) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [o2b(value, parts) for value in obj]
    if isinstance(obj, np.ndarray):
        assert obj.dtype != object, \
            'Cannot convert ndarray of type "object" to bytes.'
        offset = sum(len(part) for part in parts)
        if not np.little_endian:
            obj = obj.byteswap()
        parts.append(obj.tobytes())
        return {'__ndarray__': [obj.shape,
                                obj.dtype.name,
                                offset]}
    if isinstance(obj, complex):
        return {'__complex__': [obj.real, obj.imag]}
    objtype = obj.ase_objtype
    if objtype:
        dct = o2b(obj.todict(), parts)
        dct['__ase_objtype__'] = objtype
        return dct
    raise ValueError('Objects of type {type} not allowed'
                     .format(type=type(obj)))


def b2o(obj: Any, b: bytes) -> Any:
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj

    if isinstance(obj, list):
        return [b2o(value, b) for value in obj]

    assert isinstance(obj, dict)

    x = obj.get('__complex__')
    if x is not None:
        return complex(*x)

    x = obj.get('__ndarray__')
    if x is not None:
        shape, name, offset = x
        dtype = np.dtype(name)
        size = dtype.itemsize * np.prod(shape).astype(int)
        a = np.frombuffer(b[offset:offset + size], dtype)
        a.shape = shape
        if not np.little_endian:
            a = a.byteswap()
        return a

    dct = {key: b2o(value, b) for key, value in obj.items()}
    objtype = dct.pop('__ase_objtype__', None)
    if objtype is None:
        return dct
    return create_ase_object(objtype, dct)