# Simple logical inference system: resolution and model checking for first-order logic.
# @author Percy Liang

import collections


def rstr(x):
    """Return str(x), recursively applying str to the elements of containers.

    Tuples, lists, sets and dicts (keys and values) are converted element-wise,
    so the result shows each element's str() rather than its repr().
    """
    if isinstance(x, tuple):
        return str(tuple(map(rstr, x)))
    if isinstance(x, list):
        return str(list(map(rstr, x)))
    if isinstance(x, set):
        return str(set(map(rstr, x)))
    if isinstance(x, dict):
        return str({rstr(k): rstr(v) for k, v in x.items()})
    return str(x)


class Expression:
    """Base class for Formulas and Terms.

    Two expressions are equal (and hash alike) iff their string representations
    are equal. Subclasses set self.strRepn = None in their constructor and
    implement computeStrRepn(); __repr__ computes it once and caches it.
    """

    # Helper functions used by subclasses.
    def ensureType(self, arg, wantedType):
        """Return |arg| if it is an instance of |wantedType|; otherwise raise an Exception."""
        if not isinstance(arg, wantedType):
            raise Exception('%s: wanted %s, but got %s' % (self.__class__.__name__, wantedType, arg))
        return arg

    def ensureFormula(self, arg):
        """Return |arg| if it is a Formula; otherwise raise an Exception."""
        return self.ensureType(arg, Formula)

    def ensureFormulas(self, args):
        """Check that every element of |args| is a Formula and return |args|."""
        for arg in args:
            self.ensureFormula(arg)
        return args

    def isa(self, wantedType):
        return isinstance(self, wantedType)

    def join(self, args):
        """Return the strings of |args| separated by commas (no spaces)."""
        return ','.join(str(arg) for arg in args)

    def __eq__(self, other):
        # NOTE: compares string forms, so an Expression also equals any object
        # with the same str(), e.g. Constant('a') == 'a'.
        return str(self) == str(other)

    def __hash__(self):
        return hash(str(self))

    def __repr__(self):
        # Cache the string to be more efficient. getattr() tolerates subclasses
        # whose constructor never initialized strRepn.
        cached = getattr(self, 'strRepn', None)
        if cached is None:
            cached = self.computeStrRepn()
            self.strRepn = cached
        return cached


# A Formula represents a truth value.
class Formula(Expression):
    pass


# A Term corresponds to an object.
class Term(Expression):
    pass


# Variable symbol (must start with '$')
# Example: $x
class Variable(Term):
    def __init__(self, name):
        if not name.startswith('$'):
            raise Exception('Variable must start with "$", but got %s' % name)
        self.name = name
        self.strRepn = None

    def computeStrRepn(self):
        return self.name


# Constant symbol (must be uncapitalized)
# Example: john
class Constant(Term):
    def __init__(self, name):
        # Reject the empty name explicitly: name[0] would raise IndexError instead.
        if not name or not name[0].islower():
            raise Exception('Constants must start with a lowercase letter, but got %s' % name)
        self.name = name
        self.strRepn = None

    def computeStrRepn(self):
        return self.name
# Predicate symbol (must be capitalized) applied to arguments.
# Example: LivesIn(john, palo_alto)
class Atom(Formula):
    def __init__(self, name, *args):
        # Reject the empty name explicitly: name[0] would raise IndexError instead.
        if not name or not name[0].isupper():
            raise Exception('Predicates must start with a uppercase letter, but got %s' % name)
        self.name = name
        # String arguments become Terms ('$x' -> Variable, 'a' -> Constant).
        self.args = list(map(toExpr, args))
        self.strRepn = None

    def computeStrRepn(self):
        if len(self.args) == 0:
            return self.name
        return self.name + '(' + self.join(self.args) + ')'


def toExpr(x):
    """Convert a string to a Term: names starting with '$' become Variables,
    other strings become Constants. Non-strings are returned unchanged."""
    if isinstance(x, str):
        if x.startswith('$'):
            return Variable(x)
        return Constant(x)
    return x


# The Python booleans stand in for the constant formulas False and True.
AtomFalse = False
AtomTrue = True


# Example: Not(Rain)
class Not(Formula):
    def __init__(self, arg):
        self.arg = self.ensureFormula(arg)
        self.strRepn = None

    def computeStrRepn(self):
        return 'Not(' + str(self.arg) + ')'


# Example: And(Rain,Snow)
class And(Formula):
    def __init__(self, arg1, arg2):
        self.arg1 = self.ensureFormula(arg1)
        self.arg2 = self.ensureFormula(arg2)
        self.strRepn = None

    def computeStrRepn(self):
        return 'And(' + str(self.arg1) + ',' + str(self.arg2) + ')'


# Example: Or(Rain,Snow)
class Or(Formula):
    def __init__(self, arg1, arg2):
        self.arg1 = self.ensureFormula(arg1)
        self.arg2 = self.ensureFormula(arg2)
        self.strRepn = None

    def computeStrRepn(self):
        return 'Or(' + str(self.arg1) + ',' + str(self.arg2) + ')'


# Example: Implies(Rain,Wet)
class Implies(Formula):
    def __init__(self, arg1, arg2):
        self.arg1 = self.ensureFormula(arg1)
        self.arg2 = self.ensureFormula(arg2)
        self.strRepn = None

    def computeStrRepn(self):
        return 'Implies(' + str(self.arg1) + ',' + str(self.arg2) + ')'


# Example: Exists($x,Lives(john, $x))
class Exists(Formula):
    def __init__(self, var, body):
        self.var = self.ensureType(toExpr(var), Variable)
        self.body = self.ensureFormula(body)
        self.strRepn = None

    def computeStrRepn(self):
        return 'Exists(' + str(self.var) + ',' + str(self.body) + ')'


