Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import errno 

2import functools 

3import os 

4import io 

5import pickle 

6import sys 

7import time 

8import string 

9import warnings 

10from importlib import import_module 

11from math import sin, cos, radians, atan2, degrees 

12from contextlib import contextmanager, ExitStack 

13from math import gcd 

14from pathlib import PurePath, Path 

15import re 

16 

17import numpy as np 

18 

19from ase.formula import formula_hill, formula_metal 

20 

# Names exported by a star import of this module.
# NOTE(review): 'exec_' is not defined anywhere in the part of this module
# visible here, and 'FileNotFoundError' is a builtin rather than a module
# attribute.  Unless they are defined further down, a star import of this
# module raises AttributeError -- verify before relying on it.
__all__ = [
    'exec_',
    'basestring',
    'import_module',
    'seterr',
    'plural',
    'devnull',
    'gcd',
    'convert_string_to_fd',
    'Lock',
    'opencew',
    'OpenLock',
    'rotate',
    'irotate',
    'pbc2pbc',
    'givens',
    'hsv2rgb',
    'hsv',
    'pickleload',
    'FileNotFoundError',
    'formula_hill',
    'formula_metal',
    'PurePath',
    'xwopen',
    'tokenize_version',
]

27 

28 

def tokenize_version(version_string: str):
    """Split a version string into a tuple that compares like the version.

    Every dot-separated component contributes two entries: its leading
    integer (-1 when the component has no leading digits) followed by the
    rest of the component as a string.

    Usage: tokenize_version('3.8') < tokenize_version('3.8.1').
    """
    def leading_number(digits):
        # An empty digit run (e.g. component 'dev') sorts before any number.
        try:
            return int(digits)
        except ValueError:
            return -1

    tokens = []
    for component in version_string.split('.'):
        match = re.match(r'(\d*)(.*)', component)
        assert match is not None, f'Cannot parse component {component}'
        digits, tail = match.groups()
        tokens.extend((leading_number(digits), tail))
    return tuple(tokens)

45 

46 

# Leftover Python 2/3 compatibility aliases.  They are kept so existing
# imports keep working, but are candidates for removal.

# Python 3 has a single text type; Python 2 code spelled it 'basestring'.
basestring = str
# encoding='bytes' makes pickle.load return 8-bit strings from old
# (Python 2) pickles as bytes objects instead of decoding them.
pickleload = functools.partial(pickle.load, encoding='bytes')

50 

51 

def deprecated(msg, category=FutureWarning):
    """Return a decorator deprecating a function.

    Use like @deprecated('warning message and explanation').

    Parameters:

    msg: str or Warning
        What to warn with on every call of the decorated function.  A
        ``Warning`` instance is emitted as-is (``category`` is ignored).
    category: Warning subclass
        Class used to wrap a plain ``msg`` (default: FutureWarning).

    The decorated function keeps its name and docstring
    (functools.wraps) and returns whatever the original returns.
    """
    def deprecated_decorator(func):
        @functools.wraps(func)
        def deprecated_function(*args, **kwargs):
            warning = msg
            if not isinstance(warning, Warning):
                warning = category(warning)
            # stacklevel=2 attributes the warning to the code calling the
            # deprecated function, not to this wrapper; that is the location
            # a user needs to fix, and it lets the default "once per
            # location" filter report each distinct call site.
            warnings.warn(warning, stacklevel=2)
            return func(*args, **kwargs)
        return deprecated_function
    return deprecated_decorator

66 

67 

@contextmanager
def seterr(**kwargs):
    """Temporarily change numpy's floating-point error handling.

    The keyword arguments go straight to np.seterr() on entry; whatever
    settings were active before are put back on exit, also when the body
    of the ``with`` block raises.
    """
    previous_settings = np.seterr(**kwargs)
    try:
        yield
    finally:
        np.seterr(**previous_settings)

79 

80 

def plural(n, word):
    """Return '<n> <word>', pluralising word with an 's' unless n == 1.

    >>> plural(0, 'egg'), plural(1, 'egg'), plural(2, 'egg')
    ('0 eggs', '1 egg', '2 eggs')
    """
    if n != 1:
        return '%d %ss' % (n, word)
    return '1 ' + word

