Coverage for /builds/debichem-team/python-ase/ase/db/postgresql.py: 91.30%
138 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
1import json
3import numpy as np
4from psycopg2 import connect
5from psycopg2.extras import execute_values
7from ase.db.sqlite import (
8 VERSION,
9 SQLite3Database,
10 index_statements,
11 init_statements,
12)
13from ase.io.jsonio import create_ase_object, create_ndarray
14from ase.io.jsonio import encode as ase_encode
# Extra GIN indexes for the two JSONB columns so that key/containment
# queries on key_value_pairs and calculator_parameters are fast
# (created alongside the standard index_statements in _initialize()).
jsonb_indices = [
    'CREATE INDEX idxkeys ON systems USING GIN (key_value_pairs);',
    'CREATE INDEX idxcalc ON systems USING GIN (calculator_parameters);']
def remove_nan_and_inf(obj):
    """Recursively replace non-finite floats with JSON-safe marker dicts.

    JSON has no representation for nan/inf, so each non-finite float is
    converted to ``{'__special_number__': 'nan'|'inf'|'-inf'}``; lists,
    dicts and ndarrays containing such values are walked recursively.
    The inverse operation is insert_nan_and_inf().
    """
    if isinstance(obj, float):
        return obj if np.isfinite(obj) else {'__special_number__': str(obj)}
    if isinstance(obj, list):
        return [remove_nan_and_inf(item) for item in obj]
    if isinstance(obj, dict):
        return {name: remove_nan_and_inf(val) for name, val in obj.items()}
    # Only convert arrays that actually contain non-finite entries;
    # finite arrays pass through untouched.
    if isinstance(obj, np.ndarray) and not np.isfinite(obj).all():
        return remove_nan_and_inf(obj.tolist())
    return obj
def insert_nan_and_inf(obj):
    """Recursively restore nan/inf floats from their JSON marker dicts.

    Inverse of remove_nan_and_inf(): any ``{'__special_number__': s}``
    dict becomes ``float(s)``; other lists and dicts are walked
    recursively, everything else is returned unchanged.
    """
    if isinstance(obj, dict):
        if '__special_number__' in obj:
            return float(obj['__special_number__'])
        return {name: insert_nan_and_inf(val) for name, val in obj.items()}
    if isinstance(obj, list):
        return [insert_nan_and_inf(item) for item in obj]
    return obj
class Connection:
    """Thin wrapper giving a psycopg2 connection a sqlite3-like interface.

    Only the operations used by SQLite3Database are exposed; cursors are
    wrapped in the Cursor adapter so that ``?`` placeholders work.
    """

    def __init__(self, con):
        self.con = con

    def cursor(self):
        """Return a Cursor adapter around a fresh psycopg2 cursor."""
        return Cursor(self.con.cursor())

    def commit(self):
        """Commit the current transaction."""
        self.con.commit()

    def close(self):
        """Close the underlying database connection."""
        self.con.close()
class Cursor:
    """Adapter translating sqlite3-style cursor calls for psycopg2.

    sqlite3 statements use ``?`` placeholders while psycopg2 expects
    ``%s``; every statement is rewritten before execution.
    """

    def __init__(self, cur):
        self.cur = cur

    def fetchone(self):
        return self.cur.fetchone()

    def fetchall(self):
        return self.cur.fetchall()

    def execute(self, statement, *args):
        self.cur.execute(statement.replace('?', '%s'), *args)

    def executemany(self, statement, *args):
        rows = args[0]
        if len(rows) == 0:
            # Nothing to insert; psycopg2's execute_values would fail on
            # an empty argslist anyway.
            return
        ncolumns = len(rows[0])

        placeholders = ', '.join('?' * ncolumns)
        if 'INSERT INTO systems' in statement:
            # The id column is filled by a sequence, so pass DEFAULT.
            placeholders = 'DEFAULT, ' + placeholders

        # Collapse the whole value tuple into one %s so execute_values()
        # can splice all rows in a single round trip.
        statement = statement.replace(f'({placeholders})', '%s')
        template = '({})'.format(placeholders.replace('?', '%s'))
        execute_values(self.cur, statement.replace('?', '%s'),
                       argslist=rows, template=template,
                       page_size=len(rows))
def insert_ase_and_ndarray_objects(obj):
    """Recursively rebuild ASE objects and ndarrays from decoded JSON.

    Dicts carrying '__ase_objtype__' become ASE objects, dicts carrying
    '__ndarray__' become numpy arrays; other containers are walked
    recursively.  NOTE: the '__ase_objtype__' key is popped in place,
    so input dicts are mutated.
    """
    if isinstance(obj, list):
        return [insert_ase_and_ndarray_objects(item) for item in obj]
    if not isinstance(obj, dict):
        return obj
    objtype = obj.pop('__ase_objtype__', None)
    if objtype is not None:
        return create_ase_object(objtype,
                                 insert_ase_and_ndarray_objects(obj))
    ndarray_spec = obj.get('__ndarray__')
    if ndarray_spec is not None:
        return create_ndarray(*ndarray_spec)
    return {name: insert_ase_and_ndarray_objects(val)
            for name, val in obj.items()}