# Example: Forall($x,Implies(Human($x),Alive($x)))
class Forall(Formula):
    def __init__(self, var, body):
        self.var = self.ensureType(toExpr(var), Variable)
        self.body = self.ensureFormula(body)
        self.strRepn = None

    def computeStrRepn(self):
        return 'Forall(' + str(self.var) + ',' + str(self.body) + ')'


def AndList(forms):
    """Return the left-nested conjunction of |forms|, or AtomTrue if |forms| is empty."""
    result = AtomTrue
    for form in forms:
        # Identity check against the bool sentinel: == would stringify the
        # growing conjunction on every iteration.
        result = form if result is AtomTrue else And(result, form)
    return result


def OrList(forms):
    """Return the left-nested disjunction of |forms|, or AtomFalse if |forms| is empty."""
    result = AtomFalse
    for form in forms:
        result = form if result is AtomFalse else Or(result, form)
    return result


def flattenAnd(form):
    """Return the list of conjuncts of |form|.

    Example: And(And(A, Or(B, C)), Not(D)) => [A, Or(B, C), Not(D)]
    """
    if form.isa(And):
        return flattenAnd(form.arg1) + flattenAnd(form.arg2)
    return [form]


def flattenOr(form):
    """Return the list of disjuncts of |form|.

    Example: Or(Or(A, And(B, C)), D) => [A, And(B, C), D]
    """
    if form.isa(Or):
        return flattenOr(form.arg1) + flattenOr(form.arg2)
    return [form]


# Syntactic sugar
def Equiv(a, b):
    return And(Implies(a, b), Implies(b, a))


def Xor(a, b):
    return And(Or(a, b), Not(And(a, b)))


# Special predicate which is used internally (e.g., in propositionalization).
def Equals(x, y):
    return Atom('Equals', x, y)


def AntiReflexive(predicate):
    """Return a formula asserting that |predicate| (e.g., Parent) is anti-reflexive.

    Rather than the direct Forall($x, Not(P($x,$x))), the formula is written as
    Forall($x, Forall($y, Implies(P($x,$y), Not(Equals($x,$y))))) so that
    Equals() is used and shows up in the models.
    """
    return Forall('$x', Forall('$y', Implies(Atom(predicate, '$x', '$y'), Not(Equals('$x', '$y')))))


############################################################
# Simple inference rules

# A Rule takes a sequence of argument Formulas and produces a list of result
# Formulas (possibly [] if the rule doesn't apply).
class Rule:
    pass


class UnaryRule(Rule):
    def applyRule(self, form):
        raise Exception('Override me')


class BinaryRule(Rule):
    def applyRule(self, form1, form2):
        raise Exception('Override me')

    # Override if rule is symmetric to save a factor of 2.
    def symmetric(self):
        return False


############################################################
# Unification

def unify(form1, form2, subst):
    """Try to unify |form1| with |form2|, adding variable => binding entries to |subst|.

    Returns whether unification succeeded. |subst| is mutated as matching
    proceeds, so on failure it may hold partial bindings; pass a fresh dict.
    Assumes the forms are in CNF (other connectives raise an Exception).
    Note: there is no occurs check because there are no function symbols.
    """
    if form1.isa(Variable) or form1.isa(Constant):
        return unifyTerms(form1, form2, subst)
    if form1.isa(Atom):
        return form2.isa(Atom) and form1.name == form2.name and len(form1.args) == len(form2.args) and \
            all(unify(arg1, arg2, subst) for arg1, arg2 in zip(form1.args, form2.args))
    if form1.isa(Not):
        return form2.isa(Not) and unify(form1.arg, form2.arg, subst)
    if form1.isa(And):
        return form2.isa(And) and unify(form1.arg1, form2.arg1, subst) and unify(form1.arg2, form2.arg2, subst)
    if form1.isa(Or):
        return form2.isa(Or) and unify(form1.arg1, form2.arg1, subst) and unify(form1.arg2, form2.arg2, subst)
    raise Exception('Unhandled: %s' % form1)


def getSubst(subst, x):
    """Follow the chain of bindings in |subst| from |x| and return the last value reached."""
    while True:
        y = subst.get(x)
        if y is None:
            return x
        x = y


def unifyTerms(a, b, subst):
    """Unify two terms, binding the (resolved) variable side in |subst|; return success."""
    a = getSubst(subst, a)
    b = getSubst(subst, b)
    if a == b:
        return True
    if a.isa(Variable):
        subst[a] = b
    elif b.isa(Variable):
        subst[b] = a
    else:
        return False  # Two distinct constants
    return True
def applySubst(form, subst):
    """Return |form| with each variable replaced by its (chained) binding in |subst|.

    Assumes |form| is in CNF: only Variable, Constant, Atom, Not, And and Or
    are handled; anything else raises an Exception.
    """
    if not subst:
        return form
    if form.isa(Variable):
        return getSubst(subst, form)
    if form.isa(Constant):
        return form
    if form.isa(Atom):
        return Atom(*[form.name] + [applySubst(arg, subst) for arg in form.args])
    if form.isa(Not):
        return Not(applySubst(form.arg, subst))
    if form.isa(And):
        return And(applySubst(form.arg1, subst), applySubst(form.arg2, subst))
    if form.isa(Or):
        return Or(applySubst(form.arg1, subst), applySubst(form.arg2, subst))
    raise Exception('Unhandled: %s' % form)


############################################################
# Convert to CNF, Resolution rules

def withoutElementAt(items, i):
    """Return a copy of the list |items| with the element at index |i| removed."""
    return items[0:i] + items[i + 1:]


def negateFormula(item):
    """Return Not(item), stripping one Not instead of adding a second."""
    return item.arg if item.isa(Not) else Not(item)