90 

91 

class DevNull:
    """Deprecated file-like sink: writes are discarded, reads are empty.

    Every method call emits a DeprecationWarning (through ``deprecated``);
    use ``open(os.devnull)`` instead.  Note that close() does not change
    the ``closed`` attribute.
    """
    encoding = 'UTF-8'
    closed = False

    # Decorator shared by all methods below.  Deprecated for ase-3.21.0;
    # the category is meant to become FutureWarning later on.
    _use_os_devnull = deprecated('use open(os.devnull) instead',
                                 DeprecationWarning)

    # -- reading --
    @_use_os_devnull
    def read(self, n=-1):
        return ''

    # -- writing --
    @_use_os_devnull
    def write(self, string):
        return None

    @_use_os_devnull
    def flush(self):
        return None

    # -- positioning: always at offset 0 --
    @_use_os_devnull
    def seek(self, offset, whence=0):
        return 0

    @_use_os_devnull
    def tell(self):
        return 0

    # -- status and lifetime --
    @_use_os_devnull
    def isatty(self):
        return False

    @_use_os_devnull
    def close(self):
        return None


devnull = DevNull()

130 

131 

@deprecated('convert_string_to_fd does not facilitate proper resource '
            'management. '
            'Please use e.g. ase.utils.IOContext class instead.')
def convert_string_to_fd(name, world=None):
    """Return a text file object for output selected by ``name``.

    - None, or any rank other than 0 of ``world``: a new os.devnull handle;
    - '-': sys.stdout;
    - str or PurePath: that file, opened for writing;
    - anything else: returned unchanged (assumed to be an open file).

    ``world`` defaults to ``ase.parallel.world``.  Files opened here are
    not closed by this function.
    """
    if world is None:
        from ase.parallel import world

    discard_output = name is None or world.rank != 0
    if discard_output:
        return open(os.devnull, 'w')
    if name == '-':
        return sys.stdout
    if not isinstance(name, (str, PurePath)):
        return name  # assumed to already be a file object
    return open(str(name), 'w')

150 

151 

# os.open() flags for "create exclusively, write only": opening fails if
# the file already exists.  Only Windows has O_BINARY, so fall back to 0.
CEW_FLAGS = (getattr(os, 'O_BINARY', 0)
             | os.O_WRONLY
             | os.O_EXCL
             | os.O_CREAT)

154 

155 

@contextmanager
def xwopen(filename, world=None):
    """Context manager: create and open ``filename`` exclusively for writing.

    Yields whatever opencew() returns.  That is a binary file object on
    success (a dummy devnull handle on ranks other than the master), or
    None on all ranks if the master could not create the file because it
    already exists.  A non-None handle is closed when the block exits.
    """
    handle = opencew(filename, world)
    try:
        yield handle
    finally:
        if handle is not None:
            handle.close()

171 

172 

# Candidate for deprecation in favour of xwopen():
#   @deprecated('use "with xwopen(...) as fd: ..." to prevent resource leak')
def opencew(filename, world=None):
    """Create ``filename`` exclusively and open it for binary writing.

    Returns a binary file object (a devnull handle on ranks other than 0),
    or None on all ranks if the file already exists.  The caller must close
    the returned object; prefer the xwopen() context manager.
    """
    return _opencew(filename, world=world)

176 

177 

178def _opencew(filename, world=None): 

179 if world is None: 

180 from ase.parallel import world 

181 

182 closelater = [] 

183 

184 def opener(file, flags): 

185 return os.open(file, flags | CEW_FLAGS) 

186 

187 try: 

188 error = 0 

189 if world.rank == 0: 

190 try: 

191 fd = open(filename, 'wb', opener=opener) 

192 except OSError as ex: 

193 error = ex.errno 

194 else: 

195 closelater.append(fd) 

196 else: 

197 fd = open(os.devnull, 'wb') 

198 closelater.append(fd) 

199 

200 # Synchronize: 

