堆优化的A*算法-Python实现

堆优化的A*算法-Python实现

原理参考博客地址
代码借鉴地址
A*算法解决二维网格地图中的寻路问题

  • 输入:图片(白色区域代表可行,深色区域代表不行可行)
  • 输出:路径(在图中绘制)
""" 方格地图中的A*算法 (openList进行了堆优化)
A* 算法:  F = G+H
F: 总移动代价
G: 起点到当前点的移动代价  直:1, 斜:1.4
H: 当前点到终点的预估代价  曼哈顿距离
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
1.把起点加入 openList中
2.While True:
    a.遍历openList,查找F值最小的节点,作为current
    b.current是终点:
        ========结束========
    c.从openList中弹出,放入closeList中
    d.对八个方位的格点:
        if 越界 or 是障碍物 or 在closeList中:
            continue
        if 不在openList中:
            设置父节点,F,G,H
            加入openList中
        else:
            if 这条路径更好:
                设置父节点,F,G
                更新openList中的对应节点
3.生成路径path
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
堆优化:
    openList:作为最小堆,按F值排序存储坐标 (不更新只增加)
    openDict:坐标:点详细信息 (既更新又增加)
    get_minfNode() 从openList中弹出坐标,去openDict中取点 (但由于不更新只增加,坐标可能冗余)
    in_openList() 判断坐标是否在openDict中即可 

"""
import math
from PIL import Image,ImageDraw 
import numpy as np
import heapq # 堆

STAT_OBSTACLE='#'
STAT_NORMAL='.'

class Node():
    """
    节点元素,parent用来在成功的时候回溯路径
    """
    def __init__(self, x, y,parent=None, g=0, h=0):
        self.parent = parent
        self.x = x
        self.y = y
        self.g = g
        self.h = h
        self.update()
    
    def update(self):
        self.f = self.g+self.h

class RoadMap():
    """ 读进一张图片,二值化成为有障碍物的二维网格化地图,并提供相关操作
    """
    def __init__(self,img_file):
        """图片变二维数组"""
        test_map = []
        img = Image.open(img_file)
#        img = img.resize((100,100))  ### resize图片尺寸
        img_gray = img.convert('L')  # 地图灰度化
        img_arr = np.array(img_gray)
        img_binary = np.where(img_arr<127,0,255)
        for x in range(img_binary.shape[0]):
            temp_row = []
            for y in range(img_binary.shape[1]):
                status = STAT_OBSTACLE if img_binary[x,y]==0 else STAT_NORMAL 
                temp_row.append(status)
            test_map.append(temp_row)
            
        self.map = test_map
        self.cols = len(self.map[0])
        self.rows = len(self.map)
        
    def is_valid_xy(self, x,y):
        if x < 0 or x >= self.rows or y < 0 or y >= self.cols:
            return False
        return True

    def not_obstacle(self,x,y):
        return self.map[x][y] != STAT_OBSTACLE
    
    def EuclidenDistance(self, xy1, xy2):
        """两个像素点之间的欧几里得距离"""
        dis = 0
        for (x1, x2) in zip(xy1, xy2):
            dis += (x1 - x2)**2
        return dis**0.5

    def ManhattanDistance(self,xy1,xy2):
        """两个像素点之间的曼哈顿距离"""
        dis = 0
        for x1,x2 in zip(xy1,xy2):
            dis+=abs(x1-x2)
        return dis

    def check_path(self, xy1, xy2):
        """碰撞检测 两点之间的连线是否经过障碍物"""
        steps = max(abs(xy1[0]-xy2[0]), abs(xy1[1]-xy2[1])) # 取横向、纵向较大值,确保经过的每个像素都被检测到
        xs = np.linspace(xy1[0],xy2[0],steps+1)
        ys = np.linspace(xy1[1],xy2[1],steps+1)
        for i in range(1, steps): # 第一个节点和最后一个节点是 xy1,xy2,无需检查
            if not self.not_obstacle(math.ceil(xs[i]), math.ceil(ys[i])):
                return False
        return True

    def plot(self,path):
        """绘制地图及路径"""
        out = []
        for x in range(self.rows):
            temp = []
            for y in range(self.cols):
                if self.map[x][y]==STAT_OBSTACLE:
                    temp.append(0)
                elif self.map[x][y]==STAT_NORMAL:
                    temp.append(255)
                elif self.map[x][y]=='*':
                    temp.append(127)
                else:
                    temp.append(255)
            out.append(temp)
        for x,y in path:
            out[x][y] = 127
        out = np.array(out)
        img = Image.fromarray(out)
        img.show()