def reduceFormulas(items, mode):
    """Simplify a list of conjuncts (mode=And) or disjuncts (mode=Or).

    - If some A and its negation both occur, return [AtomFalse] for a
      conjunction or [AtomTrue] for a disjunction (any other mode raises).
    - Otherwise remove duplicates and return the items sorted by string.
    """
    present = set(items)
    # Set lookup finds a complementary pair in O(n), whichever order the two
    # items appear in.
    if any(negateFormula(item) in present for item in items):
        if mode == And:
            return [AtomFalse]
        elif mode == Or:
            return [AtomTrue]
        else:
            raise Exception("Invalid mode: %s" % mode)
    return sorted(present, key=str)
def _subexpressionChildren(form):
    """Return the immediate subexpressions of |form| in left-to-right order.

    Quantifiers contribute only their body (not the bound variable). Raises an
    Exception for objects that are not one of the known Expression kinds.
    """
    if form.isa(Variable) or form.isa(Constant):
        return []
    if form.isa(Atom):
        return form.args
    if form.isa(Not):
        return [form.arg]
    if form.isa(And) or form.isa(Or) or form.isa(Implies):
        return [form.arg1, form.arg2]
    if form.isa(Exists) or form.isa(Forall):
        return [form.body]
    raise Exception("Unhandled: %s" % form)


def allSubexpressions(form):
    """Return every subexpression of |form| (including terms) in pre-order.

    Example:
    - Input: And(Atom('A', Constant('a')), Atom('B'))
    - Output: [And(Atom('A', Constant('a')), Atom('B')), Atom('A', Constant('a')), Constant('a'), Atom('B')]
    """
    collected = []

    def visit(node):
        collected.append(node)
        for child in _subexpressionChildren(node):
            visit(child)

    visit(form)
    return collected


def allFreeVars(form):
    """Return the free variables of |form| in order of occurrence (repeats included)."""
    free = []

    def visit(node, bound):
        if node.isa(Variable):
            if node not in bound:
                free.append(node)
            return
        if node.isa(Exists) or node.isa(Forall):
            visit(node.body, bound + [node.var])
            return
        for child in _subexpressionChildren(node):
            visit(child, bound)

    visit(form, [])
    return free
# Return |form| with all free occurrences of |var| replaced with |obj|.
def substituteFreeVars(form, var, obj):
    # boundVars is threaded through but not consulted: shadowing is handled by
    # returning a quantifier that rebinds |var| unchanged.
    def recurse(form, boundVars):
        if form.isa(Variable):
            if form == var:
                return obj
            return form
        elif form.isa(Constant):
            return form
        elif form.isa(Atom):
            return Atom(*[form.name] + [recurse(arg, boundVars) for arg in form.args])
        elif form.isa(Not):
            return Not(recurse(form.arg, boundVars))
        elif form.isa(And):
            return And(recurse(form.arg1, boundVars), recurse(form.arg2, boundVars))
        elif form.isa(Or):
            return Or(recurse(form.arg1, boundVars), recurse(form.arg2, boundVars))
        elif form.isa(Implies):
            return Implies(recurse(form.arg1, boundVars), recurse(form.arg2, boundVars))
        elif form.isa(Exists):
            if form.var == var:
                return form  # Don't substitute inside
            return Exists(form.var, recurse(form.body, boundVars + [form.var]))
        elif form.isa(Forall):
            if form.var == var:
                return form  # Don't substitute inside
            return Forall(form.var, recurse(form.body, boundVars + [form.var]))
        else:
            raise Exception("Unhandled: %s" % form)
    return recurse(form, [])


# Return all Constant subexpressions of |form| in pre-order (duplicates are kept).
def allConstants(form):
    return [x for x in allSubexpressions(form) if x.isa(Constant)]


