Coverage for /builds/debichem-team/python-ase/ase/io/trajectory.py: 89.89%
277 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
1"""Trajectory"""
2import contextlib
3import io
4import warnings
5from typing import Tuple
7import numpy as np
9from ase import __version__
10from ase.atoms import Atoms
11from ase.calculators.calculator import PropertyNotImplementedError
12from ase.calculators.singlepoint import SinglePointCalculator, all_properties
13from ase.io.formats import is_compressed
14from ase.io.jsonio import decode, encode
15from ase.io.pickletrajectory import PickleTrajectory
16from ase.parallel import world
17from ase.utils import tokenize_version
19__all__ = ['Trajectory', 'PickleTrajectory']
def Trajectory(filename, mode='r', atoms=None, properties=None, master=None,
               comm=world):
    """Open a trajectory for reading, writing or appending.

    This is a factory: mode 'r' yields a TrajectoryReader, every other
    mode is handed to TrajectoryWriter.

    Parameters:

    filename: str
        The name of the file.  Traditionally ends in .traj.
    mode: str
        'r' opens an existing file for reading; no atoms argument
        should be specified in that case.
        'w' opens the file for writing.  The Atoms object to store is
        either given here through ``atoms`` or later as an argument to
        the write() method.
        'a' behaves like 'w' but appends to a preexisting file.
    atoms: Atoms object
        The Atoms object to be written in write or append mode.
    properties: list of str
        Calculator properties to save in the trajectory.  When left
        unspecified, all supported quantities are saved.  Possible
        values: energy, forces, stress, dipole, charges, magmom and
        magmoms.
    master: bool
        Controls which process does the actual writing.  By default
        this is process number 0.  If this argument is given,
        processes where it is True will write.
    comm: Communicator object
        Communicator forwarded to the writer.

    In read mode only ``filename`` is used; atoms, properties, master
    and comm are ignored.
    """
    if mode != 'r':
        return TrajectoryWriter(filename, mode, atoms, properties,
                                master=master, comm=comm)
    return TrajectoryReader(filename)
