Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
import functools
import json
import numbers
import operator
import os
import re
import warnings
from time import time
from typing import List, Dict, Any

import numpy as np

from ase.atoms import Atoms
from ase.calculators.calculator import all_properties, all_changes
from ase.data import atomic_numbers
from ase.db.row import AtomsRow
from ase.formula import Formula
from ase.io.jsonio import create_ase_object
from ase.parallel import world, DummyMPI, parallel_function, parallel_generator
from ase.utils import Lock, PurePath
# Reference epoch, in seconds after the Unix epoch.  946681200 is one hour
# before 2000-01-01T00:00:00Z, i.e. midnight of January 1. 2000 at UTC+1.
# NOTE(review): time stamps computed with now() are relative to this value,
# so presumably it must never change - confirm before touching it.
T2000 = 946_681_200.0
# One (Julian) year of 365.25 days, in seconds.
YEAR = 365.25 * 24 * 3600
27# Format of key description: ('short', 'long', 'unit')
28default_key_descriptions = {
29 'id': ('ID', 'Uniqe row ID', ''),
30 'age': ('Age', 'Time since creation', ''),
31 'formula': ('Formula', 'Chemical formula', ''),
32 'pbc': ('PBC', 'Periodic boundary conditions', ''),
33 'user': ('Username', '', ''),
34 'calculator': ('Calculator', 'ASE-calculator name', ''),
35 'energy': ('Energy', 'Total energy', 'eV'),
36 'natoms': ('Number of atoms', '', ''),
37 'fmax': ('Maximum force', '', 'eV/Ang'),
38 'smax': ('Maximum stress', 'Maximum stress on unit cell',
39 '`\\text{eV/Ang}^3`'),
40 'charge': ('Charge', 'Net charge in unit cell', '|e|'),
41 'mass': ('Mass', 'Sum of atomic masses in unit cell', 'au'),
42 'magmom': ('Magnetic moment', '', 'au'),
43 'unique_id': ('Unique ID', 'Random (unique) ID', ''),
44 'volume': ('Volume', 'Volume of unit cell', '`\\text{Ang}^3`')}
def now():
    """Return the current time as the number of years elapsed since T2000.

    The result uses the same scale as ``time_string_to_float`` (years of
    ``YEAR`` seconds), so the two can be subtracted directly.
    """
    elapsed_seconds = time() - T2000
    return elapsed_seconds / YEAR
# Number of seconds in each single-letter time unit.  'M' is a month of
# YEAR / 12 seconds (2629800 s) and 'y' is one YEAR.
seconds = dict(s=1, m=60, h=3600, d=86400, w=604800, M=2629800, y=YEAR)

# Singular English names of the unit letters used in ``seconds``.
longwords = dict(s='second', m='minute', h='hour', d='day', w='week',
                 M='month', y='year')
# Comparison operators accepted in selection strings, keyed by their symbol.
ops = {'<': operator.lt,
       '<=': operator.le,
       '=': operator.eq,
       '>=': operator.ge,
       '>': operator.gt,
       '!=': operator.ne}

# Complement of each operator ('<' <-> '>=', '=' <-> '!=', ...).
# parse_selection() uses it when an 'age' comparison is rewritten as a
# 'ctime' comparison.
invop = {'<': '>=', '<=': '>', '>=': '<', '>': '<=', '=': '!=', '!=': '='}
# Valid key names: a Python-style identifier (ASCII only).  ``\Z`` is used
# instead of ``$`` because ``$`` also matches just before a trailing
# newline, which would let keys such as 'abc\n' pass validation.
word = re.compile(r'[_a-zA-Z][_0-9a-zA-Z]*\Z')
# Names that check() refuses as user keys: everything in all_properties and
# all_changes (from ase.calculators.calculator), every chemical symbol in
# atomic_numbers, and the names of the standard row columns.
reserved_keys = {*all_properties, *all_changes, *atomic_numbers,
                 'id', 'unique_id', 'ctime', 'mtime', 'user',
                 'fmax', 'smax',
                 'momenta', 'constraints', 'natoms', 'formula', 'age',
                 'calculator', 'calculator_parameters',
                 'key_value_pairs', 'data'}