class A_Star(RoadMap):
    """ x是行索引,y是列索引
    """
    def __init__(self, img_file, start=None, end=None):
        """地图文件,起点,终点"""
        RoadMap.__init__(self,img_file)
        self.startXY = tuple(start) if start else (0,0)
        self.endXY = tuple(end) if end else (self.rows-1, self.cols-1)
        self.closeList = set()
        self.path = []
        self.openList = []  # 堆,只添加,和弹出最小值点,
        self.openDict = dict() # openList中的 坐标:详细信息 -->不冗余的
        
    def find_path(self):
        """A*算法寻路主程序"""
        p = Node(self.startXY[0], self.startXY[1], 
                 h=self.ManhattanDistance(self.startXY, self.endXY)) # 构建开始节点
        heapq.heappush(self.openList, (p.f,(p.x,p.y)))
        
        self.openDict[(p.x,p.y)] = p  # 加进dict目录
        while True:
            current = self.get_minfNode()
            if (current.x,current.y)==self.endXY:
                print('found path successfully..')
                self.make_path(current)
                return 
            
            self.closeList.add((current.x,current.y))  ## 加入closeList
            del self.openDict[(current.x,current.y)]
            self.extend_surrounds(current) # 会更新close list

    def make_path(self,p):
        """从结束点回溯到开始点,开始点的parent==None"""
        while p:
            self.path.append((p.x, p.y))
            p = p.parent
    
    def extend_surrounds(self, node):
        """ 将当前点周围可走的点加到openList中,
            其中 不在openList中的点 设置parent、F,G,H 加进去,
                 在openList中的点  更新parent、F,G,H
            (斜向时,正向存在障碍物时不允许穿越)
        """
        motion_direction = [[1, 0], [0,  1], [-1, 0], [0,  -1], 
                            [1, 1], [1, -1], [-1, 1], [-1, -1]]  
        for dx, dy in motion_direction:
            x,y = node.x+dx, node.y+dy
            new_node = Node(x,y)
            # 位置无效,或者是障碍物, 或者已经在closeList中 
            if not self.is_valid_xy(x,y) or not self.not_obstacle(x,y) or self.in_closeList(new_node): 
                continue
            if abs(dx)+abs(dy)==2:  ## 斜向 需检查正向有无障碍物
                h_x,h_y = node.x+dx,node.y # 水平向
                v_x,v_y = node.x,node.y+dy # 垂直向
                if not self.is_valid_xy(h_x,h_y) or not self.not_obstacle(h_x,h_y) or self.in_closeList(Node(h_x,h_y)): 
                    continue
                if not self.is_valid_xy(v_x,v_y) or not self.not_obstacle(v_x,v_y) or self.in_closeList(Node(v_x,v_y)): 
                    continue
            #============ ** 关键 **             ========================
            #============ 不在openList中,加进去; ========================
            #============ 在openList中,更新      ========================
            #============对于openList和openDict来说,操作都一样 ===========
            new_g = node.g + self.cal_deltaG(node.x,node.y, x,y)
            sign=False # 是否执行操作的标志 
            if not self.in_openList(new_node): # 不在openList中
                # 加进来,设置 父节点, F, G, H
                new_node.h = self.cal_H(new_node)
                sign=True
            elif self.openDict[(new_node.x,new_node.y)].g > new_g: # 已在openList中,但现在的路径更好
                sign=True
            if sign:
                new_node.parent = node
                new_node.g = new_g
                new_node.f = self.cal_F(new_node)
                self.openDict[(new_node.x,new_node.y)]=new_node # 更新dict目录
                heapq.heappush(self.openList, (new_node.f,(new_node.x,new_node.y)))
        
    def get_minfNode(self):
        """从openList中取F=G+H值最小的 (堆-O(1))"""
        while True:
            f, best_xy=heapq.heappop(self.openList)
            if best_xy in self.openDict:
                return self.openDict[best_xy]

    def in_closeList(self, node):
        """判断是否在closeList中 (集合-O(1)) """
        return True if (node.x,node.y) in self.closeList else False
     
    def in_openList(self, node):
        """判断是否在openList中 (字典-O(1))"""
        if not (node.x,node.y) in self.openDict:
            return False
        else:
            return True

    def cal_deltaG(self,x1,y1,x2,y2):
        """ 计算两点之间行走的代价
            (为简化计算)上下左右直走,代价为1.0,斜走,代价为1.4  G值
        """
        if x1 == x2 or y1 == y2:
            return 1.0
        return 1.4
    
    def cal_H(self, node):
        """ 曼哈顿距离 估计距离目标点的距离"""
        return abs(node.x-self.endXY[0])+abs(node.y-self.endXY[1]) # 剩余路径的估计长度
    
    def cal_F(self, node):
        """ 计算F值 F = G+H 
            A*算法的精髓:已经消耗的代价G,和预估将要消耗的代价H
        """
        return node.g + node.h


def path_length(path):
    """计算路径长度"""
    l = 0
    for i in range(len(path)-1):
        x1,y1 = path[i]
        x2,y2 = path[i+1]
        if x1 == x2 or y1 == y2:
            l+=1.0
        else:
            l+=1.4
    return l


# ===== test case ===============
a = A_Star('map_1.bmp')
a.find_path()
a.plot(a.path)
print('path length:',path_length(a.path))

测试用例及结果

map1
map_3
迷宫

存在的问题

不确定是否是最优路径
原文描述:
“ If we overestimate this distance, however, it is not guaranteed to give us the shortest path. In such cases, we have what is called an “inadmissible heuristic.”.

Technically, in this example, the Manhattan method is inadmissible because it slightly overestimates the remaining distance.”
即如果我们高估了H,则不能保证最短路径。而曼哈顿距离略微高估了。

另外,笔者不确定程序是不是正确,以及是不是真正的A*算法,请大神们指正。

猜你喜欢

转载自blog.csdn.net/FengKuangXiaoZuo/article/details/105135005