201 error = world.sum(error) 

202 if error == errno.EEXIST: 

203 return None 

204 if error: 

205 raise OSError(error, 'Error', filename) 

206 

207 return fd 

208 except BaseException: 

209 for fd in closelater: 

210 fd.close() 

211 raise 

212 

213 

def opencew_text(*args, **kwargs):
    """Text-mode variant of opencew().

    Wraps the binary handle from opencew() in io.TextIOWrapper (default,
    i.e. locale-dependent, encoding).  Returns None when opencew() does.
    """
    binary = opencew(*args, **kwargs)
    return None if binary is None else io.TextIOWrapper(binary)

219 

220 

class Lock:
    """File-based lock shared by all ranks of ``world``.

    The lock is held while the lock file ``name`` exists.  acquire()
    creates it exclusively via opencew() and release() removes it again
    (on rank 0).

    Parameters:

    name: str or path-like
        Lock file name (default 'lock').
    world: communicator or None
        Defaults to ``ase.parallel.world``.
    timeout: float
        Seconds acquire() keeps retrying before raising TimeoutError
        (default: wait forever).

    Usable as ``with Lock(...) as lock: ...``.
    """
    def __init__(self, name='lock', world=None, timeout=float('inf')):
        self.name = str(name)
        self.timeout = timeout
        if world is None:
            from ase.parallel import world
        self.world = world

    def acquire(self):
        """Block until the lock file could be created.

        Between attempts it sleeps, starting at 0.2 s and doubling the
        interval after every failed attempt (the interval is not capped),
        but never past the deadline.  Raises TimeoutError once ``timeout``
        seconds have elapsed.
        """
        dt = 0.2
        t1 = time.time()
        while True:
            fd = opencew(self.name, self.world)
            if fd is not None:
                self.fd = fd
                break
            time_left = self.timeout - (time.time() - t1)
            if time_left <= 0:
                raise TimeoutError
            time.sleep(min(dt, time_left))
            dt *= 2

    def release(self):
        """Close the lock file and delete it (rank 0), between barriers."""
        self.world.barrier()
        # Important to close fd before deleting file on windows
        # as a WinError would otherwise be raised.
        self.fd.close()
        if self.world.rank == 0:
            os.remove(self.name)
        self.world.barrier()

    def __enter__(self):
        self.acquire()
        # Return the lock so that ``with Lock() as lock:`` binds it
        # (previously None was bound).
        return self

    def __exit__(self, type, value, tb):
        self.release()

258 

class OpenLock:
    """Do-nothing stand-in with the same interface as Lock.

    Useful where a lock object is expected but no locking is wanted.
    Like Lock, ``with OpenLock() as lock:`` binds the lock object itself.
    """
    def acquire(self):
        pass

    def release(self):
        pass

    def __enter__(self):
        # Return self for consistency with Lock.__enter__.
        return self

    def __exit__(self, type, value, tb):
        pass

271 

272 

def search_current_git_hash(arg, world=None):
    """Search for .git directory and current git commit hash.

    Parameters:

    arg: str or path-like (directory path) or python module
        .git directory is searched from the parent directory of
        the given directory or module.
    world: communicator or None
        Only rank 0 performs the search; defaults to
        ``ase.parallel.world``.

    Returns the commit hash (a non-empty string of hex digits), or None
    in these cases:
    - called on a rank other than 0;
    - no ``.git`` directory in the parent directory;
    - the hash cannot be resolved.

    A branch ref named in HEAD is looked up first as a loose file under
    ``.git`` and otherwise in ``.git/packed-refs``.
    """
    if world is None:
        from ase.parallel import world
    if world.rank != 0:
        return None

    # Check argument
    if isinstance(arg, (str, os.PathLike)):
        dpath = os.fspath(arg)  # directory path
    else:
        dpath = os.path.dirname(arg.__file__)  # assume arg is a module
    # realpath: in case this is just symlinked into $PYTHONPATH.
    # Then go to the parent directory.
    dpath = os.path.dirname(os.path.realpath(dpath))
    git_dpath = os.path.join(dpath, '.git')
    if not os.path.isdir(git_dpath):
        # Replace this 'if' with a loop if you want to check
        # further parent directories
        return None
    head_file = os.path.join(git_dpath, 'HEAD')
    if not os.path.isfile(head_file):
        return None

    line = _read_first_line(head_file)
    if line.startswith('ref: '):
        ref = line[5:]
        ref_file = os.path.join(git_dpath, ref)
        if os.path.isfile(ref_file):
            line = _read_first_line(ref_file)
        else:
            # Refs packed by 'git pack-refs' / 'git gc' have no loose file.
            line = _find_packed_ref(git_dpath, ref)
    # Otherwise HEAD itself holds the hash (detached HEAD state).

    # 'line' must be non-empty: all() over an empty string is True, which
    # used to make an empty ref file return '' instead of None.
    if line and all(c in string.hexdigits for c in line):
        return line
    return None