# Unary rule converting a first-order formula into a list of CNF clauses
# (the list is implicitly conjoined). Returns [AtomFalse] / [AtomTrue] when the
# formula reduces to a constant.
class ToCNFRule(UnaryRule):
    def __init__(self):
        # For standardizing variables.
        # For each existing variable name, the number of times it has occurred.
        # The counter lives on the rule instance, so renamed variables stay
        # distinct across all formulas converted by the same rule.
        self.varCounts = collections.Counter()

    def applyRule(self, form):
        newForm = form

        # Step 1: remove implications: Implies(A,B) becomes Or(Not(A),B)
        def removeImplications(form):
            if form.isa(Atom): return form
            if form.isa(Not): return Not(removeImplications(form.arg))
            if form.isa(And): return And(removeImplications(form.arg1), removeImplications(form.arg2))
            if form.isa(Or): return Or(removeImplications(form.arg1), removeImplications(form.arg2))
            if form.isa(Implies): return Or(Not(removeImplications(form.arg1)), removeImplications(form.arg2))
            if form.isa(Exists): return Exists(form.var, removeImplications(form.body))
            if form.isa(Forall): return Forall(form.var, removeImplications(form.body))
            raise Exception("Unhandled: %s" % form)
        newForm = removeImplications(newForm)

        # Step 2: push negation inwards (de Morgan) until Not only wraps atoms
        def pushNegationInwards(form):
            if form.isa(Atom): return form
            if form.isa(Not):
                if form.arg.isa(Not):  # Double negation
                    return pushNegationInwards(form.arg.arg)
                if form.arg.isa(And):  # De Morgan
                    return Or(pushNegationInwards(Not(form.arg.arg1)), pushNegationInwards(Not(form.arg.arg2)))
                if form.arg.isa(Or):  # De Morgan
                    return And(pushNegationInwards(Not(form.arg.arg1)), pushNegationInwards(Not(form.arg.arg2)))
                if form.arg.isa(Exists):
                    return Forall(form.arg.var, pushNegationInwards(Not(form.arg.body)))
                if form.arg.isa(Forall):
                    return Exists(form.arg.var, pushNegationInwards(Not(form.arg.body)))
                return form  # Negated atom: nothing left to push
            if form.isa(And): return And(pushNegationInwards(form.arg1), pushNegationInwards(form.arg2))
            if form.isa(Or): return Or(pushNegationInwards(form.arg1), pushNegationInwards(form.arg2))
            # Implications were removed in step 1; this branch is kept for safety.
            if form.isa(Implies): return Or(Not(pushNegationInwards(form.arg1)), pushNegationInwards(form.arg2))
            if form.isa(Exists): return Exists(form.var, pushNegationInwards(form.body))
            if form.isa(Forall): return Forall(form.var, pushNegationInwards(form.body))
            raise Exception("Unhandled: %s" % form)
        newForm = pushNegationInwards(newForm)

        # Step 3: standardize variables: make sure all variables are different
        # Don't modify subst; return a new version where var is mapped onto
        # something that hasn't been seen before (name + running count, e.g. $x -> $x1).
        def updateSubst(subst, var):
            self.varCounts[var.name] += 1
            newVar = Variable(var.name + str(self.varCounts[var.name]))
            return dict(list(subst.items()) + [(var, newVar)])

        # Rename every bound variable; a variable with no binding is free, which
        # is rejected.
        def standardizeVariables(form, subst):
            if form.isa(Variable):
                if form not in subst:
                    raise Exception("Free variable found: %s" % form)
                return subst[form]
            if form.isa(Constant): return form
            if form.isa(Atom): return Atom(*([form.name] + [standardizeVariables(arg, subst) for arg in form.args]))
            if form.isa(Not): return Not(standardizeVariables(form.arg, subst))
            if form.isa(And): return And(standardizeVariables(form.arg1, subst), standardizeVariables(form.arg2, subst))
            if form.isa(Or): return Or(standardizeVariables(form.arg1, subst), standardizeVariables(form.arg2, subst))
            if form.isa(Exists):
                newSubst = updateSubst(subst, form.var)
                return Exists(newSubst[form.var], standardizeVariables(form.body, newSubst))
            if form.isa(Forall):
                newSubst = updateSubst(subst, form.var)
                return Forall(newSubst[form.var], standardizeVariables(form.body, newSubst))
            raise Exception("Unhandled: %s" % form)
        newForm = standardizeVariables(newForm, {})

        # Step 4: replace existentially quantified variables with Skolem functions
        # |subst| is shared across branches and mutated (only for top-level
        # existentials); this relies on step 3 having made variable names unique.
        # |scope| lists the universally quantified variables enclosing |form|.
        def skolemize(form, subst, scope):
            if form.isa(Variable): return subst.get(form, form)
            if form.isa(Constant): return form
            if form.isa(Atom): return Atom(*[form.name] + [skolemize(arg, subst, scope) for arg in form.args])
            if form.isa(Not): return Not(skolemize(form.arg, subst, scope))
            if form.isa(And): return And(skolemize(form.arg1, subst, scope), skolemize(form.arg2, subst, scope))
            if form.isa(Or): return Or(skolemize(form.arg1, subst, scope), skolemize(form.arg2, subst, scope))
            if form.isa(Exists):
                # Create a Skolem function that depends on the variables in the scope (list of variables)
                # Example:
                # - Suppose scope = [$x, $y] and form = Exists($z,F($z)).
                # - Normally, we would return F(Z($x,$y)), where Z is a brand new Skolem function.
                # - But since we don't have function symbols, we replace with a Skolem predicate:
                #   Forall($z,Implies(Z($z,$x,$y),F($z)))
                # Important: when doing resolution, need to catch Not(Z($z,*,*)) as a contradiction.
                if len(scope) == 0:
                    # No enclosing universals: a fresh Skolem constant suffices.
                    subst[form.var] = Constant('skolem' + form.var.name)
                    return skolemize(form.body, subst, scope)
                else:
                    skolem = Atom(*['Skolem' + form.var.name, form.var] + scope)
                    return Forall(form.var, Or(Not(skolem), skolemize(form.body, subst, scope)))
            if form.isa(Forall):
                return Forall(form.var, skolemize(form.body, subst, scope + [form.var]))
            raise Exception("Unhandled: %s" % form)
        newForm = skolemize(newForm, {}, [])

        # Step 5: remove universal quantifiers [note: need to do this before distribute, unlike Russell/Norvig book]
        # After step 4 no Exists remains, so only Forall needs stripping.
        def removeUniversalQuantifiers(form):
            if form.isa(Atom): return form
            if form.isa(Not): return Not(removeUniversalQuantifiers(form.arg))
            if form.isa(And): return And(removeUniversalQuantifiers(form.arg1), removeUniversalQuantifiers(form.arg2))
            if form.isa(Or): return Or(removeUniversalQuantifiers(form.arg1), removeUniversalQuantifiers(form.arg2))
            if form.isa(Forall): return removeUniversalQuantifiers(form.body)
            raise Exception("Unhandled: %s" % form)
        newForm = removeUniversalQuantifiers(newForm)

        # Step 6: distribute Or over And (want And on the outside): Or(And(A,B),C) becomes And(Or(A,C),Or(B,C))
        def distribute(form):
            if form.isa(Atom): return form
            if form.isa(Not): return Not(distribute(form.arg))
            if form.isa(And): return And(distribute(form.arg1), distribute(form.arg2))
            if form.isa(Or):
                # First need to distribute as much as possible
                f1 = distribute(form.arg1)
                f2 = distribute(form.arg2)
                if f1.isa(And):
                    return And(distribute(Or(f1.arg1, f2)), distribute(Or(f1.arg2, f2)))
                if f2.isa(And):
                    return And(distribute(Or(f1, f2.arg1)), distribute(Or(f1, f2.arg2)))
                return Or(f1, f2)
            if form.isa(Exists): return Exists(form.var, distribute(form.body))
            if form.isa(Forall): return Forall(form.var, distribute(form.body))
            raise Exception("Unhandled: %s" % form)
        newForm = distribute(newForm)

        # Post-processing: break up conjuncts into conjuncts and sort the disjuncts in each conjunct
        # Remove instances of A and Not(A)
        conjuncts = [OrList(reduceFormulas(flattenOr(f), Or)) for f in flattenAnd(newForm)]
        # print rstr(form), rstr(conjuncts)
        assert len(conjuncts) > 0
        if any(x == AtomFalse for x in conjuncts): return [AtomFalse]
        if all(x == AtomTrue for x in conjuncts): return [AtomTrue]
        conjuncts = [x for x in conjuncts if x != AtomTrue]  # Drop tautological clauses
        results = reduceFormulas(conjuncts, And)
        if len(results) == 0: results = [AtomFalse]
        # print 'CNF', form, rstr(results)
        return results


