프로그래밍/Computer Science

[Data Structure] 분리 집합 (Disjoint Set)

Churnobyl 2023. 9. 21. 23:26
728x90
반응형


분리 집합 (Disjoint Set)

  분리 집합은 서로소 집합이라고도 한다. 전체 집합 U에 대해 U의 분리 집합 A, B는 다음 조건을 만족한다.

  • A, B는 U의 부분집합이다
  • A, B는 서로 공통 원소를 가지지 않는다
  • A, B의 합집합이 곧 전체집합 U이다.

 

 즉, 전체 집합 U를 겹치는 부분이 발생하지 않도록 모든 원소를 분리시킨 집합을 분리 집합이라고 한다.

 

 분리 집합은 일반적으로 두 원소가 같은 집합에 속하는지 빠르게 여부를 확인하거나, 두 집합을 합치는 등의 연산을 할 때 사용된다. 예를 들면 SNS에서 A라는 사람과 B라는 사람이 친구 관계인지 여부를 확인하거나, 친구를 맺었다면 두  집합을 합치는 연산을 할 수 있다.

 

분리 집합은 대표적으로 트리 자료 구조Union-Find 알고리즘으로 구현한다.

 


1.  Union-Find 알고리즘

class Node:
    def __init__(self, data) -> None:
        self.data = data
        self.parent = self

    def __repr__(self) -> str:
        return str(self.data)


class DisjointSet:
    def __init__(self) -> None:
        self.tree = {}

    def make_set(self, __x) -> None:
        node = Node(__x)
        self.tree[__x] = node

    def find(self, __x) -> Node:
        if self.tree[__x].parent != self.tree[__x]:
            self.tree[__x].parent = self.find(self.tree[__x].parent.data)
        return self.tree[__x].parent

    def union(self, __x, __y) -> None:
        root_x = self.find(__x)
        root_y = self.find(__y)

        if root_x != root_y:
            root_y.parent = root_x

    def is_same_set(self, __x, __y) -> bool:
        return self.find(__x) == self.find(__y)

    def __str__(self) -> str:
        return ', '.join([f"{key} -> {value.parent}" for key, value in self.tree.items()])


djs = DisjointSet()

for i in range(6):
    djs.make_set(i)

djs.union(1, 0)
djs.union(1, 2)
djs.union(4, 1)

print(djs.is_same_set(0, 4))
print(djs.is_same_set(0, 5))
print(djs)
True
False
0 -> 4, 1 -> 4, 2 -> 1, 3 -> 3, 4 -> 4, 5 -> 5

 

 다음과 같이 기본적인 트리 구조를 활용해서 DisjointSet 클래스를 만들었다. 트리 구조이므로 각 노드들은 자신의 부모를 가리키는 parent 속성을 가지고 있으며, 초기에는 자기 자신을 가리킨다. 즉 최종적으로 root노드가 되는 노드들은 계속해서 자신을 가리키고 있을 것이다.

 

 union메소드를 이용해 두번째 인자의 부모를 첫번째 인자로 설정해주면 서로 간의 관계가 설정되며 결과는 다음과 같이 나왔다. 마지막 출력문은 분리 집합의 사용에서는 의미없는 값이지만 관계를 보기 위해 출력해봤다. 실제 사용은 두 노드가 연관 있는지 없는지를 알아보기 위해 사용한다.

 

 find메소드에서는 경로 압축 기법(Path Compression)이 사용됐는데, find메소드를 실행할 때마다 해당 노드의 부모를 루트 노드와 연결시켜주는 기법이다. 이를 통해서 경로를 단축시킬 수 있다

경로 압축 기법 미사용시

 

경로 압축 기법 사용 시


2. 랭크 기법

 속도를 더 개선하기 위해 랭크 기법을 사용할 수 있다. 랭크 기법은 두 집합을 합칠 때 깊이가 더 얕은 트리를 깊은 트리에게 연결시켜 최대한 트리의 깊이를 얕게 유지하는 기법이다.

 

 각 노드에게 rank라는 속성을 추가하고 랭크가 더 작은 트리가 랭크가 큰 트리에게 연결될 수 있도록 한다. union할 때 두 노드의 parent가 같다면, 해당 parent의 랭크를 1 올려서 더 큰 트리가 되도록 유지시킨다.

 

class Node:
    def __init__(self, data) -> None:
        self.data = data
        self.parent = self
        self.rank = 0  # 랭크 초기화

    def __repr__(self) -> str:
        return str(self.data)


class DisjointSet:
    def __init__(self) -> None:
        self.tree = {}

    def make_set(self, __x) -> None:
        node = Node(__x)
        self.tree[__x] = node

    def find(self, __x) -> Node:
        if self.tree[__x].parent != self.tree[__x]:
            self.tree[__x].parent = self.find(
                self.tree[__x].parent.data)  # 경로 압축
        return self.tree[__x].parent

    def union(self, __x, __y) -> None:
        root_x = self.find(__x)
        root_y = self.find(__y)

        if root_x != root_y:
            if root_x.rank < root_y.rank:  # 랭크가 작은 트리를 큰 트리 아래에 연결
                root_x.parent = root_y
            else:
                root_y.parent = root_x
                if root_x.rank == root_y.rank:  # 랭크가 같을 경우, 루트의 랭크를 1 증가시킴
                    root_x.rank += 1

    def is_same_set(self, __x, __y) -> bool:
        return self.find(__x) == self.find(__y)

    def __str__(self) -> str:
        return ', '.join([f"{key} -> {value.parent}" for key, value in self.tree.items()])


djs = DisjointSet()

for i in range(6):
    djs.make_set(i)

djs.union(1, 0)
djs.union(1, 2)
djs.union(4, 1)

print(djs.is_same_set(0, 4))
print(djs.is_same_set(0, 5))
print(djs)
True
False
0 -> 1, 1 -> 1, 2 -> 1, 3 -> 3, 4 -> 1, 5 -> 5

 


문제

https://www.acmicpc.net/problem/1717

 

1717번: 집합의 표현

초기에 $n+1$개의 집합 $\{0\}, \{1\}, \{2\}, \dots , \{n\}$이 있다. 여기에 합집합 연산과, 두 원소가 같은 집합에 포함되어 있는지를 확인하는 연산을 수행하려고 한다. 집합을 표현하는 프로그램을 작

www.acmicpc.net

 

 분리 집합을 구현하는 간단한 문제

 

import sys


class DisjointSet:
    def __init__(self, n) -> None:
        self.parent = [i for i in range(n + 1)]
        self.rank = [0 for _ in range(n + 1)]

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)

        if root_x != root_y:
            if self.rank[root_x] < self.rank[root_y]:
                root_x, root_y = root_y, root_x

            self.parent[root_y] = root_x
            if self.rank[root_x] == self.rank[root_y]:
                self.rank[root_x] += 1

    def is_same_set(self, x, y):
        return self.find(x) == self.find(y)


N, M = map(int, sys.stdin.readline().rstrip().split())

djs = DisjointSet(N)

for _ in range(M):
    ins = list(map(int, sys.stdin.readline().rstrip().split()))

    if ins[0] == 0:
        djs.union(ins[1], ins[2])
    elif ins[0] == 1:
        print("yes" if djs.is_same_set(ins[1], ins[2]) else "no")
반응형