class TrajectoryWriter:
    """Writes Atoms objects to a .traj file."""

    def __init__(self, filename, mode='w', atoms=None, properties=None,
                 master=None, comm=world):
        """A Trajectory writer, in write or append mode.

        Parameters:

        filename: str
            The name of the file.  Traditionally ends in .traj.
        mode: str
            The mode.  'w' is write mode.  The atoms argument specifies
            the Atoms object to be written to the file, if not given it
            must instead be given as an argument to the write() method.
            'a' is append mode.  It acts as write mode, except that
            data is appended to a preexisting file.
            Any other value raises ValueError.
        atoms: Atoms object
            The Atoms object to be written in write or append mode.
        properties: list of str
            If specified, these calculator properties are saved in the
            trajectory.  If not specified, all supported quantities are
            saved.  Possible values: energy, forces, stress, dipole,
            charges, magmom and magmoms.
        master: bool
            Controls which process does the actual writing. The
            default is that process number 0 does this.  If this
            argument is given, processes where it is True will write.
        comm: MPI communicator
            MPI communicator for this trajectory writer, by default world.
            Passing a different communicator facilitates writing of
            different trajectories on different MPI ranks.
        """
        if master is None:
            master = comm.rank == 0

        self.filename = filename
        self.mode = mode
        self.atoms = atoms
        self.properties = properties
        self.master = master
        self.comm = comm

        self.description = {}
        # Header of the first image; None until something has been written
        # (or, in append mode, read back from the existing file).
        self.header_data = None
        # Becomes True as soon as one image needs a header of its own.
        self.multiple_headers = False

        self._open(filename, mode)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.close()

    def set_description(self, description):
        """Merge *description* (a dict) into the stored description."""
        self.description.update(description)

    def _open(self, filename, mode):
        """Create the backend used for writing.

        Raises ValueError unless *mode* is exactly 'w' or 'a'.
        """
        import ase.io.ulm as ulm

        # Compare against a tuple: ``mode not in 'aw'`` would be a
        # substring test and let '' and 'aw' slip through.
        if mode not in ('a', 'w'):
            raise ValueError('mode must be "w" or "a".')
        if self.master:
            self.backend = ulm.open(filename, mode, tag='ASE-Trajectory')
            if len(self.backend) > 0 and mode == 'a':
                # Appending to a non-empty file: pick up the header of its
                # first image so that new images are compared against it.
                with Trajectory(filename) as traj:
                    atoms = traj[0]
                self.header_data = get_header_data(atoms)
        else:
            self.backend = ulm.DummyWriter()

    def write(self, atoms=None, **kwargs):
        """Write the atoms to the file.

        If the atoms argument is not given, the atoms object specified
        when creating the trajectory object is used.

        Use keyword arguments to add extra properties::

            writer.write(atoms, energy=117, dipole=[0, 0, 1.0])
        """
        if atoms is None:
            atoms = self.atoms

        for image in atoms.iterimages():
            self._write_atoms(image, **kwargs)

    def _write_atoms(self, atoms, **kwargs):
        """Write a single image, its calculator results and its info."""
        b = self.backend

        if self.header_data is None:
            b.write(version=1, ase_version=__version__)
            if self.description:
                b.write(description=self.description)
            # Atomic numbers and periodic boundary conditions are written
            # in the header in the beginning.
            #
            # If an image later on has other numbers/pbc, we write a new
            # header.  All subsequent images will then have their own header
            # whether or not their numbers/pbc change.
            self.header_data = get_header_data(atoms)
            write_header = True
        else:
            if not self.multiple_headers:
                header_data = get_header_data(atoms)
                self.multiple_headers = not headers_equal(self.header_data,
                                                          header_data)
            write_header = self.multiple_headers

        write_atoms(b, atoms, write_header=write_header)

        calc = atoms.calc

        if calc is None and len(kwargs) > 0:
            # Properties were passed explicitly; give them a calculator
            # entry to live in.
            calc = SinglePointCalculator(atoms)

        if calc is not None:
            if not hasattr(calc, 'get_property'):
                calc = OldCalculatorWrapper(calc)
            c = b.child('calculator')
            c.write(name=calc.name)
            if hasattr(calc, 'todict'):
                c.write(parameters=calc.todict())
            for prop in all_properties:
                if prop in kwargs:
                    x = kwargs[prop]
                else:
                    if self.properties is not None:
                        if prop in self.properties:
                            x = calc.get_property(prop, atoms)
                        else:
                            x = None
                    else:
                        try:
                            x = calc.get_property(prop, atoms,
                                                  allow_calculation=False)
                        except (PropertyNotImplementedError, KeyError):
                            # KeyError is needed for Jacapo.
                            # XXX We can perhaps remove this.
                            x = None
                if x is not None:
                    if prop in ['stress', 'dipole']:
                        # Values given as keyword arguments may be plain
                        # sequences (see the write() docstring), which have
                        # no tolist() method of their own.
                        x = np.asarray(x).tolist()
                    c.write(prop, x)

        info = {}
        for key, value in atoms.info.items():
            try:
                encode(value)
            except TypeError:
                # Entries that cannot be encoded are dropped with a warning.
                warnings.warn(f'Skipping "{key}" info.')
            else:
                info[key] = value
        if info:
            b.write(info=info)

        b.sync()

    def close(self):
        """Close the trajectory file."""
        self.backend.close()

    def __len__(self):
        # Summed over the communicator; non-master ranks hold a dummy
        # backend.
        return self.comm.sum_scalar(len(self.backend))
class TrajectoryReader:
    """Reads Atoms objects from a .traj file."""

    def __init__(self, filename):
        """A Trajectory in read mode.

        The filename traditionally ends in .traj.
        """
        self.filename = filename
        # Header fields.  All of them get a default here so that they exist
        # even when the file holds no images, in which case _read_header()
        # has nothing to fill in.
        self.numbers = None
        self.pbc = None
        self.masses = None
        self.constraints = '[]'
        self.description = None
        self.version = None
        self.ase_version = None

        self._open(filename)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.close()

    def _open(self, filename):
        """Open the backend for reading and load the header."""
        import ase.io.ulm as ulm
        self.backend = ulm.open(filename, 'r')
        self._read_header()

    def _read_header(self):
        """Read the header fields stored with the first record.

        Raises OSError if the file does not carry the trajectory tag.
        """
        b = self.backend
        if b.get_tag() != 'ASE-Trajectory':
            raise OSError('This is not a trajectory file!')

        if len(b) > 0:
            self.pbc = b.pbc
            self.numbers = b.numbers
            self.masses = b.get('masses')
            self.constraints = b.get('constraints', '[]')
            self.description = b.get('description')
            self.version = b.version
            self.ase_version = b.get('ase_version')

    def close(self):
        """Close the trajectory file."""
        self.backend.close()

    def __getitem__(self, i=-1):
        """Return image *i*, or a SlicedTrajectory when *i* is a slice."""
        if isinstance(i, slice):
            return SlicedTrajectory(self, i)
        b = self.backend[i]
        if 'numbers' in b:
            # numbers and other header info was written alongside the image:
            atoms = read_atoms(b, traj=self)
        else:
            # header info was not written because they are the same:
            atoms = read_atoms(b,
                               header=[self.pbc, self.numbers, self.masses,
                                       self.constraints],
                               traj=self)
        if 'calculator' in b:
            results = {}
            implemented_properties = []
            c = b.calculator
            for prop in all_properties:
                if prop in c:
                    results[prop] = c.get(prop)
                    implemented_properties.append(prop)
            # Stored results are attached through a SinglePointCalculator
            # that advertises exactly the properties found in the file.
            calc = SinglePointCalculator(atoms, **results)
            calc.name = c.name
            calc.implemented_properties = implemented_properties

            if 'parameters' in c:
                calc.parameters.update(c.parameters)
            atoms.calc = calc

        return atoms

    def __len__(self):
        return len(self.backend)

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]
class SlicedTrajectory:
    """Lazy view onto a slice of a trajectory.

    Only a mapping from view indices to indices of the underlying
    trajectory is kept; images are fetched from the trajectory when they
    are asked for.  Initialize with a trajectory (in read mode) and the
    desired slice object.
    """

    def __init__(self, trajectory, sliced):
        self.trajectory = trajectory
        frame_count = len(self.trajectory)
        # Slicing a range gives another range: the selected frame indices.
        self.map = range(frame_count)[sliced]

    def __getitem__(self, i):
        if not isinstance(i, slice):
            return self.trajectory[self.map[i]]
        # Slice of a slice: build a view on the original trajectory and
        # narrow its index map, rather than wrapping this view again.
        view = SlicedTrajectory(self.trajectory, slice(0, None))
        view.map = self.map[i]
        return view

    def __len__(self):
        return len(self.map)
