최소 스패닝 트리
Minimum Spanning Tree(MST)
간선의 가중치 합이 최소인 Spanning Tree를 의미함
Spanning Tree : 그래프 내에 있는 모든 정점을 연결하고, 사이클이 없는 그래프(Spanning Tree의 간선의 수 < 전체 간선의 수)
그래프 내의 모든 정점을 가장 적은 수, 적은 비용의 간선으로 연결한 그래프
MST 특징
간선의 가중치의 합이 최소여야 함
n개의 정점을 가진 그래프에 대해 n-1개의 간선만 사용해야 함
사이클이 포함되서는 안됨
MST 알고리즘
크루스칼 알고리즘
프림 알고리즘
크루스칼(Kruskal) 알고리즘
탐욕적인 방법을 사용하여 네트워크 내의 모든 정점을 최소 비용으로 연결하여 최적의 답을 구하는 알고리즘
최소 비용의 간선으로 사이클이 형성되지 않는 방식으로 간선을 선택함 → 간선 선택 기반의 알고리즘
이전 단계에서 만들어진 Spanning Tree를 고려하지 않고 무조건 최소 비용의 간선만 선택 → 탐욕적
과정
그래프의 간선들을 가중치를 기준으로 오름차순 정렬
정렬된 간선들을 순서대로 사이클이 형성되지 않는 간선을 선택
사이클 형성 여부 판단 →
union-find 알고리즘
선택된 간선을 MST 집합에 추가
시간복잡도 → O(ElogE + E)
간선 오름차순 정렬 →
O(ElogE)
union-find →
O(1) * E
코드
import sys v, e = map(int, input().split()) # 부모 테이블 초기화 parent = [0] * (v+1) for i in range(1, v+1): parent[i] = i # find 연산 def find_parent(parent, x): if parent[x] != x: parent[x] = find_parent(parent, parent[x]) return parent[x] # union 연산 def union_parent(parent, a, b): a = find_parent(parent, a) b = find_parent(parent, b) if a < b: parent[b] = a else: parent[a] = b # 간선 정보 담을 리스트와 최소 신장 트리 계산 변수 정의 edges = [] total_cost = 0 # 간선 정보 주어지고 비용을 기준으로 정렬 for _ in range(e): a, b, cost = map(int, input().split()) edges.append((cost, a, b)) # 간선 정보 비용 기준으로 오름차순 정렬 edges.sort() # 간선 정보 하나씩 확인하면서 크루스칼 알고리즘 수행 for i in range(e): cost, a, b = edges[i] # find 연산 후, 부모노드 다르면 사이클 발생 X으므로 union 연산 수행 -> 최소 신장 트리에 포함! if find_parent(parent, a) != find_parent(parent, b): union_parent(parent, a, b) total_cost += cost print(total_cost)
프림(Prim) 알고리즘
특정 시작 정점에서 출발하여 MST 집합을 확장해나가는 알고리즘
정점 선택 기반 알고리즘
과정
최초 MST 집합에는 시작 정점만 포함되어 있음
이전 단계에서 만들어진 MST 집합에 인접한 정점들 중 최소 간선으로 연결된 정점을 MST 집합에 포함
MST 집합에 인접 정점들 중 최소 간선을 가진 정점 찾기 →
우선순위 큐
활용우선순위 큐를 heap이나 array로 구현할 경우
사이클 방지하기 위해 방문 처리가 필요함
MST 집합에 포함된 간선의 수가 n-1개가 될 때까지 1~2 번을 반복
시간복잡도 → O(ElogV)
모든 노드에 대해 탐색
O(V)
* 우선순위 큐로 매 노드마다 최소 간선을 찾음O(logV)
⇒O(VlogV)
각 노드의 인접 간선 탐색
O(E)
* 최소 비용의 간선을 큐에 삽입O(logV)
⇒O(ElogV)
O(VlogV+ElogV)
로O(ElogV)
가 된다 (일반적으로 E가 V보다 큼)코드
import heapq import collections import sys sys.setrecursionlimit(10**6) input = sys.stdin.readline n, m = map(int,input().split()) # 노드 수, 간선 수 graph = collections.defaultdict(list) # 빈 그래프 생성 visited = [0] * (n+1) # 노드의 방문 정보 초기화 # 무방향 그래프 생성 for i in range(m): # 간성 정보 입력 받기 u, v, weight = map(int,input().split()) graph[u].append([weight, u, v]) graph[v].append([weight, v, u]) # 프림 알고리즘 def prim(graph, start_node): visited[start_node] = 1 # 방문 갱신 candidate = graph[start_node] # 인접 간선 추출 heapq.heapify(candidate) # 우선순위 큐 생성 mst = [] # mst total_weight = 0 # 전체 가중치 while candidate: weight, u, v = heapq.heappop(candidate) # 가중치가 가장 적은 간선 추출 if visited[v] == 0: # 방문하지 않았다면 visited[v] = 1 # 방문 갱신 mst.append((u,v)) # mst 삽입 total_weight += weight # 전체 가중치 갱신 for edge in graph[v]: # 다음 인접 간선 탐색 if visited[edge[2]] == 0: # 방문한 노드가 아니라면, (순환 방지) heapq.heappush(candidate, edge) # 우선순위 큐에 edge 삽입 return total_weight print(prim(graph,1))
Last updated
Was this helpful?