Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
"""
Implements the Rank Determination Algorithm (RDA).

The method is described in:
Definition of a scoring parameter to identify low-dimensional materials
components
P.M. Larsen, M. Pandey, M. Strange, and K. W. Jacobsen
Phys. Rev. Materials 3, 034003 (2019)
https://doi.org/10.1103/PhysRevMaterials.3.034003
"""
from collections import defaultdict, deque

import numpy as np

from ase.geometry.dimensionality.disjoint_set import DisjointSet
# NumPy has a large per-call overhead for lots of small vectors, and the
# cross product is particularly bad. Pure Python is a lot faster here.
def dot_product(A, B):
    """Return the dot product of two equal-length sequences.

    Extra trailing elements of the longer sequence are ignored (``zip``).
    """
    return sum(x * y for x, y in zip(A, B))
def cross_product(a, b):
    """Return the 3-D cross product ``a x b`` as a list of three values."""
    x = a[1] * b[2] - a[2] * b[1]
    y = a[2] * b[0] - a[0] * b[2]
    z = a[0] * b[1] - a[1] * b[0]
    return [x, y, z]
def subtract(A, B):
    """Return the element-wise difference ``A - B`` as a list."""
    return [lhs - rhs for lhs, rhs in zip(A, B)]
def rank_increase(a, b):
    """Return True if adding point ``b`` raises the affine rank of ``a``.

    Parameters:

    a: list    Points already accepted for a component (0 to 4 entries).
    b: tuple   Candidate point.

    Raises Exception if ``a`` holds more than four points.
    """
    count = len(a)
    # Trivial cases: nothing accepted yet, a single point, or already full
    # rank (four affinely independent points span 3-D).
    if count == 0:
        return True
    if count == 1:
        return a[0] != b
    if count == 4:
        return False

    pts = a + [b]
    base = pts[0]
    normal = cross_product(subtract(pts[1], base), subtract(pts[2], base))
    if count == 2:
        # Three points are non-collinear iff the cross product is non-zero.
        return any(normal)
    if count == 3:
        # Four points are non-coplanar iff the scalar triple product is
        # non-zero.
        return dot_product(normal, subtract(pts[3], base)) != 0
    raise Exception("This shouldn't be possible.")
def bfs(adjacency, start):
    """Traverse the component graph using BFS.

    The graph is traversed until the matrix rank of the subspace spanned by
    the visited components no longer increases.

    Nodes are ``(component, position)`` pairs, where ``position`` is the
    accumulated cell offset (a 3-tuple) of that component image relative to
    the starting image at ``(0, 0, 0)``.

    Parameters:

    adjacency: dict   Maps each component to a set of
                      ``(neighbour_component, offset)`` pairs.
    start:            The component to start the traversal from.

    Returns:

    visited: set      Every ``(component, position)`` node that was taken
                      from the queue, including nodes that did not increase
                      the rank.
    rank: int         Rank (0-3) of the positions recorded for ``start``.
    """
    visited = set()
    # Per component, the positions that each increased the rank (at most 4,
    # since rank_increase refuses a fifth).
    cvisited = defaultdict(list)
    # A deque gives O(1) pops from the left; list.pop(0) is O(n).
    queue = deque([(start, (0, 0, 0))])
    while queue:
        vertex = queue.popleft()
        if vertex in visited:
            continue

        # Marked visited even when it does not increase the rank, so that
        # it is never examined again.
        visited.add(vertex)
        c, p = vertex
        if not rank_increase(cvisited[c], p):
            continue

        cvisited[c].append(p)

        for nc, offset in adjacency[c]:
            nbrpos = (p[0] + offset[0], p[1] + offset[1], p[2] + offset[2])
            nbrnode = (nc, nbrpos)
            if nbrnode in visited:
                continue

            # Only enqueue neighbours that could still raise the rank of
            # their component.
            if rank_increase(cvisited[nc], nbrpos):
                queue.append(nbrnode)

    return visited, len(cvisited[start]) - 1
def traverse_component_graphs(adjacency):
    """Run the rank-limited BFS starting from every component.

    Returns two dicts keyed by component: the set of nodes visited from that
    component, and the rank found for it.
    """
    all_visited, ranks = {}, {}
    for root in adjacency:
        all_visited[root], ranks[root] = bfs(adjacency, root)
    return all_visited, ranks
def build_adjacency_list(parents, bonds):
    """Build the component-level adjacency list.

    Parameters:

    parents: array   Component (root) label of every atom.
    bonds: list      ``(i, j, offset)`` bonds between atoms.

    Returns:

    dict mapping every component label to the set of
    ``(neighbour_component, offset)`` pairs reachable through a bond.
    Components without bonds map to an empty set.
    """
    adjacency = {root: set() for root in np.unique(parents)}
    for atom_i, atom_j, offset in bonds:
        adjacency[parents[atom_i]].add((parents[atom_j], offset))
    return adjacency