# Special keys whose selection values must be numbers (see parse_selection).
numeric_keys = {'id', 'energy', 'magmom', 'charge', 'natoms'}
def check(key_value_pairs):
    """Validate user key-value pairs before they are written.

    key_value_pairs: dict
        Mapping of key name to value.  The special key "external_tables"
        is skipped entirely.

    Raises ValueError if a key is not an identifier (see ``word``) or is
    reserved, if a value is not a real number, a string or a numpy bool,
    or if a string value could be converted to an int or a float.  Emits
    a warning (but accepts the key) when the key is also a chemical
    formula, because selecting on it would be treated as a formula.
    """
    for key, value in key_value_pairs.items():
        if key == "external_tables":
            # Checks for external_tables are not performed
            continue

        if not word.match(key) or key in reserved_keys:
            raise ValueError('Bad key: {}'.format(key))
        try:
            Formula(key, strict=True)
        except ValueError:
            pass
        else:
            warnings.warn(
                'It is best not to use keys ({0}) that are also a '
                'chemical formula.  If you do a "db.select({0!r})", '
                'you will not find rows with your key.  Instead, you will '
                'get rows containing the atoms in the formula!'.format(key))
        if not isinstance(value, (numbers.Real, str, np.bool_)):
            raise ValueError('Bad value for {!r}: {}'.format(key, value))
        if isinstance(value, str):
            # Numbers stored as strings would not compare numerically in
            # selections, so refuse them.
            for t in [int, float]:
                if str_represents(value, t):
                    raise ValueError(
                        'Value {0} is put in as string but can be '
                        'interpreted as {1}! Please convert to {1} using '
                        '{1}(value) before writing to the database OR '
                        'change to a different string.'
                        .format(value, t.__name__))
def str_represents(value, t=int):
    """Return True if *value* can be converted with ``t(value)``.

    value:
        Usually a string.
    t: callable
        Target type (default int).

    Returns False when the conversion raises ValueError, TypeError (e.g.
    ``None``) or OverflowError (e.g. ``int(float('inf'))``), so the
    function never raises for such inputs.
    """
    try:
        t(value)
    except (ValueError, TypeError, OverflowError):
        return False
    return True
def connect(name, type='extract_from_name', create_indices=True,
            use_lock_file=True, append=True, serial=False):
    """Create connection to database.

    name: str
        Filename or address of database.  A leading "~" in a filename is
        expanded to the user's home directory.
    type: str
        One of 'json', 'db', 'postgresql', 'mysql'
        (JSON, SQLite, PostgreSQL, MySQL/MariaDB).
        Default is 'extract_from_name', which will guess the type
        from the name.
    create_indices: bool
        Passed on to the SQLite backend.
    use_lock_file: bool
        You can turn this off if you know what you are doing ...
    append: bool
        Use append=False to start a new database.
    serial: bool
        Passed on to the JSON and SQLite backends.

    Returns a database object; a bare Database() when name is None.

    Raises ValueError when the type cannot be guessed from the name or is
    unknown.
    """

    if isinstance(name, PurePath):
        name = str(name)

    if type == 'extract_from_name':
        if name is None:
            type = None
        elif not isinstance(name, str):
            type = 'json'
        elif (name.startswith('postgresql://') or
              name.startswith('postgres://')):
            type = 'postgresql'
        elif name.startswith('mysql://') or name.startswith('mariadb://'):
            type = 'mysql'
        else:
            type = os.path.splitext(name)[1][1:]
            if type == '':
                raise ValueError('No file extension or database type given')

    if type is None:
        return Database()

    if type not in ['postgresql', 'mysql'] and isinstance(name, str):
        # Expand "~" *before* making the path absolute: abspath('~/x.db')
        # would otherwise yield '<cwd>/~/x.db', which later expansion can
        # no longer fix.
        name = os.path.abspath(os.path.expanduser(name))

    if not append and world.rank == 0:
        if isinstance(name, str) and os.path.isfile(name):
            os.remove(name)

    if type == 'json':
        from ase.db.jsondb import JSONDatabase
        return JSONDatabase(name, use_lock_file=use_lock_file, serial=serial)
    if type == 'db':
        from ase.db.sqlite import SQLite3Database
        return SQLite3Database(name, create_indices, use_lock_file,
                               serial=serial)
    if type == 'postgresql':
        from ase.db.postgresql import PostgreSQLDatabase
        return PostgreSQLDatabase(name)

    if type == 'mysql':
        from ase.db.mysql import MySQLDatabase
        return MySQLDatabase(name)
    # format() rather than "+" so a non-str type still gives this message.
    raise ValueError('Unknown database type: {}'.format(type))