# Binary resolution over CNF clauses (disjunctions of literals).
# Returns the list of resolvents, or [AtomFalse] as soon as a contradiction
# (an empty clause, or a lone negated Skolem predicate) is derived.
class ResolutionRule(BinaryRule):
    # Assume formulas are in CNF
    # Assume A and Not(A) don't both exist in a form (taken care of by CNF conversion)
    def applyRule(self, form1, form2):
        items1 = flattenOr(form1)
        items2 = flattenOr(form2)
        results = []
        # print 'RESOLVE', form1, form2
        for i, item1 in enumerate(items1):
            for j, item2 in enumerate(items2):
                subst = {}  # Fresh per pair: unify mutates it even on failure
                if unify(negateFormula(item1), item2, subst):
                    # Resolvent: all remaining literals of both clauses, under the unifier.
                    newItems1 = withoutElementAt(items1, i)
                    newItems2 = withoutElementAt(items2, j)
                    newItems = [applySubst(item, subst) for item in newItems1 + newItems2]
                    if len(newItems) == 0:  # Contradiction: False
                        results = [AtomFalse]
                        break
                    # print 'STEP: %s %s => %s %s' % (form1, form2, rstr(newItems), rstr(subst))
                    result = OrList(reduceFormulas(newItems, Or))
                    # Not(Skolem$x($x,...)) is a contradiction
                    # NOTE(review): any user predicate whose name starts with 'Skolem' is also caught here.
                    if isinstance(result, Not) and result.arg.name.startswith('Skolem'):
                        results = [AtomFalse]
                        break
                    # Don't add redundant stuff
                    if result == AtomTrue: continue
                    if result in results: continue
                    results.append(result)
            if results == [AtomFalse]: break
        # print 'RESOLUTION: %s %s => %s' % (form1, form2, rstr(results))
        return results

    def symmetric(self):
        return True


############################################################
# Model checking
def performModelChecking(allForms, findAll, objects=None, verbose=0):
    """Find models (truth assignments) that satisfy every formula in |allForms|.

    The formulas are propositionalized over |objects| (default: all constants
    they mention) and simplified with universalInterpret (e.g., Equals(a,a) =>
    True). A backtracking search then tries each atom as False, then True. It
    prunes as soon as a formula whose atoms are all assigned evaluates to false.

    Args:
        allForms: list of Formulas, interpreted as a conjunction.
        findAll: if False, stop after the first model is found.
        objects: optional domain of objects (Constants or constant names).
        verbose: debugging level (>= 3 prints the search plan, >= 5 the models).

    Returns:
        A list of models, each a set of the Atoms that are true in it (all other
        atoms are false). [] means the formulas are unsatisfiable. [set()] is
        returned when no constraint is left after simplification.
    """
    if verbose >= 3:
        print('performModelChecking', rstr(allForms))

    # Propositionalize, then simplify atoms with fixed values. A formula that is
    # (or simplifies to) False makes the whole conjunction unsatisfiable, so it
    # must end the search rather than be discarded along with the True ones.
    simplified = []
    for form in propositionalize(allForms, objects):
        if form is AtomTrue:
            continue
        if form is AtomFalse:
            return []  # No models
        form = universalInterpret(form)
        if form is AtomFalse:
            return []  # No models
        if form is not AtomTrue:
            simplified.append(form)
    # Dedup; sort by string so the search order (and the model found) is deterministic.
    allForms = sorted(set(simplified), key=str)
    if verbose >= 3:
        print('All Forms:', rstr(allForms))
    if not allForms:
        return [set()]  # One model

    # The atoms are the propositional variables of the search.
    formAtoms = [{x for x in allSubexpressions(form) if x.isa(Atom)} for form in allForms]
    atoms = sorted(set().union(*formAtoms), key=str)
    # Degree heuristic: branch first on the atoms used by the most formulas.
    degree = {atom: sum(1 for used in formAtoms if atom in used) for atom in atoms}
    atoms.sort(key=lambda atom: -degree[atom])
    if verbose >= 3:
        print('Atoms:', rstr(atoms))
        print('Constraints:', rstr(allForms))

    # Check each formula exactly once: at the position of its last atom in the
    # search order, i.e. as soon as all of its atoms have been assigned.
    position = {atom: i for i, atom in enumerate(atoms)}
    atomPrefixForms = [(atom, []) for atom in atoms]
    for form, used in zip(allForms, formAtoms):
        atomPrefixForms[max(position[atom] for atom in used)][1].append(form)
    if verbose >= 3:
        print('Plan:')
        for atom, forms in atomPrefixForms:
            print(' %s: %s' % (rstr(atom), rstr(forms)))

    # Build up an interpretation
    numAtoms = len(atoms)
    models = []  # Models found so far
    model = set()  # Atoms currently assigned True (mutated during the search)

    def recurse(i):  # i: index of the next atom to assign
        if not findAll and models:
            return
        if i == numAtoms:
            # Every atom is assigned and every formula has been checked.
            models.append(set(model))
            return
        atom, forms = atomPrefixForms[i]
        fixed = universalInterpretAtom(atom)  # None unless the atom's value is fixed
        if fixed is not True:  # Try atom = False
            if interpretForms(forms, model):
                recurse(i + 1)
        if fixed is not False:  # Try atom = True
            model.add(atom)
            if interpretForms(forms, model):
                recurse(i + 1)
            model.remove(atom)

    recurse(0)
    if verbose >= 5:
        print('Models:')
        for found in models:
            print(' %s' % rstr(found))
    return models


