Coverage for /builds/debichem-team/python-ase/ase/utils/__init__.py: 84.11%
384 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-03-06 04:00 +0000
1import errno
2import functools
3import io
4import os
5import pickle
6import re
7import string
8import sys
9import time
10import warnings
11from contextlib import ExitStack, contextmanager
12from importlib import import_module
13from math import atan2, cos, degrees, gcd, radians, sin
14from pathlib import Path, PurePath
15from typing import Callable, Dict, List, Type, Union
17import numpy as np
19from ase.formula import formula_hill, formula_metal
# Public names re-exported by ``from ase.utils import *``.
__all__ = [
    'basestring', 'import_module', 'seterr', 'plural',
    'devnull', 'gcd', 'convert_string_to_fd', 'Lock',
    'opencew', 'OpenLock', 'rotate', 'irotate', 'pbc2pbc', 'givens',
    'hsv2rgb', 'hsv', 'pickleload', 'reader',
    'formula_hill', 'formula_metal', 'PurePath', 'xwopen',
    'tokenize_version', 'get_python_package_path_description',
]
def tokenize_version(version_string: str):
    """Split a dotted version string into a tuple that compares naturally.

    Every dot-separated component contributes two entries: its leading
    integer (or -1 when the component does not start with a digit) and
    the remaining text, e.g. ``'1.0b1' -> (1, '', 0, 'b1')``.

    Usage: tokenize_version('3.8') < tokenize_version('3.8.1').
    """
    parts = []
    for piece in version_string.split('.'):
        found = re.match(r'(\d*)(.*)', piece)
        assert found is not None, f'Cannot parse component {piece}'
        digits, rest = found.groups()
        # An empty digit prefix sorts before any real number.
        value = int(digits) if digits else -1
        parts.extend((value, rest))
    return tuple(parts)
