Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import numpy as np
class DisjointSet:
    """Union-find structure over the integers ``0 .. n-1``.

    Sets are merged by size (the smaller tree is attached under the larger
    root) and ``find`` compresses the path it walks.

    Attributes:
        sizes: ``sizes[r]`` is the number of elements in the set rooted at
            ``r``. Entries for non-root elements are stale and meaningless.
        parents: parent pointer of each element; a root points to itself.
        nc: current number of disjoint sets.
    """

    def __init__(self, n: int) -> None:
        """Create ``n`` singleton sets ``{0}, {1}, ..., {n-1}``."""
        self.sizes = np.ones(n, dtype=int)
        self.parents = np.arange(n)
        self.nc = n

    def _compress(self):
        """Point every element directly at its root (vectorised pointer jumping).

        Repeatedly replaces each parent with its grandparent until the array
        reaches a fixpoint (``parents[parents] == parents``), at which point
        every entry is a root. May rebind ``self.parents`` to a new array.
        """
        a = self.parents
        b = a[a]
        while (a != b).any():
            a = b
            b = a[a]
        self.parents = a

    def union(self, a, b) -> bool:
        """Merge the sets containing ``a`` and ``b``.

        Returns:
            True if two distinct sets were merged, False if ``a`` and ``b``
            were already in the same set.
        """
        a = self.find(a)
        b = self.find(b)
        if a == b:
            return False

        sizes = self.sizes
        parents = self.parents
        # Union by size: hang the smaller tree under the larger root.
        # On a tie, b's root goes under a's root.
        if sizes[a] < sizes[b]:
            a, b = b, a
        parents[b] = a
        sizes[a] += sizes[b]

        self.nc -= 1
        return True

    def find(self, index):
        """Return the root of the set containing ``index``.

        Every node on the path from ``index`` to the root is re-pointed
        directly at the root (full path compression), so later lookups on
        any of them take a single step until the root itself changes.
        """
        parents = self.parents
        root = parents[index]
        while root != parents[root]:
            root = parents[root]

        # Second pass: shortcut every node on the walked path to the root.
        node = index
        while parents[node] != root:
            next_node = parents[node]
            parents[node] = root
            node = next_node
        return root

    def find_all(self, relabel: bool = False) -> np.ndarray:
        """Return the set membership of every element as a new array.

        Args:
            relabel: if False, return each element's root index. If True,
                return compact labels ``0 .. nc-1`` ordered by set size,
                label 0 being the largest set; among equally sized sets the
                one whose root has the larger index gets the smaller label.

        Returns:
            A fresh integer array of length ``n``; mutating it does not
            affect this structure.
        """
        self._compress()
        if not relabel:
            # Copy so callers cannot corrupt the parent pointers.
            return self.parents.copy()

        # order elements by frequency
        _, inverse, counts = np.unique(self.parents,
                                       return_inverse=True,
                                       return_counts=True)
        # Reversing a stable ascending sort puts the largest set first and
        # gives a deterministic tie order (higher root index first).
        order = np.argsort(counts, kind='stable')[::-1]
        # argsort of a permutation is its inverse: the rank of each root.
        return np.argsort(order)[inverse]