def _read_first_line(path):
    """Return the first line of text file ``path`` with whitespace stripped."""
    with open(path, 'r') as fd:
        return fd.readline().strip()


def _find_packed_ref(git_dpath, ref):
    """Return the hash recorded for ``ref`` in .git/packed-refs, or ''."""
    packed_file = os.path.join(git_dpath, 'packed-refs')
    if not os.path.isfile(packed_file):
        return ''
    with open(packed_file, 'r') as fd:
        for entry in fd:
            fields = entry.split()
            # Data lines are '<hash> <refname>'.  The '# pack-refs ...'
            # header and '^<hash>' peeled-tag lines are skipped.
            if (len(fields) == 2 and fields[1] == ref
                    and not fields[0].startswith(('#', '^'))):
                return fields[0]
    return ''

321 

322 

def rotate(rotations, rotation=None):
    """Convert string of format '50x,-10y,120z' to a rotation matrix.

    Note that the order of rotation matters, i.e. '50x,40z' is different
    from '40z,50x'.

    Parameters:

    rotations: str
        Comma-separated ``<angle><axis>`` items.  Angles are in degrees and
        the axis is one of 'x', 'y', 'z'.  Whitespace around items is
        ignored; an empty or all-whitespace string means "no rotation".
    rotation: 3x3 array-like or None
        Matrix the rotations are applied to.  Each item multiplies it from
        the right.  None (the default) means the identity.

    Returns a new matrix (a copy of ``rotation`` when there is nothing to
    apply); ``rotation`` itself is never modified.
    """
    # None rather than a shared np.identity(3) object as the default.
    if rotation is None:
        rotation = np.identity(3)

    if rotations.strip() == '':
        return rotation.copy()

    # Parse all items first so malformed input fails before any work.
    steps = []
    for item in rotations.split(','):
        item = item.strip()
        steps.append(('xyz'.index(item[-1]), radians(float(item[:-1]))))

    for axis, angle in steps:
        s = sin(angle)
        c = cos(angle)
        if axis == 0:
            rotation = np.dot(rotation, [(1, 0, 0),
                                         (0, c, s),
                                         (0, -s, c)])
        elif axis == 1:
            rotation = np.dot(rotation, [(c, 0, -s),
                                         (0, 1, 0),
                                         (s, 0, c)])
        else:
            rotation = np.dot(rotation, [(c, s, 0),
                                         (-s, c, 0),
                                         (0, 0, 1)])
    return rotation

350 

351 

def givens(a, b):
    """Compute the Givens rotation that zeroes the second component.

    Solve the equation system::

      [ c s]   [a]   [r]
      [    ] . [ ] = [ ]
      [-s c]   [b]   [0]

    Returns (c, s, r) with r >= 0.  When b == 0 the result is
    (sign(a), 0, abs(a)).  The ratio of the smaller to the larger input is
    used so that squaring never overflows.
    """
    if b == 0:
        return np.sign(a), 0, abs(a)

    if abs(b) >= abs(a):
        ratio = a / b  # cot of the rotation angle
        u = np.sign(b) * (1 + ratio**2)**0.5
        s = 1. / u
        return s * ratio, s, b * u

    ratio = b / a  # tan of the rotation angle
    u = np.sign(a) * (1 + ratio**2)**0.5
    c = 1. / u
    return c, c * ratio, a * u