# Leftover Python 2/3 compatibility aliases (candidates for removal).
basestring = str
# NOTE(review): like pickle.load itself, never use this on untrusted data.
pickleload = functools.partial(pickle.load, encoding='bytes')
def deprecated(
    message: Union[str, Warning],
    category: Type[Warning] = FutureWarning,
    callback: Callable[[List, Dict], bool] = lambda args, kwargs: True
):
    """Build a decorator that marks a function as deprecated.

    Parameters
    ----------
    message : str or Warning
        Text (or ready-made Warning instance) passed to ``warnings.warn``.
        When a Warning instance is given, ``warnings.warn`` uses its class
        and ``category`` has no effect.
    category : Type[Warning], default=FutureWarning
        Warning class used when ``message`` is a plain string.
    callback : Callable[[List, Dict], bool], default=lambda args, kwargs: True
        Called before every invocation of the decorated function with a
        mutable list of the positional arguments and the keyword-argument
        dict.  It may rewrite either of them in place; the (possibly
        modified) list and dict are what the decorated function receives.
        The warning is emitted only when the callback returns ``True``.

    Returns
    -------
    deprecated_decorator : Callable
        Decorator that wraps a function so that calls may emit the
        deprecation warning (attributed to the caller via ``stacklevel=2``)
        and/or pre-process its arguments.

    Example
    -------
    >>> # Inspect & replace a keyword parameter passed to a deprecated function
    >>> from typing import Any, Callable, Dict, List
    >>> import warnings
    >>> from ase.utils import deprecated

    >>> def alias_callback_factory(kwarg: str, alias: str) -> Callable:
    ...     def _replace_arg(_: List, kwargs: Dict[str, Any]) -> bool:
    ...         kwargs[kwarg] = kwargs[alias]
    ...         del kwargs[alias]
    ...         return True
    ...     return _replace_arg

    >>> MESSAGE = ("Calling this function with `atoms` is deprecated. "
    ...            "Use `optimizable` instead.")
    >>> @deprecated(
    ...     MESSAGE,
    ...     category=DeprecationWarning,
    ...     callback=alias_callback_factory("optimizable", "atoms")
    ... )
    ... def function(*, atoms=None, optimizable=None):
    ...     '''
    ...     .. deprecated:: 3.23.0
    ...         Calling this function with ``atoms`` is deprecated.
    ...         Use ``optimizable`` instead.
    ...     '''
    ...     print(f"atoms: {atoms}")
    ...     print(f"optimizable: {optimizable}")

    >>> with warnings.catch_warnings(record=True) as w:
    ...     warnings.simplefilter("always")
    ...     function(atoms="atoms")
    atoms: None
    optimizable: atoms

    >>> w[-1].category == DeprecationWarning
    True
    """

    def decorate(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # A list (not a tuple) so the callback can edit positionals.
            positional = list(args)
            should_warn = callback(positional, kwargs)
            if should_warn:
                warnings.warn(message, category=category, stacklevel=2)
            return func(*positional, **kwargs)

        return wrapper

    return decorate
@contextmanager
def seterr(**kwargs):
    """Temporarily change NumPy's floating-point error handling.

    The keyword arguments are forwarded to np.seterr(); the previous
    settings are restored on exit, even if the body raises.
    See np.seterr() for more details.
    """
    previous = np.seterr(**kwargs)
    try:
        yield
    finally:
        np.seterr(**previous)
def plural(n, word):
    """Return "<n> <word>", adding an "s" unless n equals 1.

    ``n`` is formatted with ``%d`` (non-integers are truncated).

    >>> from ase.utils import plural

    >>> plural(0, 'egg'), plural(1, 'egg'), plural(2, 'egg')
    ('0 eggs', '1 egg', '2 eggs')
    """
    return '1 ' + word if n == 1 else '%d %ss' % (n, word)
class DevNull:
    """Deprecated stand-in for a null text stream.

    Every method is wrapped by ``deprecated`` and therefore emits a
    DeprecationWarning suggesting ``open(os.devnull)`` instead.  Writes
    are discarded, reads return '', and ``close()`` does nothing, so
    ``closed`` stays False.
    """

    encoding = 'UTF-8'
    closed = False

    # Deprecated for ase-3.21.0. Change to futurewarning later on.
    _use_os_devnull = deprecated('use open(os.devnull) instead',
                                 DeprecationWarning)

    @_use_os_devnull
    def write(self, string):
        return None

    @_use_os_devnull
    def flush(self):
        return None

    @_use_os_devnull
    def seek(self, offset, whence=0):
        return 0

    @_use_os_devnull
    def tell(self):
        return 0

    @_use_os_devnull
    def close(self):
        return None

    @_use_os_devnull
    def isatty(self):
        return False

    @_use_os_devnull
    def read(self, n=-1):
        return ''


# Shared module-level instance (deprecated together with the class).
devnull = DevNull()
@deprecated('convert_string_to_fd does not facilitate proper resource '
            'management. '
            'Please use e.g. ase.utils.IOContext class instead.')
def convert_string_to_fd(name, world=None):
    """Create a file-descriptor for text output.

    Will open a file for writing with given name. Use None for no output and
    '-' for sys.stdout.  Ranks other than 0 of ``world`` (default:
    ``ase.parallel.world``) always get a handle to ``os.devnull``.  Any
    other non-path object is returned unchanged.

    .. deprecated:: 3.22.1
        Please use e.g. :class:`ase.utils.IOContext` class instead.
    """
    if world is None:
        from ase.parallel import world
    silent = name is None or world.rank != 0
    if silent:
        return open(os.devnull, 'w')
    if name == '-':
        return sys.stdout
    if not isinstance(name, (str, PurePath)):
        # Assumed to be an already-open file object.
        return name
    return open(str(name), 'w')
# Flags for exclusive creation of a write-only file.  O_BINARY exists only
# on Windows, so fall back to 0 elsewhere.
CEW_FLAGS = (os.O_CREAT | os.O_EXCL | os.O_WRONLY
             | getattr(os, 'O_BINARY', 0))
@contextmanager
def xwopen(filename, world=None):
    """Create and open filename exclusively for writing, as a context manager.

    Yields the result of ``opencew(filename, world)``: on the master rank
    the newly created binary file, on other ranks a dummy (``os.devnull``)
    handle, and None on all ranks if the file already existed.  Any
    non-None handle is closed when the ``with`` block exits.
    """
    fd = opencew(filename, world)
    with ExitStack() as stack:
        if fd is not None:
            stack.callback(fd.close)
        yield fd
# @deprecated('use "with xwopen(...) as fd: ..." to prevent resource leak')
def opencew(filename, world=None):
    """Exclusively create *filename* for writing; see ``_opencew``.

    The caller is responsible for closing the returned handle; prefer the
    ``xwopen`` context manager.
    """
    return _opencew(filename, world=world)
255def _opencew(filename, world=None):
256 import ase.parallel as parallel
257 if world is None:
258 world = parallel.world
260 closelater = []
262 def opener(file, flags):
263 return os.open(file, flags | CEW_FLAGS)
265 try:
266 error = 0
267 if world.rank == 0:
268 try:
269 fd = open(filename, 'wb', opener=opener)
270 except OSError as ex:
271 error = ex.errno
272 else:
273 closelater.append(fd)
274 else:
275 fd = open(os.devnull, 'wb')
276 closelater.append(fd)
278 # Synchronize:
279 error = world.sum_scalar(error)
280 if error == errno.EEXIST:
281 return None
282 if error:
283 raise OSError(error, 'Error', filename)
285 return fd
286 except BaseException:
287 for fd in closelater:
288 fd.close()
289 raise
def opencew_text(*args, **kwargs):
    """Text-mode variant of ``opencew``.

    Forwards all arguments to ``opencew`` and wraps a non-None result in
    ``io.TextIOWrapper``; returns None when ``opencew`` does.
    """
    binary = opencew(*args, **kwargs)
    return None if binary is None else io.TextIOWrapper(binary)
class Lock:
    """Lock implemented as an exclusively created file.

    Acquiring creates the file ``name`` via ``opencew``; releasing closes
    the handle and (on rank 0) deletes the file.  Usable as a context
    manager: ``with Lock(path) as lock: ...``.

    Parameters
    ----------
    name : str or path-like
        Path of the lock file.
    world : communicator, optional
        Defaults to ``ase.parallel.world``.
    timeout : float
        Seconds to keep retrying in ``acquire`` before raising TimeoutError.
    """

    def __init__(self, name='lock', world=None, timeout=float('inf')):
        self.name = str(name)
        self.timeout = timeout
        if world is None:
            from ase.parallel import world
        self.world = world

    def acquire(self):
        """Block until the lock file is created.

        Retries with sleeps that start at 0.2 s and double after each
        attempt (never sleeping past the deadline).

        Raises TimeoutError once ``self.timeout`` seconds have elapsed.
        """
        dt = 0.2
        t1 = time.time()
        while True:
            fd = opencew(self.name, self.world)
            if fd is not None:
                self.fd = fd
                break
            time_left = self.timeout - (time.time() - t1)
            if time_left <= 0:
                raise TimeoutError(
                    f'Could not acquire lock {self.name!r} within '
                    f'{self.timeout} seconds')
            time.sleep(min(dt, time_left))
            dt *= 2

    def release(self):
        """Close the handle and remove the lock file (rank 0 only)."""
        self.world.barrier()
        # Important to close fd before deleting file on windows
        # as a WinError would otherwise be raised.
        self.fd.close()
        if self.world.rank == 0:
            os.remove(self.name)
        self.world.barrier()

    def __enter__(self):
        self.acquire()
        # Return the lock so ``with Lock(...) as lock`` binds it (was None).
        return self

    def __exit__(self, type, value, tb):
        self.release()
class OpenLock:
    """No-op lock with the same interface as ``Lock``.

    Every method does nothing; ``__enter__`` returns the instance so that
    ``with OpenLock() as lock`` binds the lock object, matching ``Lock``.
    """

    def acquire(self):
        pass

    def release(self):
        pass

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        pass
def search_current_git_hash(arg, world=None):
    """Search for .git directory and current git commit hash.

    Parameters:

    arg: str (directory path) or python module
        .git directory is searched from the parent directory of
        the given directory or module (only one level up).
    world: communicator, optional
        Only rank 0 searches; defaults to ``ase.parallel.world``.

    Returns the hexadecimal hash read from the file referenced by
    ``.git/HEAD`` (or from ``HEAD`` itself when it holds no ``ref:``),
    or None when not on rank 0, when any of the files is missing
    (e.g. a ref stored only in ``packed-refs``), or when the file's first
    line is empty or not hexadecimal.
    """
    if world is None:
        from ase.parallel import world
    if world.rank != 0:
        return None

    # Check argument
    if isinstance(arg, str):
        # Directory path
        dpath = arg
    else:
        # Assume arg is module
        dpath = os.path.dirname(arg.__file__)
    # in case this is just symlinked into $PYTHONPATH
    dpath = os.path.realpath(dpath)
    dpath = os.path.dirname(dpath)  # Go to the parent directory
    git_dpath = os.path.join(dpath, '.git')
    if not os.path.isdir(git_dpath):
        # Replace this 'if' with a loop if you want to check
        # further parent directories
        return None
    HEAD_file = os.path.join(git_dpath, 'HEAD')
    if not os.path.isfile(HEAD_file):
        return None
    with open(HEAD_file) as fd:
        line = fd.readline().strip()
    if line.startswith('ref: '):
        ref = line[5:]
        ref_file = os.path.join(git_dpath, ref)
    else:
        # Assuming detached HEAD state
        ref_file = HEAD_file
    if not os.path.isfile(ref_file):
        return None
    with open(ref_file) as fd:
        line = fd.readline().strip()
    # all() is vacuously True for '', so require a non-empty line.
    if line and all(c in string.hexdigits for c in line):
        return line
    return None
def rotate(rotations, rotation=np.identity(3)):
    """Convert string of format '50x,-10y,120z' to a rotation matrix.

    Each comma-separated token is an angle in degrees followed by the axis
    letter; the per-axis matrices are right-multiplied onto ``rotation`` in
    the given order.  An empty string returns a copy of ``rotation``.

    Note that the order of rotation matters, i.e. '50x,40z' is different
    from '40z,50x'.
    """
    if rotations == '':
        return rotation.copy()

    # Parse every token up front, then apply them in order.
    steps = [('xyz'.index(token[-1]), radians(float(token[:-1])))
             for token in rotations.split(',')]
    result = rotation
    for axis, angle in steps:
        s = sin(angle)
        c = cos(angle)
        if axis == 0:
            factor = [(1, 0, 0),
                      (0, c, s),
                      (0, -s, c)]
        elif axis == 1:
            factor = [(c, 0, -s),
                      (0, 1, 0),
                      (s, 0, c)]
        else:
            factor = [(c, s, 0),
                      (-s, c, 0),
                      (0, 0, 1)]
        result = np.dot(result, factor)
    return result
def givens(a, b):
    """Solve the equation system::

      [ c s]   [a]   [r]
      [    ] . [ ] = [ ]
      [-s c]   [b]   [0]

    Returns ``(c, s, r)``.  The ratio is always taken as the smaller
    magnitude over the larger one to keep the square root well scaled.
    """
    if b == 0:
        return np.sign(a), 0, abs(a)
    if abs(b) >= abs(a):
        ratio = a / b
        u = np.sign(b) * (1 + ratio**2)**0.5
        s = 1. / u
        return s * ratio, s, b * u
    ratio = b / a
    u = np.sign(a) * (1 + ratio**2)**0.5
    c = 1. / u
    return c, c * ratio, a * u
def irotate(rotation, initial=np.identity(3)):
    """Determine x, y, z rotation angles (in degrees) from a rotation matrix.

    The matrix ``initial @ rotation`` is decomposed with three successive
    Givens rotations (see ``givens``); each angle is recovered with atan2.
    """
    m = np.dot(initial, rotation)
    cx, sx, rx = givens(m[2, 2], m[1, 2])
    cy, sy, _ = givens(rx, m[0, 2])
    cz, sz, _ = givens(cx * m[1, 1] - sx * m[2, 1],
                       cy * m[0, 1] - sy * (sx * m[1, 1] + cx * m[2, 1]))
    return (degrees(atan2(sx, cx)),
            degrees(atan2(-sy, cy)),
            degrees(atan2(sz, cz)))
def pbc2pbc(pbc):
    """Return *pbc* as a new length-3 boolean array.

    A scalar is broadcast to all three directions; sequences are cast to
    bool element-wise.
    """
    out = np.empty(3, dtype=bool)
    out[...] = pbc
    return out
def string2index(stridx: str) -> Union[int, slice, str]:
    """Convert index string to either int or slice.

    Strings containing ':' become slices (empty fields mean None); other
    strings become ints when possible and are otherwise returned unchanged
    (they may be e.g. a database accessor).
    """
    if ':' in stridx:
        bounds = [int(part) if part else None for part in stridx.split(':')]
        return slice(*bounds)
    try:
        return int(stridx)
    except ValueError:
        return stridx
def hsv2rgb(h, s, v):
    """http://en.wikipedia.org/wiki/HSL_and_HSV

    h (hue) in [0, 360[
    s (saturation) in [0, 1]
    v (value) in [0, 1]

    return rgb in range [0, 1]

    Raises RuntimeError if h is outside [0, 360[ (360 itself is rejected).
    """
    if v == 0:
        return 0, 0, 0
    if s == 0:
        return v, v, v

    i, f = divmod(h / 60., 1)
    p = v * (1 - s)
    q = v * (1 - s * f)
    t = v * (1 - s * (1 - f))

    if i == 0:
        return v, t, p
    elif i == 1:
        return q, v, p
    elif i == 2:
        return p, v, t
    elif i == 3:
        return p, q, v
    elif i == 4:
        return t, p, v
    elif i == 5:
        return v, p, q
    else:
        # Message now matches the half-open range actually accepted.
        raise RuntimeError('h must be in [0, 360[')


def hsv(array, s=.9, v=.9):
    """Map the values of *array* to RGB colours of varying hue.

    Values are scaled linearly so the minimum maps to hue 0 and the
    maximum to hue 359; *s* and *v* are the fixed saturation and value.
    A constant (or empty) array is handled without division by zero:
    every element gets hue 0.

    Returns an array of shape ``array.shape + (3,)``.
    """
    values = np.asarray(array, dtype=float)
    result = np.empty((values.size, 3))
    if values.size:
        lo = values.min()
        span = values.max() - lo
        if span > 0:
            # Shift by the minimum (previously it was added, which pushed
            # hues out of range whenever min != 0).
            hues = (values - lo) * 359. / span
        else:
            hues = np.zeros(values.shape)
        for rgb, h in zip(result, hues.flat):
            rgb[:] = hsv2rgb(h, s, v)
    return np.reshape(result, values.shape + (3,))
# This code does something similar, but requires pylab:
# def cmap(array, name='hsv'):
#     import pylab
#     a = (array - array.min()) / np.ptp(array)
#     rgba = getattr(pylab.cm, name)(a)
#     return rgba[..., :3]  # return rgb only (not alpha)
def longsum(x):
    """Sum *x* in extended precision and return a Python float.

    The accumulation uses ``np.longdouble`` (128-bit where the platform
    provides it; it may equal double precision elsewhere).
    """
    extended = np.asarray(x, dtype=np.longdouble)
    return float(np.sum(extended))
@contextmanager
def workdir(path, mkdir=False):
    """Temporarily change, and optionally create, working directory.

    With ``mkdir=True`` the directory (and missing parents) is created
    first.  The previous working directory is restored on exit, also when
    the body raises.
    """
    target = Path(path)
    if mkdir:
        target.mkdir(parents=True, exist_ok=True)

    previous = os.getcwd()
    os.chdir(target)
    with ExitStack() as stack:
        stack.callback(os.chdir, previous)
        yield  # Yield the Path or dirname maybe?
class iofunction:
    """Decorate func so it accepts either str or file.

    If the first argument is a ``str`` or ``PurePath``, it is opened with
    ``mode`` for the duration of the call and closed afterwards; any other
    object is passed through unchanged (and left open).

    (Won't work on functions that return a generator.)"""

    def __init__(self, mode):
        self.mode = mode

    def __call__(self, func):
        @functools.wraps(func)
        def iofunc(file, *args, **kwargs):
            if not isinstance(file, (str, PurePath)):
                return func(file, *args, **kwargs)
            # If open() itself fails there is nothing to close.
            fd = open(str(file), self.mode)
            try:
                return func(fd, *args, **kwargs)
            finally:
                fd.close()
        return iofunc
def writer(func):
    """Let *func*'s first argument be a path, opened in 'w' mode."""
    decorate = iofunction('w')
    return decorate(func)


def reader(func):
    """Let *func*'s first argument be a path, opened in 'r' mode."""
    decorate = iofunction('r')
    return decorate(func)
# The next two functions are for hotplugging into a JSONable class
# using the jsonable decorator. We are supposed to have this kind of stuff
# in ase.io.jsonio, but we'd rather import them from a 'basic' module
# like ase/utils than one which triggers a lot of extra (cyclic) imports.

def write_json(self, fd):
    """Write to JSON file (delegates to ase.io.jsonio.write_json)."""
    # Imported lazily to avoid the cyclic imports mentioned above.
    from ase.io import jsonio
    jsonio.write_json(fd, self)
@classmethod  # type: ignore[misc]
def read_json(cls, fd):
    """Read new instance from JSON file.

    Delegates to ase.io.jsonio.read_json and asserts that the decoded
    object is an instance of the class this is attached to.
    """
    from ase.io import jsonio
    obj = jsonio.read_json(fd)
    assert isinstance(obj, cls)
    return obj
def jsonable(name):
    """Decorator for facilitating JSON I/O with a class.

    Pokes JSON-based read and write functions into the class.

    In order to write an object to JSON, it needs to be a known simple type
    (such as ndarray, float, ...) or implement todict(). If the class
    defines a string called ase_objtype, the decoder will want to convert
    the object back into its original type when reading.

    Raises TypeError (leaving the class unmodified) if the class does not
    define ``todict``.
    """
    def jsonableclass(cls):
        # Validate before touching cls so a rejected class is not left
        # half-decorated (ase_objtype used to be set before this check).
        if not hasattr(cls, 'todict'):
            raise TypeError('Class must implement todict()')
        cls.ase_objtype = name

        # We may want the write and read to be optional.
        # E.g. a calculator might want to be JSONable, but not
        # that .write() produces a JSON file.
        #
        # This is mostly for 'lightweight' object IO.
        cls.write = write_json
        cls.read = read_json
        return cls
    return jsonableclass
class ExperimentalFeatureWarning(Warning):
    """Warning emitted by functions decorated with ``experimental``."""


def experimental(func):
    """Decorator for functions not ready for production use.

    Every call emits an ExperimentalFeatureWarning naming the function and
    then delegates to *func*.  ``stacklevel=2`` attributes the warning to
    the caller's line rather than to this wrapper (as ``deprecated`` does).
    """
    @functools.wraps(func)
    def expfunc(*args, **kwargs):
        warnings.warn('This function may change or misbehave: {}()'
                      .format(func.__qualname__),
                      ExperimentalFeatureWarning, stacklevel=2)
        return func(*args, **kwargs)
    return expfunc
def lazymethod(meth):
    """Decorator for lazy evaluation and caching of data.

    Example::

        class MyClass:

           @lazymethod
           def thing(self):
               return expensive_calculation()

    The method body runs only on the first call to thing(); its result is
    stored in the instance's ``_lazy_cache`` dict (created on demand and
    keyed by the method name) and returned on every later call."""
    key = meth.__name__

    @functools.wraps(meth)
    def getter(self):
        if not hasattr(self, '_lazy_cache'):
            self._lazy_cache = {}
        cache = self._lazy_cache
        if key in cache:
            return cache[key]
        value = meth(self)
        cache[key] = value
        return value
    return getter
def atoms_to_spglib_cell(atoms):
    """Convert atoms into data suitable for calling spglib.

    Returns the tuple ``(cell, scaled_positions, atomic_numbers)`` taken
    from the corresponding ``atoms.get_*()`` methods, called in that order.
    """
    cell = atoms.get_cell()
    scaled = atoms.get_scaled_positions()
    numbers = atoms.get_atomic_numbers()
    return cell, scaled, numbers
def warn_legacy(feature_name):
    """Emit a FutureWarning that *feature_name* is untested legacy code."""
    text = (f'The {feature_name} feature is untested and ASE developers do not '
            'know whether it works or how to use it. Please rehabilitate it '
            '(by writing unittests) or it may be removed.')
    warnings.warn(text, FutureWarning)
def lazyproperty(meth):
    """Decorator like lazymethod, but making item available as a property."""
    cached_getter = lazymethod(meth)
    return property(cached_getter)
class IOContext:
    """Track files opened through ``openfile``/``closelater``.

    All registered files are closed together by ``close()`` or on exiting
    a ``with`` block.  The underlying ExitStack is created lazily.
    """

    @lazyproperty
    def _exitstack(self):
        return ExitStack()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def closelater(self, fd):
        """Register *fd* to be closed by ``close()`` and return it."""
        return self._exitstack.enter_context(fd)

    def close(self):
        self._exitstack.close()

    def openfile(self, file, comm, mode='w'):
        """Open *file* for output on rank 0 of *comm*.

        Objects with a ``close`` attribute are returned untouched (not ours
        to close); None, or any rank other than 0, gets ``os.devnull``;
        '-' gives ``sys.stdout``.  Text modes use UTF-8 encoding.
        """
        if hasattr(file, 'close'):
            return file

        encoding = 'utf-8' if not mode.endswith('b') else None

        if file is None or comm.rank != 0:
            target = os.devnull
        elif file == '-':
            return sys.stdout
        else:
            target = file
        return self.closelater(open(target, mode=mode, encoding=encoding))
def get_python_package_path_description(
        package, default='module has no path') -> str:
    """Describe where a python package/module lives, as a string.

    Returns the first entry of ``package.__path__`` (converted with str),
    *default* when that path is empty, or ``"<default> (<exception>)"``
    if anything raises along the way.  Never raises itself for ordinary
    exceptions.
    """
    try:
        entries = list(package.__path__)
        return str(entries[0]) if entries else default
    except Exception as ex:
        return f"{default} ({ex})"