Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import errno
2import functools
3import os
4import io
5import pickle
6import sys
7import time
8import string
9import warnings
10from importlib import import_module
11from math import sin, cos, radians, atan2, degrees
12from contextlib import contextmanager, ExitStack
13from math import gcd
14from pathlib import PurePath, Path
15import re
17import numpy as np
19from ase.formula import formula_hill, formula_metal
# Public API of this module.  Only names actually bound in this module may
# be listed: ``from ase.utils import *`` raises AttributeError for any name
# in __all__ that the module lacks.  ('exec_' is not defined anywhere in
# this file and 'FileNotFoundError' is a builtin, not a module attribute,
# so both were removed.)
__all__ = ['basestring', 'import_module', 'seterr', 'plural',
           'devnull', 'gcd', 'convert_string_to_fd', 'Lock',
           'opencew', 'OpenLock', 'rotate', 'irotate', 'pbc2pbc', 'givens',
           'hsv2rgb', 'hsv', 'pickleload',
           'formula_hill', 'formula_metal', 'PurePath', 'xwopen',
           'tokenize_version']
def tokenize_version(version_string: str):
    """Split a dotted version string into a tuple usable for comparisons.

    Each dot-separated component contributes two items: its leading
    integer (-1 when the component does not start with digits) followed
    by the rest of the component as a string, e.g.
    '1.2b3' -> (1, '', 2, 'b3').

    Usage: tokenize_version('3.8') < tokenize_version('3.8.1').
    """
    parts = []
    for piece in version_string.split('.'):
        found = re.match(r'(\d*)(.*)', piece)
        assert found is not None, f'Cannot parse component {piece}'
        head, rest = found.groups()
        try:
            value = int(head)
        except ValueError:  # e.g. empty head: no leading digits
            value = -1
        parts.extend((value, rest))
    return tuple(parts)
# Python 2+3 compatibility stuff (let's try to remove these things):
# Alias for Python 2's common string base class; on Python 3 only str exists.
basestring = str
# pickle.load preconfigured with encoding='bytes'.  Per the pickle docs this
# returns Python-2 8-bit string instances as bytes; presumably kept for
# reading legacy pickles written under Python 2 -- confirm before removing.
pickleload = functools.partial(pickle.load, encoding='bytes')
def deprecated(msg, category=FutureWarning):
    """Return a decorator deprecating a function.

    Use like @deprecated('warning message and explanation').

    Parameters:

    msg: str or Warning
        Emitted on every call of the decorated function.  If it is already
        a Warning instance it is emitted as-is and *category* is ignored.
    category: Warning subclass
        Used to wrap a plain *msg* (default FutureWarning).
    """
    def deprecated_decorator(func):
        @functools.wraps(func)
        def deprecated_function(*args, **kwargs):
            warning = msg
            if not isinstance(warning, Warning):
                warning = category(warning)
            # stacklevel=2 attributes the warning to the code that called
            # the deprecated function instead of to this wrapper, so users
            # can see which of their lines needs updating.
            warnings.warn(warning, stacklevel=2)
            return func(*args, **kwargs)
        return deprecated_function
    return deprecated_decorator
@contextmanager
def seterr(**kwargs):
    """Temporarily change NumPy's floating-point error handling.

    Keyword arguments are forwarded to np.seterr(); the settings that were
    active before are restored on exit, even if the body raises.
    """
    saved_state = np.seterr(**kwargs)
    try:
        yield
    finally:
        # np.seterr() handed back the previous settings; replay them.
        np.seterr(**saved_state)
def plural(n, word):
    """Return '<n> <word>' with an 's' appended to word unless n == 1.

    >>> plural(0, 'egg'), plural(1, 'egg'), plural(2, 'egg')
    ('0 eggs', '1 egg', '2 eggs')
    """
    return '1 ' + word if n == 1 else '%d %ss' % (n, word)