377 

378 

def irotate(rotation, initial=np.identity(3)):
    """Determine x, y, z rotation angles (degrees) from a rotation matrix.

    Decomposes ``initial @ rotation`` with three successive Givens
    rotations.  The angles are presumably meant to be fed back to rotate()
    as '<x>x,<y>y,<z>z' -- confirm for |y| >= 90 degrees.
    """
    m = np.dot(initial, rotation)
    cx, sx, rx = givens(m[2, 2], m[1, 2])
    cy, sy, _ = givens(rx, m[0, 2])
    cz, sz, _ = givens(cx * m[1, 1] - sx * m[2, 1],
                       cy * m[0, 1] - sy * (sx * m[1, 1] + cx * m[2, 1]))
    return (degrees(atan2(sx, cx)),
            degrees(atan2(-sy, cy)),
            degrees(atan2(sz, cz)))

390 

391 

def pbc2pbc(pbc):
    """Return a new length-3 boolean array built from ``pbc``.

    ``pbc`` may be a single value, broadcast to all three directions, or
    three values; entries are converted to bool.
    """
    result = np.empty(3, dtype=bool)
    result[...] = pbc
    return result

396 

397 

def hsv2rgb(h, s, v):
    """Convert an HSV colour to RGB.

    http://en.wikipedia.org/wiki/HSL_and_HSV

    h (hue) in degrees; since hue is an angle, any finite value is accepted
    and taken modulo 360 (so 360 == 0 and -120 == 240)
    s (saturation) in [0, 1]
    v (value) in [0, 1]

    return rgb in range [0, 1]

    Raises RuntimeError if h is not a finite number.
    """
    if v == 0:
        return 0, 0, 0
    if s == 0:
        return v, v, v

    i, f = divmod(h / 60., 1)
    # Wrap the sextant index instead of h itself: 'h % 360' can round to
    # exactly 360 for tiny negative h, whereas i % 6 is always in 0..5 for
    # finite h (NaN/inf give NaN and reach the error branch).  For h in
    # [0, 360[ this is a no-op.
    i %= 6
    p = v * (1 - s)
    q = v * (1 - s * f)
    t = v * (1 - s * (1 - f))

    if i == 0:
        return v, t, p
    elif i == 1:
        return q, v, p
    elif i == 2:
        return p, v, t
    elif i == 3:
        return p, q, v
    elif i == 4:
        return t, p, v
    elif i == 5:
        return v, p, q
    else:
        raise RuntimeError('h must be a finite number')


def hsv(array, s=.9, v=.9):
    """Colour-code the values of ``array`` along the hue circle.

    Values are scaled linearly so that the minimum maps to hue 0 and the
    maximum to hue 359.  A constant array maps entirely to hue 0.  ``s`` and
    ``v`` are passed to hsv2rgb.

    Returns an array of shape ``array.shape + (3,)`` with RGB in [0, 1].
    """
    array = np.asarray(array, dtype=float)
    lowest = array.min()
    span = array.max() - lowest
    if span > 0:
        # Shift by the minimum (previously '+ min'); this maps the data
        # onto [0, 359] instead of hues that hsv2rgb could not handle.
        hues = (array - lowest) * 359. / span
    else:
        hues = np.zeros_like(array)  # avoid 0/0 for constant input
    result = np.empty((hues.size, 3))
    for rgb, h in zip(result, hues.flat):
        rgb[:] = hsv2rgb(h, s, v)
    return np.reshape(result, hues.shape + (3,))

439 

440 

# This code does roughly the same, but requires pylab (matplotlib):
# def cmap(array, name='hsv'):
#     import pylab
#     a = (array - array.min()) / np.ptp(array)
#     rgba = getattr(pylab.cm, name)(a)
#     return rgba[..., :-1]  # return rgb only (not alpha)

447 

448 

