最小生成树MST 之 Prim及其实现

昨天突然想不起来并查集的名字,忽觉忘记的有点多,于是决定复习一下各种算法。

问题引入

要在n个城市之间铺设光缆,主要目标是要使这 n 个城市的任意两个之间都可以通信,但铺设光缆的费用很高,且各个城市之间铺设光缆的费用不同,因此另一个目标是要使铺设光缆的总费用最低。这就需要找到带权的最小生成树。

将城市抽象为点,光缆抽象为边,费用抽象为权值,就可以绘制一幅连通加权无向图

类似的还有修路,建桥等问题,抽象出来都属于最小生成树。

最小生成树简介

最小生成树是一副连通加权无向图中一棵权值最小的生成树。

在一给定的无向图 G = (V, E) 中,(u, v) 代表连接顶点 u 与顶点 v 的边(即 {\displaystyle (u,v)\in E}),而 w(u, v) 代表此边的权重,若存在 T 为 E 的子集(即 {\displaystyle T\subseteq E})且 (V, T) 为树,使得

的 w(T) 最小,则此 T 为 G 的最小生成树

最小生成树其实是最小权重生成树的简称。

一个连通图可能有多个生成树。当图中的边具有权值时,总会有一个生成树的边的权值之和小于或者等于其它生成树的边的权值之和。广义上而言,对于非连通无向图来说,它的每一连通分量同样有最小生成树,它们的并被称为最小生成森林

以有线电视电缆的架设为例,若只能沿着街道布线,则以街道为边,而路口为顶点,其中必然有一最小生成树能使布线成本最低

来自 最小生成树 – 维基百科

Prim算法

与之齐名的还有Kruskal算法,相比之下,Prim是从图中的顶点出发,而Kruskal更注重图中的边。

不管是Prim还是Kruskal,其核心都是贪心——寻找当前状态下最小权值的边。

算法步骤:

  1. 从一个顶点出发。
  2. 查找与已连接顶点相连的边权值最小的边。
  3. 通过此边连接对应顶点。
  4. 重复2、3步骤,直到连接所有顶点。

Prim算法实现(Python)

  • 使用random随机生成图(邻接矩阵)
  • 权值最大为100
  • 输出顶点连接顺序以及最小权值和
import random
MAX_DISTANCE = 100  # 边的最大权值


def create_map():  # 生成图(邻接矩阵)
    map = []
    size = random.randint(5, 8)
    print("图大小:", size)
    
    for i in range(size):
        row = []
        for j in range(size):
            row.append(random.randint(1, MAX_DISTANCE))
        map.append(row)
    
    print("生成图:")
    for i in range(size):
        print(map[i])
        
    return map


def prim(map):
    n = len(map)  # 图的大小
    start_node = 0  # 开始顶点
    tol_distance = 0  # 总开销
    visit_order = []  # 记录访问顺序
    visited = [False] * n  # 初始化访问标记
    visited_distance = [0] * n  # 初始化当前最小权值集合
    
    for i in range(n):  # 载入开始顶点到各顶点的权值
        visited_distance[i] = map[start_node][i]
    visited[start_node] = True
    visit_order.append([start_node, 0])
    
    for i in range(n-1):  # 开始算法,连接N个点需要N-1条边
        visiting_node = 0
        min_distance = MAX_DISTANCE + 1
        for j in range(n):  # 查找当前连通边中权值最小的边,以及其连接的顶点
            if not visited[j] and visited_distance[j] < min_distance:
                min_distance = visited_distance[j]
                visiting_node = j
        
        if visiting_node == 0:
            return
        visited[visiting_node] = True
        visit_order.append([visiting_node, min_distance])
        tol_distance += min_distance
        
        for j in range(n):  # 维护最小权值的边集合
            if not visited[j] and visited_distance[j] > map[visiting_node][j]:
                visited_distance[j] = map[visiting_node][j]
    
    return visit_order, tol_distance


def main():
    map = create_map()
    result, tol_distant = prim(map)
    
    print("生成顺序(节点,权值):")
    for i in range(len(result)):
        print(i, ": ", result[i])
    
    print("最小开销:", tol_distant)


if __name__ == "__main__":
    main()

输出: