Coverage for /builds/debichem-team/python-ase/ase/db/postgresql.py: 91.30%

138 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-03-06 04:00 +0000

1import json 

2 

3import numpy as np 

4from psycopg2 import connect 

5from psycopg2.extras import execute_values 

6 

7from ase.db.sqlite import ( 

8 VERSION, 

9 SQLite3Database, 

10 index_statements, 

11 init_statements, 

12) 

13from ase.io.jsonio import create_ase_object, create_ndarray 

14from ase.io.jsonio import encode as ase_encode 

15 

# Additional GIN indexes on the two JSONB columns of the systems table.
# _initialize executes these together with index_statements when
# create_indices is set.
jsonb_indices = [
    'CREATE INDEX idxkeys ON systems USING GIN (key_value_pairs);',
    'CREATE INDEX idxcalc ON systems USING GIN (calculator_parameters);',
]

19 

20 

def remove_nan_and_inf(obj):
    """Replace non-finite floats in *obj* with JSON-safe marker dicts.

    Every NaN or infinity, whether a Python ``float`` or any NumPy floating
    scalar (including ``np.float32``, which is not a ``float`` subclass),
    becomes ``{'__special_number__': str(value)}``.  Lists, tuples and dicts
    are processed recursively.  Tuples come back as lists, which is exactly
    how JSON serializes them anyway.  An ndarray containing at least one
    non-finite value is converted to (nested) lists and processed the same
    way.  Anything else, including fully finite arrays, is returned
    unchanged.
    """
    if isinstance(obj, (float, np.floating)) and not np.isfinite(obj):
        return {'__special_number__': str(obj)}
    if isinstance(obj, (list, tuple)):
        return [remove_nan_and_inf(x) for x in obj]
    if isinstance(obj, dict):
        return {key: remove_nan_and_inf(value) for key, value in obj.items()}
    if isinstance(obj, np.ndarray) and not np.isfinite(obj).all():
        # Only arrays holding bad values are expanded to lists; finite
        # arrays are passed through for the encoder to handle (presumably
        # as tagged ndarrays by ase_encode -- confirm).
        return remove_nan_and_inf(obj.tolist())
    return obj

31 

32 

def insert_nan_and_inf(obj):
    """Turn ``{'__special_number__': text}`` markers back into floats.

    Lists and dicts are rebuilt recursively.  A dict that carries the
    ``'__special_number__'`` key is replaced wholesale by
    ``float(obj['__special_number__'])``.  Any other value is returned
    as-is.
    """
    if isinstance(obj, list):
        return [insert_nan_and_inf(item) for item in obj]
    if not isinstance(obj, dict):
        return obj
    if '__special_number__' in obj:
        return float(obj['__special_number__'])
    return {k: insert_nan_and_inf(v) for k, v in obj.items()}

41 

42 

class Connection:
    """Thin wrapper around a DB-API connection that hands out ``Cursor``s.

    Exposes only ``cursor``, ``commit`` and ``close``; the latter two are
    forwarded unchanged to the wrapped connection.
    """

    def __init__(self, con):
        self.con = con  # the wrapped (psycopg2) connection object

    def close(self):
        """Close the wrapped connection."""
        self.con.close()

    def commit(self):
        """Commit the current transaction of the wrapped connection."""
        self.con.commit()

    def cursor(self):
        """Open a raw cursor and return it wrapped in a ``Cursor``."""
        raw_cursor = self.con.cursor()
        return Cursor(raw_cursor)

55 

56 