class DevNull:
    """Deprecated file-like sink: writes are discarded, reads are empty.

    Every method call emits a DeprecationWarning recommending
    ``open(os.devnull)`` instead.
    """
    encoding = 'UTF-8'
    closed = False

    # Deprecated for ase-3.21.0. Change to futurewarning later on.
    _use_os_devnull = deprecated('use open(os.devnull) instead',
                                 DeprecationWarning)

    @_use_os_devnull
    def write(self, string):
        """Discard *string*."""

    @_use_os_devnull
    def flush(self):
        """Nothing is buffered, so there is nothing to flush."""

    @_use_os_devnull
    def seek(self, offset, whence=0):
        """Ignore the request and report position 0."""
        return 0

    @_use_os_devnull
    def tell(self):
        """The position is always 0."""
        return 0

    @_use_os_devnull
    def close(self):
        """No-op; note that the ``closed`` attribute stays False."""

    @_use_os_devnull
    def isatty(self):
        """Never a terminal."""
        return False

    @_use_os_devnull
    def read(self, n=-1):
        """Return an empty string: nothing is ever stored."""
        return ''


devnull = DevNull()
@deprecated('convert_string_to_fd does not facilitate proper resource '
            'management. '
            'Please use e.g. ase.utils.IOContext class instead.')
def convert_string_to_fd(name, world=None):
    """Create a file-descriptor for text output.

    * None, or any rank other than 0: a new handle on os.devnull.
    * '-': sys.stdout.
    * str or PurePath: the file opened for (text) writing.
    * anything else: returned unchanged (assumed to be an open file).

    The caller is responsible for closing any handle opened here.
    """
    if world is None:
        from ase.parallel import world
    if name is not None and world.rank == 0:
        if name == '-':
            return sys.stdout
        if isinstance(name, (str, PurePath)):
            return open(str(name), 'w')  # str for py3.5 pathlib
        return name  # we assume name is already a file-descriptor
    return open(os.devnull, 'w')
# os.open flags for "create exclusively for writing": open fails with
# EEXIST if the file is already there.
CEW_FLAGS = os.O_WRONLY | os.O_CREAT | os.O_EXCL
# Only Windows has O_BINARY; elsewhere this ORs in 0.
CEW_FLAGS |= getattr(os, 'O_BINARY', 0)
@contextmanager
def xwopen(filename, world=None):
    """Create and open filename exclusively for writing.

    Context-manager form of opencew(): if master cpu gets exclusive write
    access to filename, a file descriptor is yielded (a dummy file
    descriptor is yielded on the slaves).  If the master cpu does not get
    write access, None is yielded on all processors.  Any descriptor is
    closed when the with-block exits.
    """
    handle = opencew(filename, world)
    with ExitStack() as cleanup:
        if handle is not None:
            cleanup.callback(handle.close)
        yield handle
# TODO: consider @deprecated('use "with xwopen(...) as fd: ..." to prevent
# resource leak') once callers have migrated.
def opencew(filename, world=None):
    """Create filename exclusively for writing (binary mode).

    Returns the open handle on rank 0 (a handle on os.devnull on other
    ranks), or None on all ranks when the file already exists.  The caller
    must close the handle; prefer ``with xwopen(...)``, which does that.
    """
    return _opencew(filename, world=world)
def _opencew(filename, world=None):
    """Implementation of opencew().

    Rank 0 opens *filename* in 'wb' mode with CEW_FLAGS added (so creation
    fails if the file exists); every other rank opens os.devnull instead.
    Rank 0's errno (0 on success) is passed through world.sum() --
    presumably a collective reduction -- so every rank acts on it:

    * 0: return the opened handle;
    * EEXIST: close any handle opened here and return None;
    * other errno: close any handle opened here and raise OSError.
    """
    if world is None:
        from ase.parallel import world

    closelater = []

    def opener(file, flags):
        return os.open(file, flags | CEW_FLAGS)

    def close_all():
        for handle in closelater:
            handle.close()

    try:
        error = 0
        if world.rank == 0:
            try:
                fd = open(filename, 'wb', opener=opener)
            except OSError as ex:
                error = ex.errno
            else:
                closelater.append(fd)
        else:
            fd = open(os.devnull, 'wb')
            closelater.append(fd)

        # Synchronize:
        error = world.sum(error)
        if error == errno.EEXIST:
            # The dummy devnull handle on non-master ranks would otherwise
            # leak, because None is returned instead of it.
            close_all()
            return None
        if error:
            raise OSError(error, os.strerror(error), filename)

        return fd
    except BaseException:
        close_all()
        raise
