Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk


"""
netcdftrajectory - I/O of trajectory files in the AMBER NetCDF convention

More information on the AMBER NetCDF conventions can be found at
http://ambermd.org/netcdf/. This module supports extensions to
these conventions, such as writing of additional fields and writing to
HDF5 (NetCDF-4) files.

The netCDF4-python package is required by this module:

    netCDF4-python - https://github.com/Unidata/netcdf4-python

NetCDF files can be directly visualized using the libAtoms flavor of
AtomEye (http://www.libatoms.org/),
VMD (http://www.ks.uiuc.edu/Research/vmd/)
or Ovito (http://www.ovito.org/, starting with version 2.3).
"""




20import os 

21import warnings 


23import numpy as np 


25import ase 


27from ase.data import atomic_masses 

28from ase.geometry import cellpar_to_cell 

29import collections 

30from functools import reduce 



class NetCDFTrajectory:
    """
    Reads/writes Atoms objects into an AMBER-style .nc trajectory file.

    The underlying file is opened lazily on the first read or write. The
    object can be used as a context manager; leaving the ``with`` block
    closes the file.
    """

    # Default dimension names
    _frame_dim = 'frame'
    _spatial_dim = 'spatial'
    _atom_dim = 'atom'
    _cell_spatial_dim = 'cell_spatial'
    _cell_angular_dim = 'cell_angular'
    _label_dim = 'label'
    _Voigt_dim = 'Voigt'  # For stress/strain tensors

    # Default field names. If it is a list, check for any of these names upon
    # opening. Upon writing, use the first name.
    _spatial_var = 'spatial'
    _cell_spatial_var = 'cell_spatial'
    _cell_angular_var = 'cell_angular'
    _time_var = 'time'
    _numbers_var = ['atom_types', 'type', 'Z']
    _positions_var = 'coordinates'
    _velocities_var = 'velocities'
    _cell_origin_var = 'cell_origin'
    _cell_lengths_var = 'cell_lengths'
    _cell_angles_var = 'cell_angles'

    # Names handled explicitly by the reader/writer. This class-level list is
    # a template only: every instance builds its own copy in __init__.
    _default_vars = _numbers_var + [_positions_var, _velocities_var,
                                    _cell_origin_var, _cell_lengths_var,
                                    _cell_angles_var]

    def __init__(self, filename, mode='r', atoms=None, types_to_numbers=None,
                 double=True, netcdf_format='NETCDF3_CLASSIC', keep_open=True,
                 index_var='id', chunk_size=1000000):
        """
        A NetCDFTrajectory can be created in read, write or append mode.

        Parameters:

        filename:
            The name of the trajectory file.  Should end in .nc.

        mode='r':
            The mode.

            'r' is read mode, the file should already exist, and no atoms
            argument should be specified.

            'w' is write mode.  The atoms argument specifies the Atoms object
            to be written to the file, if not given it must instead be given
            as an argument to the write() method.

            'a' is append mode.  It acts as write mode, except that data is
            appended to a preexisting file.

        atoms=None:
            The Atoms object to be written in write or append mode.

        types_to_numbers=None:
            Dictionary or list for conversion of atom types to atomic numbers
            when reading a trajectory file.

        double=True:
            Create new variable in double precision.

        netcdf_format='NETCDF3_CLASSIC':
            Format string for the underlying NetCDF file format. Only relevant
            if a new file is created. More information can be found at
            https://www.unidata.ucar.edu/software/netcdf/docs/netcdf/File-Format.html

            'NETCDF3_CLASSIC' is the original binary format.

            'NETCDF3_64BIT' can be used to write larger files.

            'NETCDF4_CLASSIC' is HDF5 with some NetCDF limitations.

            'NETCDF4' is HDF5.

        keep_open=True:
            Keep the file open during consecutive read/write operations.
            Set to false if you experience data corruption. This will close the
            file after each read/write operation but comes with serious
            performance penalty.

        index_var='id':
            Name of variable containing the atom indices. Atoms are reordered
            by this index upon reading if this variable is present. Default
            value is for LAMMPS output. None switches atom indices off.

        chunk_size=1000000:
            Maximum size of consecutive number of records (along the 'atom')
            dimension read when reading from a NetCDF file. This is used to
            reduce the memory footprint of a read operation on very large files.
        """
        self.nc = None
        self.chunk_size = chunk_size

        self.numbers = None
        self.pre_observers = []   # Callback functions before write
        self.post_observers = []  # Callback functions after write are called
        # Number of completed write() calls; consulted by _call_observers to
        # honour the ``interval`` of attached observers.
        self.write_counter = 0

        self.has_header = False
        self._set_atoms(atoms)

        if isinstance(types_to_numbers, list):
            types_to_numbers = dict(enumerate(types_to_numbers))
        self.types_to_numbers = types_to_numbers

        self.index_var = index_var

        # Build a per-instance list. An in-place ``+=`` here would extend the
        # list shared by the class and thereby leak into every other instance.
        self._default_vars = list(type(self)._default_vars)
        if self.index_var is not None:
            self._default_vars.append(self.index_var)

        # 'l' should be a valid type according to the netcdf4-python
        # documentation, but does not appear to work.
        self.dtype_conv = {'l': 'i'}
        if not double:
            self.dtype_conv['d'] = 'f'

        self.extra_per_frame_vars = []
        self.extra_per_file_vars = []
        # per frame atts are global quantities, not quantities stored for each
        # atom
        self.extra_per_frame_atts = []

        self.mode = mode
        self.netcdf_format = netcdf_format

        if atoms:
            self.n_atoms = len(atoms)
        else:
            self.n_atoms = None

        self.filename = filename
        if keep_open is None:
            # Only netCDF4-python supports append to files
            self.keep_open = self.mode == 'r'
        else:
            self.keep_open = keep_open

    def __del__(self):
        self.close()

    def _open(self):
        """
        Opens the file.

        Does nothing if the file is already open. Append mode silently turns
        into write mode when the file does not exist yet. In read and append
        mode the header is parsed and ``self.frame`` is set to the number of
        frames already present.

        For internal use only.
        """
        import netCDF4
        if self.nc is not None:
            return
        if self.mode == 'a' and not os.path.exists(self.filename):
            self.mode = 'w'
        self.nc = netCDF4.Dataset(self.filename, self.mode,
                                  format=self.netcdf_format)

        self.frame = 0
        if self.mode in ('r', 'a'):
            self._read_header()
            self.frame = self._len()

    def _set_atoms(self, atoms=None):
        """
        Associate an Atoms object with the trajectory.

        Raises TypeError if ``atoms`` is given but has no ``get_positions``
        attribute.

        For internal use only.
        """
        if atoms is not None and not hasattr(atoms, 'get_positions'):
            raise TypeError('"atoms" argument is not an Atoms object.')
        self.atoms = atoms

    def _read_header(self):
        """
        Classify the variables of the open file.

        Variables not listed in ``_default_vars`` are sorted into per-frame
        per-atom arrays, per-file per-atom arrays and per-frame global
        quantities, based on the names of their leading dimensions.
        """
        if not self.n_atoms:
            self.n_atoms = len(self.nc.dimensions[self._atom_dim])

        for name, var in self.nc.variables.items():
            # This can be unicode which confuses ASE
            name = str(name)
            # _default_vars is taken care of already
            if name in self._default_vars:
                continue
            dims = var.dimensions
            if len(dims) >= 2:
                if dims[0] == self._frame_dim:
                    if dims[1] == self._atom_dim:
                        self.extra_per_frame_vars.append(name)
                    else:
                        self.extra_per_frame_atts.append(name)
            elif len(dims) == 1:
                if dims[0] == self._atom_dim:
                    self.extra_per_file_vars.append(name)
                elif dims[0] == self._frame_dim:
                    self.extra_per_frame_atts.append(name)

        self.has_header = True

    def write(self, atoms=None, frame=None, arrays=None, time=None):
        """
        Write the atoms to the file.

        If the atoms argument is not given, the atoms object specified
        when creating the trajectory object is used.

        Parameters:

        atoms=None:
            Object to write. An object with an ``interpolate`` attribute is
            treated as a NEB and each of its images is written in turn.

        frame=None:
            Index of the frame to write to. Defaults to the next frame.

        arrays=None:
            Names of additional per-atom arrays of ``atoms`` to store.

        time=None:
            If given, stored in the time variable for this frame.

        Raises ValueError if the number of atoms, the atomic numbers or a
        per-file array do not match what is already stored in the file.
        """
        self._open()
        self._call_observers(self.pre_observers)
        if atoms is None:
            atoms = self.atoms

        if hasattr(atoms, 'interpolate'):
            # seems to be a NEB
            neb = atoms
            assert not neb.parallel
            try:
                neb.get_energies_and_forces(all=True)
            except AttributeError:
                pass
            for image in neb.images:
                self.write(image)
            return

        if not self.has_header:
            self._define_file_structure(atoms)
        elif len(atoms) != self.n_atoms:
            raise ValueError('Bad number of atoms!')

        i = self.frame if frame is None else frame

        # Number can be per file variable
        numbers = self._get_variable(self._numbers_var)
        if numbers.dimensions[0] == self._frame_dim:
            numbers[i] = atoms.get_atomic_numbers()
        elif np.any(numbers != atoms.get_atomic_numbers()):
            raise ValueError('Atomic numbers do not match!')

        self._get_variable(self._positions_var)[i] = atoms.get_positions()
        if atoms.has('momenta'):
            self._add_velocities()
            self._get_variable(self._velocities_var)[i] = \
                atoms.get_momenta() / atoms.get_masses().reshape(-1, 1)

        a, b, c, alpha, beta, gamma = atoms.cell.cellpar()
        if np.any(np.logical_not(atoms.pbc)):
            warnings.warn('Atoms have nonperiodic directions. Cell lengths in '
                          'these directions are lost and will be '
                          'shrink-wrapped when reading the NetCDF file.')
        # Non-periodic directions are stored with zero length; the reader
        # uses that to recover the pbc flags.
        cell_lengths = np.array([a, b, c]) * atoms.pbc
        self._get_variable(self._cell_lengths_var)[i] = cell_lengths
        self._get_variable(self._cell_angles_var)[i] = [alpha, beta, gamma]
        self._get_variable(self._cell_origin_var)[i] = \
            atoms.get_celldisp().reshape(3)

        if arrays is not None:
            for array in arrays:
                self._write_array(atoms, i, array)

        if time is not None:
            self._add_time()
            self._get_variable(self._time_var)[i] = time

        self.sync()

        self._call_observers(self.post_observers)
        self.frame += 1
        self.write_counter += 1
        self._close()

    def write_arrays(self, atoms, frame, arrays):
        """
        Write only the named per-atom arrays of ``atoms`` to frame ``frame``.

        Pre- and post-write observers are called, but neither the frame
        counter nor the write counter is advanced.
        """
        self._open()
        self._call_observers(self.pre_observers)
        for array in arrays:
            self._write_array(atoms, frame, array)
        self._call_observers(self.post_observers)
        self._close()

    def _write_array(self, atoms, frame, array):
        """
        Store the per-atom array named ``array`` of ``atoms`` at ``frame``.

        An array that exists as per-file data is not rewritten; it is only
        checked for consistency and a ValueError is raised on mismatch.
        """
        data = atoms.get_array(array)
        if array in self.extra_per_file_vars:
            # This field exists but is per file data. Check that the
            # data remains consistent.
            if np.any(self._get_variable(array) != data):
                raise ValueError('Trying to write Atoms object with '
                                 'incompatible data for the {0} '
                                 'array.'.format(array))
        else:
            self._add_array(atoms, array, data.dtype, data.shape)
            self._get_variable(array)[frame] = data

    def _define_file_structure(self, atoms):
        """
        Create the global attributes, dimensions and default variables.

        Dimensions and variables that already exist are left alone, so the
        method can be called repeatedly on the same file.
        """
        self.nc.Conventions = 'AMBER'
        self.nc.ConventionVersion = '1.0'
        self.nc.program = 'ASE'
        self.nc.programVersion = ase.__version__
        self.nc.title = "MOL"

        # (name, size) pairs; a size of None creates an unlimited dimension.
        dimensions = (
            (self._frame_dim, None),
            (self._spatial_dim, 3),
            (self._atom_dim, len(atoms)),
            (self._cell_spatial_dim, 3),
            (self._cell_angular_dim, 3),
            (self._label_dim, 5),
        )
        for dim_name, dim_size in dimensions:
            if dim_name not in self.nc.dimensions:
                self.nc.createDimension(dim_name, dim_size)

        # Self-describing variables from AMBER convention
        if not self._has_variable(self._spatial_var):
            self.nc.createVariable(self._spatial_var, 'S1',
                                   (self._spatial_dim,))
            self.nc.variables[self._spatial_var][:] = ['x', 'y', 'z']
        if not self._has_variable(self._cell_spatial_var):
            self.nc.createVariable(self._cell_spatial_var, 'S1',
                                   (self._cell_spatial_dim,))
            self.nc.variables[self._cell_spatial_var][:] = ['a', 'b', 'c']
        if not self._has_variable(self._cell_angular_var):
            self.nc.createVariable(self._cell_angular_var, 'S1',
                                   (self._cell_angular_dim, self._label_dim,))
            # Labels are padded to the length of the label dimension (5).
            angular = self.nc.variables[self._cell_angular_var]
            angular[0] = list('alpha')
            angular[1] = list('beta ')
            angular[2] = list('gamma')

        if not self._has_variable(self._numbers_var):
            self.nc.createVariable(self._numbers_var[0], 'i',
                                   (self._frame_dim, self._atom_dim,))
        if not self._has_variable(self._positions_var):
            self.nc.createVariable(self._positions_var, 'f4',
                                   (self._frame_dim, self._atom_dim,
                                    self._spatial_dim))
            self.nc.variables[self._positions_var].units = 'Angstrom'
            self.nc.variables[self._positions_var].scale_factor = 1.
        if not self._has_variable(self._cell_lengths_var):
            self.nc.createVariable(self._cell_lengths_var, 'd',
                                   (self._frame_dim, self._cell_spatial_dim))
            self.nc.variables[self._cell_lengths_var].units = 'Angstrom'
            self.nc.variables[self._cell_lengths_var].scale_factor = 1.
        if not self._has_variable(self._cell_angles_var):
            self.nc.createVariable(self._cell_angles_var, 'd',
                                   (self._frame_dim, self._cell_angular_dim))
            self.nc.variables[self._cell_angles_var].units = 'degree'
        if not self._has_variable(self._cell_origin_var):
            self.nc.createVariable(self._cell_origin_var, 'd',
                                   (self._frame_dim, self._cell_spatial_dim))
            self.nc.variables[self._cell_origin_var].units = 'Angstrom'
            self.nc.variables[self._cell_origin_var].scale_factor = 1.

    def _add_time(self):
        """Create the per-frame time variable if it does not exist yet."""
        if not self._has_variable(self._time_var):
            self.nc.createVariable(self._time_var, 'f8', (self._frame_dim,))

    def _add_velocities(self):
        """Create the velocities variable if it does not exist yet."""
        if not self._has_variable(self._velocities_var):
            self.nc.createVariable(self._velocities_var, 'f4',
                                   (self._frame_dim, self._atom_dim,
                                    self._spatial_dim))
            # Attributes belong on the velocities variable, not on the
            # positions variable.
            self.nc.variables[self._velocities_var].units = \
                'Angstrom/Femtosecond'
            self.nc.variables[self._velocities_var].scale_factor = 1.

    def _add_array(self, atoms, array_name, type, shape):
        """
        Create a per-frame variable for an extra array if it is missing.

        The dimensions are deduced from ``shape``: an extent equal to
        ``len(atoms)`` maps to the atom dimension, 3 to the spatial dimension
        and 6 to the Voigt dimension. Any other extent raises TypeError.
        """
        if self._has_variable(array_name):
            return

        dims = [self._frame_dim]
        for extent in shape:
            if extent == len(atoms):
                dims.append(self._atom_dim)
            elif extent == 3:
                dims.append(self._spatial_dim)
            elif extent == 6:
                # This can only be stress/strain tensor in Voigt notation
                if self._Voigt_dim not in self.nc.dimensions:
                    self.nc.createDimension(self._Voigt_dim, 6)
                dims.append(self._Voigt_dim)
            else:
                raise TypeError("Don't know how to dump array of shape {0}"
                                " into NetCDF trajectory.".format(shape))

        # Translate numpy dtypes through dtype_conv where an entry exists.
        if hasattr(type, 'char'):
            t = self.dtype_conv.get(type.char, type)
        else:
            t = type
        self.nc.createVariable(array_name, t, dims)

    def _get_variable(self, name, exc=True):
        """
        Return the NetCDF variable called ``name``.

        ``name`` may be a list of alternative names; the first one present in
        the file wins. If nothing is found, RuntimeError is raised when
        ``exc`` is true, otherwise None is returned.
        """
        if isinstance(name, list):
            for candidate in name:
                if candidate in self.nc.variables:
                    return self.nc.variables[candidate]
            if exc:
                raise RuntimeError(
                    'None of the variables {0} was found in the '
                    'NetCDF trajectory.'.format(', '.join(name)))
        else:
            if name in self.nc.variables:
                return self.nc.variables[name]
            if exc:
                raise RuntimeError('Variable {0} was not found in the NetCDF '
                                   'trajectory.'.format(name))
        return None

    def _has_variable(self, name):
        """Return True if ``name`` (or any name of a list) is in the file."""
        if isinstance(name, list):
            return any(n in self.nc.variables for n in name)
        return name in self.nc.variables

    def _get_data(self, name, frame, index, exc=True):
        """
        Read a variable into a new array, reordered by ``index``.

        For a variable whose first dimension is the frame dimension only
        frame ``frame`` is read; any other variable is read in full. Row ``k``
        of the file ends up at position ``index[k]`` of the result. Returns
        None if the variable is missing and ``exc`` is false.
        """
        var = self._get_variable(name, exc=exc)
        if var is None:
            return None

        per_frame = var.dimensions[0] == self._frame_dim
        shape = var.shape[1:] if per_frame else var.shape
        data = np.zeros(shape, dtype=var.dtype)
        s = shape[0]

        if s < self.chunk_size:
            data[index] = var[frame] if per_frame else var[...]
        else:
            # If this is a large data set, only read chunks from it to
            # reduce memory footprint of the NetCDFTrajectory reader.
            for start in range(0, s, self.chunk_size):
                sl = slice(start, min(start + self.chunk_size, s))
                data[index[sl]] = var[frame, sl] if per_frame else var[sl]
        return data

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def close(self):
        """Close the trajectory file."""
        if self.nc is not None:
            self.nc.close()
            self.nc = None

    def _close(self):
        """
        Close the file after an operation unless ``keep_open`` is set.

        A file closed in write mode is switched to append mode so that the
        next operation does not truncate it.
        """
        if not self.keep_open:
            self.close()
            if self.mode == 'w':
                self.mode = 'a'

    def sync(self):
        """Call ``sync()`` on the open NetCDF dataset."""
        self.nc.sync()

    def __getitem__(self, i=-1):
        """
        Return frame ``i`` as an Atoms object, or a list of them for a slice.

        Negative indices count from the end. Raises IndexError if the index
        is out of range.
        """
        self._open()

        if isinstance(i, slice):
            return [self[j] for j in range(*i.indices(self._len()))]

        N = self._len()
        if i < 0:
            i = N + i
        if not 0 <= i < N:
            self._close()
            raise IndexError('Trajectory index out of range.')

        # Non-periodic boundaries have cell_length == 0.0
        cell_lengths = \
            np.array(self.nc.variables[self._cell_lengths_var][i][:])
        pbc = np.abs(cell_lengths) > 1e-6

        # Do we have a cell origin?
        if self._has_variable(self._cell_origin_var):
            origin = np.array(
                self.nc.variables[self._cell_origin_var][i][:])
        else:
            origin = np.zeros([3], dtype=float)

        # Do we have an index variable?
        if (self.index_var is not None and
                self._has_variable(self.index_var)):
            index = np.array(self.nc.variables[self.index_var][i][:])
            # The index variable can be non-consecutive, we here construct
            # a consecutive one.
            consecutive_index = np.zeros_like(index)
            consecutive_index[np.argsort(index)] = np.arange(self.n_atoms)
        else:
            consecutive_index = np.arange(self.n_atoms)

        # Read element numbers
        self.numbers = self._get_data(self._numbers_var, i,
                                      consecutive_index, exc=False)
        if self.numbers is None:
            self.numbers = np.ones(self.n_atoms, dtype=int)
        if self.types_to_numbers is not None:
            # Types without an explicit mapping are kept as they are.
            unmapped = set(self.numbers).difference(
                self.types_to_numbers.keys())
            if unmapped:
                self.types_to_numbers.update({num: num for num in unmapped})
            func = np.vectorize(self.types_to_numbers.get)
            self.numbers = func(self.numbers)
        self.masses = atomic_masses[self.numbers]

        # Read positions
        positions = self._get_data(self._positions_var, i,
                                   consecutive_index)

        # Determine cell size for non-periodic directions from shrink
        # wrapped cell.
        for dim in np.arange(3)[np.logical_not(pbc)]:
            origin[dim] = positions[:, dim].min()
            cell_lengths[dim] = positions[:, dim].max() - origin[dim]

        # Construct cell shape from cell lengths and angles
        cell = cellpar_to_cell(
            list(cell_lengths) +
            list(self.nc.variables[self._cell_angles_var][i])
        )

        # Compute momenta from velocities (if present)
        momenta = self._get_data(self._velocities_var, i,
                                 consecutive_index, exc=False)
        if momenta is not None:
            momenta *= self.masses.reshape(-1, 1)

        # Fill info dict with additional data found in the NetCDF file
        info = {}
        for name in self.extra_per_frame_atts:
            info[name] = np.array(self.nc.variables[name][i])

        # Create atoms object
        atoms = ase.Atoms(
            positions=positions,
            numbers=self.numbers,
            cell=cell,
            celldisp=origin,
            momenta=momenta,
            masses=self.masses,
            pbc=pbc,
            info=info
        )

        # Attach additional arrays found in the NetCDF file
        for name in self.extra_per_frame_vars:
            atoms.set_array(name, self._get_data(name, i,
                                                 consecutive_index))
        for name in self.extra_per_file_vars:
            atoms.set_array(name, self._get_data(name, i,
                                                 consecutive_index))
        self._close()
        return atoms

    def _len(self):
        """Return the number of frames in the open file (0 if none)."""
        if self._frame_dim in self.nc.dimensions:
            return int(self._get_variable(self._positions_var).shape[0])
        return 0

    def __len__(self):
        self._open()
        n_frames = self._len()
        self._close()
        return n_frames

    def pre_write_attach(self, function, interval=1, *args, **kwargs):
        """
        Attach a function to be called before writing begins.

        function: The function or callable object to be called.

        interval: How often the function is called.  Default: every time (1).

        All other arguments are stored, and passed to the function.

        Raises ValueError if ``function`` is not callable.
        """
        if not callable(function):
            raise ValueError('Callback object must be callable.')
        self.pre_observers.append((function, interval, args, kwargs))

    def post_write_attach(self, function, interval=1, *args, **kwargs):
        """
        Attach a function to be called after writing ends.

        function: The function or callable object to be called.

        interval: How often the function is called.  Default: every time (1).

        All other arguments are stored, and passed to the function.

        Raises ValueError if ``function`` is not callable.
        """
        if not callable(function):
            raise ValueError('Callback object must be callable.')
        self.post_observers.append((function, interval, args, kwargs))

    def _call_observers(self, obs):
        """Call pre/post write observers."""
        for function, interval, args, kwargs in obs:
            if self.write_counter % interval == 0:
                function(*args, **kwargs)



def read_netcdftrajectory(filename, index=-1):
    """
    Read from the NetCDF trajectory file ``filename``.

    The file is opened in read mode, indexed with ``index`` (default: the
    last frame) and closed again before the result is returned.
    """
    trajectory = NetCDFTrajectory(filename, mode='r')
    with trajectory as reader:
        selection = reader[index]
    return selection



def write_netcdftrajectory(filename, images):
    """
    Write ``images`` to a new NetCDF trajectory file ``filename``.

    ``images`` is either an iterable of objects to write or a single object
    exposing ``get_positions``, which is then written as the only frame.
    """
    # A lone Atoms-like object is wrapped so the loop below handles both cases.
    frames = [images] if hasattr(images, 'get_positions') else images

    with NetCDFTrajectory(filename, mode='w') as writer:
        for frame in frames:
            writer.write(frame)