class Cursor:
    """Adapter that runs SQLite-style SQL on a psycopg2 cursor.

    SQL shared with the SQLite backend uses ``?`` placeholders, while
    psycopg2 uses ``%s``; this wrapper rewrites them before delegating to
    the wrapped cursor.
    """

    def __init__(self, cur):
        self.cur = cur  # the wrapped psycopg2 cursor

    def fetchone(self):
        """Return the next row from the wrapped cursor."""
        return self.cur.fetchone()

    def fetchall(self):
        """Return all remaining rows from the wrapped cursor."""
        return self.cur.fetchall()

    def execute(self, statement, *args):
        """Execute *statement* after turning every ``?`` into ``%s``.

        NOTE(review): the rewrite is a plain text replacement, so a literal
        ``?`` anywhere in *statement* (string literal, PostgreSQL's JSONB
        ``?`` operator) is rewritten too.
        """
        self.cur.execute(statement.replace('?', '%s'), *args)

    def executemany(self, statement, *args):
        """Insert many rows in one statement using ``execute_values``.

        ``args[0]`` holds the parameter rows and may be any iterable of
        equal-length sequences.  The ``(?, ?, ...)`` values group in
        *statement* is replaced by the single ``%s`` that
        ``execute_values`` expects, and the group is instead passed as the
        per-row template.  For ``INSERT INTO systems`` the template is
        prefixed with ``DEFAULT`` so the id column is generated by the
        database.  Does nothing when there are no rows.
        """
        # Materialize first so generators/iterators work as well as lists
        # (len() and indexing below need a real sequence).
        rows = list(args[0])
        if not rows:
            return

        placeholders = ', '.join(['?'] * len(rows[0]))
        if 'INSERT INTO systems' in statement:
            placeholders = 'DEFAULT, ' + placeholders  # DEFAULT for id

        statement = statement.replace(f'({placeholders})', '%s')
        template = '({})'.format(placeholders.replace('?', '%s'))

        # page_size=len(rows) sends every row in a single statement.
        execute_values(self.cur, statement.replace('?', '%s'),
                       argslist=rows, template=template,
                       page_size=len(rows))

84 

85 

def insert_ase_and_ndarray_objects(obj):
    """Rebuild ASE objects and ndarrays from their decoded JSON dict forms.

    Lists are processed element-wise and non-dict values are returned as-is.
    For a dict:

    * a non-None ``'__ase_objtype__'`` tag makes it an ASE object, built via
      ``create_ase_object`` from the remaining (recursively processed) keys;
    * otherwise a non-None ``'__ndarray__'`` entry is unpacked into
      ``create_ndarray``;
    * otherwise a new dict with recursively processed values is returned.

    The ``'__ase_objtype__'`` key is always dropped from the result.  The
    input *obj* is never modified.
    """
    if isinstance(obj, list):
        return [insert_ase_and_ndarray_objects(value) for value in obj]
    if not isinstance(obj, dict):
        return obj

    objtype = obj.get('__ase_objtype__')
    # Shallow copy without the type tag, so the caller's dict is untouched.
    fields = {key: value for key, value in obj.items()
              if key != '__ase_objtype__'}

    if objtype is not None:
        return create_ase_object(objtype,
                                 insert_ase_and_ndarray_objects(fields))
    data = fields.get('__ndarray__')
    if data is not None:
        return create_ndarray(*data)
    return {key: insert_ase_and_ndarray_objects(value)
            for key, value in fields.items()}

100 

101 