def lock(method):
    """Decorator that runs *method* while holding ``self.lock``.

    If the instance has no lock (``self.lock is None``) the method is
    simply called.  The lock object is used as a context manager.
    """
    @functools.wraps(method)
    def new_method(self, *args, **kwargs):
        guard = self.lock
        if guard is not None:
            with guard:
                return method(self, *args, **kwargs)
        return method(self, *args, **kwargs)
    return new_method
def convert_str_to_int_float_or_str(value):
    """Safe eval(): turn a string into an int, float, bool or itself.

    int is tried first, then float; 'True'/'False' become booleans and any
    other string is returned unchanged.
    """
    for number_type in (int, float):
        try:
            return number_type(value)
        except ValueError:
            pass
    return {'True': True, 'False': False}.get(value, value)
def parse_selection(selection, **kwargs):
    """Parse a selection into plain keys and normalized comparisons.

    selection: None, '', int, list or str
        See Database.select().  Any integral number (including numpy
        integers) selects that row id.
    kwargs:
        Extra ``key=value`` equality constraints.

    Returns ``(keys, cmps)``.  *keys* lists names that must merely be
    present.  *cmps* is a list of ``(key, op, value)`` triples in which:

    * 'age' comparisons are rewritten as 'ctime' comparisons,
    * 'formula=...' becomes per-element counts plus a 'natoms' equality,
    * chemical symbols become atomic-number keys with int values,
    * other string values go through convert_str_to_int_float_or_str().

    Raises ValueError for 'formula' with an operator other than '=', and
    for a non-numeric value of one of the ``numeric_keys``.
    """
    if selection is None:
        expressions = []
    elif isinstance(selection, numbers.Integral):
        # int() so numpy integers pass the numeric type check below.
        expressions = [('id', '=', int(selection))]
    elif isinstance(selection, list):
        expressions = selection
    elif selection == '':
        expressions = []
    else:
        expressions = [w.strip() for w in selection.split(',')]
    keys = []
    comparisons = []
    for expression in expressions:
        if isinstance(expression, (list, tuple)):
            comparisons.append(expression)
            continue
        if expression.count('<') == 2:
            # Range syntax "a<key<b" / "a<=key<b": emit the lower bound
            # here and leave "key<b" for the generic parsing below.
            value, expression = expression.split('<', 1)
            if expression[0] == '=':
                op = '>='
                expression = expression[1:]
            else:
                op = '>'
            key = expression.split('<', 1)[0]
            comparisons.append((key, op, value))
        # Two-character operators must be tried before their prefixes.
        for op in ['!=', '<=', '>=', '<', '>', '=']:
            if op in expression:
                break
        else:  # no break
            if expression in atomic_numbers:
                comparisons.append((expression, '>', 0))
            else:
                try:
                    count = Formula(expression).count()
                except ValueError:
                    keys.append(expression)
                else:
                    comparisons.extend((symbol, '>', n - 1)
                                       for symbol, n in count.items())
            continue
        key, value = expression.split(op)
        comparisons.append((key, op, value))

    cmps = []
    for key, value in kwargs.items():
        comparisons.append((key, '=', value))

    for key, op, value in comparisons:
        if key == 'age':
            # A larger age means a smaller ctime, so flip the operator.
            key = 'ctime'
            op = invop[op]
            value = now() - time_string_to_float(value)
        elif key == 'formula':
            if op != '=':
                raise ValueError('Use formula=...')
            f = Formula(value)
            count = f.count()
            cmps.extend((atomic_numbers[symbol], '=', n)
                        for symbol, n in count.items())
            key = 'natoms'
            value = len(f)
        elif key in atomic_numbers:
            key = atomic_numbers[key]
            value = int(value)
        elif isinstance(value, str):
            value = convert_str_to_int_float_or_str(value)
        if key in numeric_keys and not isinstance(value, (int, float)):
            msg = 'Wrong type for "{}{}{}" - must be a number'
            raise ValueError(msg.format(key, op, value))
        cmps.append((key, op, value))

    return keys, cmps