def opencew_text(*args, **kwargs):
    """Text-mode variant of opencew().

    Takes the same arguments as opencew().  Returns None when opencew()
    does; otherwise wraps the binary handle in io.TextIOWrapper.
    """
    raw = opencew(*args, **kwargs)
    return None if raw is None else io.TextIOWrapper(raw)
class Lock:
    """Lock implemented as a file created exclusively via opencew().

    acquire() retries opencew(name, world) -- which returns None while the
    lock file already exists -- sleeping 0.2 s, then doubling the delay,
    until it succeeds or *timeout* seconds have passed.  release() closes
    the handle and, on rank 0, deletes the lock file.

    Usable as a context manager: ``with Lock('name') as lock: ...``.
    """
    def __init__(self, name='lock', world=None, timeout=float('inf')):
        self.name = str(name)
        self.timeout = timeout
        if world is None:
            from ase.parallel import world
        self.world = world

    def acquire(self):
        """Block until the lock file is ours.

        Raises TimeoutError if *timeout* seconds elapse first.
        """
        dt = 0.2
        # monotonic() is unaffected by wall-clock adjustments, which could
        # otherwise shorten or extend the timeout arbitrarily.
        t1 = time.monotonic()
        while True:
            fd = opencew(self.name, self.world)
            if fd is not None:
                self.fd = fd
                break
            time_left = self.timeout - (time.monotonic() - t1)
            if time_left <= 0:
                raise TimeoutError
            time.sleep(min(dt, time_left))
            dt *= 2

    def release(self):
        """Close our handle and remove the lock file (on rank 0)."""
        self.world.barrier()
        # Important to close fd before deleting file on windows
        # as a WinError would otherwise be raised.
        self.fd.close()
        if self.world.rank == 0:
            os.remove(self.name)
        self.world.barrier()

    def __enter__(self):
        self.acquire()
        # Return the lock so ``with Lock(...) as lock`` binds it.
        return self

    def __exit__(self, type, value, tb):
        self.release()
class OpenLock:
    """Lock-like object whose acquire/release do nothing.

    Offers the same methods as Lock (acquire, release, context manager)
    without any locking.
    """
    def acquire(self):
        pass

    def release(self):
        pass

    def __enter__(self):
        # Follow the context-manager convention (as Lock does): bind self.
        return self

    def __exit__(self, type, value, tb):
        pass
def _read_packed_ref(git_dpath, ref):
    """Return the hash stored for *ref* in .git/packed-refs, or '' if none."""
    packed_file = os.path.join(git_dpath, 'packed-refs')
    if not os.path.isfile(packed_file):
        return ''
    with open(packed_file, 'r') as fd:
        for packed_line in fd:
            # Skip the header comment and '^<hash>' peeled-tag lines.
            if packed_line.startswith(('#', '^')):
                continue
            fields = packed_line.split()
            if len(fields) == 2 and fields[1] == ref:
                return fields[0]
    return ''


def search_current_git_hash(arg, world=None):
    """Search for .git directory and current git commit hash.

    Parameters:

    arg: str (directory path) or python module
        .git directory is searched from the parent directory of
        the given directory or module.

    Returns the commit hash as a string, or None if it cannot be found
    (no .git directory, unreadable/empty ref, or not on rank 0 of world).
    Refs that git has moved into .git/packed-refs are also resolved.
    """
    if world is None:
        from ase.parallel import world
    if world.rank != 0:
        return None

    # Check argument
    if isinstance(arg, str):
        # Directory path
        dpath = arg
    else:
        # Assume arg is module
        dpath = os.path.dirname(arg.__file__)
    # in case this is just symlinked into $PYTHONPATH
    dpath = os.path.realpath(dpath)
    dpath = os.path.dirname(dpath)  # Go to the parent directory
    git_dpath = os.path.join(dpath, '.git')
    if not os.path.isdir(git_dpath):
        # Replace this 'if' with a loop if you want to check
        # further parent directories
        return None
    HEAD_file = os.path.join(git_dpath, 'HEAD')
    if not os.path.isfile(HEAD_file):
        return None
    with open(HEAD_file, 'r') as fd:
        line = fd.readline().strip()
    if line.startswith('ref: '):
        ref = line[5:]
        ref_file = os.path.join(git_dpath, ref)
        if os.path.isfile(ref_file):
            with open(ref_file, 'r') as fd:
                line = fd.readline().strip()
        else:
            # No loose ref file: 'git gc' / 'git pack-refs' may have
            # moved the ref into packed-refs.
            line = _read_packed_ref(git_dpath, ref)
    # Otherwise HEAD is detached and its first line is the hash itself.

    # An empty line must not count as a hash (all() of nothing is True).
    if line and all(c in string.hexdigits for c in line):
        return line
    return None
