Coverage for /builds/debichem-team/python-ase/ase/db/sqlite.py: 90.13%

557 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-03-06 04:00 +0000

1"""SQLite3 backend. 

2 

3Versions: 

4 

51) Added 3 more columns. 

62) Changed "user" to "username". 

73) Now adding keys to keyword table and added an "information" table containing 

8 a version number. 

94) Got rid of keywords. 

105) Add fmax, smax, mass, volume, charge 

116) Use REAL for magmom and drop possibility for non-collinear spin 

127) Volume can be None 

138) Added name='metadata' row to "information" table 

149) Row data is now stored in binary format. 

15""" 

16 

17import json 

18import numbers 

19import os 

20import sqlite3 

21import sys 

22from contextlib import contextmanager 

23 

24import numpy as np 

25 

26import ase.io.jsonio 

27from ase.calculators.calculator import all_properties 

28from ase.data import atomic_numbers 

29from ase.db.core import ( 

30 Database, 

31 bytes_to_object, 

32 invop, 

33 lock, 

34 now, 

35 object_to_bytes, 

36 ops, 

37 parse_selection, 

38) 

39from ase.db.row import AtomsRow 

40from ase.parallel import parallel_function 

41 

# Current schema version written to new files (history in the module
# docstring); _initialize() refuses files with a larger version.
VERSION = 9