def printModel(model):
    """Print a model (a set of true atoms): one line per true atom, sorted by name."""
    for x in sorted(map(str, model)):
        print('*', x, '=', 'True')
    print('*', '(other atoms if any)', '=', 'False')


def propositionalize(forms, objects=None):
    """Convert first-order formulas into propositional ones, assuming database semantics.

    Forall becomes an And over all objects and Exists becomes an Or; the result
    is the flattened list of conjuncts of all converted formulas.
    Example: forms = [Forall('$x', Atom('Alive', '$x'))], objects = ['alice', 'bob']
             => [Atom('Alive', 'alice'), Atom('Alive', 'bob')]
    Example: forms = [Exists('$x', Atom('Alive', '$x'))], objects = ['alice', 'bob']
             => [Or(Atom('Alive', 'alice'), Atom('Alive', 'bob'))]
    Returns [AtomFalse] if some formula converts to False; formulas that convert
    to True are dropped. Raises an Exception on free variables.
    """
    if objects is None:
        # Default domain: all constants mentioned in |forms|.
        objects = set()
        for form in forms:
            objects |= set(allConstants(form))
        # Sort so the output order does not depend on set iteration order.
        objects = sorted(objects, key=str)
    else:
        # Make sure objects are expressions: Convert ['a', 'b'] to [Constant('a'), Constant('b')]
        objects = [toExpr(obj) for obj in objects]

    def bind(subst, var, obj):
        """Return a copy of |subst| extended with var => obj."""
        extended = dict(subst)
        extended[var] = obj
        return extended

    # Recursively convert |form|, which could contain Exists and Forall, to a
    # form without quantifiers. |subst| maps bound variables to objects.
    def convert(form, subst):
        if form.isa(Variable):
            if form not in subst:
                raise Exception("Free variable found: %s" % form)
            return subst[form]
        if form.isa(Constant):
            return form
        if form.isa(Atom):
            return Atom(*[form.name] + [convert(arg, subst) for arg in form.args])
        if form.isa(Not):
            return Not(convert(form.arg, subst))
        if form.isa(And):
            return And(convert(form.arg1, subst), convert(form.arg2, subst))
        if form.isa(Or):
            return Or(convert(form.arg1, subst), convert(form.arg2, subst))
        if form.isa(Implies):
            return Implies(convert(form.arg1, subst), convert(form.arg2, subst))
        if form.isa(Exists):
            return OrList([convert(form.body, bind(subst, form.var, obj)) for obj in objects])
        if form.isa(Forall):
            return AndList([convert(form.body, bind(subst, form.var, obj)) for obj in objects])
        raise Exception("Unhandled: %s" % form)

    # Think of newForms as conjoined
    newForms = []
    for form in forms:
        newForm = convert(form, {})
        if newForm is AtomFalse:
            return [AtomFalse]
        if newForm is AtomTrue:
            continue
        newForms.extend(flattenAnd(newForm))
    return newForms


def universalInterpretAtom(atom):
    """Return the fixed truth value of |atom| if it has one, else None.

    Only Equals(x, y) is fixed: True iff its two arguments are equal.
    Assumption: atom is propositional logic (arguments are constants).
    """
    if atom.name == 'Equals':
        return AtomTrue if atom.args[0] == atom.args[1] else AtomFalse
    return None
def universalInterpret(form):
    """Simplify a propositional |form| by evaluating atoms with a fixed value.

    Fixed atoms (see universalInterpretAtom, e.g. Equals(a,a) => True) are
    replaced by AtomTrue/AtomFalse, and those constants are folded through
    Not, And, Or and Implies. Terms are returned unchanged. The result is either
    AtomTrue, AtomFalse, or a formula with no constants left in it.
    Assumption: |form| is propositional (no quantifiers).
    """
    if form.isa(Variable) or form.isa(Constant):
        return form
    if form.isa(Atom):
        fixed = universalInterpretAtom(form)
        if fixed is not None:
            return fixed
        return Atom(*([form.name] + [universalInterpret(arg) for arg in form.args]))
    if form.isa(Not):
        inner = universalInterpret(form.arg)
        if inner == AtomTrue:
            return AtomFalse
        if inner == AtomFalse:
            return AtomTrue
        return Not(inner)
    if form.isa(And):
        left, right = universalInterpret(form.arg1), universalInterpret(form.arg2)
        if left == AtomFalse or right == AtomFalse:
            return AtomFalse
        if left == AtomTrue:
            return right
        if right == AtomTrue:
            return left
        return And(left, right)
    if form.isa(Or):
        left, right = universalInterpret(form.arg1), universalInterpret(form.arg2)
        if left == AtomTrue or right == AtomTrue:
            return AtomTrue
        if left == AtomFalse:
            return right
        if right == AtomFalse:
            return left
        return Or(left, right)
    if form.isa(Implies):
        left, right = universalInterpret(form.arg1), universalInterpret(form.arg2)
        if left == AtomFalse or right == AtomTrue:
            return AtomTrue
        if left == AtomTrue:
            return right
        if right == AtomFalse:
            return Not(left)
        return Implies(left, right)
    raise Exception("Unhandled: %s" % form)


