# Sample code from https://www.redblobgames.com/pathfinding/a-star/
# Copyright 2014 Red Blob Games
#
# Feel free to use this code in your own projects, including commercial projects
# License: Apache v2.0

from __future__ import annotations

import collections
import heapq
# some of these types are deprecated: https://www.python.org/dev/peps/pep-0585/
from typing import Callable, Iterator, Optional, Protocol, Tuple, TypeVar

T = TypeVar('T')
Location = TypeVar('Location')


class Graph(Protocol):
    """Structural interface: any object with a ``neighbors`` method is a Graph."""

    def neighbors(self, id: Location) -> list[Location]: ...


class SimpleGraph:
    """Graph backed by an explicit adjacency dict: ``edges[node] -> [nodes]``."""

    def __init__(self):
        self.edges: dict[Location, list[Location]] = {}

    def neighbors(self, id: Location) -> list[Location]:
        # A node with no entry in ``edges`` raises KeyError.
        return self.edges[id]


example_graph = SimpleGraph()
example_graph.edges = {
    'A': ['B'],
    'B': ['C'],
    'C': ['B', 'D', 'F'],
    'D': ['C', 'E'],
    'E': ['F'],
    'F': [],
}


class Queue:
    """FIFO queue for breadth-first search; a deque gives O(1) put and get."""

    def __init__(self):
        self.elements = collections.deque()

    def empty(self) -> bool:
        return not self.elements

    def put(self, x: T):
        self.elements.append(x)

    def get(self) -> T:
        return self.elements.popleft()


# utility functions for dealing with square grids
def from_id_width(id, width):
    """Convert a row-major cell index into an ``(x, y)`` coordinate."""
    return (id % width, id // width)


def draw_tile(graph, id, style):
    """Return the 3-character text used to draw cell ``id``.

    ``style`` may contain the keys 'number', 'point_to', 'path', 'start'
    and 'goal'. Each check below overwrites the previous one, so the
    precedence from lowest to highest is: number, point_to arrow, path,
    start, goal, and finally a wall in ``graph.walls``.
    """
    r = " . "
    if 'number' in style and id in style['number']:
        r = " %-2d" % style['number'][id]
    if 'point_to' in style and style['point_to'].get(id, None) is not None:
        (x1, y1) = id
        (x2, y2) = style['point_to'][id]
        # Arrow toward the cell this one points to. If both axes differ,
        # the later (vertical) check wins.
        if x2 == x1 + 1:
            r = " > "
        if x2 == x1 - 1:
            r = " < "
        if y2 == y1 + 1:
            r = " v "
        if y2 == y1 - 1:
            r = " ^ "
    if 'path' in style and id in style['path']:
        r = " @ "
    if 'start' in style and id == style['start']:
        r = " A "
    if 'goal' in style and id == style['goal']:
        r = " Z "
    if id in graph.walls:
        r = "###"
    return r


def draw_grid(graph, **style):
    """Print ``graph`` as text, one ``draw_tile`` per cell, framed by rulers."""
    print("___" * graph.width)
    for y in range(graph.height):
        for x in range(graph.width):
            print(draw_tile(graph, (x, y), style), end="")
        print()
    print("~~~" * graph.width)


# data from main article
DIAGRAM1_WALLS = [from_id_width(id, width=30) for id in [21, 22, 51, 52, 81, 82, 93, 94, 111, 112, 123, 124, 133, 134, 141, 142, 153, 154, 163, 164, 171, 172, 173, 174, 175, 183, 184, 193, 194, 201, 202, 203, 204, 205, 213, 214, 223, 224, 243, 244, 253, 254, 273, 274, 283, 284, 303, 304, 313, 314, 333, 334, 343, 344, 373, 374, 403, 404, 433, 434]]

GridLocation = Tuple[int, int]


class SquareGrid:
    """Width x height 4-connected grid with impassable wall cells."""

    def __init__(self, width: int, height: int):
        self.width = width
        self.height = height
        # Callers replace this wholesale (e.g. with a list of walls). Any
        # container supporting ``in`` works, and a set makes lookups O(1).
        self.walls: list[GridLocation] = []

    def in_bounds(self, id: GridLocation) -> bool:
        (x, y) = id
        return 0 <= x < self.width and 0 <= y < self.height

    def passable(self, id: GridLocation) -> bool:
        return id not in self.walls

    def neighbors(self, id: GridLocation) -> Iterator[GridLocation]:
        """Lazily yield the in-bounds, non-wall orthogonal neighbours of ``id``."""
        (x, y) = id
        neighbors = [(x+1, y), (x-1, y), (x, y-1), (x, y+1)]  # E W N S
        # see "Ugly paths" section for an explanation:
        if (x + y) % 2 == 0:
            neighbors.reverse()  # S N W E
        results = filter(self.in_bounds, neighbors)
        results = filter(self.passable, results)
        return results


class WeightedGraph(Graph, Protocol):
    """Structural interface for a Graph that also has per-edge costs.

    ``Protocol`` must be listed explicitly: otherwise this would be an
    ordinary class, and ``GridWithWeights`` (which does not subclass it)
    would not satisfy ``graph: WeightedGraph`` annotations.
    """

    def cost(self, from_id: Location, to_id: Location) -> float: ...


class GridWithWeights(SquareGrid):
    """SquareGrid where entering a cell costs ``weights[cell]`` (default 1)."""

    def __init__(self, width: int, height: int):
        super().__init__(width, height)
        self.weights: dict[GridLocation, float] = {}

    def cost(self, from_node: GridLocation, to_node: GridLocation) -> float:
        # The cost depends only on the destination cell.
        return self.weights.get(to_node, 1)


diagram4 = GridWithWeights(10, 10)
diagram4.walls = [(1, 7), (1, 8), (2, 7), (2, 8), (3, 7), (3, 8)]
diagram4.weights = {loc: 5 for loc in [(3, 4), (3, 5), (4, 1), (4, 2), (4, 3), (4, 4), (4, 5), (4, 6), (4, 7), (4, 8), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), (6, 2), (6, 3), (6, 4), (6, 5), (6, 6), (6, 7), (7, 3), (7, 4), (7, 5)]}


