Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import os 

2import sys 

3import errno 

4import pickle 

5import warnings 

6import collections 

7 

# Compatibility alias: Python 3 has no ``unicode`` builtin, so fall back
# to ``str`` when the name lookup fails.
try:
    unicode
except NameError:
    unicode = str

# ``WindowsError`` only exists as a builtin on Windows.  Define a stand-in
# elsewhere so that ``except WindowsError`` clauses remain valid.
try:
    WindowsError
except NameError:
    class WindowsError(OSError):
        pass

20 

21import numpy as np 

22 

23from ase.atoms import Atoms 

24from ase.calculators.singlepoint import SinglePointCalculator 

25from ase.calculators.calculator import PropertyNotImplementedError 

26from ase.constraints import FixAtoms 

27from ase.parallel import world, barrier 

28 

29 

class PickleTrajectory:
    """Reads/writes Atoms objects into a .traj file.

    The file starts with the magic bytes ``b'PickleTrajectory'`` followed
    by one pickled header dict and then one pickled dict per image.  The
    byte offset of every record is cached in ``self.offsets`` so that
    images can be fetched by index.

    NOTE: reading uses ``pickle.load``, which can execute arbitrary code.
    Only open trajectory files that come from a trusted source.
    """
    # Per default, write these quantities
    write_energy = True
    write_forces = True
    write_stress = True
    write_charges = True
    write_magmoms = True
    # NOTE(review): write_momenta is not consulted anywhere in this class;
    # momenta are stored whenever the Atoms object has them.
    write_momenta = True
    write_info = True

    def __init__(self, filename, mode='r', atoms=None, master=None,
                 backup=True, _warn=True):
        """A PickleTrajectory can be created in read, write or append mode.

        Parameters:

        filename:
            The name of the trajectory file.  Should end in .traj.

        mode='r':
            The mode.

            'r' is read mode, the file should already exist, and
            no atoms argument should be specified.

            'w' is write mode.  If the file already exists, it is
            renamed by appending .bak to the file name.  The atoms
            argument specifies the Atoms object to be written to the
            file, if not given it must instead be given as an argument
            to the write() method.

            'a' is append mode.  It acts as write mode, except that
            data is appended to a preexisting file.

        atoms=None:
            The Atoms object to be written in write or append mode.

        master=None:
            Controls which process does the actual writing. The
            default is that process number 0 does this.  If this
            argument is given, processes where it is True will write.

        backup=True:
            Use backup=False to disable renaming of an existing file.

        _warn=True:
            When true, construction is refused: a DeprecationWarning is
            *raised* (not merely issued) before any file is touched.
            Pass _warn=False to actually use the class.
        """

        if _warn:
            msg = 'Please stop using old trajectory files!'
            if mode == 'r':
                msg += ('\nConvert to the new future-proof format like this:\n'
                        '\n $ python3 -m ase.io.trajectory ' +
                        filename + '\n')
            raise DeprecationWarning(msg)

        self.numbers = None
        self.pbc = None
        self.sanitycheck = True
        self.pre_observers = []   # Callback functions before write
        self.post_observers = []  # Callback functions after write

        # Counter used to determine when callbacks are called:
        self.write_counter = 0

        self.offsets = []
        if master is None:
            master = (world.rank == 0)
        self.master = master
        self.backup = backup
        self.set_atoms(atoms)
        self.open(filename, mode)

    def open(self, filename, mode):
        """Opens the file.

        For internal use only.

        Raises ValueError if *mode* is not 'r', 'w' or 'a'.
        """
        self.fd = filename
        if mode == 'r':
            # Anything that is not a string is used as-is as the file object.
            if isinstance(filename, str):
                self.fd = open(filename, 'rb')
            self.read_header()
        elif mode == 'a':
            exists = True
            if isinstance(filename, str):
                exists = os.path.isfile(filename)
                if exists:
                    exists = os.path.getsize(filename) > 0
            if exists:
                # Read the existing header first so that later writes can be
                # checked against it.  Close the temporary read handle even
                # if the header turns out to be unreadable.
                # NOTE(review): a non-string filename also reaches open()
                # here - confirm whether file objects are meant to be
                # supported in append mode.
                self.fd = open(filename, 'rb')
                try:
                    self.read_header()
                finally:
                    self.fd.close()
            barrier()
            if self.master:
                self.fd = open(filename, 'ab+')
            else:
                # Non-master processes write into the void.
                self.fd = open(os.devnull, 'ab+')
        elif mode == 'w':
            if self.master:
                if isinstance(filename, str):
                    if self.backup and os.path.isfile(filename):
                        try:
                            os.rename(filename, filename + '.bak')
                        except WindowsError as e:
                            # this must run on Win only!  Not atomic!
                            if e.errno != errno.EEXIST:
                                raise
                            os.unlink(filename + '.bak')
                            os.rename(filename, filename + '.bak')
                    self.fd = open(filename, 'wb')
            else:
                self.fd = open(os.devnull, 'wb')
        else:
            raise ValueError('mode must be "r", "w" or "a".')

    def set_atoms(self, atoms=None):
        """Associate an Atoms object with the trajectory.

        Mostly for internal use.

        Raises TypeError if *atoms* is given but has no ``get_positions``
        attribute.
        """
        if atoms is not None and not hasattr(atoms, 'get_positions'):
            raise TypeError('"atoms" argument is not an Atoms object.')
        self.atoms = atoms

    def read_header(self):
        """Read the magic bytes and the header record from ``self.fd``.

        Sets ``pbc``, ``numbers``, ``tags``, ``masses`` and ``constraints``
        and records the offset of the first image.  Returns without doing
        anything when ``self.fd`` refers to an existing, empty file.

        Raises IOError if the magic bytes are missing and EOFError if the
        file ends before the header has been read.
        """
        if hasattr(self.fd, 'name'):
            if os.path.isfile(self.fd.name):
                if os.path.getsize(self.fd.name) == 0:
                    return
        self.fd.seek(0)
        try:
            if self.fd.read(len('PickleTrajectory')) != b'PickleTrajectory':
                raise IOError('This is not a trajectory file!')
            # NOTE: unpickling untrusted data is unsafe.
            d = pickle.load(self.fd)
        except EOFError:
            raise EOFError('Bad trajectory file.')

        self.pbc = d['pbc']
        self.numbers = d['numbers']
        self.tags = d.get('tags')
        self.masses = d.get('masses')
        self.constraints = dict2constraints(d)
        self.offsets.append(self.fd.tell())

    def write(self, atoms=None):
        """Write every image of *atoms* to the file.

        If *atoms* is None, the Atoms object associated with the
        trajectory (see set_atoms()) is used.
        """
        if atoms is None:
            atoms = self.atoms

        for image in atoms.iterimages():
            self._write_atoms(image)

    def _write_atoms(self, atoms):
        """Write a single image to the file.

        The header is written first if nothing has been written or read
        yet; otherwise the image is checked against the stored header.
        Pre-write observers run before, post-write observers after the
        record has been written.

        Raises ValueError if the periodic boundary conditions, the number
        of atoms or the atomic numbers differ from the header (the latter
        two only while ``self.sanitycheck`` is true).
        """
        self._call_observers(self.pre_observers)

        if len(self.offsets) == 0:
            self.write_header(atoms)
        else:
            if (atoms.pbc != self.pbc).any():
                raise ValueError('Bad periodic boundary conditions!')
            elif self.sanitycheck and len(atoms) != len(self.numbers):
                raise ValueError('Bad number of atoms!')
            elif self.sanitycheck and (atoms.numbers != self.numbers).any():
                raise ValueError('Bad atomic numbers!')

        if atoms.has('momenta'):
            momenta = atoms.get_momenta()
        else:
            momenta = None

        d = {'positions': atoms.get_positions(),
             'cell': atoms.get_cell(),
             'momenta': momenta}

        if atoms.calc is not None:
            if self.write_energy:
                d['energy'] = atoms.get_potential_energy()
            if self.write_forces:
                assert self.write_energy
                try:
                    d['forces'] = atoms.get_forces(apply_constraint=False)
                except PropertyNotImplementedError:
                    pass
            if self.write_stress:
                assert self.write_energy
                try:
                    d['stress'] = atoms.get_stress()
                except PropertyNotImplementedError:
                    pass
            if self.write_charges:
                try:
                    d['charges'] = atoms.get_charges()
                except PropertyNotImplementedError:
                    pass
            if self.write_magmoms:
                try:
                    magmoms = atoms.get_magnetic_moments()
                    # Only store magnetic moments that are not all zero.
                    if any(np.asarray(magmoms).flat):
                        d['magmoms'] = magmoms
                except (PropertyNotImplementedError, AttributeError):
                    pass

        # Fall back to the initial values when the calculator gave none.
        if 'magmoms' not in d and atoms.has('initial_magmoms'):
            d['magmoms'] = atoms.get_initial_magnetic_moments()
        if 'charges' not in d and atoms.has('initial_charges'):
            charges = atoms.get_initial_charges()
            if (charges != 0).any():
                d['charges'] = charges

        if self.write_info:
            d['info'] = stringnify_info(atoms.info)

        if self.master:
            pickle.dump(d, self.fd, protocol=2)
        self.fd.flush()
        self.offsets.append(self.fd.tell())
        self._call_observers(self.post_observers)
        self.write_counter += 1

    def write_header(self, atoms):
        """Write the magic bytes and the header record for *atoms*.

        The header holds the periodic boundary conditions, atomic numbers,
        tags, masses and the pickled constraints.  The offset of the first
        image is recorded afterwards.
        """
        self.fd.write(b'PickleTrajectory')
        if atoms.has('tags'):
            tags = atoms.get_tags()
        else:
            tags = None
        if atoms.has('masses'):
            masses = atoms.get_masses()
        else:
            masses = None
        d = {'version': 3,
             'pbc': atoms.get_pbc(),
             'numbers': atoms.get_atomic_numbers(),
             'tags': tags,
             'masses': masses,
             'constraints': [],  # backwards compatibility
             'constraints_string': pickle.dumps(atoms.constraints, protocol=0)}
        pickle.dump(d, self.fd, protocol=2)
        self.header_written = True
        self.offsets.append(self.fd.tell())

        # Atomic numbers and periodic boundary conditions are only
        # written once - in the header.  Store them here so that we can
        # check that they are the same for all images:
        self.numbers = atoms.get_atomic_numbers()
        self.pbc = atoms.get_pbc()

    def close(self):
        """Close the trajectory file."""
        self.fd.close()

    def __getitem__(self, i=-1):
        """Return image number *i* as an Atoms object.

        *i* may be a non-negative index, a negative index (counted from
        the end) or a slice, in which case a list of Atoms objects is
        returned.  If the record contains an energy, a
        SinglePointCalculator holding the stored results is attached.

        Raises IndexError when the requested image does not exist.
        """
        if isinstance(i, slice):
            return [self[j] for j in range(*i.indices(len(self)))]

        N = len(self.offsets)
        if 0 <= i < N:
            self.fd.seek(self.offsets[i])
            try:
                # NOTE: unpickling untrusted data is unsafe.
                d = pickle.load(self.fd, encoding='bytes')
                # encoding='bytes' may yield bytes keys; normalise to str.
                d = {k.decode() if isinstance(k, bytes) else k: v
                     for k, v in d.items()}
            except EOFError:
                raise IndexError
            if i == N - 1:
                # We just read the last known record, so the current file
                # position is the offset of the next one.
                self.offsets.append(self.fd.tell())
            charges = d.get('charges')
            magmoms = d.get('magmoms')
            try:
                constraints = [c.copy() for c in self.constraints]
            except Exception:
                constraints = []
                warnings.warn('Constraints did not unpickle correctly.')
            atoms = Atoms(positions=d['positions'],
                          numbers=self.numbers,
                          cell=d['cell'],
                          momenta=d['momenta'],
                          magmoms=magmoms,
                          charges=charges,
                          tags=self.tags,
                          masses=self.masses,
                          pbc=self.pbc,
                          info=unstringnify_info(d.get('info', {})),
                          constraint=constraints)
            if 'energy' in d:
                calc = SinglePointCalculator(
                    atoms,
                    energy=d.get('energy', None),
                    forces=d.get('forces', None),
                    stress=d.get('stress', None),
                    magmoms=magmoms)
                atoms.calc = calc
            return atoms

        if i >= N:
            # Offsets beyond the known ones are discovered by reading the
            # intermediate images one after another.
            for j in range(N - 1, i + 1):
                atoms = self[j]
            return atoms

        i = len(self) + i
        if i < 0:
            raise IndexError('Trajectory index out of range.')
        return self[i]

    def __len__(self):
        """Return the number of images, scanning the file to its end.

        Offsets of records found during the scan are appended to
        ``self.offsets``.
        """
        if len(self.offsets) == 0:
            return 0
        N = len(self.offsets) - 1
        while True:
            self.fd.seek(self.offsets[N])
            try:
                pickle.load(self.fd)
            except EOFError:
                return N
            self.offsets.append(self.fd.tell())
            N += 1

    def pre_write_attach(self, function, interval=1, *args, **kwargs):
        """Attach a function to be called before writing begins.

        function: The function or callable object to be called.

        interval: How often the function is called.  Default: every time (1).

        All other arguments are stored, and passed to the function.

        Raises ValueError if *function* is not callable.
        """
        # The builtin callable() replaces collections.Callable, which is
        # no longer available from Python 3.10 on.
        if not callable(function):
            raise ValueError('Callback object must be callable.')
        self.pre_observers.append((function, interval, args, kwargs))

    def post_write_attach(self, function, interval=1, *args, **kwargs):
        """Attach a function to be called after writing ends.

        function: The function or callable object to be called.

        interval: How often the function is called.  Default: every time (1).

        All other arguments are stored, and passed to the function.

        Raises ValueError if *function* is not callable.
        """
        # The builtin callable() replaces collections.Callable, which is
        # no longer available from Python 3.10 on.
        if not callable(function):
            raise ValueError('Callback object must be callable.')
        self.post_observers.append((function, interval, args, kwargs))

    def _call_observers(self, obs):
        """Call pre/post write observers.

        An observer runs when the current write counter is a multiple of
        its interval.
        """
        for function, interval, args, kwargs in obs:
            if self.write_counter % interval == 0:
                function(*args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

387 

388 

def stringnify_info(info):
    """Return a stringnified version of the dict *info* that is
    ensured to be picklable.

    Every value is pickled on its own with protocol 0 and stored under
    its original key.  Items with non-string keys or unpicklable values
    are dropped and a UserWarning is issued for each of them.
    """
    stringnified = {}
    for k, v in info.items():
        if not isinstance(k, str):
            warnings.warn('Non-string info-dict key is not stored in ' +
                          'trajectory: ' + repr(k), UserWarning)
            continue
        try:
            # Should highest protocol be used here for efficiency?
            # Protocol 2 seems not to raise an exception when one
            # tries to pickle a file object, so by using that, we
            # might end up with file objects in inconsistent states.
            s = pickle.dumps(v, protocol=0)
        except (pickle.PicklingError, TypeError, AttributeError) as err:
            # Unpicklable objects do not only raise PicklingError: locks
            # and file objects raise TypeError, and some locally defined
            # objects raise AttributeError.  All of them mean "skip this".
            warnings.warn('Skipping not picklable info-dict item: ' +
                          '"%s" (%s)' % (k, err), UserWarning)
        else:
            stringnified[k] = s
    return stringnified

411 

412 

def unstringnify_info(stringnified):
    """Convert the dict *stringnified* to a dict with unstringnified
    objects and return it.

    Every value is expected to be a pickle payload.  Objects that cannot
    be unpickled will be skipped and a UserWarning will be issued.

    NOTE: pickle.loads can execute arbitrary code; only pass data that
    comes from a trusted source.
    """
    # Besides UnpicklingError, broken or truncated payloads surface as
    # several other exception types; treat all of them as "cannot unpickle".
    unpickle_errors = (pickle.UnpicklingError, AttributeError, EOFError,
                       ImportError, IndexError, KeyError, TypeError,
                       ValueError)
    info = {}
    for k, s in stringnified.items():
        try:
            v = pickle.loads(s)
        except unpickle_errors as err:
            warnings.warn('Skipping not unpicklable info-dict item: ' +
                          '"%s" (%s)' % (k, err), UserWarning)
        else:
            info[k] = v
    return info

427 

428 

def dict2constraints(d):
    """Build the list of constraints from an unpickled trajectory header.

    Version 1 headers (also assumed when no 'version' key is present)
    carry the constraints directly under 'constraints'.  Version 2 and 3
    headers carry them pickled under 'constraints_string'.  Any other
    version yields an empty list, as does a failed unpickling, in which
    case a warning is issued.
    """
    file_version = d.get('version', 1)

    if file_version == 1:
        return d['constraints']
    if file_version not in (2, 3):
        return []

    try:
        loaded = pickle.loads(d['constraints_string'])
        for constraint in loaded:
            is_fix_atoms = isinstance(constraint, FixAtoms)
            if is_fix_atoms and constraint.index.dtype == bool:
                # Special handling of old pickles: turn the boolean
                # mask into the positions of its True entries.
                mask = constraint.index
                constraint.index = np.arange(len(mask))[mask]
        return loaded
    except (AttributeError, KeyError, EOFError, ImportError, TypeError):
        warnings.warn('Could not unpickle constraints!')
        return []