# Statements executed, in order, when a new database file is created.
# NOTE: SQLite3Database.columnnames is derived from the first statement by
# taking the first word of every line after the first one, so each column of
# the "systems" table must stay on a line of its own, in this order.
init_statements = [
    """CREATE TABLE systems (
    id INTEGER PRIMARY KEY AUTOINCREMENT, -- ID's, timestamps and user name
    unique_id TEXT UNIQUE,
    ctime REAL,
    mtime REAL,
    username TEXT,
    numbers BLOB, -- stuff that defines an Atoms object
    positions BLOB,
    cell BLOB,
    pbc INTEGER,
    initial_magmoms BLOB,
    initial_charges BLOB,
    masses BLOB,
    tags BLOB,
    momenta BLOB,
    constraints TEXT, -- constraints and calculator
    calculator TEXT,
    calculator_parameters TEXT,
    energy REAL, -- calculated properties
    free_energy REAL,
    forces BLOB,
    stress BLOB,
    dipole BLOB,
    magmoms BLOB,
    magmom REAL,
    charges BLOB,
    key_value_pairs TEXT, -- key-value pairs and data as json
    data BLOB,
    natoms INTEGER, -- stuff for making queries faster
    fmax REAL,
    smax REAL,
    volume REAL,
    mass REAL,
    charge REAL)""",

    # Element counts per row: atomic number Z, count n, row id.
    """CREATE TABLE species (
    Z INTEGER,
    n INTEGER,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    # One entry per key-value key of a row (used for "has key" queries).
    """CREATE TABLE keys (
    key TEXT,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    # String-valued key-value pairs.
    """CREATE TABLE text_key_values (
    key TEXT,
    value TEXT,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    # Numeric (and boolean) key-value pairs, stored as REAL.
    """CREATE TABLE number_key_values (
    key TEXT,
    value REAL,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    # name/value store: schema version, metadata and external-table registry.
    """CREATE TABLE information (
    name TEXT,
    value TEXT)""",

    f"INSERT INTO information VALUES ('version', '{VERSION}')"]

# Indices created for a new file only when self.create_indices is true
# (see SQLite3Database._initialize).
index_statements = [
    'CREATE INDEX unique_id_index ON systems(unique_id)',
    'CREATE INDEX ctime_index ON systems(ctime)',
    'CREATE INDEX username_index ON systems(username)',
    'CREATE INDEX calculator_index ON systems(calculator)',
    'CREATE INDEX species_index ON species(Z)',
    'CREATE INDEX key_index ON keys(key)',
    'CREATE INDEX text_index ON text_key_values(key)',
    'CREATE INDEX number_index ON number_key_values(key)']

# Tables holding per-row data.  Deletion walks this list in reverse, so the
# tables that reference systems(id) are cleared before "systems" itself.
all_tables = ['systems', 'species', 'keys',
              'text_key_values', 'number_key_values']

121 

122 

def float_if_not_none(x):
    """Return ``float(x)``, or None when *x* is None.

    Turns numpy scalars such as numpy.float64 into plain Python floats,
    which old db-interfaces need.
    """
    return None if x is None else float(x)

127 

128 

class SQLite3Database(Database):
    """ASE database backend storing rows in an SQLite3 file.

    The schema is defined by the module-level ``init_statements``.
    """
    type = 'db'  # backend tag: vacuum() only runs for 'db'; _old2new() checks 'postgresql'
    initialized = False  # set to True by _initialize() once schema/version are known
    _allow_reading_old_format = False  # lets _initialize() accept versions < 5
    default = 'NULL'  # used for autoincrement id
    connection = None  # shared connection while inside ``with db:`` (__enter__/__exit__)
    version = None  # schema version of the file, set by _initialize()
    # Column names of the "systems" table: first word of every line of its
    # CREATE TABLE statement after the header line.
    columnnames = [line.split()[0].lstrip()
                   for line in init_statements[0].splitlines()[1:]]

138 

139 def encode(self, obj, binary=False): 

140 if binary: 

141 return object_to_bytes(obj) 

142 return ase.io.jsonio.encode(obj) 

143 

144 def decode(self, txt, lazy=False): 

145 if lazy: 

146 return txt 

147 if isinstance(txt, str): 

148 return ase.io.jsonio.decode(txt) 

149 return bytes_to_object(txt) 

150 

151 def blob(self, array): 

152 """Convert array to blob/buffer object.""" 

153 

154 if array is None: 

155 return None 

156 if len(array) == 0: 

157 array = np.zeros(0) 

158 if array.dtype == np.int64: 

159 array = array.astype(np.int32) 

160 if not np.little_endian: 

161 array = array.byteswap() 

162 return memoryview(np.ascontiguousarray(array)) 

163 

164 def deblob(self, buf, dtype=float, shape=None): 

165 """Convert blob/buffer object to ndarray of correct dtype and shape. 

166 

167 (without creating an extra view).""" 

168 if buf is None: 

169 return None 

170 if len(buf) == 0: 

171 array = np.zeros(0, dtype) 

172 else: 

173 array = np.frombuffer(buf, dtype) 

174 if not np.little_endian: 

175 array = array.byteswap() 

176 if shape is not None: 

177 array.shape = shape 

178 return array 

179 

180 def _connect(self): 

181 return sqlite3.connect(self.filename, timeout=20) 

182 

183 def __enter__(self): 

184 assert self.connection is None 

185 self.change_count = 0 

186 self.connection = self._connect() 

187 return self 

188 

189 def __exit__(self, exc_type, exc_value, tb): 

190 if exc_type is None: 

191 self.connection.commit() 

192 else: 

193 self.connection.rollback() 

194 self.connection.close() 

195 self.connection = None 

196 

197 @contextmanager 

198 def managed_connection(self, commit_frequency=5000): 

199 try: 

200 con = self.connection or self._connect() 

201 self._initialize(con) 

202 yield con 

203 except ValueError as exc: 

204 if self.connection is None: 

205 con.close() 

206 raise exc 

207 else: 

208 if self.connection is None: 

209 con.commit() 

210 con.close() 

211 else: 

212 self.change_count += 1 

213 if self.change_count % commit_frequency == 0: 

214 con.commit() 

215 

216 def _initialize(self, con): 

217 if self.initialized: 

218 return 

219 

220 self._metadata = {} 

221 

222 cur = con.execute( 

223 'SELECT COUNT(*) FROM sqlite_master WHERE name="systems"') 

224 

225 if cur.fetchone()[0] == 0: 

226 for statement in init_statements: 

227 con.execute(statement) 

228 if self.create_indices: 

229 for statement in index_statements: 

230 con.execute(statement) 

231 con.commit() 

232 self.version = VERSION 

233 else: 

234 cur = con.execute( 

235 'SELECT COUNT(*) FROM sqlite_master WHERE name="user_index"') 

236 if cur.fetchone()[0] == 1: 

237 # Old version with "user" instead of "username" column 

238 self.version = 1 

239 else: 

240 try: 

241 cur = con.execute( 

242 'SELECT value FROM information WHERE name="version"') 

243 except sqlite3.OperationalError: 

244 self.version = 2 

245 else: 

246 self.version = int(cur.fetchone()[0]) 

247 

248 cur = con.execute( 

249 'SELECT value FROM information WHERE name="metadata"') 

250 results = cur.fetchall() 

251 if results: 

252 self._metadata = json.loads(results[0][0]) 

253 

254 if self.version > VERSION: 

255 raise OSError('Can not read new ase.db format ' 

256 '(version {}). Please update to latest ASE.' 

257 .format(self.version)) 

258 if self.version < 5 and not self._allow_reading_old_format: 

259 raise OSError('Please convert to new format. ' + 

260 'Use: python -m ase.db.convert ' + self.filename) 

261 

262 self.initialized = True 

263 

    def _write(self, atoms, key_value_pairs, data, id):
        # Write *atoms* (an Atoms object or an AtomsRow) as a new row when
        # *id* is None, or overwrite the existing row *id*; returns the id.
        # Besides the "systems" row this (re)fills the species, keys,
        # text_key_values and number_key_values tables and any external
        # tables.
        #
        # NOTE(review): "external_tables" is popped from the caller's
        # key_value_pairs dict, and an "id" item is added to every
        # external-table dict further down, so caller-owned dicts are mutated.
        ext_tables = key_value_pairs.pop("external_tables", {})
        Database._write(self, atoms, key_value_pairs, data)

        mtime = now()

        encode = self.encode
        blob = self.blob

        if not isinstance(atoms, AtomsRow):
            # A plain Atoms object: wrap it and stamp creation time and user.
            row = AtomsRow(atoms)
            row.ctime = mtime
            row.user = os.getenv('USER')
        else:
            row = atoms
            # Extract the external tables from AtomsRow
            names = self._get_external_table_names()
            for name in names:
                new_table = row.get(name, {})
                if new_table:
                    ext_tables[name] = new_table

        # Without an id and without explicit key-value pairs / external
        # tables, keep the key-value pairs carried by the row itself.
        if not id and not key_value_pairs and not ext_tables:
            key_value_pairs = row.key_value_pairs

        # Make sure every external table exists before writing to it.
        for k, v in ext_tables.items():
            dtype = self._guess_type(v)
            self._create_table_if_not_exists(k, dtype)

        # List constraints are JSON-encoded; other truthy values are stored
        # as-is and missing/empty ones as NULL.
        constraints = row._constraints
        if constraints:
            if isinstance(constraints, list):
                constraints = encode(constraints)
        else:
            constraints = None

        # Column values in "systems" column order (without the leading id).
        values = (row.unique_id,
                  row.ctime,
                  mtime,
                  row.user,
                  blob(row.numbers),
                  blob(row.positions),
                  blob(row.cell),
                  int(np.dot(row.pbc, [1, 2, 4])),  # pbc packed as a bitmask
                  blob(row.get('initial_magmoms')),
                  blob(row.get('initial_charges')),
                  blob(row.get('masses')),
                  blob(row.get('tags')),
                  blob(row.get('momenta')),
                  constraints)

        if 'calculator' in row:
            values += (row.calculator, encode(row.calculator_parameters))
        else:
            values += (None, None)

        if not data:
            data = row._data

        with self.managed_connection() as con:
            # self.version is only known once managed_connection() has run
            # _initialize(); from version 9 on data is stored in binary form.
            if not isinstance(data, (str, bytes)):
                data = encode(data, binary=self.version >= 9)

            values += (float_if_not_none(row.get('energy')),
                       float_if_not_none(row.get('free_energy')),
                       blob(row.get('forces')),
                       blob(row.get('stress')),
                       blob(row.get('dipole')),
                       blob(row.get('magmoms')),
                       row.get('magmom'),
                       blob(row.get('charges')),
                       encode(key_value_pairs),
                       data,
                       len(row.numbers),
                       float_if_not_none(row.get('fmax')),
                       float_if_not_none(row.get('smax')),
                       float_if_not_none(row.get('volume')),
                       float(row.mass),
                       float(row.charge))

            cur = con.cursor()
            if id is None:
                # One "?" per value ('?' * n joined -> "?, ?, ..."); the id
                # column gets self.default ('NULL') so SQLite assigns it.
                q = self.default + ', ' + ', '.join('?' * len(values))
                cur.execute(f'INSERT INTO systems VALUES ({q})',
                            values)
                id = self.get_last_id(cur)
            else:
                # Drop the dependent rows; they are re-inserted below.
                self._delete(cur, [id], ['keys', 'text_key_values',
                                         'number_key_values', 'species'])
                q = ', '.join(name + '=?' for name in self.columnnames[1:])
                cur.execute(f'UPDATE systems SET {q} WHERE id=?',
                            values + (id,))

            # One species row (Z, count, id) per element present.
            count = row.count_atoms()
            if count:
                species = [(atomic_numbers[symbol], n, id)
                           for symbol, n in count.items()]
                cur.executemany('INSERT INTO species VALUES (?, ?, ?)',
                                species)

            # Numbers (including numpy bools) go to number_key_values as
            # floats, strings to text_key_values; any other type fails the
            # assertion.
            text_key_values = []
            number_key_values = []
            for key, value in key_value_pairs.items():
                if isinstance(value, (numbers.Real, np.bool_)):
                    number_key_values.append([key, float(value), id])
                else:
                    assert isinstance(value, str)
                    text_key_values.append([key, value, id])

            cur.executemany('INSERT INTO text_key_values VALUES (?, ?, ?)',
                            text_key_values)
            cur.executemany('INSERT INTO number_key_values VALUES (?, ?, ?)',
                            number_key_values)
            cur.executemany('INSERT INTO keys VALUES (?, ?)',
                            [(key, id) for key in key_value_pairs])

            # Insert entries in the valid tables
            for tabname in ext_tables.keys():
                entries = ext_tables[tabname]
                entries['id'] = id
                self._insert_in_external_table(
                    cur, name=tabname, entries=ext_tables[tabname])

        return id

388 

    def _update(self, id, key_value_pairs, data=None):
        """Update key_value_pairs and data for a single row.

        Also refreshes mtime, rebuilds the row's keys / text_key_values /
        number_key_values entries and writes any external tables given under
        ``key_value_pairs["external_tables"]``.  *data* is only written when
        truthy.  Returns *id*.
        """
        # NOTE(review): "external_tables" is popped from the caller's dict
        # and every external-table dict gets an "id" item added below.
        encode = self.encode
        ext_tables = key_value_pairs.pop('external_tables', {})

        # Make sure every external table exists before writing to it.
        for k, v in ext_tables.items():
            dtype = self._guess_type(v)
            self._create_table_if_not_exists(k, dtype)

        mtime = now()
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(
                'UPDATE systems SET mtime=?, key_value_pairs=? WHERE id=?',
                (mtime, encode(key_value_pairs), id))
            if data:
                # From version 9 on data is stored in binary form.
                if not isinstance(data, (str, bytes)):
                    data = encode(data, binary=self.version >= 9)
                cur.execute('UPDATE systems set data=? where id=?', (data, id))

            # Rebuild the per-key lookup tables from scratch.
            self._delete(cur, [id], ['keys', 'text_key_values',
                                     'number_key_values'])

            # Numbers (including numpy bools) are stored as floats, strings
            # as text; any other type fails the assertion.
            text_key_values = []
            number_key_values = []
            for key, value in key_value_pairs.items():
                if isinstance(value, (numbers.Real, np.bool_)):
                    number_key_values.append([key, float(value), id])
                else:
                    assert isinstance(value, str)
                    text_key_values.append([key, value, id])

            cur.executemany('INSERT INTO text_key_values VALUES (?, ?, ?)',
                            text_key_values)
            cur.executemany('INSERT INTO number_key_values VALUES (?, ?, ?)',
                            number_key_values)
            cur.executemany('INSERT INTO keys VALUES (?, ?)',
                            [(key, id) for key in key_value_pairs])

            # Insert entries in the valid tables
            for tabname in ext_tables.keys():
                entries = ext_tables[tabname]
                entries['id'] = id
                self._insert_in_external_table(
                    cur, name=tabname, entries=ext_tables[tabname])

        return id

436 

437 def get_last_id(self, cur): 

438 cur.execute('SELECT seq FROM sqlite_sequence WHERE name="systems"') 

439 result = cur.fetchone() 

440 if result is not None: 

441 id = result[0] 

442 return id 

443 else: 

444 return 0 

445 

446 def _get_row(self, id): 

447 with self.managed_connection() as con: 

448 cur = con.cursor() 

449 if id is None: 

450 cur.execute('SELECT COUNT(*) FROM systems') 

451 assert cur.fetchone()[0] == 1 

452 cur.execute('SELECT * FROM systems') 

453 else: 

454 cur.execute('SELECT * FROM systems WHERE id=?', (id,)) 

455 values = cur.fetchone() 

456 

457 return self._convert_tuple_to_row(values) 

458 

    def _convert_tuple_to_row(self, values):
        # Turn a raw "systems" tuple (columns in self.columnnames order) into
        # an AtomsRow: blobs are deblobbed, JSON columns decoded (data is left
        # undecoded), and the row's external-table entries are attached.
        deblob = self.deblob
        decode = self.decode

        values = self._old2new(values)
        dct = {'id': values[0],
               'unique_id': values[1],
               'ctime': values[2],
               'mtime': values[3],
               'user': values[4],
               'numbers': deblob(values[5], np.int32),
               'positions': deblob(values[6], shape=(-1, 3)),
               'cell': deblob(values[7], shape=(3, 3))}

        # Optional columns are only added when not NULL.
        if values[8] is not None:
            # pbc is stored as a bitmask (first axis 1, second 2, third 4).
            dct['pbc'] = (values[8] & np.array([1, 2, 4])).astype(bool)
        if values[9] is not None:
            dct['initial_magmoms'] = deblob(values[9])
        if values[10] is not None:
            dct['initial_charges'] = deblob(values[10])
        if values[11] is not None:
            dct['masses'] = deblob(values[11])
        if values[12] is not None:
            dct['tags'] = deblob(values[12], np.int32)
        if values[13] is not None:
            dct['momenta'] = deblob(values[13], shape=(-1, 3))
        if values[14] is not None:
            dct['constraints'] = values[14]
        if values[15] is not None:
            dct['calculator'] = values[15]
        if values[16] is not None:
            dct['calculator_parameters'] = decode(values[16])
        if values[17] is not None:
            dct['energy'] = values[17]
        if values[18] is not None:
            dct['free_energy'] = values[18]
        if values[19] is not None:
            dct['forces'] = deblob(values[19], shape=(-1, 3))
        if values[20] is not None:
            dct['stress'] = deblob(values[20])
        if values[21] is not None:
            dct['dipole'] = deblob(values[21])
        if values[22] is not None:
            dct['magmoms'] = deblob(values[22])
        if values[23] is not None:
            dct['magmom'] = values[23]
        if values[24] is not None:
            dct['charges'] = deblob(values[24])
        # '{}' and 'null' are the JSON text for "no key-value pairs" and
        # "no data"; _select also uses them as placeholders for unselected
        # columns.
        if values[25] != '{}':
            dct['key_value_pairs'] = decode(values[25])
        if len(values) >= 27 and values[26] != 'null':
            dct['data'] = decode(values[26], lazy=True)

        # Now we need to update with info from the external tables
        external_tab = self._get_external_table_names()
        tables = {}
        for tab in external_tab:
            row = self._read_external_table(tab, dct["id"])
            tables[tab] = row

        dct.update(tables)
        return AtomsRow(dct)

521 

522 def _old2new(self, values): 

523 if self.type == 'postgresql': 

524 assert self.version >= 8, 'Your db-version is too old!' 

525 assert self.version >= 4, 'Your db-file is too old!' 

526 if self.version < 5: 

527 pass # should be ok for reading by convert.py script 

528 if self.version < 6: 

529 m = values[23] 

530 if m is not None and not isinstance(m, float): 

531 magmom = float(self.deblob(m, shape=())) 

532 values = values[:23] + (magmom,) + values[24:] 

533 return values 

534 

    def create_select_statement(self, keys, cmps,
                                sort=None, order=None, sort_table=None,
                                what='systems.*'):
        # Build an SQL query selecting *what* from "systems" for rows that
        # have every key in *keys* and satisfy every (key, op, value)
        # comparison in *cmps*.  Returns (sql, args), where args fill the
        # "?" placeholders in order.
        tables = ['systems']
        where = []
        args = []
        for key in keys:
            if key == 'forces':
                where.append('systems.fmax IS NOT NULL')
            elif key == 'strain':
                where.append('systems.smax IS NOT NULL')
            elif key in ['energy', 'fmax', 'smax',
                         'constraints', 'calculator']:
                where.append(f'systems.{key} IS NOT NULL')
            else:
                # Key-value key: a "-" in it selects rows WITHOUT the key
                # (the "-" characters are stripped before matching).
                if '-' not in key:
                    q = 'systems.id in (select id from keys where key=?)'
                else:
                    key = key.replace('-', '')
                    q = 'systems.id not in (select id from keys where key=?)'
                where.append(q)
                args.append(key)

        # Special handling of "H=0" and "H<2" type of selections:
        # bad[Z] is True when every comparison on element Z also holds for a
        # count of 0.  Rows lacking Z have no species row, so such conditions
        # are written as "not in (rows where the inverted comparison holds)".
        bad = {}
        for key, op, value in cmps:
            if isinstance(key, int):
                bad[key] = bad.get(key, True) and ops[op](0, value)

        for key, op, value in cmps:
            if key in ['id', 'energy', 'magmom', 'ctime', 'user',
                       'calculator', 'natoms', 'pbc', 'unique_id',
                       'fmax', 'smax', 'volume', 'mass', 'charge']:
                # Direct comparison on a "systems" column.
                if key == 'user':
                    key = 'username'
                elif key == 'pbc':
                    # e.g. "TTF" -> bitmask matching the stored pbc column.
                    assert op in ['=', '!=']
                    value = int(np.dot([x == 'T' for x in value], [1, 2, 4]))
                elif key == 'magmom':
                    assert self.version >= 6, 'Update your db-file'
                where.append(f'systems.{key}{op}?')
                args.append(value)
            elif isinstance(key, int):
                # Integer key: count of atoms with atomic number Z == key.
                if self.type == 'postgresql':
                    where.append(
                        'cardinality(array_positions(' +
                        f'numbers::int[], ?)){op}?')
                    args += [key, value]
                else:
                    if bad[key]:
                        where.append(
                            'systems.id not in (select id from species ' +
                            f'where Z=? and n{invop[op]}?)')
                        args += [key, value]
                    else:
                        where.append('systems.id in (select id from species ' +
                                     f'where Z=? and n{op}?)')
                        args += [key, value]

            elif self.type == 'postgresql':
                # Query the JSON column directly; ->> yields text, -> json.
                # NOTE(review): *key* is interpolated into the SQL text here
                # rather than passed as a parameter.
                jsonop = '->'
                if isinstance(value, str):
                    jsonop = '->>'
                elif isinstance(value, bool):
                    jsonop = '->>'
                    value = str(value).lower()
                where.append("systems.key_value_pairs {} '{}'{}?"
                             .format(jsonop, key, op))
                args.append(str(value))

            elif isinstance(value, str):
                where.append('systems.id in (select id from text_key_values ' +
                             f'where key=? and value{op}?)')
                args += [key, value]
            else:
                where.append(
                    'systems.id in (select id from number_key_values ' +
                    f'where key=? and value{op}?)')
                args += [key, float(value)]

        if sort:
            if sort_table != 'systems':
                # Sorting on a key-value key: join the text/number table
                # holding that key and order by its value.
                tables.append(f'{sort_table} AS sort_table')
                where.append('systems.id=sort_table.id AND '
                             'sort_table.key=?')
                args.append(sort)
                sort_table = 'sort_table'
                sort = 'value'

        sql = f'SELECT {what} FROM\n ' + ', '.join(tables)
        if where:
            sql += '\n WHERE\n ' + ' AND\n '.join(where)
        if sort:
            # NULL values are ordered last ("IS NULL" sorts False first).
            # XXX use "?" instead of "{}"
            sql += '\nORDER BY {0}.{1} IS NULL, {0}.{1} {2}'.format(
                sort_table, sort, order)

        return sql, args

633 

    def _select(self, keys, cmps, explain=False, verbosity=0,
                limit=None, offset=0, sort=None, include_data=True,
                columns='all'):
        # Generator yielding AtomsRow objects for rows matching *keys* and
        # *cmps*, or {'explain': ...} dicts with the query plan when
        # *explain* is true.  *sort* names a column or key, optionally
        # prefixed with "-" for descending order.  *columns* restricts which
        # of the first 26 "systems" columns are fetched; data (column 26) is
        # fetched only when *include_data* is true.

        # Template for a full 27-column row.  Unselected columns keep these
        # placeholders (None, '{}', 'null'), so _convert_tuple_to_row skips
        # the corresponding optional fields.
        values = np.array([None for _ in range(27)])
        values[25] = '{}'
        values[26] = 'null'

        if columns == 'all':
            columnindex = list(range(26))
        else:
            columnindex = [c for c in range(26)
                           if self.columnnames[c] in columns]
        if include_data:
            columnindex.append(26)

        if sort:
            if sort[0] == '-':
                order = 'DESC'
                sort = sort[1:]
            else:
                order = 'ASC'
            if sort in ['id', 'energy', 'username', 'calculator',
                        'ctime', 'mtime', 'magmom', 'pbc',
                        'fmax', 'smax', 'volume', 'mass', 'charge', 'natoms']:
                sort_table = 'systems'
            else:
                # Key-value sort key: inspect one row that has the key to
                # decide whether its values live in the text or number table.
                for dct in self._select(keys + [sort], cmps=[], limit=1,
                                        include_data=False,
                                        columns=['key_value_pairs']):
                    if isinstance(dct['key_value_pairs'][sort], str):
                        sort_table = 'text_key_values'
                    else:
                        sort_table = 'number_key_values'
                    break
                else:
                    # No rows. Just pick a table:
                    sort_table = 'number_key_values'

        else:
            order = None
            sort_table = None

        what = ', '.join('systems.' + name
                         for name in
                         np.array(self.columnnames)[np.array(columnindex)])

        sql, args = self.create_select_statement(keys, cmps, sort, order,
                                                 sort_table, what)

        if explain:
            sql = 'EXPLAIN QUERY PLAN ' + sql

        if limit:
            sql += f'\nLIMIT {limit}'

        if offset:
            sql += self.get_offset_string(offset, limit=limit)

        if verbosity == 2:
            print(sql, args)

        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql, args)
            if explain:
                for row in cur.fetchall():
                    yield {'explain': row}
            else:
                n = 0
                for shortvalues in cur.fetchall():
                    # Scatter the fetched columns into the 27-column template.
                    values[columnindex] = shortvalues
                    yield self._convert_tuple_to_row(tuple(values))
                    n += 1

                if sort and sort_table != 'systems':
                    # Yield rows without sort key last:
                    # NOTE(review): the same *offset* is applied again to this
                    # second query even if the first one already consumed
                    # part of it -- verify pagination across the boundary.
                    if limit is not None:
                        if n == limit:
                            return
                        limit -= n
                    for row in self._select(keys + ['-' + sort], cmps,
                                            limit=limit, offset=offset,
                                            include_data=include_data,
                                            columns=columns):
                        yield row

720 

721 def get_offset_string(self, offset, limit=None): 

722 sql = '' 

723 if not limit: 

724 # In sqlite you cannot have offset without limit, so we 

725 # set it to -1 meaning no limit 

726 sql += '\nLIMIT -1' 

727 sql += f'\nOFFSET {offset}' 

728 return sql 

729 

730 @parallel_function 

731 def count(self, selection=None, **kwargs): 

732 keys, cmps = parse_selection(selection, **kwargs) 

733 sql, args = self.create_select_statement(keys, cmps, what='COUNT(*)') 

734 

735 with self.managed_connection() as con: 

736 cur = con.cursor() 

737 cur.execute(sql, args) 

738 return cur.fetchone()[0] 

739 

740 def analyse(self): 

741 with self.managed_connection() as con: 

742 con.execute('ANALYZE') 

743 

744 @parallel_function 

745 @lock 

746 def delete(self, ids): 

747 if len(ids) == 0: 

748 return 

749 table_names = self._get_external_table_names() + all_tables[::-1] 

750 with self.managed_connection() as con: 

751 self._delete(con.cursor(), ids, 

752 tables=table_names) 

753 self.vacuum() 

754 

755 def _delete(self, cur, ids, tables=None): 

756 tables = tables or all_tables[::-1] 

757 for table in tables: 

758 cur.execute('DELETE FROM {} WHERE id in ({});'. 

759 format(table, ', '.join([str(id) for id in ids]))) 

760 

761 def vacuum(self): 

762 if self.type != 'db': 

763 return 

764 

765 with self.managed_connection() as con: 

766 con.commit() 

767 con.cursor().execute("VACUUM") 

768 

769 @property 

770 def metadata(self): 

771 if self._metadata is None: 

772 self._initialize(self._connect()) 

773 return self._metadata.copy() 

774 

775 @metadata.setter 

776 def metadata(self, dct): 

777 self._metadata = dct 

778 md = json.dumps(dct) 

779 with self.managed_connection() as con: 

780 cur = con.cursor() 

781 cur.execute( 

782 "SELECT COUNT(*) FROM information WHERE name='metadata'") 

783 

784 if cur.fetchone()[0]: 

785 cur.execute( 

786 "UPDATE information SET value=? WHERE name='metadata'", 

787 [md]) 

788 else: 

789 cur.execute('INSERT INTO information VALUES (?, ?)', 

790 ('metadata', md)) 

791 

792 def _get_external_table_names(self, db_con=None): 

793 """Return a list with the external table names.""" 

794 sql = "SELECT value FROM information WHERE name='external_table_name'" 

795 with self.managed_connection() as con: 

796 cur = con.cursor() 

797 cur.execute(sql) 

798 ext_tab_names = [x[0] for x in cur.fetchall()] 

799 return ext_tab_names 

800 

801 def _external_table_exists(self, name): 

802 """Return True if an external table name exists.""" 

803 return name in self._get_external_table_names() 

804 

805 def _create_table_if_not_exists(self, name, dtype): 

806 """Create a new table if it does not exits. 

807 

808 Arguments 

809 ========== 

810 name: str 

811 Name of the new table 

812 dtype: str 

813 Datatype of the value field (typically REAL, INTEGER, TEXT etc.) 

814 """ 

815 

816 taken_names = set(all_tables + all_properties + self.columnnames) 

817 if name in taken_names: 

818 raise ValueError("External table can not be any of {}" 

819 "".format(taken_names)) 

820 

821 if self._external_table_exists(name): 

822 return 

823 

824 sql = f"CREATE TABLE IF NOT EXISTS {name} " 

825 sql += f"(key TEXT, value {dtype}, id INTEGER, " 

826 sql += "FOREIGN KEY (id) REFERENCES systems(id))" 

827 sql2 = "INSERT INTO information VALUES (?, ?)" 

828 with self.managed_connection() as con: 

829 cur = con.cursor() 

830 cur.execute(sql) 

831 # Insert an entry saying that there is a new external table 

832 # present and an entry with the datatype 

833 cur.execute(sql2, ("external_table_name", name)) 

834 cur.execute(sql2, (name + "_dtype", dtype)) 

835 

836 def delete_external_table(self, name): 

837 """Delete an external table.""" 

838 if not self._external_table_exists(name): 

839 return 

840 

841 with self.managed_connection() as con: 

842 cur = con.cursor() 

843 

844 sql = f"DROP TABLE {name}" 

845 cur.execute(sql) 

846 

847 sql = "DELETE FROM information WHERE value=?" 

848 cur.execute(sql, (name,)) 

849 sql = "DELETE FROM information WHERE name=?" 

850 cur.execute(sql, (name + "_dtype",)) 

851 

852 def _convert_to_recognized_types(self, value): 

853 """Convert Numpy types to python types.""" 

854 if np.issubdtype(type(value), np.integer): 

855 return int(value) 

856 elif np.issubdtype(type(value), np.floating): 

857 return float(value) 

858 return value 

859 

860 def _insert_in_external_table(self, cursor, name=None, entries=None): 

861 """Insert into external table""" 

862 if name is None or entries is None: 

863 # There is nothing to do 

864 return 

865 

866 id = entries.pop("id") 

867 dtype = self._guess_type(entries) 

868 expected_dtype = self._get_value_type_of_table(cursor, name) 

869 if dtype != expected_dtype: 

870 raise ValueError("The provided data type for table {} " 

871 "is {}, while it is initialized to " 

872 "be of type {}" 

873 "".format(name, dtype, expected_dtype)) 

874 

875 # First we check if entries already exists 

876 cursor.execute(f"SELECT key FROM {name} WHERE id=?", (id,)) 

877 updates = [] 

878 for item in cursor.fetchall(): 

879 value = entries.pop(item[0], None) 

880 if value is not None: 

881 updates.append( 

882 (value, id, self._convert_to_recognized_types(item[0]))) 

883 

884 # Update entry if key and ID already exists 

885 sql = f"UPDATE {name} SET value=? WHERE id=? AND key=?" 

886 cursor.executemany(sql, updates) 

887 

888 # Insert the ones that does not already exist 

889 inserts = [(k, self._convert_to_recognized_types(v), id) 

890 for k, v in entries.items()] 

891 sql = f"INSERT INTO {name} VALUES (?, ?, ?)" 

892 cursor.executemany(sql, inserts) 

893 

894 def _guess_type(self, entries): 

895 """Guess the type based on the first entry.""" 

896 values = [v for _, v in entries.items()] 

897 

898 # Check if all datatypes are the same 

899 all_types = [type(v) for v in values] 

900 if any(t != all_types[0] for t in all_types): 

901 typenames = [t.__name__ for t in all_types] 

902 raise ValueError("Inconsistent datatypes in the table. " 

903 "given types: {}".format(typenames)) 

904 

905 val = values[0] 

906 if isinstance(val, int) or np.issubdtype(type(val), np.integer): 

907 return "INTEGER" 

908 if isinstance(val, float) or np.issubdtype(type(val), np.floating): 

909 return "REAL" 

910 if isinstance(val, str): 

911 return "TEXT" 

912 raise ValueError("Unknown datatype!") 

913 

914 def _get_value_type_of_table(self, cursor, tab_name): 

915 """Return the expected value name.""" 

916 sql = "SELECT value FROM information WHERE name=?" 

917 cursor.execute(sql, (tab_name + "_dtype",)) 

918 return cursor.fetchone()[0] 

919 

920 def _read_external_table(self, name, id): 

921 """Read row from external table.""" 

922 

923 with self.managed_connection() as con: 

924 cur = con.cursor() 

925 cur.execute(f"SELECT * FROM {name} WHERE id=?", (id,)) 

926 items = cur.fetchall() 

927 dictionary = {item[0]: item[1] for item in items} 

928 

929 return dictionary 

930 

931 def get_all_key_names(self): 

932 """Create set of all key names.""" 

933 with self.managed_connection() as con: 

934 cur = con.cursor() 

935 cur.execute('SELECT DISTINCT key FROM keys;') 

936 all_keys = {row[0] for row in cur.fetchall()} 

937 return all_keys 

938 

939 

if __name__ == '__main__':
    # Print the schema version of the database file given as first argument.
    from ase.db import connect
    db = connect(sys.argv[1])
    db._initialize(db._connect())
    print('Version:', db.version)