def longsum(x):
    """Sum ``x`` in numpy's extended-precision np.longdouble.

    The result is returned as a Python float.  The width of np.longdouble
    is platform dependent: it is typically 80-bit x87 extended precision on
    x86 Linux, but equal to a plain 64-bit double on e.g. Windows/MSVC.  So
    this is not guaranteed to be a true 128-bit (quad precision) sum.
    """
    extended = np.asarray(x, dtype=np.longdouble)
    return float(np.sum(extended))

452 

453 

@contextmanager
def workdir(path, mkdir=False):
    """Run the ``with`` block inside directory ``path``.

    With mkdir=True the directory, including missing parents, is created
    first.  The previous working directory is restored on exit, even if
    the block raises.  Nothing is bound by ``as``.
    """
    target = Path(path)
    if mkdir:
        target.mkdir(parents=True, exist_ok=True)

    previous = os.getcwd()
    os.chdir(target)  # os.chdir accepts path-like objects (Python >= 3.6)
    try:
        yield
    finally:
        os.chdir(previous)

467 

468 

class iofunction:
    """Decorate func so it accepts either str or file.

    A str or PurePath first argument is opened with ``mode``, passed to
    func, and closed again when func returns or raises.  Any other object
    is passed through untouched and left open.

    (Won't work on functions that return a generator.)
    """
    def __init__(self, mode):
        self.mode = mode

    def __call__(self, func):
        @functools.wraps(func)
        def iofunc(file, *args, **kwargs):
            if not isinstance(file, (str, PurePath)):
                return func(file, *args, **kwargs)
            # We opened it, so we close it -- also if func raises.
            with open(str(file), self.mode) as fd:
                return func(fd, *args, **kwargs)
        return iofunc

493 

494 

def writer(func):
    """Let func take a file name (opened with mode 'w') or a file object."""
    open_for_writing = iofunction('w')
    return open_for_writing(func)

497 

498 

def reader(func):
    """Let func take a file name (opened with mode 'r') or a file object."""
    open_for_reading = iofunction('r')
    return open_for_reading(func)

501 

502 

# The next two functions are hot-plugged into a JSONable class by the
# jsonable decorator.  Conceptually they belong in ase.io.jsonio, but they
# are defined in this basic module because importing them from there would
# trigger a lot of extra (cyclic) imports.  ase.io.jsonio is therefore only
# imported when one of them is actually called.

def write_json(self, fd):
    """Write ``self`` to ``fd`` in JSON format (via ase.io.jsonio)."""
    import ase.io.jsonio as jsonio
    jsonio.write_json(fd, self)

512 

513 

@classmethod  # type: ignore
def read_json(cls, fd):
    """Read new instance from JSON file.

    Attached to classes by the jsonable decorator (hence ``cls``).  The
    decoding is done by ase.io.jsonio.read_json.

    Raises TypeError if ``fd`` does not decode to an instance of exactly
    this class.
    """
    from ase.io.jsonio import read_json as _read_json
    obj = _read_json(fd)
    # Explicit check instead of 'assert', which 'python -O' strips and
    # which would then silently return an object of the wrong type.
    if type(obj) is not cls:
        raise TypeError(f'Expected a {cls.__name__} object in JSON file, '
                        f'got {type(obj).__name__}')
    return obj

521 

522 

def jsonable(name):
    """Decorator for facilitating JSON I/O with a class.

    Pokes JSON-based read and write functions into the class.

    In order to write an object to JSON, it needs to be a known simple type
    (such as ndarray, float, ...) or implement todict().  If the class
    defines a string called ase_objtype, the decoder will want to convert
    the object back into its original type when reading.

    The decorator sets ``cls.ase_objtype = name`` and installs
    ``cls.write`` / ``cls.read``.  It raises TypeError if the class has no
    todict(); in that case the class is left unmodified.
    """
    def jsonableclass(cls):
        # Validate before touching the class so a failed decoration does
        # not leave a half-configured class behind.
        if not hasattr(cls, 'todict'):
            raise TypeError('Class must implement todict()')
        cls.ase_objtype = name

        # We may want the write and read to be optional.
        # E.g. a calculator might want to be JSONable, but not
        # that .write() produces a JSON file.
        #
        # This is mostly for 'lightweight' object IO.
        cls.write = write_json
        cls.read = read_json
        return cls
    return jsonableclass