class PostgreSQLDatabase(SQLite3Database):
    """ASE database backend that stores rows in a PostgreSQL server.

    Builds on the SQLite backend and overrides the PostgreSQL-specific
    parts: JSON encoding for JSONB columns, array storage as native
    PostgreSQL arrays instead of binary blobs, connection handling,
    schema initialization and id lookup.
    """

    type = 'postgresql'
    # SQL keyword presumably used by the base class as the id value on
    # insert so that the SERIAL column generates it -- confirm.
    default = 'DEFAULT'

    def encode(self, obj, binary=False):
        """Serialize *obj* to JSON text for a JSONB column.

        Non-finite floats are first replaced by ``__special_number__``
        markers, since JSON has no NaN/Infinity literals.  *binary* is
        accepted for interface compatibility and ignored.
        """
        return ase_encode(remove_nan_and_inf(obj))

    def decode(self, obj, lazy=False):
        """Undo :meth:`encode` on a value read back from the database.

        No JSON parsing happens here, so *obj* is assumed to be already
        parsed into Python lists/dicts (presumably by psycopg2's JSONB
        handling -- confirm).  *lazy* is ignored.
        """
        return insert_ase_and_ndarray_objects(insert_nan_and_inf(obj))

    def blob(self, array):
        """Convert *array* to a (nested) list for a PostgreSQL array column.

        ``None`` is passed through, and any empty array becomes ``[]``.
        """
        if array is None:
            return None
        if len(array) == 0:
            array = np.zeros(0)
        if array.dtype == np.int64:
            # NOTE(review): values outside the int32 range wrap silently
            # here; presumably meant to match the INTEGER[] columns created
            # by schema_update -- confirm whether the cast is still needed.
            array = array.astype(np.int32)
        return array.tolist()

    def deblob(self, buf, dtype=float, shape=None):
        """Convert a value read from an array column to an ndarray.

        The value arrives as a (nested) list, so it is only converted to
        *dtype*.  When *shape* is given the result is reshaped to it.  This
        matters for empty arrays, which :meth:`blob` stores as a flat
        ``[]`` and which would otherwise come back with shape ``(0,)``
        instead of, e.g., ``(0, 3)``.
        """
        if buf is None:
            return None
        array = np.array(buf, dtype=dtype)
        if shape is not None:
            array = array.reshape(shape)
        return array

    def _connect(self):
        """Open a new connection; ``self.filename`` is passed to psycopg2."""
        return Connection(connect(self.filename))

    def _initialize(self, con):
        """Create the tables on first use, or load version and metadata.

        Does nothing if this instance is already initialized.
        """
        if self.initialized:
            return

        self._metadata = {}

        cur = con.cursor()
        # current_schema() is the first existing schema on search_path,
        # i.e. the schema that the unqualified CREATE TABLE statements below
        # write into, so the existence check looks in the same place.  It
        # also copes with quoted schema names and a "$user" entry.
        cur.execute("""
            SELECT EXISTS(SELECT * FROM information_schema.tables
                          WHERE table_name='information'
                          AND table_schema=current_schema());
            """)

        if not cur.fetchone()[0]:  # no 'information' table yet
            # Fresh database: create the tables with PostgreSQL types.
            sql = ';\n'.join(init_statements)
            sql = schema_update(sql)
            cur.execute(sql)
            if self.create_indices:
                cur.execute(';\n'.join(index_statements))
                cur.execute(';\n'.join(jsonb_indices))
            con.commit()
            self.version = VERSION
        else:
            cur.execute('select * from information;')
            for name, value in cur.fetchall():
                if name == 'version':
                    self.version = int(value)
                elif name == 'metadata':
                    self._metadata = json.loads(value)

        # Only schema versions above 5 and not newer than this code are
        # supported.
        assert 5 < self.version <= VERSION

        self.initialized = True

    def get_offset_string(self, offset, limit=None):
        """Return an OFFSET clause; PostgreSQL does not require a LIMIT."""
        return f'\nOFFSET {offset}'

    def get_last_id(self, cur):
        """Return the current ``last_value`` of the systems id sequence.

        NOTE(review): ``last_value`` is not session-local; with concurrent
        writers it may belong to another session's insert.  If callers only
        use this right after their own insert, ``currval('systems_id_seq')``
        may be safer -- confirm.
        """
        cur.execute('SELECT last_value FROM systems_id_seq')
        last_id = cur.fetchone()[0]
        return int(last_id)

184 

185 

def schema_update(sql):
    """Translate SQLite table definitions in *sql* to PostgreSQL types.

    Works by plain substring substitution, applied in a fixed order:

    * ``REAL`` becomes ``DOUBLE PRECISION``, and the autoincrement primary
      key becomes ``SERIAL PRIMARY KEY``;
    * array-valued ``BLOB`` columns become native arrays: ``INTEGER[]`` for
      numbers/tags, ``DOUBLE PRECISION[]`` for the other 1-D quantities and
      ``DOUBLE PRECISION[][]`` for positions/cell/forces;
    * the text columns calculator_parameters and key_value_pairs, and the
      data blob, become ``JSONB``.

    Column patterns include the trailing comma, so a column written last in
    its table definition (followed by ``)``) is left untouched.
    """
    substitutions = [
        ('REAL', 'DOUBLE PRECISION'),
        ('INTEGER PRIMARY KEY AUTOINCREMENT', 'SERIAL PRIMARY KEY'),
    ]

    integer_columns = ('numbers', 'tags')
    for name in ('numbers', 'initial_magmoms', 'initial_charges', 'masses',
                 'tags', 'momenta', 'stress', 'dipole', 'magmoms',
                 'charges'):
        element_type = ('INTEGER' if name in integer_columns
                        else 'DOUBLE PRECISION')
        substitutions.append((f'{name} BLOB,', f'{name} {element_type}[],'))

    substitutions.extend(
        (f'{name} BLOB,', f'{name} DOUBLE PRECISION[][],')
        for name in ('positions', 'cell', 'forces'))

    substitutions.extend(
        (f'{name} TEXT,', f'{name} JSONB,')
        for name in ('calculator_parameters', 'key_value_pairs'))

    substitutions.append(('data BLOB,', 'data JSONB,'))

    for old, new in substitutions:
        sql = sql.replace(old, new)
    return sql