def get_header_data(atoms):
    """Collect the per-trajectory header fields of *atoms* in a dict.

    The dict has the keys 'pbc' (a copy), 'numbers', 'masses' (None
    unless atoms.has('masses')) and 'constraints' (a new list).
    """
    header = {'pbc': atoms.pbc.copy()}
    header['numbers'] = atoms.get_atomic_numbers()
    header['masses'] = None
    if atoms.has('masses'):
        header['masses'] = atoms.get_masses()
    header['constraints'] = list(atoms.constraints)
    return header
def headers_equal(headers1, headers2):
    """Tell whether two header dicts hold equal values.

    The entries are compared key by key, over the keys of *headers1*,
    with np.array_equal.  Both dicts must have the same number of
    entries.
    """
    assert len(headers1) == len(headers2)
    # Every key is compared (no early exit) before the verdict is formed.
    matches = [np.array_equal(headers1[name], headers2[name])
               for name in headers1]
    return all(matches)
class VersionTooOldError(Exception):
    """Raised when reading a trajectory written by a newer ASE fails."""
def read_atoms(backend,
               header: Tuple = None,
               traj: TrajectoryReader = None,
               _try_except: bool = True) -> Atoms:
    """Build an Atoms object from one trajectory record.

    Parameters:

    backend:
        The record to read from.
    header:
        Sequence ``(pbc, numbers, masses, constraints)`` to use in place
        of the values stored in the record.  If not given (or empty)
        they are taken from *backend*.
    traj:
        The trajectory being read.  Only its ``ase_version`` is used, to
        produce a more helpful error if reading fails.
    _try_except:
        Internal flag.  When True, the read is done by a nested call and
        any failure is inspected before being passed on.

    Raises VersionTooOldError (chained to the original exception) when
    reading fails and the file was written by a newer ASE than the one
    running.  Any other failure is re-raised unchanged.
    """
    from ase.constraints import dict2constraint

    if _try_except:
        try:
            return read_atoms(backend, header, traj, False)
        except Exception as ex:
            # The writing version may be unknown: no trajectory given, or
            # the file carries no ase_version.  In that case there is
            # nothing to compare, and the original error must surface
            # rather than a failure to parse a missing version string.
            written_by = getattr(traj, 'ase_version', None)
            if (written_by is not None and
                    tokenize_version(__version__) <
                    tokenize_version(written_by)):
                msg = ('You are trying to read a trajectory file written '
                       f'by ASE-{traj.ase_version} from ASE-{__version__}. '
                       'It might help to update your ASE')
                raise VersionTooOldError(msg) from ex
            raise

    b = backend
    if header:
        pbc, numbers, masses, constraints = header
    else:
        pbc = b.pbc
        numbers = b.numbers
        masses = b.get('masses')
        constraints = b.get('constraints', '[]')

    atoms = Atoms(positions=b.positions,
                  numbers=numbers,
                  cell=b.cell,
                  masses=masses,
                  pbc=pbc,
                  info=b.get('info'),
                  # Constraints are stored encoded; decode them and turn
                  # each dict back into a constraint object.
                  constraint=[dict2constraint(d)
                              for d in decode(constraints)],
                  momenta=b.get('momenta'),
                  magmoms=b.get('magmoms'),
                  charges=b.get('charges'),
                  tags=b.get('tags'))
    return atoms
