Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
netcdftrajectory - I/O trajectory files in the AMBER NetCDF convention

More information on the AMBER NetCDF conventions can be found at
http://ambermd.org/netcdf/. This module supports extensions to
these conventions, such as writing of additional fields and writing to
HDF5 (NetCDF-4) files.

The netCDF4-python package is required by this module:

    netCDF4-python - https://github.com/Unidata/netcdf4-python

NetCDF files can be directly visualized using the libAtoms flavor of
AtomEye (http://www.libatoms.org/),
VMD (http://www.ks.uiuc.edu/Research/vmd/)
or Ovito (http://www.ovito.org/, starting with version 2.3).
17"""
20import os
21import warnings
23import numpy as np
25import ase
27from ase.data import atomic_masses
28from ase.geometry import cellpar_to_cell
29import collections
30from functools import reduce
class NetCDFTrajectory:
    """
    Reads/writes Atoms objects into an AMBER-style .nc trajectory file.

    The underlying NetCDF file is opened lazily by the first read or write
    operation. The object can be used as a context manager, in which case
    the file is closed when the ``with`` block is left.
    """

    # Default dimension names
    _frame_dim = 'frame'
    _spatial_dim = 'spatial'
    _atom_dim = 'atom'
    _cell_spatial_dim = 'cell_spatial'
    _cell_angular_dim = 'cell_angular'
    _label_dim = 'label'
    _Voigt_dim = 'Voigt'  # For stress/strain tensors

    # Default field names. If it is a list, check for any of these names upon
    # opening. Upon writing, use the first name.
    _spatial_var = 'spatial'
    _cell_spatial_var = 'cell_spatial'
    _cell_angular_var = 'cell_angular'
    _time_var = 'time'
    _numbers_var = ['atom_types', 'type', 'Z']
    _positions_var = 'coordinates'
    _velocities_var = 'velocities'
    _cell_origin_var = 'cell_origin'
    _cell_lengths_var = 'cell_lengths'
    _cell_angles_var = 'cell_angles'

    # Names that are handled explicitly and therefore never reported as
    # "extra" arrays by _read_header(). Instances extend a private copy of
    # this list with their index variable.
    _default_vars = _numbers_var + [_positions_var, _velocities_var,
                                    _cell_origin_var, _cell_lengths_var,
                                    _cell_angles_var]

    def __init__(self, filename, mode='r', atoms=None, types_to_numbers=None,
                 double=True, netcdf_format='NETCDF3_CLASSIC', keep_open=True,
                 index_var='id', chunk_size=1000000):
        """
        A NetCDFTrajectory can be created in read, write or append mode.

        Parameters:

        filename:
            The name of the trajectory file. Should end in .nc.

        mode='r':
            The mode.

            'r' is read mode, the file should already exist, and no atoms
            argument should be specified.

            'w' is write mode. The atoms argument specifies the Atoms object
            to be written to the file, if not given it must instead be given
            as an argument to the write() method.

            'a' is append mode. It acts as write mode, except that data is
            appended to a preexisting file.

        atoms=None:
            The Atoms object to be written in write or append mode.

        types_to_numbers=None:
            Dictionary or list for conversion of atom types to atomic numbers
            when reading a trajectory file.

        double=True:
            Create new variable in double precision.

        netcdf_format='NETCDF3_CLASSIC':
            Format string for the underlying NetCDF file format. Only relevant
            if a new file is created. More information can be found at
            https://www.unidata.ucar.edu/software/netcdf/docs/netcdf/File-Format.html

            'NETCDF3_CLASSIC' is the original binary format.

            'NETCDF3_64BIT' can be used to write larger files.

            'NETCDF4_CLASSIC' is HDF5 with some NetCDF limitations.

            'NETCDF4' is HDF5.

        keep_open=True:
            Keep the file open during consecutive read/write operations.
            Set to false if you experience data corruption. This will close
            the file after each read/write operation but comes with a serious
            performance penalty.

        index_var='id':
            Name of variable containing the atom indices. Atoms are reordered
            by this index upon reading if this variable is present. Default
            value is for LAMMPS output. None switches atom indices off.

        chunk_size=1000000:
            Maximum size of consecutive number of records (along the 'atom'
            dimension) read when reading from a NetCDF file. This is used to
            reduce the memory footprint of a read operation on very large
            files.

        Raises:

        TypeError if ``atoms`` is given but has no ``get_positions`` method.
        """
        self.nc = None
        self.chunk_size = chunk_size

        self.numbers = None
        self.pre_observers = []   # Callback functions before write
        self.post_observers = []  # Callback functions after write are called
        # Number of completed write() calls; observers are triggered when
        # this counter is a multiple of their interval.
        self.write_counter = 0

        self.has_header = False
        self._set_atoms(atoms)

        self.types_to_numbers = None
        if isinstance(types_to_numbers, list):
            types_to_numbers = dict(enumerate(types_to_numbers))
        if types_to_numbers is not None:
            self.types_to_numbers = types_to_numbers

        self.index_var = index_var

        if self.index_var is not None:
            # Build a new, per-instance list. Using ``+=`` here would extend
            # the list object shared by the class (and all other instances).
            self._default_vars = self._default_vars + [self.index_var]

        # 'l' should be a valid type according to the netcdf4-python
        # documentation, but does not appear to work.
        self.dtype_conv = {'l': 'i'}
        if not double:
            self.dtype_conv.update(dict(d='f'))

        self.extra_per_frame_vars = []
        self.extra_per_file_vars = []
        # per frame atts are global quantities, not quantities stored for each
        # atom
        self.extra_per_frame_atts = []

        self.mode = mode
        self.netcdf_format = netcdf_format

        if atoms is not None:
            self.n_atoms = len(atoms)
        else:
            self.n_atoms = None

        self.filename = filename
        if keep_open is None:
            # Only netCDF4-python supports append to files
            self.keep_open = self.mode == 'r'
        else:
            self.keep_open = keep_open

    def __del__(self):
        self.close()

    def _open(self):
        """
        Opens the file.

        For internal use only.
        """
        if self.nc is not None:
            return
        import netCDF4
        if self.mode == 'a' and not os.path.exists(self.filename):
            self.mode = 'w'
        self.nc = netCDF4.Dataset(self.filename, self.mode,
                                  format=self.netcdf_format)

        self.frame = 0
        # The structure of a freshly opened file is unknown until the header
        # has been read (read/append mode) or defined by the first write.
        self.has_header = False
        if self.mode == 'r' or self.mode == 'a':
            self._read_header()
            self.frame = self._len()

    def _set_atoms(self, atoms=None):
        """
        Associate an Atoms object with the trajectory.

        For internal use only.
        """
        if atoms is not None and not hasattr(atoms, 'get_positions'):
            raise TypeError('"atoms" argument is not an Atoms object.')
        self.atoms = atoms

    def _read_header(self):
        """
        Scan the open file and classify all non-default variables.

        Variables are sorted into per-frame/per-atom arrays, per-file/per-atom
        arrays and per-frame global quantities according to their leading
        dimensions. For internal use only.
        """
        if not self.n_atoms:
            self.n_atoms = len(self.nc.dimensions[self._atom_dim])

        # Start from scratch: the header is re-read every time the file is
        # reopened and must not accumulate duplicate entries.
        self.extra_per_frame_vars = []
        self.extra_per_file_vars = []
        self.extra_per_frame_atts = []

        for name, var in self.nc.variables.items():
            # This can be unicode which confuses ASE
            name = str(name)
            # _default_vars is taken care of already
            if name in self._default_vars:
                continue
            dims = var.dimensions
            if len(dims) >= 2:
                if dims[0] == self._frame_dim:
                    if dims[1] == self._atom_dim:
                        self.extra_per_frame_vars.append(name)
                    else:
                        self.extra_per_frame_atts.append(name)
            elif len(dims) == 1:
                if dims[0] == self._atom_dim:
                    self.extra_per_file_vars.append(name)
                elif dims[0] == self._frame_dim:
                    self.extra_per_frame_atts.append(name)

        self.has_header = True

    def write(self, atoms=None, frame=None, arrays=None, time=None):
        """
        Write the atoms to the file.

        If the atoms argument is not given, the atoms object specified
        when creating the trajectory object is used.

        Parameters:

        atoms=None:
            Object to write. An object with an ``interpolate`` attribute is
            treated as a NEB and each of its images is written in turn.

        frame=None:
            Index of the frame to write to. Defaults to the current end of
            the trajectory.

        arrays=None:
            Names of additional per-atom arrays of ``atoms`` to store.

        time=None:
            If given, stored in the 'time' variable for this frame.

        Raises:

        ValueError if the number of atoms, the atomic numbers or per-file
        array data are inconsistent with what is already in the file.
        """
        self._open()
        self._call_observers(self.pre_observers)
        if atoms is None:
            atoms = self.atoms

        if hasattr(atoms, 'interpolate'):
            # seems to be a NEB
            neb = atoms
            assert not neb.parallel
            try:
                neb.get_energies_and_forces(all=True)
            except AttributeError:
                pass
            for image in neb.images:
                self.write(image)
            return

        if not self.has_header:
            self._define_file_structure(atoms)
        elif len(atoms) != self.n_atoms:
            raise ValueError('Bad number of atoms!')

        i = self.frame if frame is None else frame

        # Number can be per file variable
        numbers = self._get_variable(self._numbers_var)
        if numbers.dimensions[0] == self._frame_dim:
            numbers[i] = atoms.get_atomic_numbers()
        elif np.any(numbers != atoms.get_atomic_numbers()):
            raise ValueError('Atomic numbers do not match!')
        self._get_variable(self._positions_var)[i] = atoms.get_positions()
        if atoms.has('momenta'):
            self._add_velocities()
            self._get_variable(self._velocities_var)[i] = \
                atoms.get_momenta() / atoms.get_masses().reshape(-1, 1)
        a, b, c, alpha, beta, gamma = atoms.cell.cellpar()
        if np.any(np.logical_not(atoms.pbc)):
            warnings.warn('Atoms have nonperiodic directions. Cell lengths in '
                          'these directions are lost and will be '
                          'shrink-wrapped when reading the NetCDF file.')
        # Non-periodic directions are stored with zero length; the reader
        # uses this to recover the periodicity flags.
        cell_lengths = np.array([a, b, c]) * atoms.pbc
        self._get_variable(self._cell_lengths_var)[i] = cell_lengths
        self._get_variable(self._cell_angles_var)[i] = [alpha, beta, gamma]
        self._get_variable(self._cell_origin_var)[i] = \
            atoms.get_celldisp().reshape(3)
        if arrays is not None:
            self._write_extra_arrays(atoms, i, arrays)
        if time is not None:
            self._add_time()
            self._get_variable(self._time_var)[i] = time

        self.sync()

        self._call_observers(self.post_observers)
        self.frame += 1
        self.write_counter += 1
        self._close()

    def write_arrays(self, atoms, frame, arrays):
        """
        Write only the named per-atom arrays of atoms to an existing frame.

        Raises ValueError if a per-file array differs from the stored data.
        """
        self._open()
        self._call_observers(self.pre_observers)
        self._write_extra_arrays(atoms, frame, arrays)
        self._call_observers(self.post_observers)
        self._close()

    def _write_extra_arrays(self, atoms, frame, arrays):
        """Store the named arrays of atoms in frame, creating variables."""
        for array in arrays:
            data = atoms.get_array(array)
            if array in self.extra_per_file_vars:
                # This field exists but is per file data. Check that the
                # data remains consistent.
                if np.any(self._get_variable(array) != data):
                    raise ValueError('Trying to write Atoms object with '
                                     'incompatible data for the {0} '
                                     'array.'.format(array))
            else:
                self._add_array(atoms, array, data.dtype, data.shape)
                self._get_variable(array)[frame] = data

    def _define_file_structure(self, atoms):
        """
        Create global attributes, dimensions and default variables.

        Anything that already exists in the file is left alone. For internal
        use only.
        """
        self.nc.Conventions = 'AMBER'
        self.nc.ConventionVersion = '1.0'
        self.nc.program = 'ASE'
        self.nc.programVersion = ase.__version__
        self.nc.title = "MOL"

        if self._frame_dim not in self.nc.dimensions:
            self.nc.createDimension(self._frame_dim, None)
        if self._spatial_dim not in self.nc.dimensions:
            self.nc.createDimension(self._spatial_dim, 3)
        if self._atom_dim not in self.nc.dimensions:
            self.nc.createDimension(self._atom_dim, len(atoms))
        if self._cell_spatial_dim not in self.nc.dimensions:
            self.nc.createDimension(self._cell_spatial_dim, 3)
        if self._cell_angular_dim not in self.nc.dimensions:
            self.nc.createDimension(self._cell_angular_dim, 3)
        if self._label_dim not in self.nc.dimensions:
            self.nc.createDimension(self._label_dim, 5)

        # Self-describing variables from AMBER convention
        if not self._has_variable(self._spatial_var):
            self.nc.createVariable(self._spatial_var, 'S1',
                                   (self._spatial_dim,))
            self.nc.variables[self._spatial_var][:] = ['x', 'y', 'z']
        if not self._has_variable(self._cell_spatial_var):
            self.nc.createVariable(self._cell_spatial_var, 'S1',
                                   (self._cell_spatial_dim,))
            self.nc.variables[self._cell_spatial_var][:] = ['a', 'b', 'c']
        if not self._has_variable(self._cell_angular_var):
            self.nc.createVariable(self._cell_angular_var, 'S1',
                                   (self._cell_angular_dim, self._label_dim,))
            # Labels are padded to the length of the label dimension (5).
            self.nc.variables[self._cell_angular_var][0] = list('alpha')
            self.nc.variables[self._cell_angular_var][1] = list('beta ')
            self.nc.variables[self._cell_angular_var][2] = list('gamma')

        if not self._has_variable(self._numbers_var):
            self.nc.createVariable(self._numbers_var[0], 'i',
                                   (self._frame_dim, self._atom_dim,))
        if not self._has_variable(self._positions_var):
            self.nc.createVariable(self._positions_var, 'f4',
                                   (self._frame_dim, self._atom_dim,
                                    self._spatial_dim))
            self.nc.variables[self._positions_var].units = 'Angstrom'
            self.nc.variables[self._positions_var].scale_factor = 1.
        if not self._has_variable(self._cell_lengths_var):
            self.nc.createVariable(self._cell_lengths_var, 'd',
                                   (self._frame_dim, self._cell_spatial_dim))
            self.nc.variables[self._cell_lengths_var].units = 'Angstrom'
            self.nc.variables[self._cell_lengths_var].scale_factor = 1.
        if not self._has_variable(self._cell_angles_var):
            self.nc.createVariable(self._cell_angles_var, 'd',
                                   (self._frame_dim, self._cell_angular_dim))
            self.nc.variables[self._cell_angles_var].units = 'degree'
        if not self._has_variable(self._cell_origin_var):
            self.nc.createVariable(self._cell_origin_var, 'd',
                                   (self._frame_dim, self._cell_spatial_dim))
            self.nc.variables[self._cell_origin_var].units = 'Angstrom'
            self.nc.variables[self._cell_origin_var].scale_factor = 1.

        # The structure is now known; later writes only need to verify that
        # the number of atoms stays the same.
        self.n_atoms = len(atoms)
        self.has_header = True

    def _add_time(self):
        """Create the per-frame time variable if it does not exist yet."""
        if not self._has_variable(self._time_var):
            self.nc.createVariable(self._time_var, 'f8', (self._frame_dim,))

    def _add_velocities(self):
        """Create the velocities variable if it does not exist yet."""
        if not self._has_variable(self._velocities_var):
            self.nc.createVariable(self._velocities_var, 'f4',
                                   (self._frame_dim, self._atom_dim,
                                    self._spatial_dim))
            # Annotate the velocities themselves; the positions variable
            # keeps its own units.
            self.nc.variables[self._velocities_var].units = \
                'Angstrom/Femtosecond'
            self.nc.variables[self._velocities_var].scale_factor = 1.

    def _add_array(self, atoms, array_name, type, shape):
        """
        Create a per-frame variable for an additional array if missing.

        Each entry of shape is mapped onto a dimension: the number of atoms,
        3 (spatial) or 6 (Voigt). Raises TypeError for any other extent.
        """
        if not self._has_variable(array_name):
            dims = [self._frame_dim]
            for i in shape:
                if i == len(atoms):
                    dims += [self._atom_dim]
                elif i == 3:
                    dims += [self._spatial_dim]
                elif i == 6:
                    # This can only be stress/strain tensor in Voigt notation
                    if self._Voigt_dim not in self.nc.dimensions:
                        self.nc.createDimension(self._Voigt_dim, 6)
                    dims += [self._Voigt_dim]
                else:
                    raise TypeError("Don't know how to dump array of shape {0}"
                                    " into NetCDF trajectory.".format(shape))
            if hasattr(type, 'char'):
                t = self.dtype_conv.get(type.char, type)
            else:
                t = type
            self.nc.createVariable(array_name, t, dims)

    def _get_variable(self, name, exc=True):
        """
        Return the NetCDF variable called name.

        name may be a list of alternative names, in which case the first one
        present in the file is returned. If nothing is found, RuntimeError is
        raised when exc is true, otherwise None is returned.
        """
        candidates = name if isinstance(name, list) else [name]
        for n in candidates:
            if n in self.nc.variables:
                return self.nc.variables[n]
        if exc:
            if isinstance(name, list):
                raise RuntimeError(
                    'None of the variables {0} was found in the '
                    'NetCDF trajectory.'.format(', '.join(name)))
            raise RuntimeError('Variable {0} was not found in the NetCDF '
                               'trajectory.'.format(name))
        return None

    def _has_variable(self, name):
        """Return True if name (or any name of a list) is in the file."""
        if isinstance(name, list):
            return any(n in self.nc.variables for n in name)
        return name in self.nc.variables

    def _get_data(self, name, frame, index, exc=True):
        """
        Read a per-atom array and reorder it according to index.

        Entry k read from the file is stored at position index[k] of the
        result. Per-frame variables are read at frame; per-file variables are
        read as a whole. Returns None if the variable is missing and exc is
        false.
        """
        var = self._get_variable(name, exc=exc)
        if var is None:
            return None
        per_frame = var.dimensions[0] == self._frame_dim
        shape = var.shape[1:] if per_frame else var.shape
        data = np.zeros(shape, dtype=var.dtype)
        s = shape[0]
        if s < self.chunk_size:
            data[index] = var[frame] if per_frame else var[...]
        else:
            # If this is a large data set, only read chunks from it to
            # reduce memory footprint of the NetCDFTrajectory reader.
            for start in range(0, s, self.chunk_size):
                sl = slice(start, min(start + self.chunk_size, s))
                data[index[sl]] = var[frame, sl] if per_frame else var[sl]
        return data

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def close(self):
        """Close the trajectory file."""
        # getattr keeps this safe when called from __del__ on an object
        # whose __init__ did not run to completion.
        if getattr(self, 'nc', None) is not None:
            self.nc.close()
            self.nc = None

    def _close(self):
        """Close the file after an operation unless it is kept open."""
        if not self.keep_open:
            self.close()
            # The file exists now; reopening must not truncate it.
            if self.mode == 'w':
                self.mode = 'a'

    def sync(self):
        """Flush pending data to disk if the file is open."""
        if self.nc is not None:
            self.nc.sync()

    def __getitem__(self, i=-1):
        """
        Return frame i as an Atoms object, or a list of them for a slice.

        Negative indices count from the end. Raises IndexError if i is out of
        range.
        """
        self._open()

        if isinstance(i, slice):
            frames = [self[j] for j in range(*i.indices(self._len()))]
            # Also covers the empty slice, where no frame read closed the
            # file for us.
            self._close()
            return frames

        N = self._len()
        if i < 0:
            i += N
        if not 0 <= i < N:
            self._close()
            raise IndexError('Trajectory index out of range.')

        # Non-periodic boundaries have cell_length == 0.0
        cell_lengths = \
            np.array(self.nc.variables[self._cell_lengths_var][i][:])
        pbc = np.abs(cell_lengths) > 1e-6

        # Do we have a cell origin?
        if self._has_variable(self._cell_origin_var):
            origin = np.array(
                self.nc.variables[self._cell_origin_var][i][:])
        else:
            origin = np.zeros([3], dtype=float)

        # Do we have an index variable?
        if (self.index_var is not None and
                self._has_variable(self.index_var)):
            index = np.array(self.nc.variables[self.index_var][i][:])
            # The index variable can be non-consecutive, we here construct
            # a consecutive one.
            consecutive_index = np.zeros_like(index)
            consecutive_index[np.argsort(index)] = np.arange(self.n_atoms)
        else:
            consecutive_index = np.arange(self.n_atoms)

        # Read element numbers
        self.numbers = self._get_data(self._numbers_var, i,
                                      consecutive_index, exc=False)
        if self.numbers is None:
            self.numbers = np.ones(self.n_atoms, dtype=int)
        if self.types_to_numbers is not None:
            # Types without an explicit mapping are used as atomic numbers.
            d = set(self.numbers).difference(self.types_to_numbers.keys())
            if len(d) > 0:
                self.types_to_numbers.update({num: num for num in d})
            func = np.vectorize(self.types_to_numbers.get)
            self.numbers = func(self.numbers)
        self.masses = atomic_masses[self.numbers]

        # Read positions
        positions = self._get_data(self._positions_var, i,
                                   consecutive_index)

        # Determine cell size for non-periodic directions from shrink
        # wrapped cell.
        for dim in np.arange(3)[np.logical_not(pbc)]:
            origin[dim] = positions[:, dim].min()
            cell_lengths[dim] = positions[:, dim].max() - origin[dim]

        # Construct cell shape from cell lengths and angles
        cell = cellpar_to_cell(
            list(cell_lengths) +
            list(self.nc.variables[self._cell_angles_var][i])
        )

        # Compute momenta from velocities (if present)
        momenta = self._get_data(self._velocities_var, i,
                                 consecutive_index, exc=False)
        if momenta is not None:
            momenta *= self.masses.reshape(-1, 1)

        # Fill info dict with additional data found in the NetCDF file
        info = {}
        for name in self.extra_per_frame_atts:
            info[name] = np.array(self.nc.variables[name][i])

        # Create atoms object
        atoms = ase.Atoms(
            positions=positions,
            numbers=self.numbers,
            cell=cell,
            celldisp=origin,
            momenta=momenta,
            masses=self.masses,
            pbc=pbc,
            info=info
        )

        # Attach additional arrays found in the NetCDF file
        for name in self.extra_per_frame_vars:
            atoms.set_array(name, self._get_data(name, i,
                                                 consecutive_index))
        for name in self.extra_per_file_vars:
            atoms.set_array(name, self._get_data(name, i,
                                                 consecutive_index))
        self._close()
        return atoms

    def _len(self):
        """Return the number of frames of the open file (0 if undefined)."""
        if self._frame_dim in self.nc.dimensions:
            return int(self._get_variable(self._positions_var).shape[0])
        else:
            return 0

    def __len__(self):
        self._open()
        n_frames = self._len()
        self._close()
        return n_frames

    def pre_write_attach(self, function, interval=1, *args, **kwargs):
        """
        Attach a function to be called before writing begins.

        function: The function or callable object to be called.

        interval: How often the function is called. Default: every time (1).

        All other arguments are stored, and passed to the function.

        Raises ValueError if function is not callable.
        """
        if not callable(function):
            raise ValueError('Callback object must be callable.')
        self.pre_observers.append((function, interval, args, kwargs))

    def post_write_attach(self, function, interval=1, *args, **kwargs):
        """
        Attach a function to be called after writing ends.

        function: The function or callable object to be called.

        interval: How often the function is called. Default: every time (1).

        All other arguments are stored, and passed to the function.

        Raises ValueError if function is not callable.
        """
        if not callable(function):
            raise ValueError('Callback object must be callable.')
        self.post_observers.append((function, interval, args, kwargs))

    def _call_observers(self, obs):
        """Call pre/post write observers."""
        for function, interval, args, kwargs in obs:
            if self.write_counter % interval == 0:
                function(*args, **kwargs)
def read_netcdftrajectory(filename, index=-1):
    """
    Read from the NetCDF trajectory file ``filename``.

    ``index`` is passed straight to ``NetCDFTrajectory.__getitem__``: an
    integer selects a single frame (the last one by default), a slice
    selects a list of frames. The file is closed before returning.
    """
    trajectory = NetCDFTrajectory(filename, mode='r')
    with trajectory:
        return trajectory[index]
653def write_netcdftrajectory(filename, images):
654 if hasattr(images, 'get_positions'):
655 images = [images]
657 with NetCDFTrajectory(filename, mode='w') as traj:
658 for atoms in images:
659 traj.write(atoms)