C#: A* pathfinding - performance and simplicity

Code Review Asked by Xamtos on October 27, 2021

Yet another implementation of A* pathfinding. It is focused on:

  • Performance (both speed and memory allocations).
  • Readability and simplicity.
  • Well defined objects and methods.
  • Adherence to general conventions (naming, signatures, class structure, design principles, etc.).

The path is calculated on a 2D grid using integer vectors:

public interface IPath
{
    IReadOnlyCollection<Vector2Int> Calculate(Vector2Int start, Vector2Int target, IReadOnlyCollection<Vector2Int> obstacles);
}

First, I’ll define Vector2Int. It’s pretty straightforward:

using System;

namespace AI.A_Star
{
    public readonly struct Vector2Int : IEquatable<Vector2Int>
    {
        private static readonly float Sqr = (float) Math.Sqrt(2);

        public Vector2Int(int x, int y)
        {
            X = x;
            Y = y;
        }

        public int X { get; }
        public int Y { get; }
        
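        /// <summary>
        /// Estimated path distance without obstacles (octile distance):
        /// straight steps cost 1, diagonal steps cost sqrt(2).
        /// </summary>
        public float DistanceEstimate()
        {
            // Work with absolute offsets so opposite signs (e.g. (3, -3)) are handled correctly.
            int absX = Math.Abs(X);
            int absY = Math.Abs(Y);
            int diagonalSteps = Math.Min(absX, absY);              // steps that advance on both axes
            int linearSteps = Math.Max(absX, absY) - diagonalSteps; // remaining steps along the longer axis
            return linearSteps + Sqr * diagonalSteps;
        }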
        
        public static Vector2Int operator +(Vector2Int a, Vector2Int b) => new Vector2Int(a.X + b.X, a.Y + b.Y);
        public static Vector2Int operator -(Vector2Int a, Vector2Int b) => new Vector2Int(a.X - b.X, a.Y - b.Y);
        public static bool operator ==(Vector2Int a, Vector2Int b) => a.X == b.X && a.Y == b.Y;
        public static bool operator !=(Vector2Int a, Vector2Int b) => !(a == b);

        public bool Equals(Vector2Int other)
            => X == other.X && Y == other.Y;

        public override bool Equals(object obj)
            => obj is Vector2Int other && Equals(other);

        public override int GetHashCode()
            => HashCode.Combine(X, Y);

        public override string ToString()
            => $"({X}, {Y})";
    }
}

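The IEquatable<Vector2Int> interface is implemented for future optimizations: generic collections such as HashSet and Dictionary use it to compare values without boxing. The Sqr value (sqrt(2)) is cached because there is no need to calculate it more than once.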

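DistanceEstimate() is used for the heuristic cost calculation (it is the octile distance). It is more accurate than the Manhattan version, Math.Abs(X) + Math.Abs(Y), which counts every diagonal step as two straight steps and therefore overestimates diagonal cost. An overestimating heuristic is inadmissible, so A* could return paths that are not the shortest.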
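To make the difference concrete, here is a small standalone sketch comparing the two heuristics for an offset (dx, dy). The Heuristics class and its method names are hypothetical and not part of the original code.

```csharp
using System;

// Hypothetical helper, not part of the original code: compares two grid heuristics for an offset (dx, dy).
static class Heuristics
{
    private static readonly float Sqrt2 = (float)Math.Sqrt(2);

    // Octile distance: cheapest 8-connected path on an empty grid,
    // where straight steps cost 1 and diagonal steps cost sqrt(2).
    public static float Octile(int dx, int dy)
    {
        int ax = Math.Abs(dx);
        int ay = Math.Abs(dy);
        int diagonal = Math.Min(ax, ay);            // steps that advance on both axes at once
        int straight = Math.Max(ax, ay) - diagonal; // remaining steps along the longer axis
        return straight + Sqrt2 * diagonal;
    }

    // Manhattan distance: treats every diagonal step as two straight steps.
    public static int Manhattan(int dx, int dy) => Math.Abs(dx) + Math.Abs(dy);
}
```

For the offset (3, -3), Octile gives about 4.24 (three diagonal steps) while Manhattan gives 6, more than the true cost of the best path. Taking absolute values before splitting the offset into diagonal and straight steps keeps the octile result correct when dx and dy have opposite signs.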


Next: PathNode, which represents a single location on the grid:

using JetBrains.Annotations; // [CanBeNull]

namespace AI.A_Star
{
    internal interface IPathNode
    {
        Vector2Int Position { get; }
        [CanBeNull] IPathNode Parent { get; }
        float TraverseDistance { get; }
        float HeuristicDistance { get; }
        float EstimatedTotalCost { get; }
    }
    
    internal readonly struct PathNode : IPathNode
    {
        public PathNode(Vector2Int position, float traverseDistance, float heuristicDistance, [CanBeNull] IPathNode parent)
        {
            Position = position;
            TraverseDistance = traverseDistance;
            HeuristicDistance = heuristicDistance;
            Parent = parent;
        }

        public Vector2Int Position { get; }
        public IPathNode Parent { get; }
        public float TraverseDistance { get; }
        public float HeuristicDistance { get; }

        public float EstimatedTotalCost => TraverseDistance + HeuristicDistance;
    }
}

PathNode is defined as a struct because a lot of nodes will be created. However, it has to include a reference to its parent, so I'm using the IPathNode interface to avoid a cycle inside the struct (a struct cannot contain a field of its own type).
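Storing a struct in an interface-typed member such as IPathNode.Parent boxes it: every conversion copies the value into a new heap object. A minimal sketch with hypothetical INode/Node stand-ins (not the original types) that makes the separate copies visible:

```csharp
using System;

// Hypothetical stand-ins for IPathNode/PathNode, reduced to a single property.
interface INode
{
    int Cost { get; }
}

readonly struct Node : INode
{
    public Node(int cost) => Cost = cost;
    public int Cost { get; }
}

static class BoxingDemo
{
    // Returns true when two conversions of the same struct value produce two distinct heap objects.
    public static bool EachConversionAllocates()
    {
        var node = new Node(42);
        INode first = node;  // box #1: the struct is copied into a new heap object
        INode second = node; // box #2: another, separate copy
        return !ReferenceEquals(first, second) && first.Cost == second.Cost;
    }
}
```

In the code above, GenerateFrontierNodes passes its PathNode argument to FillAdjacentNodesNonAlloc, whose parameter is IPathNode, so each expanded node is boxed once and its eight neighbours share that box: one heap allocation per expanded node.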


Next: the class that creates a node's neighbours:

using System;

namespace AI.A_Star
{
    internal class PathNodeNeighbours
    {
        private static readonly (Vector2Int position, float cost)[] NeighboursTemplate = {
            (new Vector2Int(1, 0), 1),
            (new Vector2Int(0, 1), 1),
            (new Vector2Int(-1, 0), 1),
            (new Vector2Int(0, -1), 1),
            (new Vector2Int(1, 1), (float) Math.Sqrt(2)),
            (new Vector2Int(1, -1), (float) Math.Sqrt(2)),
            (new Vector2Int(-1, 1), (float) Math.Sqrt(2)),
            (new Vector2Int(-1, -1), (float) Math.Sqrt(2))
        };

        private readonly PathNode[] buffer = new PathNode[NeighboursTemplate.Length];

        public PathNode[] FillAdjacentNodesNonAlloc(IPathNode parent, Vector2Int target)
        {
            var i = 0;
            foreach ((Vector2Int position, float cost) in NeighboursTemplate)
            {
                Vector2Int nodePosition = position + parent.Position;
                float traverseDistance = parent.TraverseDistance + cost;
                float heuristicDistance = (nodePosition - target).DistanceEstimate();
                buffer[i++] = new PathNode(nodePosition, traverseDistance, heuristicDistance, parent);
            }

            return buffer;
        }
    }
}

Another straightforward class, which simply creates the neighbouring nodes around the parent on the grid (including the diagonal ones). It reuses a single array buffer instead of allocating a new collection on every call.

This code didn't feel right inside the PathNode struct or inside the Path class; it felt like a minor SRP violation, so I moved it to a separate class.


Now, the interesting one:

using System;
using System.Collections.Generic;

namespace AI.A_Star
{
    public class Path : IPath
    {
        private readonly PathNodeNeighbours neighbours = new PathNodeNeighbours();
        private readonly int maxSteps;
        
        private readonly SortedSet<PathNode> frontier = new SortedSet<PathNode>(Comparer<PathNode>.Create((a, b) => a.EstimatedTotalCost.CompareTo(b.EstimatedTotalCost)));
        private readonly HashSet<Vector2Int> ignoredPositions = new HashSet<Vector2Int>();
        private readonly List<Vector2Int> output = new List<Vector2Int>();

        public Path(int maxSteps)
        {
            this.maxSteps = maxSteps;
        }

        public IReadOnlyCollection<Vector2Int> Calculate(Vector2Int start, Vector2Int target, IReadOnlyCollection<Vector2Int> obstacles)
        {
            if (!TryGetPathNodes(start, target, obstacles, out IPathNode node))
                return Array.Empty<Vector2Int>();

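            // Positions are collected from the target back to the start, so the path is returned in reverse order.
            output.Clear();
            while (node != null)
            {
                output.Add(node.Position);
                node = node.Parent;
            }

            // The read-only wrapper is a view over 'output', which the next Calculate() call overwrites.
            return output.AsReadOnly();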
        }
        
        private bool TryGetPathNodes(Vector2Int start, Vector2Int target, IReadOnlyCollection<Vector2Int> obstacles, out IPathNode node)
        {
            frontier.Clear();
            ignoredPositions.Clear();

            frontier.Add(new PathNode(start, 0, 0, null));
            ignoredPositions.UnionWith(obstacles);
            var step = 0;
            
            while (frontier.Count > 0 && ++step <= maxSteps)
            {
                PathNode current = frontier.Min;
                if (current.Position.Equals(target))
                {
                    node = current;
                    return true;
                }

                ignoredPositions.Add(current.Position);
                frontier.Remove(current);
                GenerateFrontierNodes(current, target);
            }

            // Frontier exhausted or step limit reached: no path found.
            node = default;
            return false;
        }

        private void GenerateFrontierNodes(PathNode parent, Vector2Int target)
        {
            // Get adjacent positions and remove already checked.
            var nodes = neighbours.FillAdjacentNodesNonAlloc(parent, target);
                
            foreach(PathNode newNode in nodes)
            {
                // Position is already checked or occupied by an obstacle.
                if (ignoredPositions.Contains(newNode.Position)) 
                    continue;
                    
                // Node is not present in queue.
                if (!frontier.TryGetValue(newNode, out PathNode existingNode))
                    frontier.Add(newNode);

                // Node is present in queue and new optimal path is detected.
                else if (newNode.TraverseDistance < existingNode.TraverseDistance)
                {
                    frontier.Remove(existingNode);
                    frontier.Add(newNode);
                }
            }
        }
    }
}

Collections are defined in the class body rather than inside methods: this way, subsequent calculations don't need to create or resize collections (assuming the calculated paths are always of roughly the same length).
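Reusing the collections pays off because List<T>.Clear() resets Count to zero without releasing the backing array, so the next calculation starts with whatever capacity the previous one grew to. A tiny sketch (the ReuseDemo class is hypothetical) that shows this:

```csharp
using System.Collections.Generic;

// Hypothetical demo class: fills a list, clears it, and reports what is left.
static class ReuseDemo
{
    public static (int count, int capacity) FillAndClear(int items)
    {
        var list = new List<int>();
        for (int i = 0; i < items; i++)
            list.Add(i);  // the backing array grows as needed

        list.Clear();     // Count becomes 0, but the backing array is kept
        return (list.Count, list.Capacity);
    }
}
```

If one unusually long search inflates a buffer, List<T>.TrimExcess() can release the extra memory.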

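Using SortedSet and HashSet makes the calculation complete 150-200 times faster than using List, which is miserably slow here (List.Contains and List.Remove scan the whole list, whereas HashSet lookups are O(1) on average and SortedSet operations are O(log n)).

TryGetPathNodes() returns the final (target) node as an out parameter; Calculate() walks through that node's chain of parents and returns a collection of their positions.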


I'm really uncertain about the following things:

  1. The PathNode struct contains an IPathNode reference. That doesn't seem normal at all.

  2. The rule of thumb is to never return a reference to a mutable collection. However, the PathNodeNeighbours class returns the original array buffer itself instead of a copy. Is that tolerable for internal classes (which are expected to be used in exactly one place)? Or is it always preferable to provide an external buffer and fill it via CopyTo()?
    I'd prefer to keep classes as clean as possible, without multiple 'temporary' arrays.

  3. 85% of memory allocations happen inside the GenerateFrontierNodes() method, and half of that is caused by SortedSet.Add(). Is there nothing I can do there?

  4. Boxing PathNode values into IPathNode references causes the other half of the allocations.
    But making PathNode a class instead of a struct makes things worse! There are thousands of PathNodes! And I have to give each node a reference to its parent: otherwise there is no way to trace the final path through the nodes.


Are there any poor solutions in my pathfinding algorithm? Are there further performance improvements to be had? How can I further improve readability?

2 Answers

(For anyone who stumbles across this question and decides to use the sample code).

Actually, the following collection does not work as intended:

        private readonly SortedSet<PathNode> frontier = new SortedSet<PathNode>(Comparer<PathNode>.Create((a, b) => a.EstimatedTotalCost.CompareTo(b.EstimatedTotalCost)));

It treats any two nodes with the same estimated cost as equal, even when their positions differ: SortedSet.Add silently rejects the second node, and TryGetValue may return a node at a completely different position. Dropping those nodes speeds up pathfinding dramatically (many nodes share the same cost), but it can lead to inaccurate paths or false-negative results.

I didn't find any built-in collection that combines key sorting, duplicate keys, fast lookup and low allocation overhead. Instead, here is a non-generic binary heap implementation to use in place of SortedSet, as @harold suggested:
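The silent drop is easy to reproduce. This sketch uses (Cost, X, Y) tuples as hypothetical stand-ins for PathNode and compares a cost-only comparer with one that breaks ties on position:

```csharp
using System.Collections.Generic;

static class SortedSetDemo
{
    // (Cost, X, Y) tuples stand in for PathNode; hypothetical, for illustration only.
    public static (int costOnly, int tieBroken) Run()
    {
        // Same ordering as the original frontier: cost only.
        var costOnly = new SortedSet<(float Cost, int X, int Y)>(
            Comparer<(float Cost, int X, int Y)>.Create((a, b) => a.Cost.CompareTo(b.Cost)));

        // Breaks ties on position, so only nodes with equal cost AND position collide.
        var tieBroken = new SortedSet<(float Cost, int X, int Y)>(
            Comparer<(float Cost, int X, int Y)>.Create((a, b) =>
            {
                int c = a.Cost.CompareTo(b.Cost);
                if (c != 0) return c;
                c = a.X.CompareTo(b.X);
                return c != 0 ? c : a.Y.CompareTo(b.Y);
            }));

        foreach (var frontier in new[] { costOnly, tieBroken })
        {
            frontier.Add((5f, 1, 0));
            frontier.Add((5f, 0, 1)); // same cost, different position
        }

        return (costOnly.Count, tieBroken.Count);
    }
}
```

Run() returns (1, 2): the cost-only set keeps just one of the two nodes. Tie-breaking fixes the drops, but TryGetValue then only finds a node whose cost also matches, so looking a node up by position still needs a separate position-to-node map, which is what the binary heap below maintains.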

using System;
using System.Collections.Generic;

internal interface IBinaryHeap<in TKey, T> where TKey : IEquatable<TKey>
{
    void Enqueue(T item);
    T Dequeue();
    void Clear();
    bool TryGet(TKey key, out T value);
    void Modify(T value);
    int Count { get; }
}

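internal class BinaryHeap : IBinaryHeap<Vector2Int, PathNode> 
{
    // Position -> index of that node in 'collection'; updated on every move so lookups stay O(1).
    private readonly IDictionary<Vector2Int, int> map;
    private readonly IList<PathNode> collection;
    private readonly IComparer<PathNode> comparer;
    
    // Max-heap with respect to 'comparer': Dequeue returns the element the comparer ranks highest.
    // For A*, pass a comparer that ranks the lowest cost highest, e.g.
    // Comparer<PathNode>.Create((a, b) => b.EstimatedTotalCost.CompareTo(a.EstimatedTotalCost)).
    public BinaryHeap(IComparer<PathNode> comparer)
    {
        this.comparer = comparer;
        collection = new List<PathNode>();
        map = new Dictionary<Vector2Int, int>();
    }

    public int Count => collection.Count;

    public void Enqueue(PathNode item)
    {
        collection.Add(item);
        map[item.Position] = collection.Count - 1;
        SiftUp(collection.Count - 1);
    }

    public PathNode Dequeue()
    {
        if (collection.Count == 0) return default;
        
        PathNode result = collection[0];
        RemoveRoot();
        // Remove the key last: RemoveRoot re-adds the moved element's key, which is the root's own key when Count was 1.
        map.Remove(result.Position);
        return result;
    }
    
    public bool TryGet(Vector2Int key, out PathNode value)
    {
        if (!map.TryGetValue(key, out int index))
        {
            value = default;
            return false;
        }
        
        value = collection[index];
        return true;
    }

    public void Modify(PathNode value)
    {
        if (!map.TryGetValue(value.Position, out int index))
            throw new KeyNotFoundException($"No node at {value.Position} in the heap.");

        // Overwrite in place: RemoveAt(index) would shift every later element,
        // breaking both the heap order and the indices stored in 'map'.
        collection[index] = value;
        // The priority may have moved either way, so restore the heap property in both directions.
        SiftDown(SiftUp(index));
    }

    public void Clear()
    {
        collection.Clear();
        map.Clear();
    }

    private void RemoveRoot()
    {
        int last = collection.Count - 1;
        collection[0] = collection[last];
        map[collection[0].Position] = 0;
        collection.RemoveAt(last);
        SiftDown(0);
    }

    // Moves the element at i towards the root while it outranks its parent; returns its final index.
    private int SiftUp(int i)
    {
        while (i > 0)
        {
            int parent = (i - 1) / 2;
            if (comparer.Compare(collection[i], collection[parent]) <= 0)
                break;

            Swap(i, parent);
            i = parent;
        }
        return i;
    }

    // Moves the element at i towards the leaves while one of its children outranks it.
    private void SiftDown(int i)
    {
        while (true)
        {
            int largest = LargestIndex(i);
            if (largest == i)
                return;

            Swap(i, largest);
            i = largest;
        }
    }

    private void Swap(int i, int j)
    {
        PathNode temp = collection[i];
        collection[i] = collection[j];
        collection[j] = temp;
        map[collection[i].Position] = i;
        map[collection[j].Position] = j;
    }

    private int LargestIndex(int i)
    {
        int leftInd = 2 * i + 1;
        int rightInd = 2 * i + 2;
        int largest = i;

        if (leftInd < collection.Count && comparer.Compare(collection[leftInd], collection[largest]) > 0) largest = leftInd;

        if (rightInd < collection.Count && comparer.Compare(collection[rightInd], collection[largest]) > 0) largest = rightInd;
        
        return largest;
    }
}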

Generic version:

internal class BinaryHeap<TKey, T> : IBinaryHeap<TKey, T> where TKey : IEquatable<TKey>
{
    // Key -> index of that element in 'collection'; updated on every move so lookups stay O(1).
    private readonly IDictionary<TKey, int> map;
    private readonly IList<T> collection;
    private readonly IComparer<T> comparer;
    private readonly Func<T, TKey> lookupFunc;
    
    // Max-heap with respect to 'comparer': Dequeue returns the element the comparer ranks highest.
    // lookupFunc extracts the key (e.g. the position) used by TryGet and Modify.
    public BinaryHeap(IComparer<T> comparer, Func<T, TKey> lookupFunc)
    {
        this.comparer = comparer;
        this.lookupFunc = lookupFunc;
        collection = new List<T>();
        map = new Dictionary<TKey, int>();
    }

    public int Count => collection.Count;

    public void Enqueue(T item)
    {
        collection.Add(item);
        map[lookupFunc(item)] = collection.Count - 1;
        SiftUp(collection.Count - 1);
    }

    public T Dequeue()
    {
        if (collection.Count == 0) return default;
        
        T result = collection[0];
        RemoveRoot();
        // Remove the key last: RemoveRoot re-adds the moved element's key, which is the root's own key when Count was 1.
        map.Remove(lookupFunc(result));
        return result;
    }

    public void Clear()
    {
        collection.Clear();
        map.Clear();
    }

    public bool TryGet(TKey key, out T value)
    {
        if (!map.TryGetValue(key, out int index))
        {
            value = default;
            return false;
        }
        
        value = collection[index];
        return true;
    }

    public void Modify(T value)
    {
        TKey key = lookupFunc(value);
        if (!map.TryGetValue(key, out int index))
            throw new KeyNotFoundException($"No element with key {key} in the heap.");
        
        // Replacing the element alone would leave the heap out of order if its priority changed,
        // so restore the heap property in both directions.
        collection[index] = value;
        SiftDown(SiftUp(index));
    }
    
    private void RemoveRoot()
    {
        int last = collection.Count - 1;
        collection[0] = collection[last];
        map[lookupFunc(collection[0])] = 0;
        collection.RemoveAt(last);
        SiftDown(0);
    }

    // Moves the element at i towards the root while it outranks its parent; returns its final index.
    private int SiftUp(int i)
    {
        while (i > 0)
        {
            int parent = (i - 1) / 2;
            if (comparer.Compare(collection[i], collection[parent]) <= 0)
                break;

            Swap(i, parent);
            i = parent;
        }
        return i;
    }

    // Moves the element at i towards the leaves while one of its children outranks it.
    private void SiftDown(int i)
    {
        while (true)
        {
            int largest = LargestIndex(i);
            if (largest == i)
                return;

            Swap(i, largest);
            i = largest;
        }
    }

    private void Swap(int i, int j)
    {
        T temp = collection[i];
        collection[i] = collection[j];
        collection[j] = temp;
        map[lookupFunc(collection[i])] = i;
        map[lookupFunc(collection[j])] = j;
    }

    private int LargestIndex(int i)
    {
        int leftInd = 2 * i + 1;
        int rightInd = 2 * i + 2;
        int largest = i;

        if (leftInd < collection.Count && comparer.Compare(collection[leftInd], collection[largest]) > 0) largest = leftInd;

        if (rightInd < collection.Count && comparer.Compare(collection[rightInd], collection[largest]) > 0) largest = rightInd;
        
        return largest;
    }
}

Answered by Xamtos on October 27, 2021

"Boxing from value PathNode to reference IPathNode causes another half of allocations. But making PathNode a class instead of struct makes things worse! There are thousands of PathNode's! And I have to provide a reference to a parent to each node: otherwise there will be no way to track final path through nodes."

Having the interface is probably good software engineering practice in general, but in this situation I recommend removing it. Boxing should be avoided not by switching to classes, but by removing the boxing. So let's work around needing a reference to a node.

There are other ways to remember the "parent" information that do not involve a reference to a node: for example, a Dictionary<Vector2Int, Vector2Int>, a Vector2Int[,] or a Direction[,]; there are many variants. When the path is reconstructed at the end of A*, the nodes themselves are mostly irrelevant: only the positions matter, so only the positions need to be accessible, and they still are with these solutions.
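A minimal sketch of the Dictionary<Vector2Int, Vector2Int> variant. To stay self-contained it uses (int X, int Y) tuples in place of Vector2Int, and the names PathReconstruction, Reconstruct and cameFrom are hypothetical. Each reached position records the position it was reached from, and the path is rebuilt by walking that map back from the target:

```csharp
using System.Collections.Generic;

static class PathReconstruction
{
    // (int X, int Y) tuples stand in for Vector2Int (an assumption for self-containment).
    // cameFrom maps each reached position to the position it was reached from; the start has no entry.
    public static List<(int X, int Y)> Reconstruct(
        IReadOnlyDictionary<(int X, int Y), (int X, int Y)> cameFrom,
        (int X, int Y) start,
        (int X, int Y) target)
    {
        var path = new List<(int X, int Y)> { target };
        var current = target;
        while (current != start)
        {
            current = cameFrom[current]; // throws KeyNotFoundException if the target was never reached
            path.Add(current);
        }

        path.Reverse(); // collected target -> start; return start -> target
        return path;
    }
}
```

During the search, the map would be written wherever a cheaper route to a neighbour is found (cameFrom[neighbour] = current), so the node struct no longer needs a Parent member and nothing gets boxed.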

"85% of memory allocations are happening inside GenerateFrontierNodes() method. Half of that caused by SortedSet.Add() method. Nothing I can do there?"

There is something that can be done: use a binary heap. SortedSet is not that good to begin with: it has decent asymptotic behaviour, but its constant factor is poor (and every Add allocates a new tree node). A binary heap is great for this use: it is simple to implement, low-overhead and low-allocation. It doesn't keep the collection completely sorted, but A* does not require that.

Then "the update problem" needs to be solved. Currently, it is solved by calling frontier.Remove and frontier.Add to re-insert the node with its new weight. A binary heap cannot be searched efficiently, but a Dictionary<Vector2Int, int> can be maintained on the side to record the heap index of the node at each position. Maintaining that dictionary is not a great burden for the heap, and it allows an O(log n) "change weight" operation.

Answered by harold on October 27, 2021