class Database:
    """Base class for all databases.

    The hooks ``_select``, ``_get_row`` and ``_update`` are called here but
    not defined, and ``metadata``/``delete`` raise NotImplementedError;
    presumably backend subclasses provide them (confirm in the backends).
    """

    def __init__(self, filename=None, create_indices=True,
                 use_lock_file=False, serial=False):
        """Database object.

        filename: str or None
            A leading "~" is expanded.
        create_indices: bool
            Stored as ``self.create_indices``.
        use_lock_file: bool
            Create a ``<filename>.lock`` lock (only for str filenames).
        serial: bool
            Let someone else handle parallelization.  Default behavior is
            to interact with the database on the master only and then
            distribute results to all slaves.
        """
        if isinstance(filename, str):
            filename = os.path.expanduser(filename)
        self.filename = filename
        self.create_indices = create_indices
        if use_lock_file and isinstance(filename, str):
            self.lock = Lock(filename + '.lock', world=DummyMPI())
        else:
            self.lock = None
        self.serial = serial

        # Description of columns and other stuff:
        self._metadata: Dict[str, Any] = None

    @property
    def metadata(self) -> Dict[str, Any]:
        """Database metadata; not implemented in the base class."""
        raise NotImplementedError

    @parallel_function
    @lock
    def write(self, atoms, key_value_pairs=None, data=None, id=None,
              **kwargs):
        """Write atoms to database with key-value pairs.

        atoms: Atoms object
            Write atomic numbers, positions, unit cell and boundary
            conditions.  If a calculator is attached, write also already
            calculated properties such as the energy and forces.
            None writes an empty Atoms object.
        key_value_pairs: dict
            Dictionary of key-value pairs.  Values must be strings or
            numbers.  None (default) means no pairs.
        data: dict
            Extra stuff (not for searching).  None (default) means {}.
        id: int
            Overwrite existing row.

        Key-value pairs can also be set using keyword arguments::

            connection.write(atoms, name='ABC', frequency=42.0)

        Returns integer id of the new row.
        """

        if atoms is None:
            atoms = Atoms()

        # Work on a copy so the caller's dict is never modified.
        kvp = dict({} if key_value_pairs is None else key_value_pairs)
        kvp.update(kwargs)

        # A fresh dict per call instead of a shared mutable default.
        if data is None:
            data = {}

        id = self._write(atoms, kvp, data, id)
        return id

    def _write(self, atoms, key_value_pairs, data, id=None):
        """Validate the key-value pairs; the base class stores nothing."""
        check(key_value_pairs)
        return 1

    @parallel_function
    @lock
    def reserve(self, **key_value_pairs):
        """Write empty row if not already present.

        Usage::

            id = conn.reserve(key1=value1, key2=value2, ...)

        Write an empty row with the given key-value pairs and
        return the integer id.  If such a row already exists, don't write
        anything and return None.  A lower-case 'calculator' key is
        stored as the name of a placeholder calculator.
        """

        for _ in self._select([],
                              [(key, '=', value)
                               for key, value in key_value_pairs.items()]):
            return None

        atoms = Atoms()

        calc_name = key_value_pairs.pop('calculator', None)

        if calc_name:
            # Allow use of calculator key
            assert calc_name.lower() == calc_name

            # Fake calculator class:
            class Fake:
                name = calc_name

                def todict(self):
                    return {}

                def check_state(self, atoms):
                    return ['positions']

            atoms.calc = Fake()

        id = self._write(atoms, key_value_pairs, {}, None)

        return id

    def __delitem__(self, id):
        self.delete([id])

    def get_atoms(self, selection=None, attach_calculator=False,
                  add_additional_information=False, **kwargs):
        """Get Atoms object.

        selection: int, str or list
            See the select() method.
        attach_calculator: bool
            Attach calculator object to Atoms object (default value is
            False).
        add_additional_information: bool
            Put key-value pairs and data into dictionary.

        In addition, one can use keyword arguments to select specific
        key-value pairs.
        """

        row = self.get(selection, **kwargs)
        return row.toatoms(attach_calculator, add_additional_information)

    def __getitem__(self, selection):
        return self.get(selection)

    def get(self, selection=None, **kwargs):
        """Select a single row and return it.

        selection: int, str or list
            See the select() method.

        Raises KeyError if no row matches and AssertionError if more than
        one row matches.
        """
        # Two rows are enough to detect an ambiguous selection.
        rows = list(self.select(selection, limit=2, **kwargs))
        if not rows:
            raise KeyError('no match')
        if len(rows) > 1:
            # Explicit raise (not ``assert``) so it survives ``python -O``.
            raise AssertionError('more than one row matched')
        return rows[0]

    @parallel_generator
    def select(self, selection=None, filter=None, explain=False,
               verbosity=1, limit=None, offset=0, sort=None,
               include_data=True, columns='all', **kwargs):
        """Select rows.

        Return AtomsRow iterator with results.  Selection is done
        using key-value pairs and the special keys:

            formula, age, user, calculator, natoms, energy, magmom
            and/or charge.

        selection: int, str or list
            Can be:

            * an integer id
            * a string like 'key=value', where '=' can also be one of
              '<=', '<', '>', '>=' or '!='.
            * a string like 'key'
            * comma separated strings like 'key1<value1,key2=value2,key'
            * list of strings or tuples: [('charge', '=', 1)].
        filter: function
            A function that takes as input a row and returns True or False.
        explain: bool
            Explain query plan.
        verbosity: int
            Possible values: 0, 1 or 2.
        limit: int or None
            Limit selection.
        offset: int
            Offset into selected rows.
        sort: str
            Sort rows after key.  Prepend with minus sign for a descending
            sort.
        include_data: bool
            Use include_data=False to skip reading data from rows.
        columns: 'all' or list of str
            Specify which columns from the SQL table to include.
            For example, if only the row id and the energy is needed,
            queries can be speeded up by setting columns=['id', 'energy'].
        """

        if sort:
            # Age grows as ctime shrinks, so the direction is flipped.
            if sort == 'age':
                sort = '-ctime'
            elif sort == '-age':
                sort = 'ctime'
            elif sort.lstrip('-') == 'user':
                sort += 'name'

        keys, cmps = parse_selection(selection, **kwargs)
        for row in self._select(keys, cmps, explain=explain,
                                verbosity=verbosity,
                                limit=limit, offset=offset, sort=sort,
                                include_data=include_data,
                                columns=columns):
            if filter is None or filter(row):
                yield row

    def count(self, selection=None, **kwargs):
        """Count rows.

        See the select() method for the selection syntax.  Use db.count() or
        len(db) to count all rows.
        """
        return sum(1 for _ in self.select(selection, **kwargs))

    def __len__(self):
        return self.count()

    @parallel_function
    @lock
    def update(self, id, atoms=None, delete_keys=(), data=None,
               **add_key_value_pairs):
        """Update and/or delete key-value pairs of a row.

        id: int
            ID of row to update.
        atoms: Atoms object
            Optionally update the Atoms data (positions, cell, ...).
            An Atoms object with zero atoms is also applied.
        data: dict
            Data dict to be added to the existing data.
        delete_keys: iterable of str
            Keys to remove.

        Use keyword arguments to add new key-value pairs.

        Returns ``(m, n)``: the number of key-value pairs added and the
        number removed.

        Raises ValueError if *id* is a list and TypeError if it is not an
        integer.
        """

        if not isinstance(id, numbers.Integral):
            if isinstance(id, list):
                err = ('First argument must be an int and not a list.\n'
                       'Do something like this instead:\n\n'
                       'with db:\n'
                       '    for id in ids:\n'
                       '        db.update(id, ...)')
                raise ValueError(err)
            raise TypeError('id must be an int')

        check(add_key_value_pairs)

        row = self._get_row(id)
        kvp = row.key_value_pairs

        n = len(kvp)
        for key in delete_keys:
            kvp.pop(key, None)
        n -= len(kvp)
        m = -len(kvp)
        kvp.update(add_key_value_pairs)
        m += len(kvp)

        moredata = data
        data = row.get('data', {})
        if moredata:
            data.update(moredata)
        if not data:
            data = None

        # "is not None": an Atoms object with no atoms is falsy (len 0) but
        # must still replace the stored structure.
        if atoms is not None:
            oldrow = row
            row = AtomsRow(atoms)
            # Copy over data, kvp, ctime, user and id
            row._data = oldrow._data
            row.__dict__.update(kvp)
            row._keys = list(kvp)
            row.ctime = oldrow.ctime
            row.user = oldrow.user
            row.id = id

        if (atoms is not None or
                os.path.splitext(self.filename)[1] == '.json'):
            self._write(row, kvp, data, row.id)
        else:
            self._update(row.id, kvp, data)
        return m, n

    def delete(self, ids):
        """Delete rows."""
        raise NotImplementedError
