TransWikia.com

Generalised NxN Sudoku solver using heap

Code Review Asked by srt1104 on November 25, 2021

This is my implementation of a Sudoku solver. It is not the most naive approach, but it still does an exhaustive search, with some assistance from a heap. The only constraints I have used are the basic rules of Sudoku (a number can occur only once in its row, its column and its box). There are probably more techniques or lines of reasoning that could improve it, but before that I would like to get this version as optimized as possible. I would appreciate any advice on how to make it faster and how to bring my code in line with modern C++ best practices. Thank you for your time!

Edit: I forgot to mention the main idea here. The heap is used to choose the next cell to fill: the one with the fewest possible numbers it can be filled with. When you place one of the possible numbers in that cell, say n in cell (x, y), then n is removed from the list of possibilities of all cells in row x, column y and the box which (x, y) belongs to, AND these changes are reflected in the heap. To backtrack, n is added back to those lists (these changes too are reflected in the heap). When the heap becomes empty, all cells have been filled and we have found a solution.
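The bookkeeping described in the edit can be sketched on a single cell's candidate mask. This is only a minimal illustration, not the solver itself: the names `removeCandidate`, `addCandidate` and `candidateCount` are hypothetical, and "bit k stands for digit k" is assumed to match the `1 << k` convention used in the code.

```cpp
#include <bitset>

// Hypothetical helpers illustrating the candidate bitmask; bit k set means
// digit k can still be placed in the cell (bit 0 is unused).

// Clears digit k from the mask. Returns false if k was not a candidate,
// which tells the caller there is nothing to restore when backtracking.
bool removeCandidate(int& mask, int k)
{
    if (!(mask & (1 << k)))
        return false;
    mask &= ~(1 << k);
    return true;
}

// Puts digit k back into the mask (used when backtracking).
void addCandidate(int& mask, int k)
{
    mask |= 1 << k;
}

// Number of digits the cell can still take; this is the key the heap orders by.
int candidateCount(int mask)
{
    return static_cast<int>(std::bitset<32>(static_cast<unsigned>(mask)).count());
}
```

The boolean result of `removeCandidate` plays the same role as the return value of `heapUpdateKey(ptr, 0, k)`: only cells that actually lost a candidate need it restored on backtracking.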

#include <cmath>        // ceil, sqrt
#include <cstdlib>      // exit
#include <iostream>
#include <vector>
#include <unordered_map>

using namespace std;

// table to calculate no. of set bits in a number
vector<int> bitset_table(256);

// function to print the board
ostream& operator<< (ostream& out, const vector<vector<int>>& M)
{
    for (const vector<int>& V : M)
    {
        for (int e : V)
            out << e << ' ';
        out << endl;
    }
    return out;
}

// function used by heap to order its elements based on the contents of `*ptr1` and `*ptr2`
bool isLower(const int* ptr1, const int* ptr2)
{
    int size1, size2;

    size1 = bitset_table[*ptr1 & 0xff] + bitset_table[*ptr1 >> 8 & 0xff] +
        bitset_table[*ptr1 >> 16 & 0xff] + bitset_table[*ptr1 >> 24 & 0xff];
    size2 = bitset_table[*ptr2 & 0xff] + bitset_table[*ptr2 >> 8 & 0xff] +
        bitset_table[*ptr2 >> 16 & 0xff] + bitset_table[*ptr2 >> 24 & 0xff];
    return size1 < size2;
}

class Heap
{
private:
    int heap_size;                          // no. of elements in the heap
    vector<int*> A;                         // heap container of elements of type `int*` (for 1 by 1 mapping), note that `A.size()` can be greater than `heap_size`
    unordered_map<int*, int> mapping;       // mapping to keep track of the index of `int*` in `A`

    int parent(int i) { return (i - 1) / 2; }
    int left(int i) { return 2 * i + 1; }
    int right(int i) { return 2 * i + 2; }

    // taken from CLRS. Puts A[i] at the correct place by "heapifying" the heap (requires A[left(i)] and A[right(i)] to follow the heap property.)
    void minHeapify(int i)
    {
        int l, r, smallest;

        l = left(i);
        r = right(i);
        smallest = i;
        if (l < heap_size && isLower(A[l], A[i]))
            smallest = l;
        if (r < heap_size && isLower(A[r], A[smallest]))
            smallest = r;

        if (smallest != i)
        {
            swap(mapping[A[i]], mapping[A[smallest]]);
            swap(A[i], A[smallest]);
            minHeapify(smallest);
        }
    }

    // updated key at A[i] is pushed towards the top of the heap if its priority is high, otherwise towards the bottom.
    void heapUpdateKey(int i)
    {
        if (i == 0 || !isLower(A[i], A[parent(i)]))
            minHeapify(i);
        else
        {
            int p = parent(i);
            while (i > 0 && isLower(A[i], A[p]))
            {
                swap(mapping[A[i]], mapping[A[p]]);
                swap(A[i], A[p]);
                i = p;
                p = parent(i);
            }
        }
    }

public:
    Heap() : heap_size(0) {}

    // `opt = 0` means delete `val` from `*ptr`, otherwise insert.
    // if it fails to delete, return false. (this fact is used in `search` method)
    bool heapUpdateKey(int *ptr, int opt, int val)
    {
        if (mapping.find(ptr) == mapping.cend() || (opt == 0 && !(*ptr & (1 << val))))
            return false;

        if (opt == 0)
            *ptr &= ~(1 << val);
        else
            *ptr |= 1 << val;
        heapUpdateKey(mapping[ptr]);
        return true;
    }

    // inserts element at the end of the heap and calls `heapUpdateKey` on it
    void insert(int *ptr)
    {
        if (heap_size < A.size())
            A[heap_size] = ptr;
        else
            A.push_back(ptr);
        mapping[ptr] = heap_size;
        heapUpdateKey(heap_size++);
    }

    // returns the element at the top of the heap and heapifies the rest of the heap.
    int* heapExtractMin()
    {
        //if (heap_size == 0)
            //return nullptr;

        int *res = A[0];
        mapping.erase(res);
        A[0] = A[--heap_size];
        mapping[A[0]] = 0;
        minHeapify(0);
        return res;
    }

    bool isEmpty()
    {
        return heap_size == 0;
    }
};

class Solve
{
private:
    int N;

    // recursive function which basically performs an exhaustive search using backtracking
    bool search(Heap& H, unordered_map<int*, unordered_map<int, vector<int*>>>& adj, vector<vector<int>>& board, unordered_map<int*, pair<int, int>>& mapping)
    {
        if (H.isEmpty())
            return true;

        int *ptr = H.heapExtractMin();
        pair<int, int>& p = mapping[ptr];
        for (int k = 1; k <= N; ++k)
            if (*ptr & (1 << k))
            {
                board[p.first][p.second] = k;

                vector<int*> deleted_from;
                for (int *ptr2 : adj[ptr][k])
                    if (H.heapUpdateKey(ptr2, 0, k))
                        deleted_from.push_back(ptr2);

                if (search(H, adj, board, mapping))
                    return true;

                for (int *ptr2 : deleted_from)
                    H.heapUpdateKey(ptr2, 1, k);
            }
        H.insert(ptr);
        return false;
    }

public:
    Solve() {}

    Solve(vector<vector<int>>& board) : N(board.size())
    {
        int n = (int)ceil(sqrt(N));

        if (n*n != N)
            exit(0);

        // look at already filled cells like number 5 at cell say (x, y).
        // set the 5th bit at rows[x], columns[y] and the 3x3 (for 9x9 Sudoku) box which (x, y) belongs to.
        vector<int> rows(N), columns(N), boxes(N);
        for (int i = 0; i < N; ++i)
            for (int j = 0; j < N; ++j)
                if (board[i][j])
                {
                    int bit = 1 << board[i][j];
                    rows[i] |= bit;
                    columns[j] |= bit;
                    boxes[(i / n)*n + (j / n)] |= bit;
                }

        // possibilities[i][j] = list of numbers which the cell (i, j) can be filled with.
        // &possibilities[i][j] is the pointer int* used in the heap.
        vector<vector<int>> possibilities(N, vector<int>(N));
        // mapping used in `search` method to get the coordinates (i, j) which &possibilities[i][j] represents.
        unordered_map<int*, pair<int, int>> mapping;
        // look at yet to be filled cells and calculate their possibilities[i][j]
        for (int i = 0; i < N; ++i)
            for (int j = 0; j < N; ++j)
                if (!board[i][j])
                {
                    mapping.emplace(&possibilities[i][j], make_pair(i, j));
                    for (int k = 1; k <= N; ++k)
                    {
                        int bit = 1 << k;
                        if (!(rows[i] & bit) && !(columns[j] & bit) && !(boxes[(i / n)*n + (j / n)] & bit))
                            possibilities[i][j] |= bit;
                    }
                }

        // adjacency list used in 'search' method.
        // adj[p][k] is the list of pointers (of cells, i.e., &possibilities[i][j]) which are adjacent to cell at pointer p (same row, column and box)
        // and have their kth bit set. It seems complex and congested but it simply creates the adjacency list adj[p][k] for all values of p and k.
        unordered_map<int*, unordered_map<int, vector<int*>>> adj;
        for (int i = 0; i < N; ++i)
            for (int j = 0; j < N; ++j)
                if (possibilities[i][j])
                {
                    for (int k = 0; k < N; ++k)
                        if (!board[i][k] && k / n != j / n)
                            for (int l = 1; l <= N; ++l)
                                if (possibilities[i][k] & (1 << l))
                                    adj[&possibilities[i][j]][l].push_back(&possibilities[i][k]);

                    for (int k = 0; k < N; ++k)
                        if (!board[k][j] && k / n != i / n)
                            for (int l = 1; l <= N; ++l)
                                if (possibilities[k][j] & (1 << l))
                                    adj[&possibilities[i][j]][l].push_back(&possibilities[k][j]);

                    int ti, tj;
                    ti = (i / n)*n, tj = (j / n)*n;
                    for (int tti = 0; tti < n; ++tti)
                        for (int ttj = 0; ttj < n; ++ttj)
                            if (!board[ti + tti][tj + ttj] && (ti + tti != i || tj + ttj != j))
                                for (int l = 1; l <= N; ++l)
                                    if (possibilities[ti + tti][tj + ttj] & (1 << l))
                                        adj[&possibilities[i][j]][l].push_back(&possibilities[ti + tti][tj + ttj]);
                }

        // create heap and insert the address (int*) of the list of possibilities of unfilled cells.
        Heap H;
        for (int i = 0; i < N; ++i)
            for (int j = 0; j < N; ++j)
                if (possibilities[i][j])
                    H.insert(&possibilities[i][j]);

        if (search(H, adj, board, mapping))
            cout << board << endl;
    }
};

int main()
{
    // fill the bitset_table (bitset_table[i] = no. of set bits of i)
    for (int i = 1; i < bitset_table.size(); ++i)
        bitset_table[i] = (i & 1) + bitset_table[i / 2];

    int N;
    cin >> N;
    vector<vector<int>> board(N, vector<int>(N));
    for (int i = 0; i < N; ++i)
        for (int j = 0; j < N; ++j)
            cin >> board[i][j];
    Solve obj(board);
}

Some puzzles you can try:

9
8 0 0 0 0 0 0 0 0
0 0 3 6 0 0 0 0 0
0 7 0 0 9 0 2 0 0
0 5 0 0 0 7 0 0 0
0 0 0 0 4 5 7 0 0
0 0 0 1 0 0 0 3 0
0 0 1 0 0 0 0 6 8
0 0 8 5 0 0 0 1 0
0 9 0 0 0 0 4 0 0

16
0 2 14 0 0 0 16 4 0 0 0 1 0 0 5 0
0 0 9 0 0 10 0 1 0 0 0 0 0 4 0 0
0 0 0 0 13 6 0 0 0 14 0 0 15 12 0 16
6 5 10 0 8 2 0 0 0 12 0 0 0 1 0 7
9 0 5 4 1 0 0 2 0 0 0 0 12 0 7 0
0 0 0 0 11 0 0 13 0 3 0 0 0 0 0 1
0 0 0 0 16 0 0 0 13 10 15 9 14 0 4 0
10 0 0 11 0 4 8 15 0 0 0 0 5 0 13 0
0 11 0 1 0 0 0 0 10 7 4 0 3 0 0 6
0 7 0 2 14 16 6 10 0 0 0 11 0 0 0 0
16 0 0 0 0 0 1 0 12 0 0 14 0 0 0 0
0 4 0 10 0 0 0 0 15 0 0 2 16 5 0 11
11 0 12 0 0 0 14 0 0 0 13 7 0 9 6 2
8 0 7 9 0 0 11 0 0 0 14 10 0 0 0 0
0 0 4 0 0 0 0 0 11 0 2 0 0 8 0 0
0 6 0 0 12 0 0 0 9 8 0 0 0 14 1 0

25
0 0 12 6 0 0 7 0 18 0 5 24 0 10 1 0 0 4 0 0 0 0 0 0 0 
2 0 19 0 13 0 0 0 10 0 0 0 0 0 0 0 0 18 5 0 0 0 0 0 1 
0 0 0 0 0 0 0 22 0 0 0 0 3 0 2 0 0 14 12 0 16 8 25 0 0 
0 16 0 0 0 2 23 0 0 13 12 22 0 0 0 21 15 19 3 0 0 0 0 14 0 
23 0 24 0 0 0 0 0 25 8 4 0 16 19 21 0 0 7 0 0 0 3 12 0 9 
0 4 0 2 0 0 0 0 0 0 0 10 0 24 12 17 16 0 0 0 5 0 0 0 0 
0 0 9 0 0 6 25 0 0 0 8 0 5 3 0 0 0 0 0 0 20 0 0 18 19 
15 0 10 11 0 0 0 18 12 19 0 0 0 0 0 0 0 23 0 0 7 0 0 4 0 
0 0 0 0 0 0 0 14 0 22 0 0 18 16 20 0 6 11 13 0 0 0 0 0 0 
0 22 0 25 0 0 1 17 5 4 7 0 0 14 0 8 3 21 0 0 11 0 0 0 6 
0 20 13 15 0 0 0 0 0 0 9 0 0 2 0 25 0 1 8 0 0 5 0 21 0 
0 1 0 0 0 0 16 10 0 7 0 0 4 20 0 0 9 0 0 14 0 24 0 17 0 
25 2 5 0 0 0 0 0 13 0 0 0 0 0 22 0 0 0 0 0 19 1 8 0 0 
0 0 7 21 0 0 12 0 2 17 0 0 0 18 6 16 0 0 15 0 0 13 0 10 0 
8 10 18 12 16 9 0 0 0 5 0 0 0 0 19 0 0 17 0 21 0 15 0 0 22 
0 8 0 0 15 0 3 0 6 0 21 0 0 7 0 18 14 5 0 1 0 0 0 0 0 
0 0 0 19 0 1 0 16 11 0 0 0 10 22 25 15 0 0 0 0 0 0 21 0 0 
0 3 1 0 21 0 0 4 0 0 0 0 2 0 13 0 24 25 0 0 14 0 0 6 0 
0 0 0 0 0 0 0 15 0 12 14 0 6 17 24 0 0 0 0 0 0 0 13 0 0 
0 5 23 16 4 0 13 24 7 2 0 9 0 0 15 3 0 22 0 0 0 0 0 0 8 
0 0 25 20 2 0 19 0 0 0 0 1 0 0 0 0 21 3 0 0 12 0 0 0 0 
16 12 0 5 0 11 21 0 23 0 0 15 0 0 0 0 19 9 0 0 0 0 0 25 10 
0 0 0 0 9 20 22 7 4 0 3 0 14 25 18 0 11 0 0 0 0 0 1 0 15 
24 0 6 0 22 8 0 25 14 0 10 11 0 9 0 20 1 16 0 7 0 23 0 0 13 
14 13 21 1 0 0 5 0 0 0 6 0 22 0 23 10 0 0 0 2 0 0 18 7 11

The 9×9 is supposedly the "hardest 9×9 Sudoku puzzle". Takes no time. The 16×16 is another hard one and takes about 20 minutes on my machine lol.

One Answer

Freebies

Looking at the performance profile for the 16x16 puzzle (there is a profiler built into Visual Studio 2017, which you said you are using, and I used that, so you can reproduce this), I see that deleted_from.push_back(ptr2); is hotter than it deserves. That indicates the vector is growing too often.

So change this:

vector<int*> deleted_from;

To this:

vector<int*> deleted_from;
deleted_from.reserve(8);    // capacity only; deleted_from(8) would start with 8 null entries

Before: 6 seconds. After: 5.5 seconds. That's a significant gain for a trivial change to the code.
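One thing to watch when pre-sizing a vector: the sized constructor creates elements, whereas `reserve` only allocates capacity. A small sketch of the difference (the function names are made up for illustration):

```cpp
#include <cstddef>
#include <vector>

// Sized constructor: the vector already holds 8 null pointers,
// so later push_back calls append after them.
std::size_t sizeAfterSizedConstruction()
{
    std::vector<int*> v(8);
    return v.size();
}

// reserve: room for at least 8 elements, but the vector is still empty.
std::size_t sizeAfterReserve()
{
    std::vector<int*> v;
    v.reserve(8);
    return v.size();
}
```

In `search`, null entries in `deleted_from` would be harmless only because `heapUpdateKey` rejects any pointer that is not in `mapping`.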

Reading between the lines of the profile, it turns out that isLower is taking a substantial amount of time. It is not directly implicated by the profile, but the places where it is called are redder than they ought to be. It really should be trivial, but it's not.

Here is another way to write it:

#include <intrin.h>

...

// function used by heap to order its elements based on the contents of `*ptr1` and `*ptr2`
bool isLower(const int* ptr1, const int* ptr2)
{
    return _mm_popcnt_u32(*ptr1) < _mm_popcnt_u32(*ptr2);
}

Before: 5.5 seconds. After: 5.0 seconds. That's nice, and it even made the code simpler.
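Note that `<intrin.h>` is specific to MSVC. A portable sketch, assuming nothing beyond the standard library, is `std::bitset::count`; C++20 also offers `std::popcount` in `<bit>`. The name `isLowerPortable` is made up here, and whether a given compiler turns this into a single `popcnt` instruction was not measured.

```cpp
#include <bitset>

// Portable variant of isLower: compares the number of set bits
// in *ptr1 and *ptr2 using only the standard library.
bool isLowerPortable(const int* ptr1, const int* ptr2)
{
    const std::bitset<32> a(static_cast<unsigned>(*ptr1));
    const std::bitset<32> b(static_cast<unsigned>(*ptr2));
    return a.count() < b.count();
}
```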

The Heap

It should be no surprise that a lot of time is spent on modifying the heap. So let's tinker with it.

This logic:

   if (l < heap_size && isLower(A[l], A[i]))
       smallest = l;
   if (r < heap_size && isLower(A[r], A[smallest]))
       smallest = r;

Can be rewritten to:

if (r < heap_size)
{
    smallest = isLower(A[l], A[r]) ? l : r;
    smallest = isLower(A[i], A[smallest]) ? i : smallest;
}
else if (l < heap_size)
    smallest = isLower(A[l], A[i]) ? l : i;

It looks like it should be about the same, but it's not.

Before: 5.0 seconds. After: 2.0 seconds.

What?! The biggest difference I saw in the disassembly of the function was that cmovl was used this way, but not before. Conditional-move is better than a badly-predicted branch, but worse than a well-predicted branch - it makes sense that these branches would be badly predicted, after all they depend on which path the data item takes "down the heap", which is some semi-randomly zig-zagging path.

This on the other hand does not help:

smallest = (l < heap_size && isLower(A[l], A[i])) ? l : i;
smallest = (r < heap_size && isLower(A[r], A[smallest])) ? r : smallest;

When MSVC chooses to use a cmov or not is a mystery. Clearly it has a large impact, but there seems to be no reliable way to ask for a cmov.

An extra trick relies on what this "minHeapify" is really doing: it moves items up the heap along a path, then drops the item it was originally called on into the open spot at the end. That isn't how it is implemented, though: it does a lot of swaps, so in total it performs twice as many assignments as necessary. That could be changed like this:

void minHeapify(int i)
{
    int l, r, smallest;
    int* item = A[i];
    do {
        l = left(i);
        r = right(i);
        smallest = i;

        if (r < heap_size)
        {
            smallest = isLower(A[l], A[r]) ? l : r;
            smallest = isLower(item, A[smallest]) ? i : smallest;
        }
        else if (l < heap_size)
            smallest = isLower(A[l], item) ? l : i;

        if (smallest == i)
            break;

        A[i] = A[smallest];
        mapping[A[i]] = i;
        i = smallest;
    } while (1);

    A[i] = item;
    mapping[item] = i;
}

Before: 2.0 seconds. After: 1.85 seconds.
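The same hole-moving idea should carry over to the upward loop in `heapUpdateKey`, which also swaps at every level. The sketch below is self-contained and uses plain `int` keys with a position vector in place of the `int*` elements and `mapping`; `siftUp`, `heap` and `pos` are illustrative names, and this variant was not timed.

```cpp
#include <cstddef>
#include <vector>

// Moves heap[i] towards the root of a min-heap without swapping:
// parents slide down into the hole, and the item is written once at the end.
// pos[key] tracks the index of each key in heap (keys are assumed to be
// distinct values in [0, pos.size())).
void siftUp(std::vector<int>& heap, std::vector<std::size_t>& pos, std::size_t i)
{
    const int item = heap[i];
    while (i > 0)
    {
        const std::size_t p = (i - 1) / 2;
        if (!(item < heap[p]))
            break;
        heap[i] = heap[p];      // parent drops into the hole
        pos[heap[i]] = i;
        i = p;
    }
    heap[i] = item;
    pos[item] = i;
}
```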

unordered_map

Often some other hash map can do better than the default unordered_map. For example you could try Boost's version of unordered_map, or Abseil's flat_hash_map, or various others. There are too many to list.

In any case, with Skarupke's flat_hash_map, the time went from 1.85 seconds to 1.8 seconds. Not amazing, but it's as simple as including a header and changing unordered_map to ska::flat_hash_map.

By the way, for MSVC specifically, unordered_map is a common reason for poor performance of the Debug build. It's not nearly as bad for the Release build.
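One way to make such an experiment a one-line change is to route every map through an alias. A sketch, assuming the replacement container is interface-compatible with `std::unordered_map` for the operations used; the alias name `HashMap`, the helper `indexOf` and the header name in the comment are assumptions made for illustration:

```cpp
#include <unordered_map>
// #include "flat_hash_map.hpp"   // assumed header name for Skarupke's map

// Single place to switch the hash map implementation.
template <class Key, class Value>
using HashMap = std::unordered_map<Key, Value>;
// using HashMap = ska::flat_hash_map<Key, Value>;   // drop-in candidate

// Example use mirroring Heap::mapping: looks up the heap index of ptr,
// or returns -1 if ptr is not in the heap.
int indexOf(const HashMap<int*, int>& mapping, int* ptr)
{
    const auto it = mapping.find(ptr);
    return it == mapping.cend() ? -1 : it->second;
}
```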

Answered by harold on November 25, 2021