def get_dimensionality_histogram(ranks, roots):
    """Count how many components have each dimensionality.

    Parameters:

    ranks: dict      Maps a component root to its rank.
    roots: iterable  The component roots to count.

    Returns:

    tuple of four ints; entry ``d`` is the number of roots with rank ``d``.
    """
    counts = [0] * 4
    for dim in (ranks[root] for root in roots):
        counts[dim] += 1
    return tuple(counts)
def merge_mutual_visits(all_visited, ranks, graph):
    """Find components with mutual visits and merge them.

    Two components are merged (via ``graph.union``) when their BFS
    traversals visited a common ``(component, position)`` node.

    Returns:

    merged: bool          Whether any union changed the graph.
    all_visited: dict     Visited nodes per component root (unchanged input
                          when nothing was merged).
    ranks: dict           Rank per component root (unchanged input when
                          nothing was merged).
    """
    any_union = False
    seen_by = defaultdict(list)
    for comp, nodes in all_visited.items():
        for node in nodes:
            for earlier in seen_by[node]:
                # Components that reach each other must share a rank.
                assert ranks[earlier] == ranks[comp]
                any_union |= graph.union(earlier, comp)
            seen_by[node].append(comp)

    if not any_union:
        return any_union, all_visited, ranks

    roots = graph.find_all()
    combined_visits = defaultdict(set)
    combined_ranks = {}
    for comp, nodes in all_visited.items():
        root = roots[comp]
        combined_visits[root].update(nodes)
        combined_ranks[root] = ranks[root]
    return any_union, combined_visits, combined_ranks
class RDA:

    def __init__(self, num_atoms):
        """
        Initializes the RDA class.

        A disjoint set is used to maintain the component graph.

        Parameters:

        num_atoms: int    The number of atoms in the unit cell.
        """
        self.bonds = []
        self.graph = DisjointSet(num_atoms)
        self.adjacency = None
        self.hcached = None
        self.components_cached = None
        self.cdim_cached = None
        # Results of the most recent traversal; populated by check().
        self.all_visited = None
        self.ranks = None
        self.roots = None

    def insert_bond(self, i, j, offset):
        """
        Adds a bond to the list of graph edges.

        Graph components are merged if the bond does not cross a cell boundary.
        Bonds which cross cell boundaries can inappropriately connect
        components which are not connected in the infinite crystal. This is
        tested during graph traversal.

        Parameters:

        i: int          The index of the first atom.
        j: int          The index of the second atom.
        offset: tuple   The cell offset of the second atom. Any sequence
                        (list, tuple, array) is accepted; it is stored as a
                        tuple so it stays hashable.
        """
        # Normalise so that lists/arrays compare equal to (0, 0, 0) and can
        # be placed in the (hashed) adjacency sets later on.
        offset = tuple(offset)
        if offset == (0, 0, 0):  # only want bonds in aperiodic unit cell
            self.graph.union(i, j)
        else:
            # Store the bond in both directions; the reverse bond goes from
            # j back to i, so its offset is negated.
            roffset = tuple(-x for x in offset)
            self.bonds += [(i, j, offset)]
            self.bonds += [(j, i, roffset)]

    def check(self):
        """
        Determines the dimensionality histogram.

        The component graph is traversed (using BFS) until the matrix rank
        of the subspace spanned by the visited components no longer increases.
        Components with mutual visits are then merged in ``self.graph``.

        Returns:

        hist : tuple    Dimensionality histogram: ``hist[d]`` is the number
                        of components of dimensionality ``d``.
        """
        adjacency = build_adjacency_list(self.graph.find_all(),
                                         self.bonds)
        # Nothing changed since the last traversal: reuse the cached result.
        if adjacency == self.adjacency:
            return self.hcached

        self.adjacency = adjacency
        self.all_visited, self.ranks = traverse_component_graphs(adjacency)
        res = merge_mutual_visits(self.all_visited, self.ranks, self.graph)
        _, self.all_visited, self.ranks = res

        self.roots = np.unique(self.graph.find_all())
        h = get_dimensionality_histogram(self.ranks, self.roots)
        self.hcached = h
        return h

    def get_components(self):
        """
        Determines the dimensionality and constituent atoms of each component.

        Runs check() first if it has not been called yet.

        Returns:

        components: array   The (relabelled) component ID of every atom.
        cdim: dict          Maps each component ID to its dimensionality.
        """
        if self.ranks is None:
            self.check()

        component_dim = {e: self.ranks[e] for e in self.roots}
        relabelled_components = self.graph.find_all(relabel=True)
        relabelled_dim = {relabelled_components[k]: v
                          for k, v in component_dim.items()}
        self.cdim_cached = relabelled_dim
        self.components_cached = relabelled_components
        return relabelled_components, relabelled_dim