def time_string_to_float(s):
    """Convert a time string such as '2d' or '1w+3h' to years.

    s: str, int or float
        Numbers are returned unchanged (taken to be years already).
        A string is a '+'-separated sum of terms; spaces are ignored.
        Each term is a count (optionally negative, optionally with a
        decimal part) followed by one unit letter from ``seconds``
        (s, m, h, d, w, M, y) and an optional plural 's'; e.g. '10ms'
        means ten minutes.

    Returns the duration in units of YEAR.

    Raises ValueError for a malformed term or an unknown unit letter.
    """
    if isinstance(s, (float, int)):
        return s
    s = s.replace(' ', '')
    if '+' in s:
        return sum(time_string_to_float(x) for x in s.split('+'))
    match = re.fullmatch(r'(-?\d+(?:\.\d+)?)([a-zA-Z])s?', s)
    if match is None or match.group(2) not in seconds:
        raise ValueError('Bad time string: {!r}'.format(s))
    number, unit = match.groups()
    # Keep integer counts as int so results match the previous parser
    # exactly.
    count = float(number) if '.' in number else int(number)
    return seconds[unit] * count / YEAR
def float_to_time_string(t, long=False):
    """Format a duration *t*, given in years, for display.

    The largest unit from years down to minutes in which the duration
    exceeds 5 is used; otherwise the duration is given in seconds.
    With long=True the result looks like '6.000 days', otherwise like '6d'.
    """
    total_seconds = t * YEAR
    unit = next((u for u in 'yMwdhm' if total_seconds / seconds[u] > 5), 's')
    amount = total_seconds / seconds[unit]
    if not long:
        return '{:.0f}{}'.format(round(amount), unit)
    return '{:.3f} {}s'.format(amount, longwords[unit])
