Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import os
2import sys
3import errno
4import pickle
5import warnings
6import collections
# Python 2/3 compatibility: on Python 3 there is no ``unicode`` builtin,
# so fall back to ``str`` (binding it as a module-level name).
try:
    unicode = unicode  # noqa: F821 -- Python 2 builtin
except NameError:
    unicode = str

# ``WindowsError`` only exists on Windows; provide an OSError subclass
# elsewhere so ``except WindowsError`` clauses stay valid everywhere.
try:
    WindowsError
except NameError:
    class WindowsError(OSError):
        """Placeholder for the Windows-only builtin on other platforms."""
21import numpy as np
23from ase.atoms import Atoms
24from ase.calculators.singlepoint import SinglePointCalculator
25from ase.calculators.calculator import PropertyNotImplementedError
26from ase.constraints import FixAtoms
27from ase.parallel import world, barrier
class PickleTrajectory:
    """Reads/writes Atoms objects into a .traj file.

    Legacy pickle-based format: the file starts with the magic bytes
    ``b'PickleTrajectory'``, followed by one pickled header dict and then
    one pickled dict per image.

    Reading uses :mod:`pickle`, which can execute code embedded in the
    file -- only open trajectory files from trusted sources.
    """
    # Per default, write these quantities.  These are class-level switches
    # read via ``self`` in _write_atoms, so instances may override them.
    # NOTE(review): write_momenta is not consulted anywhere in this class;
    # momenta are always stored.
    write_energy = True
    write_forces = True
    write_stress = True
    write_charges = True
    write_magmoms = True
    write_momenta = True
    write_info = True

    def __init__(self, filename, mode='r', atoms=None, master=None,
                 backup=True, _warn=True):
        """A PickleTrajectory can be created in read, write or append mode.

        Parameters:

        filename:
            The name of the trajectory file (or an already-open binary
            file object). Should end in .traj.

        mode='r':
            The mode.

            'r' is read mode, the file should already exist, and
            no atoms argument should be specified.

            'w' is write mode. If the file already exists, it is
            renamed by appending .bak to the file name. The atoms
            argument specifies the Atoms object to be written to the
            file, if not given it must instead be given as an argument
            to the write() method.

            'a' is append mode. It acts like write mode, except that
            data is appended to a preexisting file.

        atoms=None:
            The Atoms object to be written in write or append mode.

        master=None:
            Controls which process does the actual writing. The
            default is that process number 0 does this. If this
            argument is given, processes where it is True will write.

        backup=True:
            Use backup=False to disable renaming of an existing file.

        _warn=True:
            If true, raise DeprecationWarning immediately, because this
            file format is deprecated.
        """
        if _warn:
            msg = 'Please stop using old trajectory files!'
            if mode == 'r':
                # ``filename`` may be a file object; use its name if it has
                # one so that building the hint cannot itself raise TypeError.
                name = getattr(filename, 'name', filename)
                msg += ('\nConvert to the new future-proof format like this:\n'
                        '\n $ python3 -m ase.io.trajectory ' +
                        str(name) + '\n')
            raise DeprecationWarning(msg)

        self.numbers = None
        self.pbc = None
        self.sanitycheck = True
        self.pre_observers = []  # Callback functions before write
        self.post_observers = []  # Callback functions after write

        # Counter used to determine when callbacks are called:
        self.write_counter = 0

        self.offsets = []
        if master is None:
            master = (world.rank == 0)
        self.master = master
        self.backup = backup
        self.set_atoms(atoms)
        self.open(filename, mode)

    def open(self, filename, mode):
        """Opens the file.

        For internal use only.

        ``filename`` may be a path or an already-open binary file object.
        Only paths are opened here, and only those handles are closed again
        if reading the header fails.
        """
        self.fd = filename
        if mode == 'r':
            opened_here = isinstance(filename, str)
            if opened_here:
                self.fd = open(filename, 'rb')
            try:
                self.read_header()
            except Exception:
                # Don't leak the handle we just opened on a bad header.
                if opened_here:
                    self.fd.close()
                raise
        elif mode == 'a':
            exists = True
            if isinstance(filename, str):
                exists = (os.path.isfile(filename) and
                          os.path.getsize(filename) > 0)
            if exists:
                # Read the existing header so appended images can be
                # sanity-checked against it, then reopen for appending.
                self.fd = open(filename, 'rb')
                try:
                    self.read_header()
                finally:
                    self.fd.close()
            barrier()
            if self.master:
                self.fd = open(filename, 'ab+')
            else:
                self.fd = open(os.devnull, 'ab+')
        elif mode == 'w':
            if self.master:
                if isinstance(filename, str):
                    if self.backup and os.path.isfile(filename):
                        try:
                            os.rename(filename, filename + '.bak')
                        except WindowsError as e:
                            # this must run on Win only! Not atomic!
                            if e.errno != errno.EEXIST:
                                raise
                            os.unlink(filename + '.bak')
                            os.rename(filename, filename + '.bak')
                    self.fd = open(filename, 'wb')
            else:
                # Non-master processes write into the void.
                self.fd = open(os.devnull, 'wb')
        else:
            raise ValueError('mode must be "r", "w" or "a".')

    def set_atoms(self, atoms=None):
        """Associate an Atoms object with the trajectory.

        Mostly for internal use.

        Raises TypeError if *atoms* is neither None nor Atoms-like
        (i.e. lacks ``get_positions``).
        """
        if atoms is not None and not hasattr(atoms, 'get_positions'):
            raise TypeError('"atoms" argument is not an Atoms object.')
        self.atoms = atoms

    def read_header(self):
        """Read the magic bytes and header dict from ``self.fd``.

        Sets ``pbc``, ``numbers``, ``tags``, ``masses`` and ``constraints``
        and records the offset where the first image starts. Returns
        without reading anything if ``self.fd`` names an existing empty
        file.

        Raises IOError if the magic bytes do not match. An EOFError while
        unpickling is re-raised as EOFError('Bad trajectory file.').
        """
        if hasattr(self.fd, 'name'):
            if os.path.isfile(self.fd.name):
                if os.path.getsize(self.fd.name) == 0:
                    return
        self.fd.seek(0)
        try:
            if self.fd.read(len('PickleTrajectory')) != b'PickleTrajectory':
                raise IOError('This is not a trajectory file!')
            # NOTE(review): unlike __getitem__, no encoding='bytes' here, so
            # headers pickled by Python 2 may fail to load -- confirm.
            d = pickle.load(self.fd)
        except EOFError:
            raise EOFError('Bad trajectory file.')

        self.pbc = d['pbc']
        self.numbers = d['numbers']
        self.tags = d.get('tags')
        self.masses = d.get('masses')
        self.constraints = dict2constraints(d)
        self.offsets.append(self.fd.tell())

    def write(self, atoms=None):
        """Write the atoms to the file.

        If the atoms argument is not given, the atoms object specified
        when creating the trajectory object is used. Every image yielded
        by ``atoms.iterimages()`` is written.
        """
        if atoms is None:
            atoms = self.atoms

        for image in atoms.iterimages():
            self._write_atoms(image)

    def _write_atoms(self, atoms):
        """Write a single image to the file.

        The first image also writes the header. Later images must match
        the header's pbc and, if ``sanitycheck`` is set, its atom count
        and atomic numbers; otherwise ValueError is raised. Pre/post
        observers are called around the write.
        """
        self._call_observers(self.pre_observers)

        if len(self.offsets) == 0:
            self.write_header(atoms)
        else:
            if (atoms.pbc != self.pbc).any():
                raise ValueError('Bad periodic boundary conditions!')
            elif self.sanitycheck and len(atoms) != len(self.numbers):
                raise ValueError('Bad number of atoms!')
            elif self.sanitycheck and (atoms.numbers != self.numbers).any():
                raise ValueError('Bad atomic numbers!')

        if atoms.has('momenta'):
            momenta = atoms.get_momenta()
        else:
            momenta = None

        d = {'positions': atoms.get_positions(),
             'cell': atoms.get_cell(),
             'momenta': momenta}

        if atoms.calc is not None:
            if self.write_energy:
                d['energy'] = atoms.get_potential_energy()
            if self.write_forces:
                # Forces/stress are only meaningful alongside the energy.
                assert self.write_energy
                try:
                    d['forces'] = atoms.get_forces(apply_constraint=False)
                except PropertyNotImplementedError:
                    pass
            if self.write_stress:
                assert self.write_energy
                try:
                    d['stress'] = atoms.get_stress()
                except PropertyNotImplementedError:
                    pass
            if self.write_charges:
                try:
                    d['charges'] = atoms.get_charges()
                except PropertyNotImplementedError:
                    pass
            if self.write_magmoms:
                try:
                    magmoms = atoms.get_magnetic_moments()
                    # Only store magnetic moments if any are non-zero.
                    if any(np.asarray(magmoms).flat):
                        d['magmoms'] = magmoms
                except (PropertyNotImplementedError, AttributeError):
                    pass

        # Fall back to the initial values stored on the Atoms object.
        if 'magmoms' not in d and atoms.has('initial_magmoms'):
            d['magmoms'] = atoms.get_initial_magnetic_moments()
        if 'charges' not in d and atoms.has('initial_charges'):
            charges = atoms.get_initial_charges()
            if (charges != 0).any():
                d['charges'] = charges

        if self.write_info:
            d['info'] = stringnify_info(atoms.info)

        if self.master:
            pickle.dump(d, self.fd, protocol=2)
        self.fd.flush()
        self.offsets.append(self.fd.tell())
        self._call_observers(self.post_observers)
        self.write_counter += 1

    def write_header(self, atoms):
        """Write the magic bytes and the pickled header dict for *atoms*."""
        self.fd.write(b'PickleTrajectory')
        if atoms.has('tags'):
            tags = atoms.get_tags()
        else:
            tags = None
        if atoms.has('masses'):
            masses = atoms.get_masses()
        else:
            masses = None
        d = {'version': 3,
             'pbc': atoms.get_pbc(),
             'numbers': atoms.get_atomic_numbers(),
             'tags': tags,
             'masses': masses,
             'constraints': [],  # backwards compatibility
             'constraints_string': pickle.dumps(atoms.constraints, protocol=0)}
        pickle.dump(d, self.fd, protocol=2)
        self.header_written = True
        self.offsets.append(self.fd.tell())

        # Atomic numbers and periodic boundary conditions are only
        # written once - in the header. Store them here so that we can
        # check that they are the same for all images:
        self.numbers = atoms.get_atomic_numbers()
        self.pbc = atoms.get_pbc()

    def close(self):
        """Close the trajectory file."""
        self.fd.close()

    def __getitem__(self, i=-1):
        """Return image *i* as an Atoms object, or a list for a slice.

        Images beyond the already-known offsets are located by reading
        forward through the file. Negative indices count from the end.

        Raises IndexError if the index is out of range.
        """
        if isinstance(i, slice):
            return [self[j] for j in range(*i.indices(len(self)))]

        N = len(self.offsets)
        if 0 <= i < N:
            self.fd.seek(self.offsets[i])
            try:
                # encoding='bytes' lets Python-2 pickles load; their str
                # keys then arrive as bytes and are decoded here.
                d = pickle.load(self.fd, encoding='bytes')
                d = {k.decode() if isinstance(k, bytes) else k: v
                     for k, v in d.items()}
            except EOFError:
                raise IndexError
            if i == N - 1:
                # Remember where the next image starts.
                self.offsets.append(self.fd.tell())
            charges = d.get('charges')
            magmoms = d.get('magmoms')
            try:
                constraints = [c.copy() for c in self.constraints]
            except Exception:
                constraints = []
                warnings.warn('Constraints did not unpickle correctly.')
            atoms = Atoms(positions=d['positions'],
                          numbers=self.numbers,
                          cell=d['cell'],
                          momenta=d['momenta'],
                          magmoms=magmoms,
                          charges=charges,
                          tags=self.tags,
                          masses=self.masses,
                          pbc=self.pbc,
                          info=unstringnify_info(d.get('info', {})),
                          constraint=constraints)
            if 'energy' in d:
                calc = SinglePointCalculator(
                    atoms,
                    energy=d.get('energy', None),
                    forces=d.get('forces', None),
                    stress=d.get('stress', None),
                    magmoms=magmoms)
                atoms.calc = calc
            return atoms

        if i >= N:
            # Walk forward so each step caches the next offset.
            for j in range(N - 1, i + 1):
                atoms = self[j]
            return atoms

        i = len(self) + i
        if i < 0:
            raise IndexError('Trajectory index out of range.')
        return self[i]

    def __len__(self):
        """Return the number of images in the file.

        Scans forward from the last known offset, caching each newly
        discovered image offset in ``self.offsets``.
        """
        if len(self.offsets) == 0:
            return 0
        N = len(self.offsets) - 1
        while True:
            self.fd.seek(self.offsets[N])
            try:
                # Only skipping over the image; decode it the same way as
                # __getitem__ so Python-2 pickles are handled consistently.
                pickle.load(self.fd, encoding='bytes')
            except EOFError:
                return N
            self.offsets.append(self.fd.tell())
            N += 1

    def pre_write_attach(self, function, interval=1, *args, **kwargs):
        """Attach a function to be called before writing begins.

        function: The function or callable object to be called.

        interval: How often the function is called. Default: every time (1).

        All other arguments are stored, and passed to the function.

        Raises ValueError if *function* is not callable.
        """
        # collections.Callable was removed in Python 3.10; use the builtin.
        if not callable(function):
            raise ValueError('Callback object must be callable.')
        self.pre_observers.append((function, interval, args, kwargs))

    def post_write_attach(self, function, interval=1, *args, **kwargs):
        """Attach a function to be called after writing ends.

        function: The function or callable object to be called.

        interval: How often the function is called. Default: every time (1).

        All other arguments are stored, and passed to the function.

        Raises ValueError if *function* is not callable.
        """
        if not callable(function):
            raise ValueError('Callback object must be callable.')
        self.post_observers.append((function, interval, args, kwargs))

    def _call_observers(self, obs):
        """Call pre/post write observers whose interval divides the
        current write counter."""
        for function, interval, args, kwargs in obs:
            if self.write_counter % interval == 0:
                function(*args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
def stringnify_info(info):
    """Return a stringnified version of the dict *info* that is
    ensured to be picklable. Items with non-string keys or
    unpicklable values are dropped and a warning is issued.

    Each kept value is replaced by its protocol-0 pickle (``bytes``).
    """
    stringnified = {}
    for k, v in info.items():
        if not isinstance(k, str):
            warnings.warn('Non-string info-dict key is not stored in ' +
                          'trajectory: ' + repr(k), UserWarning)
            continue
        try:
            # Should highest protocol be used here for efficiency?
            # Protocol 2 seems not to raise an exception when one
            # tries to pickle a file object, so by using that, we
            # might end up with file objects in inconsistent states.
            s = pickle.dumps(v, protocol=0)
        except (pickle.PicklingError, TypeError, AttributeError) as err:
            # Python 3 reports many unpicklable objects (e.g. generators,
            # open files) with TypeError/AttributeError, not PicklingError.
            warnings.warn('Skipping not picklable info-dict item: ' +
                          '"%s" (%s)' % (k, err), UserWarning)
        else:
            stringnified[k] = s
    return stringnified
def unstringnify_info(stringnified):
    """Convert the dict *stringnified* to a dict with unstringnified
    objects and return it. Objects that cannot be unpickled will be
    skipped and a warning will be issued.

    Security note: values are passed to ``pickle.loads``, which can
    execute arbitrary code; only use this on data from trusted files.
    """
    info = {}
    for k, s in stringnified.items():
        try:
            v = pickle.loads(s)
        except (pickle.UnpicklingError, AttributeError, EOFError,
                ImportError, IndexError, TypeError, ValueError) as err:
            # pickle.loads raises far more than UnpicklingError for corrupt
            # or stale data (e.g. EOFError on empty input, ImportError when
            # a pickled class's module is gone).
            warnings.warn('Skipping not unpicklable info-dict item: ' +
                          '"%s" (%s)' % (k, err), UserWarning)
        else:
            info[k] = v
    return info
def dict2constraints(d):
    """Convert dict unpickled from trajectory file to list of constraints.

    Version 1 headers store the constraint list directly under
    ``'constraints'``. Versions 2 and 3 store it pickled under
    ``'constraints_string'``. Unknown versions yield no constraints.
    If unpickling fails, a warning is issued and ``[]`` is returned.
    """
    version = d.get('version', 1)

    if version == 1:
        return d['constraints']
    elif version in (2, 3):
        try:
            constraints = pickle.loads(d['constraints_string'])
            for c in constraints:
                if isinstance(c, FixAtoms) and c.index.dtype == bool:
                    # Special handling of old pickles: convert a boolean
                    # mask into explicit integer indices.
                    c.index = np.arange(len(c.index))[c.index]
            return constraints
        except (pickle.UnpicklingError, AttributeError, KeyError, EOFError,
                ImportError, IndexError, TypeError, ValueError):
            # Corrupt data typically raises UnpicklingError, so it must be
            # in this list too; otherwise reading the whole file aborts
            # instead of just dropping the constraints.
            warnings.warn('Could not unpickle constraints!')
            return []
    else:
        return []