def rotate(rotations, rotation=np.identity(3)):
    """Convert string of format '50x,-10y,120z' to a rotation matrix.

    Each comma-separated token is an angle in degrees followed by the axis
    letter.  The per-axis matrices are right-multiplied onto *rotation* in
    the order given, so the order matters: '50x,40z' differs from
    '40z,50x'.  An empty string returns a copy of *rotation*.
    """
    if rotations == '':
        return rotation.copy()

    # Parse every token before multiplying anything.
    steps = [('xyz'.index(token[-1]), radians(float(token[:-1])))
             for token in rotations.split(',')]
    for axis, angle in steps:
        sn = sin(angle)
        cs = cos(angle)
        if axis == 0:
            factor = [(1, 0, 0), (0, cs, sn), (0, -sn, cs)]
        elif axis == 1:
            factor = [(cs, 0, -sn), (0, 1, 0), (sn, 0, cs)]
        else:
            factor = [(cs, sn, 0), (-sn, cs, 0), (0, 0, 1)]
        rotation = np.dot(rotation, factor)
    return rotation
def givens(a, b):
    """Solve the equation system::

      [ c s]   [a]   [r]
      [    ] . [ ] = [ ]
      [-s c]   [b]   [0]

    Returns (c, s, r).  The ratio of the smaller to the larger magnitude
    is formed so the square root never sees a value above 2.
    """
    if b == 0:
        return np.sign(a), 0, abs(a)
    if abs(b) >= abs(a):
        ratio = a / b
        u = np.sign(b) * (1 + ratio**2)**0.5
        sn = 1. / u
        return sn * ratio, sn, b * u
    ratio = b / a
    u = np.sign(a) * (1 + ratio**2)**0.5
    cs = 1. / u
    return cs, cs * ratio, a * u
def irotate(rotation, initial=np.identity(3)):
    """Determine x, y, z rotation angles (degrees) from rotation matrix.

    Three successive Givens rotations are applied to
    ``initial @ rotation``; the angles are recovered with atan2.
    """
    m = np.dot(initial, rotation)
    cx, sx, rx = givens(m[2, 2], m[1, 2])
    cy, sy, _ = givens(rx, m[0, 2])
    cz, sz, _ = givens(cx * m[1, 1] - sx * m[2, 1],
                       cy * m[0, 1] - sy * (sx * m[1, 1] + cx * m[2, 1]))
    return tuple(degrees(atan2(sn, cs))
                 for sn, cs in ((sx, cx), (-sy, cy), (sz, cz)))
def pbc2pbc(pbc):
    """Return a new length-3 boolean array from a scalar or 3 flags."""
    flags = np.empty(3, dtype=bool)
    flags[...] = pbc  # broadcasts a scalar to all three directions
    return flags
def hsv2rgb(h, s, v):
    """Convert an HSV colour to RGB (http://en.wikipedia.org/wiki/HSL_and_HSV).

    h (hue) in [0, 360[
    s (saturation) in [0, 1]
    v (value) in [0, 1]

    return rgb in range [0, 1]

    Raises RuntimeError if h is outside [0, 360[; the hue is only checked
    when both v and s are nonzero, because those cases return early.
    """
    if v == 0:
        return 0, 0, 0
    if s == 0:
        return v, v, v

    i, f = divmod(h / 60., 1)
    if not 0 <= i < 6:
        # The range is half-open: h == 360 falls into a 7th sector.
        # (NaN also fails this test.)
        raise RuntimeError('h must be in [0, 360[')
    p = v * (1 - s)
    q = v * (1 - s * f)
    t = v * (1 - s * (1 - f))
    # (r, g, b) for each 60-degree hue sector.
    sectors = ((v, t, p), (q, v, p), (p, v, t),
               (p, q, v), (t, p, v), (v, p, q))
    return sectors[int(i)]
