Hot-keys on this page:
r, m, x, p: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1import datetime
2import json
4import numpy as np
5from ase.utils import reader, writer
class MyEncoder(json.JSONEncoder):
    """JSON encoder for ASE objects, NumPy values, complex numbers and datetimes.

    Values plain JSON cannot represent are replaced by tagged dicts
    (``__ndarray__``, ``__complex__``, ``__datetime__`` and, for objects
    with ``todict()``/``ase_objtype``, ``__ase_objtype__``); ``object_hook``
    recognises these tags when decoding.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for *obj*.

        Called by :class:`json.JSONEncoder` for objects it cannot serialize
        natively.

        Raises:
            RuntimeError: if ``obj.todict()`` does not return a dict.
            TypeError: (from the base class) for unsupported types.
        """
        if hasattr(obj, 'todict'):
            d = obj.todict()

            if not isinstance(d, dict):
                raise RuntimeError('todict() of {} returned object of type {} '
                                   'but should have returned dict'
                                   .format(obj, type(d)))
            if hasattr(obj, 'ase_objtype'):
                d['__ase_objtype__'] = obj.ase_objtype

            return d
        if isinstance(obj, np.ndarray):
            # ravel() always returns a C-contiguous 1-D array (a view or
            # a copy), so it can be reinterpreted with view() below.
            flatobj = obj.ravel()
            if np.iscomplexobj(obj):
                # Store complex arrays as interleaved (real, imag) floats.
                # Use view() instead of assigning to .dtype, which NumPy
                # discourages.
                flatobj = flatobj.view(obj.real.dtype)
            return {'__ndarray__': (obj.shape,
                                    obj.dtype.name,
                                    flatobj.tolist())}
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.floating):
            # np.float64 subclasses float and never reaches default(), but
            # float32/float16/longdouble scalars are not serializable.
            return float(obj)
        if isinstance(obj, datetime.datetime):
            return {'__datetime__': obj.isoformat()}
        if isinstance(obj, (complex, np.complexfloating)):
            # np.complex64 is not a subclass of complex, hence the tuple.
            return {'__complex__': (float(obj.real), float(obj.imag))}
        return json.JSONEncoder.default(self, obj)


# Module-level shortcut: serialize an object to a JSON string with MyEncoder.
encode = MyEncoder().encode
def _parse_datetime(text):
    """Parse a timestamp written by ``datetime.datetime.isoformat()``.

    ``isoformat()`` omits the fractional seconds when ``microsecond == 0``
    and appends a ``+HH:MM`` offset for timezone-aware datetimes, so one
    fixed strptime format cannot read everything the encoder writes.

    Raises:
        ValueError: if *text* is not a recognisable ISO timestamp.
    """
    for fmt in ('%Y-%m-%dT%H:%M:%S.%f', '%Y-%m-%dT%H:%M:%S'):
        try:
            return datetime.datetime.strptime(text, fmt)
        except ValueError:
            pass
    # Remaining case: an aware datetime with a UTC offset suffix.
    return datetime.datetime.fromisoformat(text)


def object_hook(dct):
    """Rebuild objects from the tagged dicts written by MyEncoder.

    Intended as a ``json.JSONDecoder`` object hook: it receives every
    decoded JSON object. Tagged dicts become datetimes, complex numbers,
    ndarrays or ASE objects; any other dict is returned unchanged.
    """
    if '__datetime__' in dct:
        return _parse_datetime(dct['__datetime__'])

    if '__complex__' in dct:
        return complex(*dct['__complex__'])

    if '__ndarray__' in dct:
        return create_ndarray(*dct['__ndarray__'])

    # No longer used (only here for backwards compatibility):
    if '__complex_ndarray__' in dct:
        r, i = (np.array(x) for x in dct['__complex_ndarray__'])
        return r + i * 1j

    if '__ase_objtype__' in dct:
        objtype = dct.pop('__ase_objtype__')
        dct = numpyfy(dct)
        return create_ase_object(objtype, dct)

    return dct
def create_ndarray(shape, dtype, data):
    """Create ndarray from shape, dtype and flattened data.

    This reverses the ``__ndarray__`` encoding in MyEncoder. *data* is the
    flat C-order list of values; for complex dtypes it holds interleaved
    real and imaginary parts.
    """
    array = np.empty(shape, dtype=dtype)
    # np.empty returns a C-contiguous array, so ravel() is a view into it
    # and writes through flatbuf land in `array`.
    flatbuf = array.ravel()
    if np.iscomplexobj(array):
        # Reinterpret the complex buffer as pairs of reals. Use view()
        # rather than assigning to .dtype, which NumPy discourages.
        flatbuf = flatbuf.view(array.real.dtype)
    flatbuf[:] = data
    return array
def create_ase_object(objtype, dct):
    """Instantiate the ASE object tagged *objtype* from the dict *dct*.

    Each supported type is imported lazily inside its own builder, so only
    the module for the requested type is loaded. *dct* may be modified in
    place (the ``pbc`` and ``labelseq`` keys are popped by their builders).

    Raises:
        ValueError: if *objtype* is not one of the known tags.
    """
    def build_cell():
        from ase.cell import Cell
        dct.pop('pbc', None)  # compatibility; we once had pbc
        return Cell(**dct)

    def build_bandstructure():
        from ase.spectrum.band_structure import BandStructure
        return BandStructure(**dct)

    def build_bandpath():
        from ase.dft.kpoints import BandPath
        return BandPath(path=dct.pop('labelseq'), **dct)

    def build_atoms():
        from ase import Atoms
        return Atoms.fromdict(dct)

    def build_vibrationsdata():
        from ase.vibrations import VibrationsData
        return VibrationsData.fromdict(dct)

    builders = (('cell', build_cell),
                ('bandstructure', build_bandstructure),
                ('bandpath', build_bandpath),
                ('atoms', build_atoms),
                ('vibrationsdata', build_vibrationsdata))

    # Linear search with == (not a dict lookup), so an unhashable
    # objtype still ends in the ValueError below.
    for tag, build in builders:
        if objtype == tag:
            obj = build()
            break
    else:
        raise ValueError('Do not know how to decode object type {} '
                         'into an actual object'.format(objtype))
    # Sanity check: the builder must produce the requested object type.
    assert obj.ase_objtype == objtype
    return obj
# Module-level shortcut: parse JSON text, rebuilding tagged objects
# through object_hook.
mydecode = json.JSONDecoder(object_hook=object_hook).decode
def intkey(key):
    """Return ``int(key)`` if *key* parses as an integer, else *key* itself."""
    try:
        converted = int(key)
    except ValueError:
        converted = key
    return converted
def fix_int_keys_in_dicts(obj):
    """Convert "int" keys: "1" -> 1.

    ``json.dump()`` stores integer dict keys as strings; this undoes that,
    recursing through nested dict values. Non-dict objects (including
    lists) are returned untouched.
    """
    if not isinstance(obj, dict):
        return obj
    fixed = {}
    for k, v in obj.items():
        fixed[intkey(k)] = fix_int_keys_in_dicts(v)
    return fixed
def numpyfy(obj):
    """Recursively turn homogeneous bool/int/float lists into ndarrays.

    A non-empty list that NumPy converts to a bool, int or float array is
    replaced by that array; other lists are processed element by element.
    Dict values are processed recursively, and the legacy
    ``__complex_ndarray__`` encoding becomes a complex array. Anything
    else is returned unchanged.
    """
    if isinstance(obj, dict):
        if '__complex_ndarray__' in obj:
            r, i = (np.array(x) for x in obj['__complex_ndarray__'])
            return r + i * 1j
        # Recurse into values so lists nested inside dicts are converted
        # too; otherwise the dict would fall through unchanged.
        return {key: numpyfy(value) for key, value in obj.items()}
    if isinstance(obj, list) and len(obj) > 0:
        try:
            a = np.array(obj)
        except ValueError:
            # e.g. ragged nesting, which recent NumPy refuses to build.
            pass
        else:
            if a.dtype in [bool, int, float]:
                return a
        obj = [numpyfy(value) for value in obj]
    return obj
def decode(txt, always_array=True):
    """Decode JSON text, e.g. as written by ``encode``.

    Tagged objects are rebuilt by ``mydecode``'s object hook, and string
    keys that parse as integers are turned back into ints. When
    *always_array* is true, the result is also passed through ``numpyfy``
    so that lists become arrays where possible.
    """
    decoded = fix_int_keys_in_dicts(mydecode(txt))
    if not always_array:
        return decoded
    return numpyfy(decoded)
@reader
def read_json(fd, always_array=True):
    """Read the whole of *fd* and return its decoded JSON content.

    *always_array* is forwarded to ``decode``.
    """
    text = fd.read()
    return decode(text, always_array=always_array)
@writer
def write_json(fd, obj):
    """Serialize *obj* with ``encode`` and write the JSON text to *fd*."""
    text = encode(obj)
    fd.write(text)