import cv2
import numpy as np
from heapq import (
heappop,
heappush,
)
# 定义8连通领域
directions = [(0, 1), (1, 1), (1, 0), (1, -1), (0, -1), (-1, -1), (-1, 0), (-1, 1)]
# 定义一个函数来获取两个点之间连通的所有像素坐标和路径长度
def get_connected_coord_list(mask, point_1, point_2):
# print(f"==>> point_1: {point_1}")
# print(f"==>> point_2: {point_2}")
if len(mask.shape) == 3:
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
rows, cols = mask.shape[:2]
visited = np.zeros_like(mask, dtype=np.uint8)
queue = []
# 优先队列,优先级是距离
# 堆是一种数据结构,它总是保持最小元素在顶部。
# 将元组(距离, 点)放入堆中,距离较小的点会优先出队
heappush(queue, (0, point_1))
# visited[point_1[1], point_1[0]] = 1
visited[point_1[0], point_1[1]] = 1
# 记录每个点的父节点,用于回溯路径
parents = {}
connected_pixels = []
while queue:
# 从对中取出距离最小对应的元组
current_distance, current_point = heappop(queue)
if current_point == point_2:
# print(f"==>> current_point: {current_point}")
#
# # 回溯路径并计算长度
# path_length = 0
while current_point in parents:
# path_length += 1
current_point = parents[current_point]
# print(f"==>> current_point: {current_point}")
connected_pixels.append(current_point)
# return path_length, connected_pixels
connected_pixels.append(point_2)
return connected_pixels
for direction in directions:
next_point = (
current_point[0] + direction[0],
current_point[1] + direction[1],
)
if 0 <= next_point[0] < rows and 0 <= next_point[1] < cols:
if (
visited[next_point[0], next_point[1]] == 0
and mask[next_point[0], next_point[1]] > 0
):
visited[next_point[0], next_point[1]] = 1
# 优先队列,优先级是距离
heappush(queue, (current_distance + 1, next_point))
parents[next_point] = current_point
# return -1, []
return []
if __name__ == "__main__":
mask_path = "./test.png"
mask = cv2.imread(mask_path, 0)
# (h, w)
end_point_crop_list = [(1, 0), (8, 30), (16, 42), (46, 43)]
intersection_point_crop_list = [(10, 32), (11, 15), (18, 36), (44, 45)]
# 调用函数并打印结果
connected_coord_list = get_connected_coord_list(
mask, (1, 0), (10, 32)
)
print(f"==>> connected_coord_list: {connected_coord_list}")
mask_visual = np.zeros_like(mask)
for coord in connected_coord_list:
mask_visual[coord[0], coord[1]] = 255
print(len(connected_coord_list))
# 36
mask_visual_path = "./test_visual_point_mask.png"
cv2.imwrite(mask_visual_path, mask_visual)
因篇幅问题不能全部显示,请点此查看更多更全内容