Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

import numpy as np

2 

3 

class DisjointSet:
    """Union-find structure over the integers ``0 .. n-1`` backed by NumPy.

    Attributes are plain arrays so that whole-structure queries can be
    vectorised:

    * ``parents[i]`` is the parent of element ``i``; a root is its own parent.
    * ``sizes[r]`` is the number of elements in the set rooted at ``r``
      (only meaningful while ``r`` is a root).
    * ``nc`` is the current number of disjoint sets.
    """

    def __init__(self, n: int):
        """Create ``n`` singleton sets, one per element."""
        self.sizes = np.ones(n, dtype=int)
        self.parents = np.arange(n)
        self.nc = n

    def _compress(self):
        """Point every element directly at its root.

        Repeatedly composes the parent map with itself (``p -> p[p]``) until
        it stops changing; each pass doubles the distance jumped, so chains
        collapse quickly.  ``self.parents`` is rebound to the resulting array.
        """
        current = self.parents
        jumped = current[current]
        while (current != jumped).any():
            current = jumped
            jumped = current[current]
        self.parents = current

    def union(self, a, b) -> bool:
        """Merge the sets containing ``a`` and ``b``.

        Returns ``False`` if both elements were already in the same set,
        otherwise merges them and returns ``True``.
        """
        a = self.find(a)
        b = self.find(b)
        if a == b:
            return False

        sizes = self.sizes
        parents = self.parents
        # Union by size: attach the smaller tree below the larger root.
        # On a tie the root of ``a`` stays the root.
        if sizes[a] < sizes[b]:
            parents[a] = b
            sizes[b] += sizes[a]
        else:
            parents[b] = a
            sizes[a] += sizes[b]

        self.nc -= 1
        return True

    def find(self, index):
        """Return the root of the set containing ``index``.

        As a side effect ``index`` itself is re-pointed straight at the root;
        intermediate elements on the path are left untouched.
        """
        parents = self.parents
        parent = parents[index]
        while parent != parents[parent]:
            parent = parents[parent]
        parents[index] = parent
        return parent

    def find_all(self, relabel: bool = False) -> np.ndarray:
        """Return a set identifier for every element.

        With ``relabel=False`` the result is the fully compressed parent
        array, i.e. the root index of each element.  This is the internal
        ``self.parents`` array itself, not a copy.

        With ``relabel=True`` the sets are renumbered ``0, 1, 2, ...`` in
        order of decreasing size, so label ``0`` is the largest set.  Among
        sets of equal size the one with the larger root index receives the
        smaller label.
        """
        self._compress()
        if not relabel:
            return self.parents

        # ``inverse`` maps each element to the position of its root among the
        # sorted distinct roots; ``counts`` holds the size of each such set.
        _, inverse, counts = np.unique(self.parents,
                                       return_inverse=True,
                                       return_counts=True)
        # Stable ascending sort, reversed: positions of the distinct roots
        # from largest set to smallest.
        by_size = np.argsort(counts, kind='merge')[::-1]
        # Inverting that permutation yields each root's rank, i.e. its new
        # label; look it up per element.
        return np.argsort(by_size)[inverse]