def write_atoms(backend, atoms, write_header=True):
    """Write one image to *backend*.

    With *write_header* true, pbc and atomic numbers are written first,
    followed by the constraints and the masses when present.  Positions
    and cell are always written; tags, momenta, initial magnetic moments
    and initial charges only if the atoms object has them.
    """
    out = backend

    if write_header:
        out.write(pbc=atoms.pbc.tolist(),
                  numbers=atoms.numbers)
        # Constraints are stored only when every one of them offers
        # todict(); otherwise none are written.
        if atoms.constraints and all(hasattr(item, 'todict')
                                     for item in atoms.constraints):
            out.write(constraints=encode(atoms.constraints))
        if atoms.has('masses'):
            out.write(masses=atoms.get_masses())

    out.write(positions=atoms.get_positions(),
              cell=atoms.get_cell().tolist())

    # (array name on the atoms, key in the file, getter method name)
    optional_arrays = (
        ('tags', 'tags', 'get_tags'),
        ('momenta', 'momenta', 'get_momenta'),
        ('initial_magmoms', 'magmoms', 'get_initial_magnetic_moments'),
        ('initial_charges', 'charges', 'get_initial_charges'),
    )
    for array_name, key, getter_name in optional_arrays:
        if atoms.has(array_name):
            out.write(**{key: getattr(atoms, getter_name)()})
def read_traj(fd, index):
    """Yield the images of the trajectory *fd* selected by the slice *index*."""
    reader = TrajectoryReader(fd)
    first, last, stride = index.indices(len(reader))
    for frame_number in range(first, last, stride):
        yield reader[frame_number]
@contextlib.contextmanager
def defer_compression(fd):
    """Defer the file compression until all the configurations are read."""
    # The trajectory and compressed-file internals do not play well
    # together, hence this detour through memory.
    # Be advised not to defer compression of very long trajectories
    # as they use a lot of memory.
    if not is_compressed(fd):
        # Nothing to defer: hand the file out as it is.
        yield fd
        return

    with io.BytesIO() as staging:
        try:
            # the caller writes the uncompressed data to the buffer
            yield staging
        finally:
            # pass everything that was buffered on to the compressed file
            fd.write(staging.getvalue())
def write_traj(fd, images):
    """Write image(s) to trajectory."""
    # A single Atoms object is treated as a one-image sequence.
    frames = [images] if isinstance(images, Atoms) else images
    with defer_compression(fd) as target:
        writer = TrajectoryWriter(target)
        for frame in frames:
            writer.write(frame)
class OldCalculatorWrapper:
    """Give a calculator without get_property() that method.

    Properties are obtained through the calculator's individual get_*
    methods.
    """

    def __init__(self, calc):
        self.calc = calc
        # Use the calculator's own name, or else its lowercased class name.
        fallback = calc.__class__.__name__.lower()
        self.name = getattr(calc, 'name', fallback)

    def get_property(self, prop, atoms, allow_calculation=True):
        """Return property *prop* for *atoms*.

        Returns None if no calculation is allowed and the calculator
        reports that one would be required.  Raises
        PropertyNotImplementedError if an AttributeError occurs while
        looking up or calling the getter.
        """
        if not allow_calculation:
            try:
                needs_run = self.calc.calculation_required(atoms, [prop])
            except AttributeError:
                # Cannot ask the calculator; go on and try the getter.
                needs_run = False
            if needs_run:
                return None

        # Getter suffixes that differ from the property name.
        special_suffixes = {'energy': 'potential_energy',
                            'magmom': 'magnetic_moment',
                            'magmoms': 'magnetic_moments',
                            'dipole': 'dipole_moment'}
        suffix = special_suffixes[prop] if prop in special_suffixes else prop
        try:
            getter = getattr(self.calc, 'get_' + suffix)
            return getter(atoms)
        except AttributeError:
            raise PropertyNotImplementedError
def convert(name):
    """Convert the old-format trajectory file *name* to the new format.

    The images are first written to ``name + '.new'``.  Then the original
    file is renamed to ``name + '.old'`` and the new file takes its place.
    """
    import os

    # The with-block closes the writer even if reading or writing fails
    # part-way; the renames below only run after a complete conversion.
    with TrajectoryWriter(name + '.new') as writer:
        for atoms in PickleTrajectory(name, _warn=False):
            writer.write(atoms)
    os.rename(name, name + '.old')
    os.rename(name + '.new', name)
def main():
    """Command-line entry point: convert each file named on the command line."""
    import optparse

    usage_text = 'python -m ase.io.trajectory a1.traj [a2.traj ...]'
    description_text = ('Convert old trajectory '
                        'file(s) to new format. '
                        'The old file is kept as a1.traj.old.')
    parser = optparse.OptionParser(usage=usage_text,
                                   description=description_text)
    # No options are defined; only the positional arguments are used.
    _opts, filenames = parser.parse_args()
    for filename in filenames:
        convert(filename)


if __name__ == '__main__':
    main()