def interpretForm(form, model):
    """Evaluate a propositional |form| under |model|, the set of atoms that are true."""
    if form.isa(Atom):
        return form in model
    if form.isa(Not):
        return not interpretForm(form.arg, model)
    if form.isa(Implies):
        return (not interpretForm(form.arg1, model)) or interpretForm(form.arg2, model)
    if form.isa(And):
        return interpretForm(form.arg1, model) and interpretForm(form.arg2, model)
    if form.isa(Or):
        return interpretForm(form.arg1, model) or interpretForm(form.arg2, model)
    raise Exception("Unhandled: %s" % form)


def interpretForms(forms, model):
    """Return whether every formula in |forms| is true under |model| (a conjunction)."""
    for form in forms:
        if not interpretForm(form, model):
            return False
    return True
############################################################
# A Derivation is a tree where each node corresponds to the application of a rule.

class Derivation:
    """A formula in the knowledge base together with how it was obtained.

    children: the Derivations the rule was applied to ([] for user axioms).
    cost: 0 for axioms; each rule application adds 1 to its children's total cost.
    permanent: marker for being able to extract (survives removeTemporary).
    derived: whether this was derived (as opposed to added by the user).
    """

    def __init__(self, form, children, cost, derived):
        self.form = form
        self.children = children
        self.cost = cost
        self.permanent = False  # Marker for being able to extract.
        self.derived = derived  # Whether this was derived (as opposed to added by the user).

    def __repr__(self):
        return 'Derivation(%s, cost=%s, permanent=%s, derived=%s)' % (
            self.form, self.cost, self.permanent, self.derived)


# Possible responses to queries to the knowledge base
ENTAILMENT = "ENTAILMENT"
CONTINGENT = "CONTINGENT"
CONTRADICTION = "CONTRADICTION"


class KBResponse:
    """A response to a KB query.

    query: what the query is (just a string description for printing)
    modify: whether we modified the knowledge base
    status: one of ENTAILMENT, CONTINGENT, CONTRADICTION
    trueModel: if available, a model consistent with the KB in which the query is true
    falseModel: if available, a model consistent with the KB in which the query is false
    """

    def __init__(self, query, modify, status, trueModel, falseModel):
        self.query = query
        self.modify = modify
        self.status = status
        self.trueModel = trueModel
        self.falseModel = falseModel

    def show(self, verbose=1):
        padding = '>>>>>'
        print(padding + ' ' + self.responseStr())
        if verbose >= 1:
            print('Query: %s[%s]' % ('TELL' if self.modify else 'ASK', self.query))
            # Compare with None: an empty model (every atom false) is still a
            # valid example and must not be skipped for being falsy.
            if self.trueModel is not None:
                print('An example of a model where query is TRUE:')
                printModel(self.trueModel)
            if self.falseModel is not None:
                print('An example of a model where query is FALSE:')
                printModel(self.falseModel)

    def responseStr(self):
        """Return the natural-language answer for this status (TELL vs. ASK wording)."""
        if self.status == ENTAILMENT:
            if self.modify:
                return 'I already knew that.'
            else:
                return 'Yes.'
        elif self.status == CONTINGENT:
            if self.modify:
                return 'I learned something.'
            else:
                return 'I don\'t know.'
        elif self.status == CONTRADICTION:
            if self.modify:
                return 'I don\'t buy that.'
            else:
                return 'No.'
        else:
            raise Exception("Invalid status: %s" % self.status)

    def __repr__(self):
        return self.responseStr()


def showKBResponse(response, verbose=1):
    """Print a KBResponse, or a wh-query result mapping (var, obj) => KBResponse."""
    if isinstance(response, KBResponse):
        response.show(verbose)
    else:
        items = [(obj, r.status) for ((var, obj), r) in response.items()]
        print('Yes: %s' % rstr([obj for obj, status in items if status == ENTAILMENT]))
        print('Maybe: %s' % rstr([obj for obj, status in items if status == CONTINGENT]))
        print('No: %s' % rstr([obj for obj, status in items if status == CONTRADICTION]))


