Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2Implements the Rank Determination Algorithm (RDA) 

3 

4Method is described in: 

5Definition of a scoring parameter to identify low-dimensional materials 

6components 

7P.M. Larsen, M. Pandey, M. Strange, and K. W. Jacobsen 

8Phys. Rev. Materials 3 034003, 2019 

9https://doi.org/10.1103/PhysRevMaterials.3.034003 

10""" 

11import numpy as np 

12from collections import defaultdict 

13from ase.geometry.dimensionality.disjoint_set import DisjointSet 

14 

15 

16# Numpy has a large overhead for lots of small vectors. The cross product is 

17# particularly bad. Pure python is a lot faster. 

18 

def dot_product(A, B):
    """Return the dot product of two equally long sequences.

    Implemented in pure Python because the vectors handled in this module
    are tiny, where numpy call overhead dominates. Pairs are formed with
    ``zip``, so surplus elements of the longer input are ignored.
    """
    return sum(x * y for x, y in zip(A, B))

21 

22 

def cross_product(a, b):
    """Return the cross product of two 3-component vectors as a list.

    Only indices 0, 1 and 2 of each input are read.
    """
    x = a[1] * b[2] - a[2] * b[1]
    y = a[2] * b[0] - a[0] * b[2]
    z = a[0] * b[1] - a[1] * b[0]
    return [x, y, z]

25 

26 

def subtract(A, B):
    """Return the element-wise difference ``A - B`` as a new list.

    Elements are paired with ``zip``, so the result is as long as the
    shorter of the two inputs.
    """
    difference = []
    for left, right in zip(A, B):
        difference.append(left - right)
    return difference

29 

30 

def rank_increase(a, b):
    """Tell whether adding point ``b`` to the points in ``a`` raises the rank.

    Parameters:

    a: list     Points accepted so far (must be a list: it is concatenated
                with ``[b]``).
    b: tuple    Candidate point.

    Returns:

    bool    True if ``b`` is not already in the affine span of ``a``:
            * no points yet: always True
            * one point: True if ``b`` differs from it
            * two points: True if the three points are not collinear
            * three points: True if the four points are not coplanar
            * four points: always False
    """
    count = len(a)
    if count == 0:
        return True
    if count == 1:
        return a[0] != b
    if count == 4:
        return False

    points = a + [b]
    origin = points[0]
    # Normal of the plane through the first three points; all zeros when
    # those points are collinear.
    normal = cross_product(subtract(points[1], origin),
                           subtract(points[2], origin))
    if count == 2:
        return any(normal)
    if count == 3:
        # Non-zero triple product <=> the fourth point lies off the plane.
        return dot_product(normal, subtract(points[3], origin)) != 0
    raise Exception("This shouldn't be possible.")

47 

48 

def bfs(adjacency, start):
    """Traverse the component graph using BFS.

    The graph is traversed until the matrix rank of the subspace spanned by
    the visited components no longer increases.

    Parameters:

    adjacency: dict   Maps a component to an iterable of
                      ``(neighbour component, cell offset)`` pairs.
    start: hashable   Component at which the traversal starts; it is placed
                      at cell position ``(0, 0, 0)``.

    Returns:

    visited: set      All ``(component, cell position)`` vertices that were
                      taken off the queue.
    rank: int         Number of positions accepted for ``start``, minus one.
    """
    # Function-scope import: a deque gives O(1) pops from the left, whereas
    # list.pop(0) shifts every remaining element.
    from collections import deque

    visited = set()
    cvisited = defaultdict(list)  # component -> accepted cell positions
    queue = deque([(start, (0, 0, 0))])
    while queue:
        vertex = queue.popleft()
        if vertex in visited:
            continue

        visited.add(vertex)
        c, p = vertex
        # The same vertex can be queued before an earlier entry raised the
        # rank of its component, so the test is repeated on dequeue.
        if not rank_increase(cvisited[c], p):
            continue

        cvisited[c].append(p)

        for nc, offset in adjacency[c]:
            nbrpos = (p[0] + offset[0], p[1] + offset[1], p[2] + offset[2])
            nbrnode = (nc, nbrpos)
            if nbrnode in visited:
                continue

            # Only queue neighbours that could still add a new direction.
            if rank_increase(cvisited[nc], nbrpos):
                queue.append(nbrnode)

    return visited, len(cvisited[start]) - 1

81 

82 

def traverse_component_graphs(adjacency):
    """Run ``bfs`` once from every component in ``adjacency``.

    Returns:

    all_visited: dict   Maps each component to the vertex set visited by
                        the traversal started from it.
    ranks: dict         Maps each component to the rank found by that
                        traversal.
    """
    all_visited = {}
    ranks = {}
    for component in adjacency:
        all_visited[component], ranks[component] = bfs(adjacency, component)
    return all_visited, ranks

93 

94 

def build_adjacency_list(parents, bonds):
    """Build the component adjacency mapping from per-atom component labels.

    Parameters:

    parents: sequence   Component label of every atom, indexed by atom.
    bonds: iterable     ``(i, j, offset)`` triples; ``offset`` must be
                        hashable because it is stored inside a set.

    Returns:

    adjacency: dict     Maps every distinct label in ``parents`` to a set of
                        ``(neighbour label, offset)`` pairs.
    """
    adjacency = {}
    for label in np.unique(parents):
        adjacency[label] = set()

    for bond in bonds:
        i, j, offset = bond
        source = parents[i]
        target = parents[j]
        adjacency[source].add((target, offset))
    return adjacency

103 

104 

def get_dimensionality_histogram(ranks, roots):
    """Count how many root components have each rank.

    Parameters:

    ranks: mapping     Rank of each component; used as an index into a
                       four-slot list.
    roots: iterable    The components to count.

    Returns:

    tuple    Four counts, where entry ``d`` is the number of roots whose
             rank is ``d``.
    """
    counts = [0] * 4
    for root in roots:
        dimension = ranks[root]
        counts[dimension] = counts[dimension] + 1
    return tuple(counts)

110 

111 

def merge_mutual_visits(all_visited, ranks, graph):
    """Find components with mutual visits and merge them.

    Two components are merged (via ``graph.union``) when the traversals
    started from them visited a common vertex.

    Parameters:

    all_visited: dict   Maps a component to the set of vertices visited by
                        its traversal.
    ranks: dict         Maps a component to its rank.
    graph: object       Must provide ``union(a, b)`` and ``find_all()``.

    Returns:

    merged: bool        Whether any ``graph.union`` call reported a merge.
    all_visited: dict   The input mapping if nothing merged, otherwise the
                        visits regrouped under each merged root.
    ranks: dict         The input mapping if nothing merged, otherwise the
                        rank of each merged root.
    """
    any_merged = False
    seen_by = defaultdict(list)  # vertex -> components that visited it
    for component, vertices in all_visited.items():
        for vertex in vertices:
            earlier = seen_by[vertex]
            for other in earlier:
                # Components sharing a vertex are expected to agree on rank.
                assert ranks[other] == ranks[component]
                any_merged = any_merged | graph.union(other, component)
            earlier.append(component)

    if any_merged:
        merged_visits = defaultdict(set)
        merged_ranks = {}
        parents = graph.find_all()
        for component, vertices in all_visited.items():
            root = parents[component]
            merged_visits[root].update(vertices)
            merged_ranks[root] = ranks[root]
        return any_merged, merged_visits, merged_ranks

    return any_merged, all_visited, ranks

134 

135 

class RDA:
    """Rank determination algorithm driven by incrementally inserted bonds."""

    def __init__(self, num_atoms):
        """
        Initializes the RDA class.

        A disjoint set is used to maintain the component graph.

        Parameters:

        num_atoms: int    The number of atoms in the unit cell.
        """
        self.bonds = []
        self.graph = DisjointSet(num_atoms)
        self.adjacency = None
        self.hcached = None
        self.components_cached = None
        self.cdim_cached = None

    def insert_bond(self, i, j, offset):
        """
        Adds a bond to the list of graph edges.

        Graph components are merged if the bond does not cross a cell boundary.
        Bonds which cross cell boundaries can inappropriately connect
        components which are not connected in the infinite crystal.  This is
        tested during graph traversal.

        Parameters:

        i: int           The index of the first atom.
        j: int           The index of the second atom.
        offset: tuple    The cell offset of the second atom.  Any sequence
                         of three numbers is accepted; it is stored as a
                         tuple.
        """
        # Normalise first: a list would never equal (0, 0, 0) and is
        # unhashable when later stored in a set, and comparing a numpy
        # array to a tuple has no single truth value.
        offset = tuple(offset)

        if offset == (0, 0, 0):  # bond stays inside the unit cell
            self.graph.union(i, j)
        else:
            # Store both directions so that the traversal can follow the
            # bond from either end.
            roffset = tuple(-component for component in offset)
            self.bonds += [(i, j, offset)]
            self.bonds += [(j, i, roffset)]

    def check(self):
        """
        Determines the dimensionality histogram.

        The component graph is traversed (using BFS) until the matrix rank
        of the subspace spanned by the visited components no longer increases.

        The histogram is cached and returned directly when the adjacency
        mapping is unchanged since the previous call.

        Returns:
        hist : tuple         Dimensionality histogram.
        """
        adjacency = build_adjacency_list(self.graph.find_all(),
                                         self.bonds)
        if adjacency == self.adjacency:
            return self.hcached

        self.adjacency = adjacency
        self.all_visited, self.ranks = traverse_component_graphs(adjacency)
        # Merging passes self.graph and may union components in it.
        res = merge_mutual_visits(self.all_visited, self.ranks, self.graph)
        _, self.all_visited, self.ranks = res

        self.roots = np.unique(self.graph.find_all())
        h = get_dimensionality_histogram(self.ranks, self.roots)
        self.hcached = h
        return h

    def get_components(self):
        """
        Determines the dimensionality and constituent atoms of each component.

        ``check`` must have been called beforehand: this method reads
        ``self.ranks`` and ``self.roots``, which are only assigned there.

        Returns:
        components: array    The component ID of every atom
        dimensions: dict     Maps each relabelled component ID to its rank.
        """
        component_dim = {e: self.ranks[e] for e in self.roots}
        relabelled_components = self.graph.find_all(relabel=True)
        relabelled_dim = {}
        for k, v in component_dim.items():
            relabelled_dim[relabelled_components[k]] = v
        self.cdim_cached = relabelled_dim
        self.components_cached = relabelled_components

        return relabelled_components, relabelled_dim