def object_to_bytes(obj: Any) -> bytes:
    """Serialize Python object to bytes.

    Layout: an 8-byte little-endian int64 holding the offset of the JSON
    part, then the raw array buffers collected by o2b(), then the compact
    JSON description of *obj* encoded as UTF-8.
    """
    # The 8-byte placeholder makes o2b() compute offsets past the header.
    chunks = [b'12345678']
    description = o2b(obj, chunks)
    json_start = sum(map(len, chunks))
    chunks[0] = np.array(json_start, dtype='<i8').tobytes()
    text = json.dumps(description, separators=(',', ':'))
    chunks.append(text.encode())
    return b''.join(chunks)
def bytes_to_object(b: bytes) -> Any:
    """Deserialize bytes produced by object_to_bytes() to a Python object."""
    # Header: little-endian int64 offset where the JSON text starts.
    json_start = np.frombuffer(b[:8], dtype='<i8').item()
    description = json.loads(b[json_start:].decode())
    return b2o(description, b)
def o2b(obj: Any, parts: List[bytes]):
    """Convert *obj* into a JSON-compatible structure.

    obj:
        None, bool, int, float, str, dict, list/tuple (tuples become
        lists), complex, numpy array (not of dtype object), or an object
        with a truthy ``ase_objtype`` attribute and a ``todict()`` method.
    parts: list of bytes
        Output buffer.  Each array's raw little-endian bytes are appended
        here and replaced in the result by
        ``{'__ndarray__': [shape, dtype name, offset]}``, where *offset* is
        the total length of *parts* before appending.

    Raises ValueError for objects of any other type.
    """
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj
    if isinstance(obj, dict):
        return {key: o2b(value, parts) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [o2b(value, parts) for value in obj]
    if isinstance(obj, np.ndarray):
        assert obj.dtype != object, \
            'Cannot convert ndarray of type "object" to bytes.'
        offset = sum(len(part) for part in parts)
        # Always store little-endian bytes.  Converting the dtype (rather
        # than byte-swapping only on big-endian hosts) also handles arrays
        # whose own byte order is non-native, e.g. '>f8'; b2o() decodes
        # the bytes using only the dtype *name*, which has no byte order.
        little = obj.astype(obj.dtype.newbyteorder('<'), copy=False)
        parts.append(little.tobytes())
        return {'__ndarray__': [obj.shape,
                                obj.dtype.name,
                                offset]}
    if isinstance(obj, complex):
        return {'__complex__': [obj.real, obj.imag]}
    # Default None: objects without the attribute get the ValueError below
    # instead of an AttributeError.
    objtype = getattr(obj, 'ase_objtype', None)
    if objtype:
        dct = o2b(obj.todict(), parts)
        dct['__ase_objtype__'] = objtype
        return dct
    raise ValueError('Objects of type {type} not allowed'
                     .format(type=type(obj)))
def b2o(obj: Any, b: bytes) -> Any:
    """Rebuild a Python object from the JSON structure made by o2b().

    obj:
        JSON-decoded structure.
    b: bytes
        The complete serialized byte string; ``__ndarray__`` entries hold
        byte offsets into it (array bytes are little-endian).

    Lists are rebuilt element-wise.  Dicts tagged ``__complex__`` or
    ``__ndarray__`` become complex numbers or arrays.  Dicts carrying
    ``__ase_objtype__`` are passed to create_ase_object().  Arrays are
    read-only views into *b* on little-endian hosts.
    """
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj

    if isinstance(obj, list):
        return [b2o(value, b) for value in obj]

    assert isinstance(obj, dict)

    x = obj.get('__complex__')
    if x is not None:
        return complex(*x)

    x = obj.get('__ndarray__')
    if x is not None:
        shape, name, offset = x
        dtype = np.dtype(name)
        # Integer product: np.prod([]) is the float 1.0, which cannot be
        # used as a slice bound (0-d arrays have shape []).
        size = dtype.itemsize * int(np.prod(shape, dtype=np.int64))
        a = np.frombuffer(b[offset:offset + size], dtype)
        a = a.reshape(tuple(shape))
        if not np.little_endian:
            a = a.byteswap()
        return a

    dct = {key: b2o(value, b) for key, value in obj.items()}
    objtype = dct.pop('__ase_objtype__', None)
    if objtype is None:
        return dct
    return create_ase_object(objtype, dct)