class KnowledgeBase:
    """A collection of Formulas (stored as Derivations keyed by formula).

    Interact with it using
    - tell: modify the KB with a new formula.
    - ask: query whether a formula is entailed, contradicted, or neither.
    """

    # Derivations costing more than this are not stored (bounds forward chaining).
    # Override per instance or in a subclass if needed.
    maxCost = 100

    def __init__(self, standardizationRule, rules, modelChecking, verbose=0):
        # Rule to apply to each formula that's added to the KB (None is possible).
        self.standardizationRule = standardizationRule
        # Set of inference rules
        self.rules = rules
        # Use model checking as opposed to applying rules.
        self.modelChecking = modelChecking
        # For debugging
        self.verbose = verbose
        # Formulas that we believe are true (used when not doing model checking).
        self.derivations = {}  # Map from Derivation key (logical form) to Derivation

    def tell(self, form):
        """Add |form| to the KB if it doesn't contradict it. Returns a KBResponse."""
        return self.query(form, modify=True)

    def ask(self, form):
        """Ask whether |form| is True, False, or unknown given the KB. Returns a KBResponse."""
        return self.query(form, modify=False)

    def dump(self):
        """Print the KB's derivations ('*' for user axioms, '-' for derived ones)."""
        print('==== Knowledge base [%d derivations] ===' % len(self.derivations))
        for deriv in self.derivations.values():
            print('-' if deriv.derived else '*', deriv if self.verbose >= 2 else deriv.form)

    ####### Internal functions

    def query(self, form, modify):
        """Answer (and, if |modify|, tell) |form|.

        Returns a KBResponse, or, if |form| has free variables, a mapping from
        (var, obj) => the response for |form| with |var| bound to |obj|.
        """
        # Handle wh-queries: try all possible values of the free variable, and recurse on query().
        freeVars = allFreeVars(form)
        if len(freeVars) > 0:
            if modify:
                raise Exception("Can't modify database with a query with free variables: %s" % form)
            var = freeVars[0]
            allForms = AndList([deriv.form for deriv in self.derivations.values()])
            if allForms is AtomTrue:
                return {}  # Weird corner case: empty KB
            # allConstants repeats constants; dedup (keeping order) so each
            # candidate is queried only once.
            objects = list(dict.fromkeys(allConstants(allForms)))
            # Try binding |var| to |obj|
            response = {}
            for obj in objects:
                response[(var, obj)] = self.query(substituteFreeVars(form, var, obj), modify)
            return response

        # Assume no free variables from here on...
        formStr = '%s, standardized: %s' % (form, rstr(self.standardize(form)))

        # Models to serve as supporting evidence
        falseModel = None  # Makes the query false
        trueModel = None  # Makes the query true

        # If Not(form) is inconsistent with the KB, form is entailed.
        if not self.addAxiom(Not(form)):
            self.removeTemporary()
            status = ENTAILMENT
        else:
            # Inconclusive...
            falseModel = self.consistentModel
            self.removeTemporary()

            # Add form: consistent means contingent, inconsistent means contradiction.
            if self.addAxiom(form):
                if modify:
                    self.makeTemporaryPermanent()
                else:
                    self.removeTemporary()
                trueModel = self.consistentModel
                status = CONTINGENT
            else:
                self.removeTemporary()
                status = CONTRADICTION

        return KBResponse(query=formStr, modify=modify, status=status, trueModel=trueModel, falseModel=falseModel)

    def standardize(self, form):
        """Apply the standardization rule to |form|; return a list of formulas."""
        if self.standardizationRule:
            return self.standardizationRule.applyRule(form)
        return [form]

    def addAxiom(self, form):
        """Temporarily add |form| to the KB; return whether it is consistent with it.

        Resets self.consistentModel. The added derivations stay temporary until
        makeTemporaryPermanent() or removeTemporary() is called.
        """
        self.consistentModel = None
        for f in self.standardize(form):
            if f is AtomFalse:
                return False
            if f is AtomTrue:
                continue
            deriv = Derivation(f, children=[], cost=0, derived=False)
            if not self.addDerivation(deriv):
                return False
        return True

    def addDerivation(self, deriv):
        """Add |deriv| and forward-chain from it; return whether the KB is still consistent.

        Derivations whose formula is already present, or whose cost exceeds
        self.maxCost, are ignored (and count as consistent).
        """
        # Derived a contradiction
        if deriv.form is AtomFalse:
            return False

        key = deriv.form
        oldDeriv = self.derivations.get(key)
        if oldDeriv is None and deriv.cost <= self.maxCost:
            # Something worth updating
            self.derivations[key] = deriv
            if self.verbose >= 3:
                print('add %s [%s derivations]' % (deriv, len(self.derivations)))

            if self.modelChecking:
                allForms = [d.form for d in self.derivations.values()]
                models = performModelChecking(allForms, findAll=False, verbose=self.verbose)
                if len(models) == 0:
                    return False
                self.consistentModel = models[0]

            # Apply rules forward (iterate over a snapshot: rules add derivations).
            if not self.applyUnaryRules(deriv):
                return False
            for deriv2 in list(self.derivations.values()):
                if not self.applyBinaryRules(deriv, deriv2):
                    return False
                if not self.applyBinaryRules(deriv2, deriv):
                    return False
        return True

    def ensureFormulas(self, rule, formulas):
        """Return |formulas| if it is a list of Formulas (AtomFalse allowed); else raise."""
        if isinstance(formulas, list) and all(formula is AtomFalse or isinstance(formula, Formula) for formula in formulas):
            return formulas
        raise Exception('Expected list of Formulas, but %s returned %s' % (rule, formulas))

    def applyUnaryRules(self, deriv):
        """Apply every UnaryRule to |deriv|; return whether everything is okay (no contradiction)."""
        for rule in self.rules:
            if not isinstance(rule, UnaryRule):
                continue
            for newForm in self.ensureFormulas(rule, rule.applyRule(deriv.form)):
                if not self.addDerivation(Derivation(newForm, children=[deriv], cost=deriv.cost + 1, derived=True)):
                    return False
        return True

    def applyBinaryRules(self, deriv1, deriv2):
        """Apply every BinaryRule to (deriv1, deriv2); return whether everything is okay."""
        for rule in self.rules:
            if not isinstance(rule, BinaryRule):
                continue
            # Optimization: for symmetric rules only one argument order is needed.
            if rule.symmetric() and str(deriv1.form) >= str(deriv2.form):
                continue
            for newForm in self.ensureFormulas(rule, rule.applyRule(deriv1.form, deriv2.form)):
                if not self.addDerivation(
                        Derivation(newForm, children=[deriv1, deriv2], cost=deriv1.cost + deriv2.cost + 1, derived=True)):
                    return False
        return True

    def removeTemporary(self):
        """Remove all the temporary (non-permanent) derivations from the KB."""
        for key, value in list(self.derivations.items()):
            if not value.permanent:
                del self.derivations[key]

    def makeTemporaryPermanent(self):
        """Mark every derivation currently in the KB as permanent."""
        for deriv in self.derivations.values():
            deriv.permanent = True


# Create an empty knowledge base equipped with the usual inference rules.
def createResolutionKB():
    return KnowledgeBase(standardizationRule=ToCNFRule(), rules=[ResolutionRule()], modelChecking=False)


def createModelCheckingKB():
    return KnowledgeBase(standardizationRule=None, rules=[], modelChecking=True)