Coverage for /builds/debichem-team/python-ase/ase/db/sqlite.py: 90.13%
557 statements
1"""SQLite3 backend.
3Versions:
51) Added 3 more columns.
62) Changed "user" to "username".
73) Now adding keys to keyword table and added an "information" table containing
8 a version number.
94) Got rid of keywords.
105) Add fmax, smax, mass, volume, charge
116) Use REAL for magmom and drop possibility for non-collinear spin
127) Volume can be None
138) Added name='metadata' row to "information" table
149) Row data is now stored in binary format.
15"""
17import json
18import numbers
19import os
20import sqlite3
21import sys
22from contextlib import contextmanager
24import numpy as np
26import ase.io.jsonio
27from ase.calculators.calculator import all_properties
28from ase.data import atomic_numbers
29from ase.db.core import (
30 Database,
31 bytes_to_object,
32 invop,
33 lock,
34 now,
35 object_to_bytes,
36 ops,
37 parse_selection,
38)
39from ase.db.row import AtomsRow
40from ase.parallel import parallel_function
42VERSION = 9
44init_statements = [
45 """CREATE TABLE systems (
46 id INTEGER PRIMARY KEY AUTOINCREMENT, -- ID's, timestamps and user name
47 unique_id TEXT UNIQUE,
48 ctime REAL,
49 mtime REAL,
50 username TEXT,
51 numbers BLOB, -- stuff that defines an Atoms object
52 positions BLOB,
53 cell BLOB,
54 pbc INTEGER,
55 initial_magmoms BLOB,
56 initial_charges BLOB,
57 masses BLOB,
58 tags BLOB,
59 momenta BLOB,
60 constraints TEXT, -- constraints and calculator
61 calculator TEXT,
62 calculator_parameters TEXT,
63 energy REAL, -- calculated properties
64 free_energy REAL,
65 forces BLOB,
66 stress BLOB,
67 dipole BLOB,
68 magmoms BLOB,
69 magmom REAL,
70 charges BLOB,
71 key_value_pairs TEXT, -- key-value pairs and data as json
72 data BLOB,
73 natoms INTEGER, -- stuff for making queries faster
74 fmax REAL,
75 smax REAL,
76 volume REAL,
77 mass REAL,
78 charge REAL)""",
80 """CREATE TABLE species (
81 Z INTEGER,
82 n INTEGER,
83 id INTEGER,
84 FOREIGN KEY (id) REFERENCES systems(id))""",
86 """CREATE TABLE keys (
87 key TEXT,
88 id INTEGER,
89 FOREIGN KEY (id) REFERENCES systems(id))""",
91 """CREATE TABLE text_key_values (
92 key TEXT,
93 value TEXT,
94 id INTEGER,
95 FOREIGN KEY (id) REFERENCES systems(id))""",
97 """CREATE TABLE number_key_values (
98 key TEXT,
99 value REAL,
100 id INTEGER,
101 FOREIGN KEY (id) REFERENCES systems(id))""",
103 """CREATE TABLE information (
104 name TEXT,
105 value TEXT)""",
107 f"INSERT INTO information VALUES ('version', '{VERSION}')"]
109index_statements = [
110 'CREATE INDEX unique_id_index ON systems(unique_id)',
111 'CREATE INDEX ctime_index ON systems(ctime)',
112 'CREATE INDEX username_index ON systems(username)',
113 'CREATE INDEX calculator_index ON systems(calculator)',
114 'CREATE INDEX species_index ON species(Z)',
115 'CREATE INDEX key_index ON keys(key)',
116 'CREATE INDEX text_index ON text_key_values(key)',
117 'CREATE INDEX number_index ON number_key_values(key)']
119all_tables = ['systems', 'species', 'keys',
120 'text_key_values', 'number_key_values']
123def float_if_not_none(x):
124 """Convert numpy.float64 to float - old db-interfaces need that."""
125 if x is not None:
126 return float(x)
129class SQLite3Database(Database):
130 type = 'db'
131 initialized = False
132 _allow_reading_old_format = False
133 default = 'NULL' # used for autoincrement id
134 connection = None
135 version = None
136 columnnames = [line.split()[0].lstrip()
137 for line in init_statements[0].splitlines()[1:]]
139 def encode(self, obj, binary=False):
140 if binary:
141 return object_to_bytes(obj)
142 return ase.io.jsonio.encode(obj)
144 def decode(self, txt, lazy=False):
145 if lazy:
146 return txt
147 if isinstance(txt, str):
148 return ase.io.jsonio.decode(txt)
149 return bytes_to_object(txt)
151 def blob(self, array):
152 """Convert array to blob/buffer object."""
154 if array is None:
155 return None
156 if len(array) == 0:
157 array = np.zeros(0)
158 if array.dtype == np.int64:
159 array = array.astype(np.int32)
160 if not np.little_endian:
161 array = array.byteswap()
162 return memoryview(np.ascontiguousarray(array))
164 def deblob(self, buf, dtype=float, shape=None):
165 """Convert blob/buffer object to ndarray of correct dtype and shape.
167 (without creating an extra view)."""
168 if buf is None:
169 return None
170 if len(buf) == 0:
171 array = np.zeros(0, dtype)
172 else:
173 array = np.frombuffer(buf, dtype)
174 if not np.little_endian:
175 array = array.byteswap()
176 if shape is not None:
177 array.shape = shape
178 return array
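
    # Illustrative sketch, not part of the public API: blob()/deblob() exist to
    # round-trip numpy arrays through SQLite BLOB columns, e.g.
    #
    #   buf = db.blob(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.1]]))  # store
    #   pos = db.deblob(buf, shape=(-1, 3))                          # read back
    #
    # ('db' is a hypothetical SQLite3Database instance.)  Note that int64 data
    # is downcast to int32 before storage, which is why _convert_tuple_to_row()
    # reads 'numbers' and 'tags' back with dtype np.int32.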

    def _connect(self):
        return sqlite3.connect(self.filename, timeout=20)

    def __enter__(self):
        assert self.connection is None
        self.change_count = 0
        self.connection = self._connect()
        return self

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is None:
            self.connection.commit()
        else:
            self.connection.rollback()
        self.connection.close()
        self.connection = None

    @contextmanager
    def managed_connection(self, commit_frequency=5000):
        try:
            con = self.connection or self._connect()
            self._initialize(con)
            yield con
        except ValueError as exc:
            if self.connection is None:
                con.close()
            raise exc
        else:
            if self.connection is None:
                con.commit()
                con.close()
            else:
                self.change_count += 1
                if self.change_count % commit_frequency == 0:
                    con.commit()

    def _initialize(self, con):
        if self.initialized:
            return

        self._metadata = {}

        cur = con.execute(
            'SELECT COUNT(*) FROM sqlite_master WHERE name="systems"')

        if cur.fetchone()[0] == 0:
            for statement in init_statements:
                con.execute(statement)
            if self.create_indices:
                for statement in index_statements:
                    con.execute(statement)
            con.commit()
            self.version = VERSION
        else:
            cur = con.execute(
                'SELECT COUNT(*) FROM sqlite_master WHERE name="user_index"')
            if cur.fetchone()[0] == 1:
                # Old version with "user" instead of "username" column
                self.version = 1
            else:
                try:
                    cur = con.execute(
                        'SELECT value FROM information WHERE name="version"')
                except sqlite3.OperationalError:
                    self.version = 2
                else:
                    self.version = int(cur.fetchone()[0])

                cur = con.execute(
                    'SELECT value FROM information WHERE name="metadata"')
                results = cur.fetchall()
                if results:
                    self._metadata = json.loads(results[0][0])

        if self.version > VERSION:
            raise OSError('Can not read new ase.db format '
                          '(version {}). Please update to latest ASE.'
                          .format(self.version))
        if self.version < 5 and not self._allow_reading_old_format:
            raise OSError('Please convert to new format. ' +
                          'Use: python -m ase.db.convert ' + self.filename)

        self.initialized = True

    def _write(self, atoms, key_value_pairs, data, id):
        ext_tables = key_value_pairs.pop("external_tables", {})
        Database._write(self, atoms, key_value_pairs, data)

        mtime = now()

        encode = self.encode
        blob = self.blob

        if not isinstance(atoms, AtomsRow):
            row = AtomsRow(atoms)
            row.ctime = mtime
            row.user = os.getenv('USER')
        else:
            row = atoms
            # Extract the external tables from AtomsRow
            names = self._get_external_table_names()
            for name in names:
                new_table = row.get(name, {})
                if new_table:
                    ext_tables[name] = new_table

        if not id and not key_value_pairs and not ext_tables:
            key_value_pairs = row.key_value_pairs

        for k, v in ext_tables.items():
            dtype = self._guess_type(v)
            self._create_table_if_not_exists(k, dtype)

        constraints = row._constraints
        if constraints:
            if isinstance(constraints, list):
                constraints = encode(constraints)
        else:
            constraints = None

        values = (row.unique_id,
                  row.ctime,
                  mtime,
                  row.user,
                  blob(row.numbers),
                  blob(row.positions),
                  blob(row.cell),
                  int(np.dot(row.pbc, [1, 2, 4])),
                  blob(row.get('initial_magmoms')),
                  blob(row.get('initial_charges')),
                  blob(row.get('masses')),
                  blob(row.get('tags')),
                  blob(row.get('momenta')),
                  constraints)

        if 'calculator' in row:
            values += (row.calculator, encode(row.calculator_parameters))
        else:
            values += (None, None)

        if not data:
            data = row._data

        with self.managed_connection() as con:
            if not isinstance(data, (str, bytes)):
                data = encode(data, binary=self.version >= 9)

            values += (float_if_not_none(row.get('energy')),
                       float_if_not_none(row.get('free_energy')),
                       blob(row.get('forces')),
                       blob(row.get('stress')),
                       blob(row.get('dipole')),
                       blob(row.get('magmoms')),
                       row.get('magmom'),
                       blob(row.get('charges')),
                       encode(key_value_pairs),
                       data,
                       len(row.numbers),
                       float_if_not_none(row.get('fmax')),
                       float_if_not_none(row.get('smax')),
                       float_if_not_none(row.get('volume')),
                       float(row.mass),
                       float(row.charge))

            cur = con.cursor()
            if id is None:
                q = self.default + ', ' + ', '.join('?' * len(values))
                cur.execute(f'INSERT INTO systems VALUES ({q})',
                            values)
                id = self.get_last_id(cur)
            else:
                self._delete(cur, [id], ['keys', 'text_key_values',
                                         'number_key_values', 'species'])
                q = ', '.join(name + '=?' for name in self.columnnames[1:])
                cur.execute(f'UPDATE systems SET {q} WHERE id=?',
                            values + (id,))

            count = row.count_atoms()
            if count:
                species = [(atomic_numbers[symbol], n, id)
                           for symbol, n in count.items()]
                cur.executemany('INSERT INTO species VALUES (?, ?, ?)',
                                species)

            text_key_values = []
            number_key_values = []
            for key, value in key_value_pairs.items():
                if isinstance(value, (numbers.Real, np.bool_)):
                    number_key_values.append([key, float(value), id])
                else:
                    assert isinstance(value, str)
                    text_key_values.append([key, value, id])

            cur.executemany('INSERT INTO text_key_values VALUES (?, ?, ?)',
                            text_key_values)
            cur.executemany('INSERT INTO number_key_values VALUES (?, ?, ?)',
                            number_key_values)
            cur.executemany('INSERT INTO keys VALUES (?, ?)',
                            [(key, id) for key in key_value_pairs])

            # Insert entries in the valid tables
            for tabname in ext_tables.keys():
                entries = ext_tables[tabname]
                entries['id'] = id
                self._insert_in_external_table(
                    cur, name=tabname, entries=ext_tables[tabname])

        return id

    def _update(self, id, key_value_pairs, data=None):
        """Update key_value_pairs and data for a single row."""
        encode = self.encode
        ext_tables = key_value_pairs.pop('external_tables', {})

        for k, v in ext_tables.items():
            dtype = self._guess_type(v)
            self._create_table_if_not_exists(k, dtype)

        mtime = now()
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(
                'UPDATE systems SET mtime=?, key_value_pairs=? WHERE id=?',
                (mtime, encode(key_value_pairs), id))
            if data:
                if not isinstance(data, (str, bytes)):
                    data = encode(data, binary=self.version >= 9)
                cur.execute('UPDATE systems set data=? where id=?', (data, id))

            self._delete(cur, [id], ['keys', 'text_key_values',
                                     'number_key_values'])

            text_key_values = []
            number_key_values = []
            for key, value in key_value_pairs.items():
                if isinstance(value, (numbers.Real, np.bool_)):
                    number_key_values.append([key, float(value), id])
                else:
                    assert isinstance(value, str)
                    text_key_values.append([key, value, id])

            cur.executemany('INSERT INTO text_key_values VALUES (?, ?, ?)',
                            text_key_values)
            cur.executemany('INSERT INTO number_key_values VALUES (?, ?, ?)',
                            number_key_values)
            cur.executemany('INSERT INTO keys VALUES (?, ?)',
                            [(key, id) for key in key_value_pairs])

            # Insert entries in the valid tables
            for tabname in ext_tables.keys():
                entries = ext_tables[tabname]
                entries['id'] = id
                self._insert_in_external_table(
                    cur, name=tabname, entries=ext_tables[tabname])

        return id

    def get_last_id(self, cur):
        cur.execute('SELECT seq FROM sqlite_sequence WHERE name="systems"')
        result = cur.fetchone()
        if result is not None:
            id = result[0]
            return id
        else:
            return 0

    def _get_row(self, id):
        with self.managed_connection() as con:
            cur = con.cursor()
            if id is None:
                cur.execute('SELECT COUNT(*) FROM systems')
                assert cur.fetchone()[0] == 1
                cur.execute('SELECT * FROM systems')
            else:
                cur.execute('SELECT * FROM systems WHERE id=?', (id,))
            values = cur.fetchone()

        return self._convert_tuple_to_row(values)

    def _convert_tuple_to_row(self, values):
        deblob = self.deblob
        decode = self.decode

        values = self._old2new(values)
        dct = {'id': values[0],
               'unique_id': values[1],
               'ctime': values[2],
               'mtime': values[3],
               'user': values[4],
               'numbers': deblob(values[5], np.int32),
               'positions': deblob(values[6], shape=(-1, 3)),
               'cell': deblob(values[7], shape=(3, 3))}

        if values[8] is not None:
            dct['pbc'] = (values[8] & np.array([1, 2, 4])).astype(bool)
        if values[9] is not None:
            dct['initial_magmoms'] = deblob(values[9])
        if values[10] is not None:
            dct['initial_charges'] = deblob(values[10])
        if values[11] is not None:
            dct['masses'] = deblob(values[11])
        if values[12] is not None:
            dct['tags'] = deblob(values[12], np.int32)
        if values[13] is not None:
            dct['momenta'] = deblob(values[13], shape=(-1, 3))
        if values[14] is not None:
            dct['constraints'] = values[14]
        if values[15] is not None:
            dct['calculator'] = values[15]
        if values[16] is not None:
            dct['calculator_parameters'] = decode(values[16])
        if values[17] is not None:
            dct['energy'] = values[17]
        if values[18] is not None:
            dct['free_energy'] = values[18]
        if values[19] is not None:
            dct['forces'] = deblob(values[19], shape=(-1, 3))
        if values[20] is not None:
            dct['stress'] = deblob(values[20])
        if values[21] is not None:
            dct['dipole'] = deblob(values[21])
        if values[22] is not None:
            dct['magmoms'] = deblob(values[22])
        if values[23] is not None:
            dct['magmom'] = values[23]
        if values[24] is not None:
            dct['charges'] = deblob(values[24])
        if values[25] != '{}':
            dct['key_value_pairs'] = decode(values[25])
        if len(values) >= 27 and values[26] != 'null':
            dct['data'] = decode(values[26], lazy=True)

        # Now we need to update with info from the external tables
        external_tab = self._get_external_table_names()
        tables = {}
        for tab in external_tab:
            row = self._read_external_table(tab, dct["id"])
            tables[tab] = row

        dct.update(tables)
        return AtomsRow(dct)

    def _old2new(self, values):
        if self.type == 'postgresql':
            assert self.version >= 8, 'Your db-version is too old!'
        assert self.version >= 4, 'Your db-file is too old!'
        if self.version < 5:
            pass  # should be ok for reading by convert.py script
        if self.version < 6:
            m = values[23]
            if m is not None and not isinstance(m, float):
                magmom = float(self.deblob(m, shape=()))
                values = values[:23] + (magmom,) + values[24:]
        return values

    def create_select_statement(self, keys, cmps,
                                sort=None, order=None, sort_table=None,
                                what='systems.*'):
        tables = ['systems']
        where = []
        args = []
        for key in keys:
            if key == 'forces':
                where.append('systems.fmax IS NOT NULL')
            elif key == 'strain':
                where.append('systems.smax IS NOT NULL')
            elif key in ['energy', 'fmax', 'smax',
                         'constraints', 'calculator']:
                where.append(f'systems.{key} IS NOT NULL')
            else:
                if '-' not in key:
                    q = 'systems.id in (select id from keys where key=?)'
                else:
                    key = key.replace('-', '')
                    q = 'systems.id not in (select id from keys where key=?)'
                where.append(q)
                args.append(key)

        # Special handling of "H=0" and "H<2" type of selections:
        bad = {}
        for key, op, value in cmps:
            if isinstance(key, int):
                bad[key] = bad.get(key, True) and ops[op](0, value)

        for key, op, value in cmps:
            if key in ['id', 'energy', 'magmom', 'ctime', 'user',
                       'calculator', 'natoms', 'pbc', 'unique_id',
                       'fmax', 'smax', 'volume', 'mass', 'charge']:
                if key == 'user':
                    key = 'username'
                elif key == 'pbc':
                    assert op in ['=', '!=']
                    value = int(np.dot([x == 'T' for x in value], [1, 2, 4]))
                elif key == 'magmom':
                    assert self.version >= 6, 'Update your db-file'
                where.append(f'systems.{key}{op}?')
                args.append(value)
            elif isinstance(key, int):
                if self.type == 'postgresql':
                    where.append(
                        'cardinality(array_positions(' +
                        f'numbers::int[], ?)){op}?')
                    args += [key, value]
                else:
                    if bad[key]:
                        where.append(
                            'systems.id not in (select id from species ' +
                            f'where Z=? and n{invop[op]}?)')
                        args += [key, value]
                    else:
                        where.append('systems.id in (select id from species ' +
                                     f'where Z=? and n{op}?)')
                        args += [key, value]

            elif self.type == 'postgresql':
                jsonop = '->'
                if isinstance(value, str):
                    jsonop = '->>'
                elif isinstance(value, bool):
                    jsonop = '->>'
                    value = str(value).lower()
                where.append("systems.key_value_pairs {} '{}'{}?"
                             .format(jsonop, key, op))
                args.append(str(value))

            elif isinstance(value, str):
                where.append('systems.id in (select id from text_key_values ' +
                             f'where key=? and value{op}?)')
                args += [key, value]
            else:
                where.append(
                    'systems.id in (select id from number_key_values ' +
                    f'where key=? and value{op}?)')
                args += [key, float(value)]

        if sort:
            if sort_table != 'systems':
                tables.append(f'{sort_table} AS sort_table')
                where.append('systems.id=sort_table.id AND '
                             'sort_table.key=?')
                args.append(sort)
                sort_table = 'sort_table'
                sort = 'value'

        sql = f'SELECT {what} FROM\n  ' + ', '.join(tables)
        if where:
            sql += '\n  WHERE\n  ' + ' AND\n  '.join(where)
        if sort:
            # XXX use "?" instead of "{}"
            sql += '\nORDER BY {0}.{1} IS NULL, {0}.{1} {2}'.format(
                sort_table, sort, order)

        return sql, args
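
    # Rough illustration derived from the branches above (not verbatim output):
    # a selection with keys=['relaxed'] and cmps=[('natoms', '<', 10)] becomes
    # parameterized SQL along the lines of
    #
    #   SELECT systems.* FROM
    #     systems
    #     WHERE
    #     systems.id in (select id from keys where key=?) AND
    #     systems.natoms<?
    #
    # with args == ['relaxed', 10]; 'relaxed' is just an example key name.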

    def _select(self, keys, cmps, explain=False, verbosity=0,
                limit=None, offset=0, sort=None, include_data=True,
                columns='all'):
        values = np.array([None for _ in range(27)])
        values[25] = '{}'
        values[26] = 'null'

        if columns == 'all':
            columnindex = list(range(26))
        else:
            columnindex = [c for c in range(26)
                           if self.columnnames[c] in columns]
        if include_data:
            columnindex.append(26)

        if sort:
            if sort[0] == '-':
                order = 'DESC'
                sort = sort[1:]
            else:
                order = 'ASC'
            if sort in ['id', 'energy', 'username', 'calculator',
                        'ctime', 'mtime', 'magmom', 'pbc',
                        'fmax', 'smax', 'volume', 'mass', 'charge', 'natoms']:
                sort_table = 'systems'
            else:
                for dct in self._select(keys + [sort], cmps=[], limit=1,
                                        include_data=False,
                                        columns=['key_value_pairs']):
                    if isinstance(dct['key_value_pairs'][sort], str):
                        sort_table = 'text_key_values'
                    else:
                        sort_table = 'number_key_values'
                    break
                else:
                    # No rows. Just pick a table:
                    sort_table = 'number_key_values'

        else:
            order = None
            sort_table = None

        what = ', '.join('systems.' + name
                         for name in
                         np.array(self.columnnames)[np.array(columnindex)])

        sql, args = self.create_select_statement(keys, cmps, sort, order,
                                                 sort_table, what)

        if explain:
            sql = 'EXPLAIN QUERY PLAN ' + sql

        if limit:
            sql += f'\nLIMIT {limit}'

        if offset:
            sql += self.get_offset_string(offset, limit=limit)

        if verbosity == 2:
            print(sql, args)

        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql, args)
            if explain:
                for row in cur.fetchall():
                    yield {'explain': row}
            else:
                n = 0
                for shortvalues in cur.fetchall():
                    values[columnindex] = shortvalues
                    yield self._convert_tuple_to_row(tuple(values))
                    n += 1

                if sort and sort_table != 'systems':
                    # Yield rows without sort key last:
                    if limit is not None:
                        if n == limit:
                            return
                        limit -= n
                    for row in self._select(keys + ['-' + sort], cmps,
                                            limit=limit, offset=offset,
                                            include_data=include_data,
                                            columns=columns):
                        yield row

    def get_offset_string(self, offset, limit=None):
        sql = ''
        if not limit:
            # In sqlite you cannot have offset without limit, so we
            # set it to -1 meaning no limit
            sql += '\nLIMIT -1'
        sql += f'\nOFFSET {offset}'
        return sql

    @parallel_function
    def count(self, selection=None, **kwargs):
        keys, cmps = parse_selection(selection, **kwargs)
        sql, args = self.create_select_statement(keys, cmps, what='COUNT(*)')

        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql, args)
            return cur.fetchone()[0]

    def analyse(self):
        with self.managed_connection() as con:
            con.execute('ANALYZE')

    @parallel_function
    @lock
    def delete(self, ids):
        if len(ids) == 0:
            return
        table_names = self._get_external_table_names() + all_tables[::-1]
        with self.managed_connection() as con:
            self._delete(con.cursor(), ids,
                         tables=table_names)
        self.vacuum()

    def _delete(self, cur, ids, tables=None):
        tables = tables or all_tables[::-1]
        for table in tables:
            cur.execute('DELETE FROM {} WHERE id in ({});'.
                        format(table, ', '.join([str(id) for id in ids])))

    def vacuum(self):
        if self.type != 'db':
            return

        with self.managed_connection() as con:
            con.commit()
            con.cursor().execute("VACUUM")

    @property
    def metadata(self):
        if self._metadata is None:
            self._initialize(self._connect())
        return self._metadata.copy()

    @metadata.setter
    def metadata(self, dct):
        self._metadata = dct
        md = json.dumps(dct)
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(
                "SELECT COUNT(*) FROM information WHERE name='metadata'")

            if cur.fetchone()[0]:
                cur.execute(
                    "UPDATE information SET value=? WHERE name='metadata'",
                    [md])
            else:
                cur.execute('INSERT INTO information VALUES (?, ?)',
                            ('metadata', md))

    def _get_external_table_names(self, db_con=None):
        """Return a list with the external table names."""
        sql = "SELECT value FROM information WHERE name='external_table_name'"
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql)
            ext_tab_names = [x[0] for x in cur.fetchall()]
        return ext_tab_names

    def _external_table_exists(self, name):
        """Return True if an external table name exists."""
        return name in self._get_external_table_names()

    def _create_table_if_not_exists(self, name, dtype):
        """Create a new table if it does not exist.

        Arguments
        ==========
        name: str
            Name of the new table
        dtype: str
            Datatype of the value field (typically REAL, INTEGER, TEXT etc.)
        """

        taken_names = set(all_tables + all_properties + self.columnnames)
        if name in taken_names:
            raise ValueError("External table can not be any of {}"
                             "".format(taken_names))

        if self._external_table_exists(name):
            return

        sql = f"CREATE TABLE IF NOT EXISTS {name} "
        sql += f"(key TEXT, value {dtype}, id INTEGER, "
        sql += "FOREIGN KEY (id) REFERENCES systems(id))"
        sql2 = "INSERT INTO information VALUES (?, ?)"
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql)
            # Insert an entry saying that there is a new external table
            # present and an entry with the datatype
            cur.execute(sql2, ("external_table_name", name))
            cur.execute(sql2, (name + "_dtype", dtype))
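
    # Hedged usage note: external tables are normally created implicitly.  As
    # _write() above suggests (it pops 'external_tables' from key_value_pairs),
    # a call along the lines of
    #
    #   db.write(atoms, external_tables={'mytable': {'rate': 1.0, 'barrier': 0.3}})
    #
    # would route through _guess_type() (giving 'REAL' here) and then this
    # method; 'mytable', 'rate' and 'barrier' are made-up names for illustration.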

    def delete_external_table(self, name):
        """Delete an external table."""
        if not self._external_table_exists(name):
            return

        with self.managed_connection() as con:
            cur = con.cursor()

            sql = f"DROP TABLE {name}"
            cur.execute(sql)

            sql = "DELETE FROM information WHERE value=?"
            cur.execute(sql, (name,))
            sql = "DELETE FROM information WHERE name=?"
            cur.execute(sql, (name + "_dtype",))

    def _convert_to_recognized_types(self, value):
        """Convert Numpy types to python types."""
        if np.issubdtype(type(value), np.integer):
            return int(value)
        elif np.issubdtype(type(value), np.floating):
            return float(value)
        return value

    def _insert_in_external_table(self, cursor, name=None, entries=None):
        """Insert into external table."""
        if name is None or entries is None:
            # There is nothing to do
            return

        id = entries.pop("id")
        dtype = self._guess_type(entries)
        expected_dtype = self._get_value_type_of_table(cursor, name)
        if dtype != expected_dtype:
            raise ValueError("The provided data type for table {} "
                             "is {}, while it is initialized to "
                             "be of type {}"
                             "".format(name, dtype, expected_dtype))

        # First we check if the entries already exist
        cursor.execute(f"SELECT key FROM {name} WHERE id=?", (id,))
        updates = []
        for item in cursor.fetchall():
            value = entries.pop(item[0], None)
            if value is not None:
                updates.append(
                    (value, id, self._convert_to_recognized_types(item[0])))

        # Update the entry if the key and ID already exist
        sql = f"UPDATE {name} SET value=? WHERE id=? AND key=?"
        cursor.executemany(sql, updates)

        # Insert the ones that do not already exist
        inserts = [(k, self._convert_to_recognized_types(v), id)
                   for k, v in entries.items()]
        sql = f"INSERT INTO {name} VALUES (?, ?, ?)"
        cursor.executemany(sql, inserts)

    def _guess_type(self, entries):
        """Guess the type based on the first entry."""
        values = [v for _, v in entries.items()]

        # Check if all datatypes are the same
        all_types = [type(v) for v in values]
        if any(t != all_types[0] for t in all_types):
            typenames = [t.__name__ for t in all_types]
            raise ValueError("Inconsistent datatypes in the table. "
                             "Given types: {}".format(typenames))

        val = values[0]
        if isinstance(val, int) or np.issubdtype(type(val), np.integer):
            return "INTEGER"
        if isinstance(val, float) or np.issubdtype(type(val), np.floating):
            return "REAL"
        if isinstance(val, str):
            return "TEXT"
        raise ValueError("Unknown datatype!")

    def _get_value_type_of_table(self, cursor, tab_name):
        """Return the expected type of the value column of a table."""
        sql = "SELECT value FROM information WHERE name=?"
        cursor.execute(sql, (tab_name + "_dtype",))
        return cursor.fetchone()[0]

    def _read_external_table(self, name, id):
        """Read row from external table."""

        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(f"SELECT * FROM {name} WHERE id=?", (id,))
            items = cur.fetchall()
            dictionary = {item[0]: item[1] for item in items}

        return dictionary

    def get_all_key_names(self):
        """Create set of all key names."""
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute('SELECT DISTINCT key FROM keys;')
            all_keys = {row[0] for row in cur.fetchall()}
        return all_keys


if __name__ == '__main__':
    from ase.db import connect
    con = connect(sys.argv[1])
    con._initialize(con._connect())
    print('Version:', con.version)
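
# Typical use goes through ase.db.connect() rather than this class directly.
# A minimal sketch (assuming a writable file 'abc.db'):
#
#   from ase.build import molecule
#   from ase.db import connect
#
#   db = connect('abc.db')              # creates the schema defined above
#   db.write(molecule('H2O'), relaxed=False)
#   row = db.get(relaxed=False)         # AtomsRow via _convert_tuple_to_row()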