Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from typing import List
2import numpy as np
3from ase import Atoms
4from .spacegroup import Spacegroup, _SPACEGROUP
# Only the dispatcher is public; every implementation helper is private.
__all__ = ('get_basis',)
9def _has_spglib() -> bool:
10 """Check if spglib is available"""
11 try:
12 import spglib
13 assert spglib # silence flakes
14 except ImportError:
15 return False
16 return True
def _get_basis_ase(atoms: Atoms,
                   spacegroup: _SPACEGROUP,
                   tol: float = 1e-5) -> np.ndarray:
    """Get a reduced basis by removing symmetry-equivalent sites.

    Uses the first remaining scaled position as a basis site, removes every
    position equivalent to it under ``spacegroup``, then repeats with the
    next position that has not been placed into a basis, until no positions
    remain.

    :param atoms: Atoms object to get basis from.
    :param spacegroup: ``int``, ``str``, or
        :class:`ase.spacegroup.Spacegroup` object.
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :return: array of the scaled positions chosen as basis sites, one row
        per site, in the order they were selected.
    :raises RuntimeError: if a pass fails to remove any position, which
        would otherwise loop forever.
    """
    spacegroup = Spacegroup(spacegroup)
    remaining = np.asarray(atoms.get_scaled_positions())
    all_basis = []

    # Iterative rather than recursive: one recursion level per basis site
    # overflows Python's recursion limit for structures with many
    # inequivalent sites (e.g. large low-symmetry cells).
    while len(remaining) > 0:
        basis = remaining[0]
        all_basis.append(basis.tolist())  # Add the site as a basis

        # Get equivalent sites
        sites, _ = spacegroup.equivalent_sites(basis)
        sites = np.asarray(sites, dtype=float)

        # Vectorized form of ``np.allclose(site, pos, atol=tol)`` for every
        # (pos, site) pair; ``pos`` stays the second argument so the
        # relative-tolerance term is computed exactly as before.
        close = np.isclose(sites[np.newaxis, :, :],
                           remaining[:, np.newaxis, :],
                           atol=tol)
        is_equivalent = close.all(axis=-1).any(axis=-1)

        # We should always have at least removed the site itself; an
        # explicit check (not ``assert``) so it survives ``python -O``.
        if not is_equivalent.any():
            raise RuntimeError(
                'Failed to match basis site {} against its own equivalent '
                'sites; try increasing tol.'.format(basis.tolist()))

        # Remove equivalent
        remaining = remaining[~is_equivalent]

    return np.array(all_basis)
def _get_basis_spglib(atoms: Atoms, tol: float = 1e-5) -> np.ndarray:
    """Get a reduced basis with spglib (the package must be installed).

    :param atoms: Atoms, atoms object to get basis from
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :return: the scaled positions of the representative atoms reported
        by spglib.
    :raises ImportError: if spglib is not installed.
    """
    if _has_spglib():
        representative = _get_reduced_indices(atoms, tol=tol)
        return atoms.get_scaled_positions()[representative]

    # Point the user at the spacegroup-based alternative that works without spglib.
    raise ImportError(
        'This function requires spglib. Use "get_basis" and specify '
        'the spacegroup instead, or install spglib.')
def _can_use_spglib(spacegroup: _SPACEGROUP = None) -> bool:
    """Decide whether the spglib-based implementation is applicable.

    True only when spglib is importable AND no explicit space group was
    given, because the spglib implementation cannot currently honour one.
    """
    return _has_spglib() and spacegroup is None
# Dispatcher function for choosing the get_basis implementation.
def get_basis(atoms: Atoms,
              spacegroup: _SPACEGROUP = None,
              method: str = 'auto',
              tol: float = 1e-5) -> np.ndarray:
    """Determine a reduced basis of an atoms object.

    Two implementations exist: an ASE-native algorithm, which requires an
    explicit space group, and an spglib-based one, which (currently) cannot
    take one. By default the implementation is picked automatically from
    the ``spacegroup`` argument and whether spglib is installed.

    :param atoms: ase Atoms object to get basis from
    :param spacegroup: Optional, ``int``, ``str``
        or :class:`ase.spacegroup.Spacegroup` object.
        If unspecified, the spacegroup can be inferred using spglib,
        if spglib is installed, and ``method`` is set to either
        ``'spglib'`` or ``'auto'``.
        Inferring the spacegroup requires spglib.
    :param method: ``str``, one of: ``'auto'`` | ``'ase'`` | ``'spglib'``.
        Selection of which implementation to use.
        It is recommended to use ``'auto'``, which is also the default.
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :raises ValueError: if ``method`` is not an allowed value, or if the
        native ASE implementation is selected without a ``spacegroup``.
    """
    allowed_methods = ('auto', 'ase', 'spglib')
    if method not in allowed_methods:
        raise ValueError('Expected one of {} methods, got {}'.format(
            allowed_methods, method))

    # An explicit choice wins; 'auto' asks whether spglib is usable here.
    use_spglib = (_can_use_spglib(spacegroup=spacegroup)
                  if method == 'auto' else method == 'spglib')

    if use_spglib:
        # The spglib path cannot accept an explicit space group yet, so
        # ``spacegroup`` is deliberately not forwarded.
        return _get_basis_spglib(atoms, tol=tol)

    # Native ASE path: either requested explicitly, or 'auto' ruled out
    # spglib (not installed, or a space group was given). This path cannot
    # infer the space group on its own.
    if spacegroup is not None:
        return _get_basis_ase(atoms, spacegroup, tol=tol)

    raise ValueError(
        ('A space group must be specified for the native ASE '
         'implementation. Try using the spglib version instead, '
         'or explicitly specifying a space group.'))
def _get_reduced_indices(atoms: Atoms, tol: float = 1e-5) -> List[int]:
    """Get a sorted list of the reduced atomic indices using spglib.

    The returned indices are the distinct values of spglib's
    ``equivalent_atoms`` array, i.e. one representative atom index per
    class of symmetry-equivalent atoms, in ascending order.
    Note: Does no checks to see if spglib is installed.

    :param atoms: ase Atoms object to reduce
    :param tol: ``float``, numeric tolerance for positional comparisons
        (passed to spglib as ``symprec``)
    :return: sorted list of unique representative atom indices.
    :raises RuntimeError: if spglib returns no symmetry dataset.
    """
    import spglib

    # Create input for spglib
    spglib_cell = (atoms.get_cell(), atoms.get_scaled_positions(),
                   atoms.numbers)
    symmetry_data = spglib.get_symmetry_dataset(spglib_cell, symprec=tol)
    if symmetry_data is None:
        # Subscripting None would otherwise fail with an opaque TypeError.
        raise RuntimeError(
            'spglib could not determine the symmetry of the given atoms '
            '(symprec={}).'.format(tol))

    # Newer spglib returns a dataset object with attribute access; older
    # versions return a plain dict.
    equivalent_atoms = getattr(symmetry_data, 'equivalent_atoms', None)
    if equivalent_atoms is None:
        equivalent_atoms = symmetry_data['equivalent_atoms']

    # Sort so the basis order is deterministic; ``list(set(...))`` has
    # unspecified ordering.
    return sorted({int(index) for index in equivalent_atoms})