def hsv(array, s=.9, v=.9):
    """Map the values of *array* to RGB colours around the hue circle.

    The minimum of *array* maps to hue 0 and the maximum to hue 359, with
    saturation *s* and value *v*; each hue is converted with hsv2rgb().
    Returns an array of shape ``array.shape + (3,)``.  A constant array
    maps entirely to hue 0, and an empty array gives an empty result.
    """
    array = np.asarray(array, dtype=float)
    if array.size == 0:
        return np.empty(array.shape + (3,))
    lo = array.min()
    span = array.max() - lo
    if span > 0:
        # Shift by the minimum (the old code added it, which pushed hues
        # beyond 359 whenever min != 0 and made hsv2rgb raise).
        hues = (array - lo) * 359. / span
    else:
        hues = np.zeros_like(array)  # avoid 0/0 for constant input
    result = np.empty((hues.size, 3))
    for rgb, h in zip(result, hues.flat):
        rgb[:] = hsv2rgb(h, s, v)
    return np.reshape(result, hues.shape + (3,))


# This code does the same, but requires pylab
# def cmap(array, name='hsv'):
#     import pylab
#     a = (array - array.min()) / np.ptp(array)
#     rgba = getattr(pylab.cm, name)(a)
#     return rgba[..., :3]  # return rgb only (not alpha)
def longsum(x):
    """Sum *x* in np.longdouble precision and return the result as a float.

    The actual width of np.longdouble depends on the platform, so the
    extra precision (nominally "128-bit") is not guaranteed.
    """
    extended = np.asarray(x, dtype=np.longdouble)
    total = np.sum(extended)
    return float(total)
@contextmanager
def workdir(path, mkdir=False):
    """Temporarily change, and optionally create, working directory.

    With mkdir=True, *path* and any missing parents are created first.
    The previous working directory is restored on exit, even on error.
    """
    target = Path(path)
    if mkdir:
        target.mkdir(parents=True, exist_ok=True)

    previous = os.getcwd()
    os.chdir(target)
    try:
        yield  # Yield the Path or dirname maybe?
    finally:
        os.chdir(previous)
class iofunction:
    """Decorate func so it accepts either str or file.

    A str/PurePath first argument is opened with *mode*, passed to func,
    and closed afterwards; any other object is passed through untouched.
    (Won't work on functions that return a generator.)
    """
    def __init__(self, mode):
        self.mode = mode

    def __call__(self, func):
        @functools.wraps(func)
        def iofunc(file, *args, **kwargs):
            if not isinstance(file, (str, PurePath)):
                # Already a file object: the caller owns and closes it.
                return func(file, *args, **kwargs)
            with open(str(file), self.mode) as fd:
                return func(fd, *args, **kwargs)
        return iofunc
def writer(func):
    """Let func's first argument be a path (opened with mode 'w') or a file."""
    decorate = iofunction('w')
    return decorate(func)
def reader(func):
    """Let func's first argument be a path (opened with mode 'r') or a file."""
    decorate = iofunction('r')
    return decorate(func)
# The next two functions are for hotplugging into a JSONable class
# using the jsonable decorator.  We are supposed to have this kind of stuff
# in ase.io.jsonio, but we'd rather import them from a 'basic' module
# like ase/utils than one which triggers a lot of extra (cyclic) imports.

def write_json(self, fd):
    """Write this object to *fd* via ase.io.jsonio.write_json."""
    # Imported lazily to avoid the cyclic imports mentioned above.
    from ase.io.jsonio import write_json as _jsonio_write
    _jsonio_write(fd, self)
@classmethod  # type: ignore
def read_json(cls, fd):
    """Read a new instance of *cls* from *fd* via ase.io.jsonio.read_json.

    Attached to classes by the jsonable decorator.  The decoded object
    must be exactly of type *cls* (checked with assert).
    """
    from ase.io.jsonio import read_json as _jsonio_read
    decoded = _jsonio_read(fd)
    assert type(decoded) is cls
    return decoded
