TransWikia.com

3-SAT Solver Python

Code Review Asked by n00bster on December 12, 2020

I’ve written a 3-SAT solver based on this prompt:

Alice recently started to work for a hardware design company and as a part of her job, she needs to identify defects in fabricated integrated circuits. An approach for identifying these defects boils down to solving a satisfiability instance. She needs your help to write a program to do this task.

Input
The first line of input contains a single integer, not more than 5, indicating the number of test cases to follow. The first line of each test case contains two integers n and m where 1 ≤ n ≤ 20 indicates the number of variables and 1 ≤ m ≤ 100 indicates the number of clauses. Then, m lines follow corresponding to each clause. Each clause is a disjunction of literals in the form Xi or ~Xi for some 1 ≤ i ≤ n, where ~Xi indicates the negation of the literal Xi. The “or” operator is denoted by a ‘v’ character and is separated from literals with a single space.

Output
For each test case, display satisfiable on a single line if there is a satisfiable assignment; otherwise display unsatisfiable.

Sample Input

2
3 3
X1 v X2
~X1
~X2 v X3
3 5
X1 v X2 v X3
X1 v ~X2
X2 v ~X3
X3 v ~X1
~X1 v ~X2 v ~X3 

Sample Output

satisfiable
unsatisfiable
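For reference, a literal such as ~X3 can be turned into a number before any solving happens. This is a minimal sketch, assuming the signed-integer convention of the DIMACS CNF format (Xi becomes i, ~Xi becomes -i), which the prompt itself does not require; parse_literal and parse_clause_ints are hypothetical helper names:

```python
from typing import List


def parse_literal(token: str) -> int:
    """Convert 'X3' to 3 and '~X3' to -3 (signed-integer convention)."""
    negated = token.startswith('~')
    index = int(token.lstrip('~')[1:])  # drop any '~', then the leading 'X'
    return -index if negated else index


def parse_clause_ints(line: str) -> List[int]:
    """Split a clause line such as '~X2 v X3' on ' v ' and convert each literal."""
    return [parse_literal(tok) for tok in line.strip().split(' v ')]


print(parse_clause_ints('~X2 v X3'))           # [-2, 3]
print(parse_clause_ints('~X1 v ~X2 v ~X3 '))   # [-1, -2, -3]
```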

This code basically maintains mySets, a list of sets in which each set is a combination of literals that makes every clause read so far true. Every time we parse a new clause, we try to extend each set with each literal of that clause; if the set already contains the literal's negation, the extended set is not included.

This works, but it runs a bit slow.

import sys

cases = int(sys.stdin.readline())


def GetReverse(literal):
    if literal[0] == '~':
        return literal[1:]
    else:
        return '~' + literal


for i in range(cases):
    vars, clauses = map(int, sys.stdin.readline().split())

    mySets = []

    firstClause = sys.stdin.readline().strip().split(" v ")

    for c in firstClause:
        this = set()
        this.add(c)
        mySets.append(this)


    for i in range(clauses-1):
        tempSets = []
        currentClause = sys.stdin.readline().strip().split(" v ")

        for s in mySets:
            for literal in currentClause:

                if not s.__contains__(GetReverse(literal)):


                    newset = s.copy()
                    newset.add(literal)

                    tempSets.append(newset)
        mySets = tempSets


    if mySets:
        print("satisfiable")
    else:
        print("unsatisfiable")

I think the problem is here, in the nested for-loops. 3-SAT is NP-complete, so I don't expect to avoid exponential time in general, but I would like to speed it up a bit (perhaps by removing a loop?)

for i in range(clauses-1):
    tempSets = []
    currentClause = sys.stdin.readline().strip().split(" v ")

    for s in mySets:
        for literal in currentClause:

            if not s.__contains__(GetReverse(literal)):


                newset = s.copy()
                newset.add(literal)

                tempSets.append(newset)
    mySets = tempSets
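The nested loops are inherent to this approach (every candidate is tried against every literal of the clause), but the work inside them can be made cheaper. A minimal sketch of one option, assuming clauses have already been parsed into signed integers (+i for Xi, -i for ~Xi): each candidate becomes a pair of integer bitmasks, so the negation check is a bitwise AND and building the extended candidate is a bitwise OR instead of a set copy. It also keeps candidates in a set so duplicates collapse. The name is_satisfiable_bitmask is hypothetical, and the worst case is still exponential.

```python
from typing import Iterable, List, Set, Tuple


def is_satisfiable_bitmask(clauses: Iterable[List[int]]) -> bool:
    # Each candidate is (true_mask, false_mask): bit i set in true_mask means
    # Xi is forced True; bit i set in false_mask means Xi is forced False.
    candidates: Set[Tuple[int, int]] = {(0, 0)}
    for clause in clauses:
        next_candidates: Set[Tuple[int, int]] = set()
        for true_mask, false_mask in candidates:
            for lit in clause:
                bit = 1 << abs(lit)
                if lit > 0:
                    if not false_mask & bit:  # Xi not already forced False
                        next_candidates.add((true_mask | bit, false_mask))
                elif not true_mask & bit:     # Xi not already forced True
                    next_candidates.add((true_mask, false_mask | bit))
        candidates = next_candidates
        if not candidates:  # nothing survived, so later clauses cannot help
            return False
    return True


case1 = [[1, 2], [-1], [-2, 3]]
case2 = [[1, 2, 3], [1, -2], [2, -3], [3, -1], [-1, -2, -3]]
print(is_satisfiable_bitmask(case1))  # True
print(is_satisfiable_bitmask(case2))  # False
```

With n ≤ 20 both masks fit easily in a Python int; whether this is actually faster for a given input would need to be measured.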

2 Answers

If you instrument your code with some strategically placed print statements, you will see that some computations are repeated. In the second test case, when processing the clause X2 v ~X3, the set {'X1', 'X2'} gets added to mySets twice. When processing the clause X3 v ~X1, the set {'X3', 'X1', 'X2'} gets added to mySets three times.
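The repeats can also be counted directly. This sketch (count_duplicates is a hypothetical name; get_reverse is a snake_case copy of the question's GetReverse) reruns the question's list-based expansion on the second test case and records, after each clause, how many candidate sets repeat an earlier one:

```python
from typing import List


def get_reverse(literal: str) -> str:
    # 'X1' <-> '~X1', same behaviour as the question's GetReverse
    return literal[1:] if literal.startswith('~') else '~' + literal


def count_duplicates(clauses: List[List[str]]) -> List[int]:
    """Run the list-based expansion and, after each clause, count the
    candidate sets that are equal to one already in the list."""
    my_sets = [{c} for c in clauses[0]]
    duplicates = [len(my_sets) - len({frozenset(s) for s in my_sets})]
    for clause in clauses[1:]:
        # Same order and same negation check as the question's nested loops.
        my_sets = [s | {lit} for s in my_sets for lit in clause
                   if get_reverse(lit) not in s]
        duplicates.append(len(my_sets) - len({frozenset(s) for s in my_sets}))
    return duplicates


case2 = [['X1', 'X2', 'X3'], ['X1', '~X2'], ['X2', '~X3'],
         ['X3', '~X1'], ['~X1', '~X2', '~X3']]
print(count_duplicates(case2))  # [0, 0, 1, 2, 0]
```

The 1 and the 2 are the extra copies of {'X1', 'X2'} and {'X3', 'X1', 'X2'} mentioned above.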

For large cases, it might speed things up to make mySets a set() instead of a list, which eliminates the duplicates. The inner sets then need to be frozensets, because ordinary sets are unhashable and cannot be elements of another set.
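A short illustration of why the inner sets must become frozensets once the outer container is a set:

```python
# A set of ordinary sets fails, because set objects are unhashable.
try:
    bad = {{'X1', 'X2'}}
except TypeError as exc:
    print(exc)  # unhashable type: 'set'

# frozensets are hashable, so equal candidates collapse into one entry.
candidates = {frozenset({'X1', 'X2'}), frozenset({'X2', 'X1'})}
print(len(candidates))  # 1

# Extending a frozenset with | returns a new frozenset; the original is untouched.
base = frozenset({'X1'})
extended = base | {'X2'}
print(type(extended).__name__, sorted(extended))  # frozenset ['X1', 'X2']
```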

mySets holds the candidate (partial) assignments that satisfy all the clauses processed so far, so I renamed it to candidates.
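Each candidate fixes only the variables it mentions; any other variable is free. A sketch (to_assignment and satisfies are hypothetical helper names) that turns one surviving candidate into a complete assignment, defaulting free variables to False as an arbitrary choice, and checks it against the clauses:

```python
from typing import Dict, FrozenSet, List


def to_assignment(candidate: FrozenSet[str], n_vars: int) -> Dict[int, bool]:
    # Free variables default to False (any value would do), then apply
    # the literals fixed by the candidate.
    assignment = {i: False for i in range(1, n_vars + 1)}
    for literal in candidate:
        if literal.startswith('~'):
            assignment[int(literal[2:])] = False  # '~X3' -> variable 3
        else:
            assignment[int(literal[1:])] = True   # 'X3' -> variable 3
    return assignment


def satisfies(assignment: Dict[int, bool], clauses: List[List[str]]) -> bool:
    # A clause holds if at least one of its literals is true.
    def value(literal: str) -> bool:
        if literal.startswith('~'):
            return not assignment[int(literal[2:])]
        return assignment[int(literal[1:])]
    return all(any(value(lit) for lit in clause) for clause in clauses)


clauses = [['X1', 'X2'], ['~X1'], ['~X2', 'X3']]
candidate = frozenset({'X2', '~X1', 'X3'})
assignment = to_assignment(candidate, 3)
print(assignment)                      # {1: False, 2: True, 3: True}
print(satisfies(assignment, clauses))  # True
```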

If you initialize candidates to contain a single empty set, then the first clause doesn't need to be handled separately.

You can also stop as soon as candidates is empty: every new candidate extends an existing one, so once it is empty it stays empty.

Also, split up the code into functions.

import sys


def GetReverse(literal):
    # Same helper as in the question: 'X1' <-> '~X1'
    if literal[0] == '~':
        return literal[1:]
    return '~' + literal


def is_satisfiable(n_vars, clauses):
    # A single empty candidate means the first clause needs no special case.
    candidates = {frozenset()}

    for clause in clauses:
        temp = set()

        for s in candidates:
            for literal in clause:

                # Skip extensions that would contain both Xi and ~Xi.
                if GetReverse(literal) not in s:

                    temp.add(s | {literal})

        candidates = temp

        # Later clauses can never bring a candidate back.
        if len(candidates) == 0:
            return False

    return True


def load_case(f):
    n_vars, n_clauses = f.readline().split()
    clauses = [f.readline().strip().split(' v ') for _ in range(int(n_clauses))]
    return int(n_vars), clauses


def main(f=sys.stdin):
    num_cases = int(f.readline())

    for _ in range(num_cases):
        n_vars, clauses = load_case(f)
        result = is_satisfiable(n_vars, clauses)

        print('satisfiable' if result else 'unsatisfiable')

Called like:

import io

data = """
2
3 3
X1 v X2
~X1
~X2 v X3
3 5
X1 v X2 v X3
X1 v ~X2
X2 v ~X3
X3 v ~X1
~X1 v ~X2 v ~X3 
""".strip()

main(io.StringIO(data))

or

import sys

main(sys.stdin)        
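Because n ≤ 20, checking every one of the 2^n assignments is feasible for small instances, which makes it a handy reference to test is_satisfiable against. A sketch, assuming the GetReverse and is_satisfiable logic from above (repeated in condensed form so the block runs on its own); brute_force, literal_true and random_clause are hypothetical helpers:

```python
import itertools
import random
from typing import List, Tuple


def GetReverse(literal: str) -> str:
    return literal[1:] if literal[0] == '~' else '~' + literal


def is_satisfiable(n_vars: int, clauses: List[List[str]]) -> bool:
    # Condensed form of the answer's candidate-set expansion.
    candidates = {frozenset()}
    for clause in clauses:
        candidates = {s | {lit} for s in candidates for lit in clause
                      if GetReverse(lit) not in s}
        if not candidates:
            return False
    return True


def literal_true(lit: str, values: Tuple[bool, ...]) -> bool:
    # values[i - 1] is the value of Xi.
    if lit.startswith('~'):
        return not values[int(lit[2:]) - 1]
    return values[int(lit[1:]) - 1]


def brute_force(n_vars: int, clauses: List[List[str]]) -> bool:
    # Try every True/False assignment of X1..Xn.
    for values in itertools.product([False, True], repeat=n_vars):
        if all(any(literal_true(lit, values) for lit in clause) for clause in clauses):
            return True
    return False


def random_clause(n_vars: int, rng: random.Random) -> List[str]:
    # 1 to 3 distinct variables, each negated with probability 1/2.
    indices = rng.sample(range(1, n_vars + 1), rng.randint(1, min(3, n_vars)))
    return [('~' if rng.random() < 0.5 else '') + f'X{i}' for i in indices]


rng = random.Random(0)
for _ in range(200):
    n = rng.randint(1, 6)
    clauses = [random_clause(n, rng) for _ in range(rng.randint(1, 15))]
    assert is_satisfiable(n, clauses) == brute_force(n, clauses), clauses
print('all random cases agree')
```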

Correct answer by RootTwo on December 12, 2020

Here's a suggested implementation that changes basically nothing about your algorithm, but

  • has proper indentation
  • uses a little bit of type hinting
  • uses set literals and generators
  • uses _ for "unused" variables
  • adds a parse_clause() because the clause code is repeated
  • uses a StringIO, for these purposes, to effectively mock away stdin and use the example input you showed
  • uses PEP 8-compliant names (with underscores)

from io import StringIO
from typing import List

stdin = StringIO('''2
3 3
X1 v X2
~X1
~X2 v X3
3 5
X1 v X2 v X3
X1 v ~X2
X2 v ~X3
X3 v ~X1
~X1 v ~X2 v ~X3
'''
)


def get_reverse(literal: str) -> str:
    if literal[0] == '~':
        return literal[1:]
    return '~' + literal


def parse_clause() -> List[str]:
    return stdin.readline().strip().split(' v ')


n_cases = int(stdin.readline())
for _ in range(n_cases):
    n_vars, n_clauses = (int(s) for s in stdin.readline().split())
    my_sets = [{c} for c in parse_clause()]

    for _ in range(n_clauses - 1):
        temp_sets = []
        current_clause = parse_clause()

        for s in my_sets:
            for literal in current_clause:
                if get_reverse(literal) not in s:
                    new_set = s.copy()
                    new_set.add(literal)
                    temp_sets.append(new_set)

        my_sets = temp_sets

    if my_sets:
        print('satisfiable')
    else:
        print('unsatisfiable')

Answered by Reinderien on December 12, 2020