546 

547 

class ExperimentalFeatureWarning(Warning):
    """Warning category emitted by functions decorated with @experimental."""


def experimental(func):
    """Decorator for functions not ready for production use.

    Every call of the decorated function first emits an
    ExperimentalFeatureWarning naming the function, then calls it and
    returns its result.
    """
    @functools.wraps(func)
    def expfunc(*args, **kwargs):
        # stacklevel=2: point the warning at the caller, not this wrapper.
        warnings.warn('This function may change or misbehave: {}()'
                      .format(func.__qualname__),
                      ExperimentalFeatureWarning, stacklevel=2)
        return func(*args, **kwargs)
    return expfunc

561 

562 

def lazymethod(meth):
    """Cache the result of a method that takes no arguments besides self.

    Example::

        class MyClass:

            @lazymethod
            def thing(self):
                return expensive_calculation()

    The body of thing() runs only on the first call for each instance.
    Its return value is stored in the per-instance dict ``_lazy_cache``
    under the method's name and handed back on every later call.
    """
    key = meth.__name__

    @functools.wraps(meth)
    def getter(self):
        try:
            store = self._lazy_cache
        except AttributeError:
            store = {}
            self._lazy_cache = store

        if key in store:
            return store[key]
        value = store[key] = meth(self)
        return value
    return getter

590 

591 

def atoms_to_spglib_cell(atoms):
    """Return (cell, scaled positions, atomic numbers) of ``atoms``.

    This tuple is the cell description passed to spglib.
    """
    cell = atoms.get_cell()
    scaled_positions = atoms.get_scaled_positions()
    numbers = atoms.get_atomic_numbers()
    return cell, scaled_positions, numbers

597 

598 

def warn_legacy(feature_name):
    """Emit a FutureWarning that ``feature_name`` is untested legacy code."""
    message = (f'The {feature_name} feature is untested and ASE developers do not '
               'know whether it works or how to use it. Please rehabilitate it '
               '(by writing unittests) or it may be removed.')
    warnings.warn(message, category=FutureWarning)

605 

606 

def lazyproperty(meth):
    """Like lazymethod, but expose the cached value as a read-only property.

    The value is computed on first attribute access and reused afterwards.
    """
    cached_getter = lazymethod(meth)
    return property(cached_getter)

610 

611 

class IOContext:
    """Keep track of files opened for (rank-aware) output and close them.

    Files are registered with closelater(), which enters them into an
    ExitStack created lazily on first use.  close(), or leaving a
    ``with`` block on the object, closes all of them.
    """
    @lazyproperty
    def _exitstack(self):
        return ExitStack()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def closelater(self, fd):
        """Register context manager ``fd`` to be closed by close().

        Returns what ``fd.__enter__()`` returns (the file itself for file
        objects).
        """
        return self._exitstack.enter_context(fd)

    def close(self):
        """Close everything registered with closelater()."""
        self._exitstack.close()

    def openfile(self, file, comm=None, mode='w'):
        """Return a file object for ``file``, closing it later if we opened it.

        - Objects with a ``close`` attribute are returned unchanged and not
          registered; the caller keeps ownership.
        - None, or any rank other than 0 of ``comm``: os.devnull opened with
          ``mode`` (registered).
        - '-' on rank 0: sys.stdout (not registered).
        - Anything else: ``open(file, mode=mode)`` (registered).

        ``comm`` defaults to ``ase.parallel.world``.
        """
        if comm is None:
            # Import only when needed: callers passing an explicit comm do
            # not need (or trigger the import of) ase.parallel.
            from ase.parallel import world
            comm = world

        if hasattr(file, 'close'):
            return file  # File already opened, not for us to close.

        if file is None or comm.rank != 0:
            return self.closelater(open(os.devnull, mode=mode))

        if file == '-':
            return sys.stdout

        return self.closelater(open(file, mode=mode))