def jsonable(name):
    """Decorator for facilitating JSON I/O with a class.

    Sets ``cls.ase_objtype = name`` and pokes JSON-based read and write
    functions into the class.

    In order to write an object to JSON, it needs to be a known simple type
    (such as ndarray, float, ...) or implement todict().  If the class
    defines a string called ase_objtype, the decoder will want to convert
    the object back into its original type when reading.

    Raises TypeError if the class has no todict() (ase_objtype has already
    been assigned at that point).
    """
    def jsonableclass(cls):
        cls.ase_objtype = name
        if not hasattr(cls, 'todict'):
            raise TypeError('Class must implement todict()')

        # We may want the write and read to be optional: e.g. a calculator
        # might want to be JSONable without .write() producing a JSON file.
        # This is mostly for 'lightweight' object IO.
        cls.write, cls.read = write_json, read_json
        return cls
    return jsonableclass
class ExperimentalFeatureWarning(Warning):
    """Warning category emitted by functions wrapped with ``experimental``."""
def experimental(func):
    """Decorator for functions not ready for production use.

    Each call emits an ExperimentalFeatureWarning naming the function,
    then calls it normally.
    """
    @functools.wraps(func)
    def expfunc(*args, **kwargs):
        # stacklevel=2 attributes the warning to the caller, not this
        # wrapper, so users see where the experimental function is used.
        warnings.warn('This function may change or misbehave: {}()'
                      .format(func.__qualname__),
                      ExperimentalFeatureWarning, stacklevel=2)
        return func(*args, **kwargs)
    return expfunc
def lazymethod(meth):
    """Decorator for lazy evaluation and caching of data.

    Example::

        class MyClass:

            @lazymethod
            def thing(self):
                return expensive_calculation()

    The first call of thing() runs the method body and stores the result
    in the per-instance dict ``self._lazy_cache`` under the method's name;
    later calls return the stored value without running the body again.
    """
    key = meth.__name__

    @functools.wraps(meth)
    def getter(self):
        try:
            store = self._lazy_cache
        except AttributeError:
            # First lazy value on this instance: create its cache.
            store = {}
            self._lazy_cache = store
        if key not in store:
            store[key] = meth(self)
        return store[key]

    return getter
def atoms_to_spglib_cell(atoms):
    """Convert atoms into data suitable for calling spglib.

    Returns the tuple (cell, scaled positions, atomic numbers), each taken
    from the corresponding getter on *atoms*.
    """
    cell = atoms.get_cell()
    scaled = atoms.get_scaled_positions()
    numbers = atoms.get_atomic_numbers()
    return cell, scaled, numbers
def warn_legacy(feature_name):
    """Emit a FutureWarning that *feature_name* is untested legacy code."""
    message = (f'The {feature_name} feature is untested and ASE developers do not '
               'know whether it works or how to use it. Please rehabilitate it '
               '(by writing unittests) or it may be removed.')
    warnings.warn(message, FutureWarning)
def lazyproperty(meth):
    """Decorator like lazymethod, but making item available as a property."""
    cached_getter = lazymethod(meth)
    return property(cached_getter)
class IOContext:
    """Owns files opened through it and closes them all together.

    Handles registered with closelater() -- including the ones openfile()
    opens -- are entered into a lazily created ExitStack, which close()
    (or leaving a ``with`` block) unwinds.
    """
    @lazyproperty
    def _exitstack(self):
        return ExitStack()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def closelater(self, fd):
        """Register fd for closing; return the value of fd.__enter__()."""
        return self._exitstack.enter_context(fd)

    def close(self):
        """Close every handle registered via closelater()."""
        self._exitstack.close()

    def openfile(self, file, comm=None, mode='w'):
        """Return a handle for *file*, registering it if we opened it.

        * objects with a ``close`` attribute are returned untouched;
        * None, or any rank other than 0 of *comm*, gets os.devnull;
        * '-' gives sys.stdout (never closed here);
        * anything else is opened with *mode*.
        """
        from ase.parallel import world
        comm = world if comm is None else comm

        if hasattr(file, 'close'):
            return file  # File already opened, not for us to close.

        if file is not None and comm.rank == 0:
            if file == '-':
                return sys.stdout
            target = file
        else:
            target = os.devnull
        return self.closelater(open(target, mode=mode))