class ResolutionProver:
    """Propositional resolution prover over a knowledge base of CNF clauses.

    Rules are strings understood by ``to_cnf``:

    * a single disjunction of literals, e.g. ``"A v ~B"``;
    * an implication ``"L -> R"``;
    * a biconditional ``"L <-> R"``.

    Here ``L`` is a single literal and ``R`` is a literal or a disjunction of
    literals. Negation is a leading ``~`` and disjuncts are separated by ``v``.
    """

    def __init__(self):
        # Knowledge base: each clause is a frozenset of literal strings,
        # read as the disjunction of its members.
        self.kb: set[frozenset[str]] = set()

    def add_rule(self, rule: str):
        """Convert *rule* to CNF and add each resulting clause to the KB.

        Prints the rule together with the clauses it was translated into.
        """
        clauses = self.to_cnf(rule)
        for clause in clauses:
            self.kb.add(frozenset(clause))
        print(f"added '{rule}' as {clauses}")

    def to_cnf(self, rule: str) -> list[set[str]]:
        """Translate a rule string into a list of CNF clauses.

        All spaces are removed first, so ``"A v B"`` and ``"AvB"`` are
        equivalent. With ``R = r1 v ... v rk``:

        * ``L <-> R`` becomes ``(~L v r1 v ... v rk)`` plus one clause
          ``(~ri v L)`` for every disjunct ``ri``;
        * ``L -> R`` becomes ``(~L v r1 v ... v rk)``;
        * anything else is parsed as a single disjunction.

        The left-hand side is treated as one literal; conjunctions or
        parenthesised expressions there are not supported.

        Returns:
            A list of clauses, each a set of literal strings.
        """
        rule = rule.replace(" ", "")
        clauses: list[set[str]] = []

        # '<->' must be tested before '->' because it contains '->'.
        if '<->' in rule:
            lhs, rhs = rule.split('<->')
            # Always parse the rhs as a disjunction. A bare literal yields a
            # one-element set, and an unparenthesised "AvB" is split correctly
            # instead of being kept as the single literal "AvB".
            rhs_lits = parse_disjuction(rhs)
            # L -> R  ==  ~L v r1 v ... v rk
            clauses.append(rhs_lits | {negate(lhs)})
            # R -> L  ==  (~r1 v L) ^ ... ^ (~rk v L)
            for lit in rhs_lits:
                clauses.append({negate(lit), lhs})

        elif '->' in rule:
            lhs, rhs = rule.split('->')
            clauses.append(parse_disjuction(rhs) | {negate(lhs)})

        else:
            clauses.append(parse_disjuction(rule))

        return clauses

    # Proving part
    def resolve(self, clause1: frozenset[str], clause2: frozenset[str]) -> list[frozenset[str]]:
        """Return every binary resolvent of *clause1* and *clause2*.

        For each literal ``l`` in *clause1* whose complement ``~l`` is in
        *clause2*, the resolvent is ``(clause1 - {l}) | (clause2 - {~l})``.
        An empty frozenset in the result is the empty clause (a contradiction).
        """
        resolvents: list[frozenset[str]] = []
        for lit in clause1:
            compl = negate(lit)
            if compl in clause2:
                # Remove the resolved pair from each parent separately.
                # Taking the union first and then removing both literals is
                # unsound when a parent is a tautology: {A, ~A} with {~A}
                # would wrongly give the empty clause instead of {~A}.
                resolvents.append(frozenset((clause1 - {lit}) | (clause2 - {compl})))
        return resolvents

    def resolution_procedure(self) -> bool:
        """Saturate a copy of the KB under resolution, searching for the empty clause.

        Every pair of known clauses is resolved in repeated passes. The loop
        stops with ``True`` as soon as the empty clause is derived, meaning
        the KB is unsatisfiable. It stops with ``False`` when a full pass adds
        no new clause.

        Tautological resolvents (containing a literal and its complement) are
        discarded: they are always true and cannot contribute to a
        refutation. Progress is printed as clauses are inferred; ``self.kb``
        itself is not modified. The loop terminates because only finitely many
        clauses exist over the KB's literals.
        """
        from itertools import combinations

        clauses = set(self.kb)
        found_new_clause = True

        while found_new_clause:
            found_new_clause = False
            # Pair over a snapshot, so clauses added during this pass are
            # only paired on the next pass.
            for c1, c2 in combinations(list(clauses), 2):
                for res in self.resolve(c1, c2):
                    if not res:  # empty clause derived
                        print(f"Contradiction found. Empty clause derived from {set(c1)} and {set(c2)}")
                        return True

                    if any(negate(lit) in res for lit in res):
                        continue  # tautology: always true, useless for refutation

                    if res not in clauses:
                        print(f"new clause inferred {set(res)}")
                        clauses.add(res)
                        found_new_clause = True

        print("No contradiction found")
        return False


def negate(literal: str):
    """Return the complement of a propositional literal.

    A leading ``~`` marks negation, so ``"~A"`` becomes ``"A"`` and ``"A"``
    becomes ``"~A"``. Only one ``~`` is stripped: ``"~~A"`` becomes ``"~A"``.
    """
    is_negative = literal.startswith('~')
    return literal[1:] if is_negative else f"~{literal}"

def parse_disjuction(expression: str) -> set[str]:
    """Split a disjunction such as ``"(A2 v A3)"`` into a set of literals.

    All parentheses are dropped and the text is split on the letter ``v``.
    Each piece has surrounding whitespace stripped.

    NOTE(review): because the separator is a bare ``v``, literal names that
    themselves contain a lowercase ``v`` will be split apart.
    """
    drop_parens = str.maketrans('', '', '()')
    flattened = expression.translate(drop_parens)
    return {piece.strip() for piece in flattened.split('v')}

if __name__ == "__main__":
    # Background knowledge, added in order R1..R9.
    background_rules = [
        "~A1",                      # R1
        "~D1",                      # R2
        "~B1",                      # R3
        "B2",                       # R4
        "B1 <-> (A2 v A3)",         # R5
        "B2 <-> (A1 v A4 v A5)",    # R6
        "D1 -> B1",                 # R7
        "D2 -> B2",                 # R8
        "A2 -> D1",                 # R9
    ]

    prover = ResolutionProver()
    for background_rule in background_rules:
        prover.add_rule(background_rule)

    # Proof by refutation: to show the query ~A2, add its negation (A2)
    # and look for a contradiction.
    print("\nAdding Negated Query of '~A2'")
    prover.add_rule("A2")

    print("\nFull knowledge base (One giant CNF)")
    for kb_clause in prover.kb:
        print(set(kb_clause))

    prover.resolution_procedure()



