Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""SQLite3 backend. 

2 

Versions:

1) Added 3 more columns.
2) Changed "user" to "username".
3) Now adding keys to keyword table and added an "information" table containing
   a version number.
4) Got rid of keywords.
5) Added fmax, smax, mass, volume and charge columns.
6) Use REAL for magmom and drop support for non-collinear spin.
7) Volume can be None.
8) Added name='metadata' row to "information" table.
9) Row data is now stored in binary format.

15""" 

16 

17import json 

18import numbers 

19import os 

20import sqlite3 

21import sys 

22from contextlib import contextmanager 

23 

24import numpy as np 

25 

26import ase.io.jsonio 

27from ase.data import atomic_numbers 

28from ase.calculators.calculator import all_properties 

29from ase.db.row import AtomsRow 

30from ase.db.core import (Database, ops, now, lock, invop, parse_selection, 

31 object_to_bytes, bytes_to_object) 

32from ase.parallel import parallel_function 

33 

# Schema version of databases created by this module.  It is written into
# the "information" table of every new file (last entry of init_statements)
# and compared against by SQLite3Database._initialize() for existing files.
VERSION = 9

# SQL executed, in order, to create the schema of a brand-new database file.
# NOTE: SQLite3Database.columnnames is derived by taking the first word of
# every line after the first one of the "systems" statement below, so keep
# exactly one column definition per line with the column name first.
init_statements = [
    """CREATE TABLE systems (
    id INTEGER PRIMARY KEY AUTOINCREMENT,  -- ID's, timestamps and user name
    unique_id TEXT UNIQUE,
    ctime REAL,
    mtime REAL,
    username TEXT,
    numbers BLOB,  -- stuff that defines an Atoms object
    positions BLOB,
    cell BLOB,
    pbc INTEGER,
    initial_magmoms BLOB,
    initial_charges BLOB,
    masses BLOB,
    tags BLOB,
    momenta BLOB,
    constraints TEXT,  -- constraints and calculator
    calculator TEXT,
    calculator_parameters TEXT,
    energy REAL,  -- calculated properties
    free_energy REAL,
    forces BLOB,
    stress BLOB,
    dipole BLOB,
    magmoms BLOB,
    magmom REAL,
    charges BLOB,
    key_value_pairs TEXT,  -- key-value pairs and data as json
    data BLOB,
    natoms INTEGER,  -- stuff for making queries faster
    fmax REAL,
    smax REAL,
    volume REAL,
    mass REAL,
    charge REAL)""",

    # One row per (element, row id): atomic number Z occurs n times.
    """CREATE TABLE species (
    Z INTEGER,
    n INTEGER,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    # Names of the key-value-pair keys present in each row.
    """CREATE TABLE keys (
    key TEXT,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    # Key-value pairs whose value is a string.
    """CREATE TABLE text_key_values (
    key TEXT,
    value TEXT,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    # Key-value pairs whose value is a number (stored as REAL).
    """CREATE TABLE number_key_values (
    key TEXT,
    value REAL,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    # Free-form name/value store: schema version, metadata and the
    # registry of external tables.
    """CREATE TABLE information (
    name TEXT,
    value TEXT)""",

    "INSERT INTO information VALUES ('version', '{}')".format(VERSION)]

# Indices created right after init_statements, only when the database
# object's create_indices flag is set (see SQLite3Database._initialize).
index_statements = [
    'CREATE INDEX unique_id_index ON systems(unique_id)',
    'CREATE INDEX ctime_index ON systems(ctime)',
    'CREATE INDEX username_index ON systems(username)',
    'CREATE INDEX calculator_index ON systems(calculator)',
    'CREATE INDEX species_index ON species(Z)',
    'CREATE INDEX key_index ON keys(key)',
    'CREATE INDEX text_index ON text_key_values(key)',
    'CREATE INDEX number_index ON number_key_values(key)']

# Tables holding per-row data, keyed by systems.id.  Deletions walk this list
# in reverse, so the referencing tables are presumably cleared before the
# "systems" row they point to.
all_tables = ['systems', 'species', 'keys',
              'text_key_values', 'number_key_values']

113 

114 

def float_if_not_none(x):
    """Return ``float(x)``, passing ``None`` through unchanged.

    Values such as ``numpy.float64`` are turned into plain Python floats,
    which is what older database interfaces require.
    """
    return None if x is None else float(x)

119 

120 

class SQLite3Database(Database):
    """ASE database backend that keeps all rows in one SQLite3 file.

    Besides the main ``systems`` table (one row per stored structure) the
    file holds helper tables for species counts and key-value pairs, an
    ``information`` table (schema version, metadata and the registry of
    external tables) and any user-created external tables.

    The object is also a context manager: inside ``with db:`` all operations
    share one connection and are committed (or rolled back) together.
    """

    type = 'db'
    initialized = False
    _allow_reading_old_format = False
    default = 'NULL'  # used for autoincrement id
    connection = None  # shared connection, set only inside ``with db:``
    version = None  # schema version, determined by _initialize()
    change_count = 0  # managed_connection() uses of the shared connection
    _metadata = None  # user metadata, loaded by _initialize()
    # Column names of the systems table, parsed from its CREATE statement
    # (first word of every line after the first).
    columnnames = [line.split()[0]
                   for line in init_statements[0].splitlines()[1:]]

    def encode(self, obj, binary=False):
        """Serialize *obj*: bytes when *binary* is true, else JSON text."""
        if binary:
            return object_to_bytes(obj)
        return ase.io.jsonio.encode(obj)

    def decode(self, txt, lazy=False):
        """Inverse of encode().

        With *lazy* the raw value is returned untouched.  Strings are decoded
        as JSON; anything else is treated as the binary format.
        """
        if lazy:
            return txt
        if isinstance(txt, str):
            return ase.io.jsonio.decode(txt)
        return bytes_to_object(txt)

    def blob(self, array):
        """Convert array to blob/buffer object.

        int64 arrays are stored as int32 and all data is stored in
        little-endian byte order.  ``None`` is passed through.
        """
        if array is None:
            return None
        if len(array) == 0:
            array = np.zeros(0)
        if array.dtype == np.int64:
            array = array.astype(np.int32)
        if not np.little_endian:
            array = array.byteswap()
        return memoryview(np.ascontiguousarray(array))

    def deblob(self, buf, dtype=float, shape=None):
        """Convert blob/buffer object to ndarray of correct dtype and shape.

        (without creating an extra view).  ``None`` is passed through.
        """
        if buf is None:
            return None
        if len(buf) == 0:
            array = np.zeros(0, dtype)
        else:
            array = np.frombuffer(buf, dtype)
            if not np.little_endian:
                array = array.byteswap()
        if shape is not None:
            array.shape = shape
        return array

    def _connect(self):
        """Open a new connection to the database file."""
        return sqlite3.connect(self.filename, timeout=20)

    def __enter__(self):
        """Open the shared connection used by all operations in the block."""
        assert self.connection is None
        self.change_count = 0
        self.connection = self._connect()
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Commit (or roll back on error) and always close the connection."""
        con = self.connection
        self.connection = None
        try:
            if exc_type is None:
                con.commit()
            else:
                con.rollback()
        finally:
            con.close()

    @contextmanager
    def managed_connection(self, commit_frequency=5000):
        """Yield a connection whose schema has been initialized.

        Outside ``with db:`` a fresh connection is opened; it is committed
        and closed on success and closed without commit (discarding
        uncommitted changes) on any error, including early exit of a
        generator that is iterating inside the block.  Inside ``with db:``
        the shared connection is yielded and committed only every
        *commit_frequency* uses; errors are left to ``__exit__``.
        """
        owns_connection = self.connection is None
        con = self._connect() if owns_connection else self.connection
        try:
            self._initialize(con)
            yield con
        except BaseException:
            if owns_connection:
                con.close()
            raise
        if owns_connection:
            try:
                con.commit()
            finally:
                con.close()
        else:
            self.change_count += 1
            if self.change_count % commit_frequency == 0:
                con.commit()

    def _initialize(self, con):
        """Create the schema of a new file or detect the version of an old one.

        Runs only once per object.  Sets ``self.version`` and
        ``self._metadata``.

        Raises IOError if the file is newer than VERSION, or older than
        version 5 while ``_allow_reading_old_format`` is false.
        """
        if self.initialized:
            return

        self._metadata = {}

        # Single quotes: SQL string literals.  Double quotes denote
        # identifiers and only fall back to strings on DQS-enabled builds.
        cur = con.execute(
            "SELECT COUNT(*) FROM sqlite_master WHERE name='systems'")

        if cur.fetchone()[0] == 0:
            for statement in init_statements:
                con.execute(statement)
            if self.create_indices:
                for statement in index_statements:
                    con.execute(statement)
            con.commit()
            self.version = VERSION
        else:
            cur = con.execute(
                "SELECT COUNT(*) FROM sqlite_master WHERE name='user_index'")
            if cur.fetchone()[0] == 1:
                # Old version with "user" instead of "username" column
                self.version = 1
            else:
                try:
                    cur = con.execute(
                        "SELECT value FROM information WHERE name='version'")
                except sqlite3.OperationalError:
                    # No information table yet.
                    self.version = 2
                else:
                    self.version = int(cur.fetchone()[0])

                    cur = con.execute(
                        "SELECT value FROM information WHERE name='metadata'")
                    results = cur.fetchall()
                    if results:
                        self._metadata = json.loads(results[0][0])

        if self.version > VERSION:
            raise IOError('Can not read new ase.db format '
                          '(version {}).  Please update to latest ASE.'
                          .format(self.version))
        if self.version < 5 and not self._allow_reading_old_format:
            raise IOError('Please convert to new format. ' +
                          'Use: python -m ase.db.convert ' + self.filename)

        self.initialized = True

    def _write(self, atoms, key_value_pairs, data, id):
        """Insert a new row (``id is None``) or overwrite row *id*.

        *atoms* is an Atoms-like object or an AtomsRow.  The pseudo key
        ``'external_tables'`` in *key_value_pairs* maps table names to
        ``{key: value}`` dicts.  Returns the id of the written row.
        """
        # Copy so that the caller's external_tables dict is never modified.
        ext_tables = dict(key_value_pairs.pop('external_tables', {}))
        Database._write(self, atoms, key_value_pairs, data)

        mtime = now()

        encode = self.encode
        blob = self.blob

        if not isinstance(atoms, AtomsRow):
            row = AtomsRow(atoms)
            row.ctime = mtime
            row.user = os.getenv('USER')
        else:
            row = atoms
            # Extract the external tables from AtomsRow
            for name in self._get_external_table_names():
                new_table = row.get(name, {})
                if new_table:
                    ext_tables[name] = new_table

        if not id and not key_value_pairs and not ext_tables:
            key_value_pairs = row.key_value_pairs

        self._create_external_tables(ext_tables)

        constraints = row._constraints
        if constraints:
            if isinstance(constraints, list):
                constraints = encode(constraints)
        else:
            constraints = None

        # Column values in the order of the systems table (id excluded).
        values = (row.unique_id,
                  row.ctime,
                  mtime,
                  row.user,
                  blob(row.numbers),
                  blob(row.positions),
                  blob(row.cell),
                  int(np.dot(row.pbc, [1, 2, 4])),  # pbc as a 3-bit mask
                  blob(row.get('initial_magmoms')),
                  blob(row.get('initial_charges')),
                  blob(row.get('masses')),
                  blob(row.get('tags')),
                  blob(row.get('momenta')),
                  constraints)

        if 'calculator' in row:
            values += (row.calculator, encode(row.calculator_parameters))
        else:
            values += (None, None)

        if not data:
            data = row._data

        with self.managed_connection() as con:
            # self.version is only known once the connection is initialized.
            if not isinstance(data, (str, bytes)):
                data = encode(data, binary=self.version >= 9)

            values += (row.get('energy'),
                       row.get('free_energy'),
                       blob(row.get('forces')),
                       blob(row.get('stress')),
                       blob(row.get('dipole')),
                       blob(row.get('magmoms')),
                       row.get('magmom'),
                       blob(row.get('charges')),
                       encode(key_value_pairs),
                       data,
                       len(row.numbers),
                       float_if_not_none(row.get('fmax')),
                       float_if_not_none(row.get('smax')),
                       float_if_not_none(row.get('volume')),
                       float(row.mass),
                       float(row.charge))

            cur = con.cursor()
            if id is None:
                q = self.default + ', ' + ', '.join('?' * len(values))
                cur.execute('INSERT INTO systems VALUES ({})'.format(q),
                            values)
                id = self.get_last_id(cur)
            else:
                self._delete(cur, [id], ['keys', 'text_key_values',
                                         'number_key_values', 'species'])
                q = ', '.join(name + '=?' for name in self.columnnames[1:])
                cur.execute('UPDATE systems SET {} WHERE id=?'.format(q),
                            values + (id,))

            count = row.count_atoms()
            if count:
                species = [(atomic_numbers[symbol], n, id)
                           for symbol, n in count.items()]
                cur.executemany('INSERT INTO species VALUES (?, ?, ?)',
                                species)

            self._insert_key_value_pairs(cur, id, key_value_pairs)
            self._insert_external_tables(cur, id, ext_tables)

        return id

    def _update(self, id, key_value_pairs, data=None):
        """Update key_value_pairs and data for a single row """
        encode = self.encode
        # Copy so that the caller's external_tables dict is never modified.
        ext_tables = dict(key_value_pairs.pop('external_tables', {}))
        self._create_external_tables(ext_tables)

        mtime = now()
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(
                'UPDATE systems SET mtime=?, key_value_pairs=? WHERE id=?',
                (mtime, encode(key_value_pairs), id))
            if data:
                if not isinstance(data, (str, bytes)):
                    data = encode(data, binary=self.version >= 9)
                cur.execute('UPDATE systems set data=? where id=?', (data, id))

            self._delete(cur, [id], ['keys', 'text_key_values',
                                     'number_key_values'])
            self._insert_key_value_pairs(cur, id, key_value_pairs)
            self._insert_external_tables(cur, id, ext_tables)

        return id

    def _create_external_tables(self, ext_tables):
        """Create (if missing) each table of ``{name: entries}``."""
        for name, entries in ext_tables.items():
            self._create_table_if_not_exists(name, self._guess_type(entries))

    def _insert_key_value_pairs(self, cur, id, key_value_pairs):
        """Fill keys/text_key_values/number_key_values for row *id*."""
        text_key_values = []
        number_key_values = []
        for key, value in key_value_pairs.items():
            if isinstance(value, (numbers.Real, np.bool_)):
                number_key_values.append([key, float(value), id])
            else:
                assert isinstance(value, str)
                text_key_values.append([key, value, id])

        cur.executemany('INSERT INTO text_key_values VALUES (?, ?, ?)',
                        text_key_values)
        cur.executemany('INSERT INTO number_key_values VALUES (?, ?, ?)',
                        number_key_values)
        cur.executemany('INSERT INTO keys VALUES (?, ?)',
                        [(key, id) for key in key_value_pairs])

    def _insert_external_tables(self, cur, id, ext_tables):
        """Upsert the entries of every external table for row *id*."""
        for name, entries in ext_tables.items():
            # Pass a copy carrying the row id; the caller's dict is untouched.
            self._insert_in_external_table(cur, name=name,
                                           entries=dict(entries, id=id))

    def get_last_id(self, cur):
        """Return the last autoincrement id of systems (0 if none yet)."""
        cur.execute("SELECT seq FROM sqlite_sequence WHERE name='systems'")
        result = cur.fetchone()
        if result is None:
            return 0
        return result[0]

    def _get_row(self, id):
        """Return row *id* as an AtomsRow.

        With ``id=None`` the database must contain exactly one row.
        Raises KeyError if no row has the given id.
        """
        with self.managed_connection() as con:
            cur = con.cursor()
            if id is None:
                cur.execute('SELECT COUNT(*) FROM systems')
                assert cur.fetchone()[0] == 1
                cur.execute('SELECT * FROM systems')
            else:
                cur.execute('SELECT * FROM systems WHERE id=?', (id,))
            values = cur.fetchone()

        if values is None:
            raise KeyError('no match')
        return self._convert_tuple_to_row(values)

    def _convert_tuple_to_row(self, values):
        """Turn a raw systems-table tuple into an AtomsRow.

        Columns that are NULL are left out.  The contents of all external
        tables for the row are added under their table names.
        """
        deblob = self.deblob
        decode = self.decode

        values = self._old2new(values)
        dct = {'id': values[0],
               'unique_id': values[1],
               'ctime': values[2],
               'mtime': values[3],
               'user': values[4],
               'numbers': deblob(values[5], np.int32),
               'positions': deblob(values[6], shape=(-1, 3)),
               'cell': deblob(values[7], shape=(3, 3))}

        if values[8] is not None:
            # pbc is stored as a bitmask (x=1, y=2, z=4).
            dct['pbc'] = (values[8] & np.array([1, 2, 4])).astype(bool)
        if values[9] is not None:
            dct['initial_magmoms'] = deblob(values[9])
        if values[10] is not None:
            dct['initial_charges'] = deblob(values[10])
        if values[11] is not None:
            dct['masses'] = deblob(values[11])
        if values[12] is not None:
            dct['tags'] = deblob(values[12], np.int32)
        if values[13] is not None:
            dct['momenta'] = deblob(values[13], shape=(-1, 3))
        if values[14] is not None:
            dct['constraints'] = values[14]
        if values[15] is not None:
            dct['calculator'] = values[15]
        if values[16] is not None:
            dct['calculator_parameters'] = decode(values[16])
        if values[17] is not None:
            dct['energy'] = values[17]
        if values[18] is not None:
            dct['free_energy'] = values[18]
        if values[19] is not None:
            dct['forces'] = deblob(values[19], shape=(-1, 3))
        if values[20] is not None:
            dct['stress'] = deblob(values[20])
        if values[21] is not None:
            dct['dipole'] = deblob(values[21])
        if values[22] is not None:
            dct['magmoms'] = deblob(values[22])
        if values[23] is not None:
            dct['magmom'] = values[23]
        if values[24] is not None:
            dct['charges'] = deblob(values[24])
        if values[25] != '{}':
            dct['key_value_pairs'] = decode(values[25])
        if len(values) >= 27 and values[26] != 'null':
            # Data is decoded lazily (kept in its stored form here).
            dct['data'] = decode(values[26], lazy=True)

        # Now we need to update with info from the external tables
        for tab in self._get_external_table_names():
            dct[tab] = self._read_external_table(tab, dct['id'])

        return AtomsRow(dct)

    def _old2new(self, values):
        """Adapt a raw systems tuple read from an older schema version."""
        if self.type == 'postgresql':
            assert self.version >= 8, 'Your db-version is too old!'
        assert self.version >= 4, 'Your db-file is too old!'
        # Version 4 files need no adaptation here: they are only expected to
        # be read by the ase.db.convert script.
        if self.version < 6:
            m = values[23]
            if m is not None and not isinstance(m, float):
                # Before version 6 magmom was stored as a blob.
                magmom = float(self.deblob(m, shape=()))
                values = values[:23] + (magmom,) + values[24:]
        return values

    def create_select_statement(self, keys, cmps,
                                sort=None, order=None, sort_table=None,
                                what='systems.*'):
        """Build the SQL query and its bound arguments for a selection.

        Parameters:

        keys: list of str
            Keys that must be present.  A key containing ``'-'`` is instead
            required to be absent (the dashes are stripped).
        cmps: list of (key, op, value) tuples
            Comparisons.  Integer keys are atomic numbers compared against
            the species counts.
        sort, order, sort_table:
            Sort column, ``'ASC'``/``'DESC'`` and the table holding it.
        what: str
            Column list of the SELECT.

        Returns ``(sql, args)``.
        """
        tables = ['systems']
        where = []
        args = []
        for key in keys:
            if key == 'forces':
                where.append('systems.fmax IS NOT NULL')
            elif key == 'strain':
                where.append('systems.smax IS NOT NULL')
            elif key in ['energy', 'fmax', 'smax',
                         'constraints', 'calculator']:
                where.append('systems.{} IS NOT NULL'.format(key))
            else:
                if '-' not in key:
                    q = 'systems.id in (select id from keys where key=?)'
                else:
                    key = key.replace('-', '')
                    q = 'systems.id not in (select id from keys where key=?)'
                where.append(q)
                args.append(key)

        # Special handling of "H=0" and "H<2" type of selections: when every
        # comparison on an element is also satisfied by a count of zero,
        # rows with no species entry for it must match as well, so the test
        # is written as "id NOT IN (rows failing the inverted comparison)".
        bad = {}
        for key, op, value in cmps:
            if isinstance(key, int):
                bad[key] = bad.get(key, True) and ops[op](0, value)

        for key, op, value in cmps:
            if key in ['id', 'energy', 'magmom', 'ctime', 'user',
                       'calculator', 'natoms', 'pbc', 'unique_id',
                       'fmax', 'smax', 'volume', 'mass', 'charge']:
                if key == 'user':
                    key = 'username'
                elif key == 'pbc':
                    assert op in ['=', '!=']
                    value = int(np.dot([x == 'T' for x in value], [1, 2, 4]))
                elif key == 'magmom':
                    assert self.version >= 6, 'Update your db-file'
                where.append('systems.{}{}?'.format(key, op))
                args.append(value)
            elif isinstance(key, int):
                if self.type == 'postgresql':
                    where.append(
                        'cardinality(array_positions(' +
                        'numbers::int[], ?)){}?'.format(op))
                    args += [key, value]
                else:
                    if bad[key]:
                        where.append(
                            'systems.id not in (select id from species ' +
                            'where Z=? and n{}?)'.format(invop[op]))
                        args += [key, value]
                    else:
                        where.append('systems.id in (select id from species ' +
                                     'where Z=? and n{}?)'.format(op))
                        args += [key, value]

            elif self.type == 'postgresql':
                jsonop = '->'
                if isinstance(value, str):
                    jsonop = '->>'
                elif isinstance(value, bool):
                    jsonop = '->>'
                    value = str(value).lower()
                # The key comes from the (possibly user supplied) selection,
                # so bind it as a parameter instead of splicing it into SQL.
                where.append('systems.key_value_pairs {} ?{}?'
                             .format(jsonop, op))
                args += [key, str(value)]

            elif isinstance(value, str):
                where.append('systems.id in (select id from text_key_values ' +
                             'where key=? and value{}?)'.format(op))
                args += [key, value]
            else:
                where.append(
                    'systems.id in (select id from number_key_values ' +
                    'where key=? and value{}?)'.format(op))
                args += [key, float(value)]

        if sort:
            if sort_table != 'systems':
                tables.append('{} AS sort_table'.format(sort_table))
                where.append('systems.id=sort_table.id AND '
                             'sort_table.key=?')
                args.append(sort)
                sort_table = 'sort_table'
                sort = 'value'

        sql = 'SELECT {} FROM\n '.format(what) + ', '.join(tables)
        if where:
            sql += '\n WHERE\n ' + ' AND\n '.join(where)
        if sort:
            # XXX use "?" instead of "{}"
            sql += '\nORDER BY {0}.{1} IS NULL, {0}.{1} {2}'.format(
                sort_table, sort, order)

        return sql, args

    def _select(self, keys, cmps, explain=False, verbosity=0,
                limit=None, offset=0, sort=None, include_data=True,
                columns='all'):
        """Yield AtomsRow objects matching *keys* and *cmps*.

        *sort* names a column or key (prefix ``'-'`` for descending).  When
        sorting on a key-value pair, rows lacking that key are yielded last.
        *columns* restricts the fetched systems columns.  With *explain* the
        query plan rows are yielded as ``{'explain': row}`` instead.
        """
        # Full-width systems-row template; columns that are not fetched keep
        # these defaults ('{}' = no key-value pairs, 'null' = no data).
        values = np.array([None] * 27)
        values[25] = '{}'
        values[26] = 'null'

        if columns == 'all':
            columnindex = list(range(26))
        else:
            columnindex = [c for c in range(0, 26)
                           if self.columnnames[c] in columns]
        if include_data:
            columnindex.append(26)

        if sort:
            if sort[0] == '-':
                order = 'DESC'
                sort = sort[1:]
            else:
                order = 'ASC'
            if sort in ['id', 'energy', 'username', 'calculator',
                        'ctime', 'mtime', 'magmom', 'pbc',
                        'fmax', 'smax', 'volume', 'mass', 'charge', 'natoms']:
                sort_table = 'systems'
            else:
                # Look at one row having the key to see whether its values
                # live in the text or the number table.
                for dct in self._select(keys + [sort], cmps=[], limit=1,
                                        include_data=False,
                                        columns=['key_value_pairs']):
                    if isinstance(dct['key_value_pairs'][sort], str):
                        sort_table = 'text_key_values'
                    else:
                        sort_table = 'number_key_values'
                    break
                else:
                    # No rows.  Just pick a table:
                    sort_table = 'number_key_values'

        else:
            order = None
            sort_table = None

        what = ', '.join('systems.' + self.columnnames[c]
                         for c in columnindex)

        sql, args = self.create_select_statement(keys, cmps, sort, order,
                                                 sort_table, what)

        if explain:
            sql = 'EXPLAIN QUERY PLAN ' + sql

        if limit:
            sql += '\nLIMIT {0}'.format(limit)

        if offset:
            sql += self.get_offset_string(offset, limit=limit)

        if verbosity == 2:
            print(sql, args)

        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql, args)
            if explain:
                for row in cur.fetchall():
                    yield {'explain': row}
                return

            n = 0
            for shortvalues in cur.fetchall():
                values[columnindex] = shortvalues
                yield self._convert_tuple_to_row(tuple(values))
                n += 1

            if sort and sort_table != 'systems':
                # Yield rows without sort key last:
                if limit is not None:
                    if n == limit:
                        return
                    limit -= n
                for row in self._select(keys + ['-' + sort], cmps,
                                        limit=limit, offset=offset,
                                        include_data=include_data,
                                        columns=columns):
                    yield row

    def get_offset_string(self, offset, limit=None):
        """Return the SQL suffix implementing *offset*."""
        sql = ''
        if not limit:
            # In sqlite you cannot have offset without limit, so we
            # set it to -1 meaning no limit
            sql += '\nLIMIT -1'
        sql += '\nOFFSET {0}'.format(offset)
        return sql

    @parallel_function
    def count(self, selection=None, **kwargs):
        """Return the number of rows matching the selection."""
        keys, cmps = parse_selection(selection, **kwargs)
        sql, args = self.create_select_statement(keys, cmps, what='COUNT(*)')

        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql, args)
            return cur.fetchone()[0]

    def analyse(self):
        """Run SQLite's ANALYZE on the database."""
        with self.managed_connection() as con:
            con.execute('ANALYZE')

    @parallel_function
    @lock
    def delete(self, ids):
        """Delete rows *ids* from all tables, then VACUUM the file."""
        if len(ids) == 0:
            return
        table_names = self._get_external_table_names() + all_tables[::-1]
        with self.managed_connection() as con:
            self._delete(con.cursor(), ids,
                         tables=table_names)
        self.vacuum()

    def _delete(self, cur, ids, tables=None):
        """Delete rows with the given ids from *tables* (default: all)."""
        tables = tables or all_tables[::-1]
        # Ids are spliced into the SQL text, so force them to be integers.
        id_list = ', '.join(str(int(i)) for i in ids)
        for table in tables:
            cur.execute('DELETE FROM {} WHERE id in ({});'.
                        format(table, id_list))

    def vacuum(self):
        """Run VACUUM to reclaim free space (SQLite files only)."""
        if self.type != 'db':
            return

        with self.managed_connection() as con:
            # Finish the current transaction first: SQLite's VACUUM fails
            # while a transaction is open.
            con.commit()
            con.cursor().execute('VACUUM')

    @property
    def metadata(self):
        """Copy of the metadata dict stored in the information table."""
        if self._metadata is None:
            # Opening a managed connection runs _initialize(), which loads
            # the metadata, and closes the connection again afterwards.
            with self.managed_connection():
                pass
        return self._metadata.copy()

    @metadata.setter
    def metadata(self, dct):
        self._metadata = dct
        md = json.dumps(dct)
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(
                "SELECT COUNT(*) FROM information WHERE name='metadata'")

            if cur.fetchone()[0]:
                cur.execute(
                    "UPDATE information SET value=? WHERE name='metadata'",
                    [md])
            else:
                cur.execute('INSERT INTO information VALUES (?, ?)',
                            ('metadata', md))

    def _get_external_table_names(self, db_con=None):
        """Return a list with the external table names.

        *db_con* is accepted for compatibility but not used.
        """
        sql = "SELECT value FROM information WHERE name='external_table_name'"
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql)
            ext_tab_names = [x[0] for x in cur.fetchall()]
        return ext_tab_names

    def _external_table_exists(self, name):
        """Return True if an external table name exists."""
        return name in self._get_external_table_names()

    def _create_table_if_not_exists(self, name, dtype):
        """Create a new table if it does not exits.

        Arguments
        ==========
        name: str
            Name of the new table
        dtype: str
            Datatype of the value field (typically REAL, INTEGER, TEXT etc.)

        Raises ValueError if *name* clashes with a built-in table, column or
        property name, or is not a plain identifier.
        """
        # 'information' is reserved too: registering it would make
        # delete_external_table() drop the information table.
        taken_names = set(all_tables + all_properties + self.columnnames +
                          ['information'])
        if name in taken_names:
            raise ValueError("External table can not be any of {}"
                             "".format(taken_names))

        # The name is spliced into SQL here and in later queries.
        if not isinstance(name, str) or not name.isidentifier():
            raise ValueError('Invalid external table name: {!r}'.format(name))

        if self._external_table_exists(name):
            return

        sql = "CREATE TABLE IF NOT EXISTS {} ".format(name)
        sql += "(key TEXT, value {}, id INTEGER, ".format(dtype)
        sql += "FOREIGN KEY (id) REFERENCES systems(id))"
        sql2 = "INSERT INTO information VALUES (?, ?)"
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql)
            # Insert an entry saying that there is a new external table
            # present and an entry with the datatype
            cur.execute(sql2, ("external_table_name", name))
            cur.execute(sql2, (name + "_dtype", dtype))

    def delete_external_table(self, name):
        """Delete an external table and its registry entries."""
        if not self._external_table_exists(name):
            return

        with self.managed_connection() as con:
            cur = con.cursor()

            sql = "DROP TABLE {}".format(name)
            cur.execute(sql)

            # Only remove the registry row, not other rows that happen to
            # store the same value.
            sql = ("DELETE FROM information "
                   "WHERE name='external_table_name' AND value=?")
            cur.execute(sql, (name,))
            sql = "DELETE FROM information WHERE name=?"
            cur.execute(sql, (name + "_dtype",))

    def _convert_to_recognized_types(self, value):
        """Convert Numpy types to python types."""
        if np.issubdtype(type(value), np.integer):
            return int(value)
        elif np.issubdtype(type(value), np.floating):
            return float(value)
        return value

    def _insert_in_external_table(self, cursor, name=None, entries=None):
        """Insert into external table.

        *entries* maps keys to values and must contain the row id under
        ``'id'``.  Existing keys for that id are updated, new keys inserted.
        The caller's dict is not modified.

        Raises ValueError if the value type differs from the table's type.
        """
        if name is None or entries is None:
            # There is nothing to do
            return

        entries = dict(entries)
        id = entries.pop("id")
        dtype = self._guess_type(entries)
        expected_dtype = self._get_value_type_of_table(cursor, name)
        if dtype != expected_dtype:
            raise ValueError("The provided data type for table {} "
                             "is {}, while it is initialized to "
                             "be of type {}"
                             "".format(name, dtype, expected_dtype))

        # First we check if entries already exists
        cursor.execute("SELECT key FROM {} WHERE id=?".format(name), (id,))
        updates = []
        for (key,) in cursor.fetchall():
            value = entries.pop(key, None)
            if value is not None:
                # Convert the value (not the key): numpy integers can not
                # be bound as SQL parameters.
                updates.append(
                    (self._convert_to_recognized_types(value), id, key))

        # Update entry if key and ID already exists
        sql = "UPDATE {} SET value=? WHERE id=? AND key=?".format(name)
        cursor.executemany(sql, updates)

        # Insert the ones that does not already exist
        inserts = [(k, self._convert_to_recognized_types(v), id)
                   for k, v in entries.items()]
        sql = "INSERT INTO {} VALUES (?, ?, ?)".format(name)
        cursor.executemany(sql, inserts)

    def _guess_type(self, entries):
        """Return the SQL type (INTEGER, REAL or TEXT) of the entry values.

        Raises ValueError for an empty dict, mixed value types or an
        unsupported type.
        """
        values = list(entries.values())
        if not values:
            raise ValueError("Can not guess the datatype of an empty table")

        # Check if all datatypes are the same
        all_types = [type(v) for v in values]
        if any(t != all_types[0] for t in all_types):
            typenames = [t.__name__ for t in all_types]
            raise ValueError("Inconsistent datatypes in the table. "
                             "given types: {}".format(typenames))

        val = values[0]
        if isinstance(val, int) or np.issubdtype(type(val), np.integer):
            return "INTEGER"
        if isinstance(val, float) or np.issubdtype(type(val), np.floating):
            return "REAL"
        if isinstance(val, str):
            return "TEXT"
        raise ValueError("Unknown datatype!")

    def _get_value_type_of_table(self, cursor, tab_name):
        """Return the SQL type registered for external table *tab_name*."""
        sql = "SELECT value FROM information WHERE name=?"
        cursor.execute(sql, (tab_name + "_dtype",))
        return cursor.fetchone()[0]

    def _read_external_table(self, name, id):
        """Return ``{key: value}`` stored in external table *name* for *id*."""
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute("SELECT key, value FROM {} WHERE id=?".format(name),
                        (id,))
            return {key: value for key, value in cur.fetchall()}

923 

if __name__ == '__main__':
    # Print the schema version of the database file given on the command line.
    from ase.db import connect
    if len(sys.argv) < 2:
        sys.exit('Usage: python {} <database-file>'.format(sys.argv[0]))
    db = connect(sys.argv[1])
    connection = db._connect()
    try:
        db._initialize(connection)
    finally:
        # _initialize() does not close the connection it is given.
        connection.close()
    print('Version:', db.version)