class PostgreSQLDatabase(SQLite3Database):
    """ase.db backend storing rows in PostgreSQL via psycopg2.

    Reuses the SQL and logic of SQLite3Database; only placeholder style,
    blob representation, JSON handling and schema creation differ.
    """
    type = 'postgresql'
    # Value inserted for the auto-generated id column (served by a sequence).
    default = 'DEFAULT'

    def encode(self, obj, binary=False):
        # JSON cannot represent nan/inf, so replace them with marker dicts
        # before encoding.  ``binary`` is part of the base-class interface
        # but unused here.
        return ase_encode(remove_nan_and_inf(obj))

    def decode(self, obj, lazy=False):
        # Inverse of encode(): restore special floats, then rebuild
        # ndarrays/ASE objects.  ``lazy`` is part of the base-class
        # interface but unused here.
        return insert_ase_and_ndarray_objects(insert_nan_and_inf(obj))

    def blob(self, array):
        """Convert array to blob/buffer object."""
        if array is None:
            return None
        if len(array) == 0:
            array = np.zeros(0)
        if array.dtype == np.int64:
            # Schema uses 32-bit INTEGER[] columns; downcast before storing.
            array = array.astype(np.int32)
        # Stored as a PostgreSQL array literal (list), not a binary blob.
        return array.tolist()

    def deblob(self, buf, dtype=float, shape=None):
        """Convert blob/buffer object to ndarray of correct dtype and shape.

        (without creating an extra view)."""
        # NOTE(review): ``shape`` is ignored here; PostgreSQL array values
        # appear to carry their own nesting — confirm callers don't rely
        # on reshaping.
        if buf is None:
            return None
        return np.array(buf, dtype=dtype)

    def _connect(self):
        # ``self.filename`` holds the postgres connection string/DSN.
        return Connection(connect(self.filename))

    def _initialize(self, con):
        # Create the schema on first use, or read version/metadata from an
        # existing 'information' table.  Idempotent per instance.
        if self.initialized:
            return

        self._metadata = {}

        cur = con.cursor()
        # Determine the effective schema: the first entry of search_path,
        # skipping the '"$user"' placeholder if present.
        cur.execute("show search_path;")
        schema = cur.fetchone()[0].split(', ')
        if schema[0] == '"$user"':
            schema = schema[1]
        else:
            schema = schema[0]

        # Does the 'information' table already exist in that schema?
        cur.execute("""
        SELECT EXISTS(select * from information_schema.tables where
        table_name='information' and table_schema='{}');
        """.format(schema))

        if not cur.fetchone()[0]:  # information schema doesn't exist.
            # Initialize database: translate the SQLite DDL to PostgreSQL
            # and run it, optionally with the standard and JSONB indexes.
            sql = ';\n'.join(init_statements)
            sql = schema_update(sql)
            cur.execute(sql)
            if self.create_indices:
                cur.execute(';\n'.join(index_statements))
                cur.execute(';\n'.join(jsonb_indices))
            con.commit()
            self.version = VERSION
        else:
            # Existing database: pick up version and metadata rows.
            cur.execute('select * from information;')
            for name, value in cur.fetchall():
                if name == 'version':
                    self.version = int(value)
                elif name == 'metadata':
                    self._metadata = json.loads(value)

        # Only schema versions in (5, VERSION] are supported by this code.
        assert 5 < self.version <= VERSION

        self.initialized = True

    def get_offset_string(self, offset, limit=None):
        # postgresql allows you to set offset without setting limit;
        # very practical
        return f'\nOFFSET {offset}'

    def get_last_id(self, cur):
        # The id column is backed by the systems_id_seq sequence.
        cur.execute('SELECT last_value FROM systems_id_seq')
        id = cur.fetchone()[0]
        return int(id)
def schema_update(sql):
    """Translate the SQLite DDL in ``sql`` to its PostgreSQL equivalent.

    Rewrites type names (REAL, AUTOINCREMENT) and converts the BLOB/TEXT
    columns used by the SQLite backend into native PostgreSQL arrays and
    JSONB columns.
    """
    # Plain type-name substitutions; REAL must run before the column
    # rewrites below so those see the original BLOB/TEXT declarations.
    for old, new in [('REAL', 'DOUBLE PRECISION'),
                     ('INTEGER PRIMARY KEY AUTOINCREMENT',
                      'SERIAL PRIMARY KEY')]:
        sql = sql.replace(old, new)

    # 1-D per-atom/per-system quantities become typed arrays.  Order
    # matters: 'initial_magmoms' must be rewritten before 'magmoms'
    # (and 'initial_charges' before 'charges') to avoid partial matches.
    integer_columns = {'numbers', 'tags'}
    for column in ['numbers', 'initial_magmoms', 'initial_charges', 'masses',
                   'tags', 'momenta', 'stress', 'dipole', 'magmoms',
                   'charges']:
        kind = 'INTEGER' if column in integer_columns else 'DOUBLE PRECISION'
        sql = sql.replace(f'{column} BLOB,', f'{column} {kind}[],')

    # 2-D quantities become two-dimensional arrays.
    for column in ['positions', 'cell', 'forces']:
        sql = sql.replace(f'{column} BLOB,', f'{column} DOUBLE PRECISION[][],')

    # JSON-encoded text columns become JSONB for indexable queries.
    for column in ['calculator_parameters', 'key_value_pairs']:
        sql = sql.replace(f'{column} TEXT,', f'{column} JSONB,')

    return sql.replace('data BLOB,', 'data JSONB,')