class PriorityQueue:
    """Min-priority queue built on ``heapq``.

    Entries are ``(priority, item)`` tuples. Ties on priority fall back to
    comparing the items themselves, so items queued with equal priority must
    be mutually comparable (grid tuples are).
    """

    def __init__(self):
        self.elements: list[tuple[float, T]] = []

    def empty(self) -> bool:
        return not self.elements

    def put(self, item: T, priority: float):
        heapq.heappush(self.elements, (priority, item))

    def get(self) -> T:
        return heapq.heappop(self.elements)[1]


def dijkstra_search(graph: WeightedGraph, start: Location, goal: Location):
    """Uniform-cost search from ``start`` toward ``goal``.

    Edge costs are expected to be non-negative.

    Returns:
        ``(came_from, cost_so_far)``. ``came_from`` maps each reached node to
        its predecessor (``None`` for ``start``); ``cost_so_far`` maps it to
        the cheapest cost found. The search stops as soon as ``goal`` is
        popped, so entries for other nodes may not be final.
    """
    frontier = PriorityQueue()
    frontier.put(start, 0)
    came_from: dict[Location, Optional[Location]] = {}
    cost_so_far: dict[Location, float] = {}
    came_from[start] = None
    cost_so_far[start] = 0

    while not frontier.empty():
        current: Location = frontier.get()

        if current == goal:
            break

        for nxt in graph.neighbors(current):
            new_cost = cost_so_far[current] + graph.cost(current, nxt)
            if nxt not in cost_so_far or new_cost < cost_so_far[nxt]:
                cost_so_far[nxt] = new_cost
                # No decrease-key: stale entries stay in the heap. When one
                # is popped again it cannot improve any cost, so it is harmless.
                priority = new_cost
                frontier.put(nxt, priority)
                came_from[nxt] = current

    return came_from, cost_so_far


# thanks to @m1sp for this simpler version of
# reconstruct_path that doesn't have duplicate entries

def reconstruct_path(came_from: dict[Location, Optional[Location]],
                     start: Location, goal: Location) -> list[Location]:
    """Walk predecessor links back from ``goal`` and return the path from ``start``.

    Returns ``[]`` when ``goal`` was never reached, and ``[start]`` when
    ``start == goal``.
    """
    current: Location = goal
    path: list[Location] = []
    if goal not in came_from:  # no path was found
        return []
    while current != start:
        path.append(current)
        current = came_from[current]
    path.append(start)  # optional
    path.reverse()  # optional
    return path


diagram_nopath = GridWithWeights(10, 10)
diagram_nopath.walls = [(5, row) for row in range(10)]


def heuristic(a: GridLocation, b: GridLocation) -> float:
    """Manhattan distance between two grid cells.

    This is admissible on a 4-connected grid as long as every step costs at
    least 1.
    """
    (x1, y1) = a
    (x2, y2) = b
    return abs(x1 - x2) + abs(y1 - y2)


def a_star_search(graph: WeightedGraph, start: Location, goal: Location,
                  heuristic_fn: Optional[Callable[[Location, Location], float]] = None):
    """A* search from ``start`` toward ``goal``.

    Args:
        graph: Anything with ``neighbors`` and ``cost``.
        start, goal: Nodes of ``graph``.
        heuristic_fn: Estimate of the remaining cost from a node to ``goal``.
            It defaults to the module-level grid ``heuristic``, looked up at
            call time. Pass your own for non-grid locations.

    Returns:
        ``(came_from, cost_so_far)``, with the same meaning as in
        ``dijkstra_search``.
    """
    if heuristic_fn is None:
        heuristic_fn = heuristic
    frontier = PriorityQueue()
    frontier.put(start, 0)
    came_from: dict[Location, Optional[Location]] = {}
    cost_so_far: dict[Location, float] = {}
    came_from[start] = None
    cost_so_far[start] = 0

    while not frontier.empty():
        current: Location = frontier.get()

        if current == goal:
            break

        for nxt in graph.neighbors(current):
            new_cost = cost_so_far[current] + graph.cost(current, nxt)
            if nxt not in cost_so_far or new_cost < cost_so_far[nxt]:
                cost_so_far[nxt] = new_cost
                # The heuristic affects only the order of exploration; the
                # stored cost stays the true path cost.
                priority = new_cost + heuristic_fn(nxt, goal)
                frontier.put(nxt, priority)
                came_from[nxt] = current

    return came_from, cost_so_far


def breadth_first_search(graph: Graph, start: Location, goal: Location):
    """Unweighted BFS from ``start``.

    Returns ``came_from``, which maps each reached node to its predecessor
    (``None`` for ``start``). The search stops once ``goal`` is dequeued.
    """
    frontier = Queue()
    frontier.put(start)
    came_from: dict[Location, Optional[Location]] = {}
    came_from[start] = None

    while not frontier.empty():
        current: Location = frontier.get()

        if current == goal:
            break

        for nxt in graph.neighbors(current):
            if nxt not in came_from:
                frontier.put(nxt)
                came_from[nxt] = current

    return came_from


class SquareGridNeighborOrder(SquareGrid):
    """SquareGrid whose neighbour order comes from ``NEIGHBOR_ORDER``.

    ``NEIGHBOR_ORDER`` is a list of ``(dx, dy)`` offsets. It has no default
    and must be assigned on the instance before ``neighbors`` is called.
    """

    NEIGHBOR_ORDER: list[tuple[int, int]]

    def neighbors(self, id):
        (x, y) = id
        neighbors = [(x + dx, y + dy) for (dx, dy) in self.NEIGHBOR_ORDER]
        results = filter(self.in_bounds, neighbors)
        results = filter(self.passable, results)
        return list(results)


def test_with_custom_order(neighbor_order):
    """Demo: run BFS on diagram 1 and print the grid.

    ``neighbor_order`` is a list of ``(dx, dy)`` offsets. A falsy value uses
    SquareGrid's default E/W/N/S order with the parity flip.
    """
    if neighbor_order:
        g = SquareGridNeighborOrder(30, 15)
        g.NEIGHBOR_ORDER = neighbor_order
    else:
        g = SquareGrid(30, 15)
    g.walls = DIAGRAM1_WALLS
    start, goal = (8, 7), (27, 2)
    came_from = breadth_first_search(g, start, goal)
    draw_grid(g, path=reconstruct_path(came_from, start=start, goal=goal),
              point_to=came_from, start=start, goal=goal)


# This is a demo, not a unit test. Without this flag, pytest collects it from
# any test module that star-imports this file and then errors on the
# unresolvable ``neighbor_order`` fixture.
test_with_custom_order.__test__ = False


class GridWithAdjustedWeights(GridWithWeights):
    """GridWithWeights with a tiny parity-based nudge added to step costs.

    Horizontal moves out of even-parity cells, and vertical moves out of
    odd-parity cells, cost an extra 0.001. Otherwise-equal paths are
    therefore no longer tied.
    """

    def cost(self, from_node, to_node):
        prev_cost = super().cost(from_node, to_node)
        nudge = 0
        (x1, y1) = from_node
        (x2, y2) = to_node
        if (x1 + y1) % 2 == 0 and x2 != x1:
            nudge = 1
        if (x1 + y1) % 2 == 1 and y2 != y1:
            nudge = 1
        return prev_cost + 0.001 * nudge