# csc-665-homework/hw2/ints/logic.py
# 2024-03-20 16:21:15 -07:00

# 1085 lines
# 41 KiB
# Python

# Simple logical inference system: resolution and model checking for first-order logic.
# @author Percy Liang
import collections
def rstr(x):
    """Recursively stringify |x|.

    Tuples, lists and sets are rebuilt with every element replaced by its
    rstr; dicts are rebuilt with rstr'd keys and values. The rebuilt
    container (or |x| itself, for anything else) is then passed to str.
    """
    for containerType in (tuple, list, set):
        if isinstance(x, containerType):
            return str(containerType(rstr(item) for item in x))
    if isinstance(x, dict):
        return str({rstr(key): rstr(value) for key, value in x.items()})
    return str(x)
class Expression:
    """Base class of every logical expression (formulas and terms).

    Equality and hashing are defined on the string representation, so two
    structurally identical expressions built separately compare equal and can
    be used interchangeably as dict keys or set members.
    Subclasses implement computeStrRepn(); its result is cached in strRepn.
    """

    # Helper functions used by subclasses.
    def ensureType(self, arg, wantedType):
        """Return |arg| unchanged; raise Exception unless it is a |wantedType|."""
        if not isinstance(arg, wantedType):
            raise Exception('%s: wanted %s, but got %s' % (self.__class__.__name__, wantedType, arg))
        return arg

    def ensureFormula(self, arg):
        """Return |arg| unchanged; raise Exception unless it is a Formula."""
        return self.ensureType(arg, Formula)

    def ensureFormulas(self, args):
        """Return |args| unchanged; raise Exception unless each one is a Formula."""
        for arg in args: self.ensureFormula(arg)
        return args

    def isa(self, wantedType):
        return isinstance(self, wantedType)

    def join(self, args):
        """Comma-join the string forms of |args| (no spaces)."""
        return ','.join(str(arg) for arg in args)

    def __eq__(self, other):
        return str(self) == str(other)

    def __hash__(self):
        return hash(str(self))

    def __repr__(self):
        # Cache the string to be more efficient. getattr tolerates subclasses
        # that never initialized strRepn instead of failing with AttributeError.
        cached = getattr(self, 'strRepn', None)
        if cached is None:
            cached = self.computeStrRepn()
            self.strRepn = cached
        return cached
class Formula(Expression):
    """Base class of expressions denoting a truth value (Atom, Not, And, Or, ...)."""
class Term(Expression):
    """Base class of expressions denoting an object (Variable, Constant)."""
class Variable(Term):
    """Variable symbol; its name must start with '$'.

    Example: $x
    """

    def __init__(self, name):
        if not name.startswith('$'):
            raise Exception('Variable must start with "$", but got %s' % name)
        self.name, self.strRepn = name, None

    def computeStrRepn(self):
        # A variable prints as its bare name.
        return self.name
class Constant(Term):
    """Constant symbol naming an object; its name must start with a lowercase letter.

    Example: john
    """

    def __init__(self, name):
        # Check for the empty string first so it gets the same descriptive
        # error rather than an IndexError from name[0].
        if not name or not name[0].islower():
            raise Exception('Constants must start with a lowercase letter, but got %s' % name)
        self.name = name
        self.strRepn = None

    def computeStrRepn(self):
        return self.name
class Atom(Formula):
    """Predicate symbol (capitalized name) applied to zero or more arguments.

    Example: LivesIn(john, palo_alto)
    String arguments are converted with toExpr ('$x' -> Variable, else Constant).
    """

    def __init__(self, name, *args):
        # Check for the empty string first so it gets the same descriptive
        # error rather than an IndexError from name[0].
        if not name or not name[0].isupper():
            raise Exception('Predicates must start with a uppercase letter, but got %s' % name)
        self.name = name
        self.args = list(map(toExpr, args))
        self.strRepn = None

    def computeStrRepn(self):
        # A zero-argument atom (a propositional symbol) prints without parentheses.
        if len(self.args) == 0: return self.name
        return self.name + '(' + self.join(self.args) + ')'
def toExpr(x):
    """Coerce |x| into an Expression.

    Strings become a Variable when they start with '$' and a Constant
    otherwise; any other value is returned unchanged.
    """
    if not isinstance(x, str):
        return x
    return Variable(x) if x.startswith('$') else Constant(x)
# Truth constants. These are plain Python bools, not Formula instances, so
# ensureFormula rejects them as arguments to the connectives.
AtomFalse, AtomTrue = False, True
class Not(Formula):
    """Negation of a formula. Example: Not(Rain)"""

    def __init__(self, arg):
        self.arg = self.ensureFormula(arg)
        self.strRepn = None

    def computeStrRepn(self):
        return 'Not(%s)' % (self.arg,)
class And(Formula):
    """Conjunction of two formulas. Example: And(Rain,Snow)"""

    def __init__(self, arg1, arg2):
        self.arg1, self.arg2 = self.ensureFormula(arg1), self.ensureFormula(arg2)
        self.strRepn = None

    def computeStrRepn(self):
        return 'And(%s,%s)' % (self.arg1, self.arg2)
class Or(Formula):
    """Disjunction of two formulas. Example: Or(Rain,Snow)"""

    def __init__(self, arg1, arg2):
        self.arg1, self.arg2 = self.ensureFormula(arg1), self.ensureFormula(arg2)
        self.strRepn = None

    def computeStrRepn(self):
        return 'Or(%s,%s)' % (self.arg1, self.arg2)
class Implies(Formula):
    """Material implication arg1 => arg2. Example: Implies(Rain,Wet)"""

    def __init__(self, arg1, arg2):
        self.arg1, self.arg2 = self.ensureFormula(arg1), self.ensureFormula(arg2)
        self.strRepn = None

    def computeStrRepn(self):
        return 'Implies(' + self.join([self.arg1, self.arg2]) + ')'
class Exists(Formula):
    """Existential quantification. Example: Exists($x,Lives(john,$x))

    |var| may be given as a string such as '$x'; it goes through toExpr and
    must end up a Variable.
    """

    def __init__(self, var, body):
        self.var = self.ensureType(toExpr(var), Variable)
        self.body = self.ensureFormula(body)
        self.strRepn = None

    def computeStrRepn(self):
        return 'Exists(' + self.join([self.var, self.body]) + ')'
class Forall(Formula):
    """Universal quantification. Example: Forall($x,Implies(Human($x),Alive($x)))

    |var| may be given as a string such as '$x'; it goes through toExpr and
    must end up a Variable.
    """

    def __init__(self, var, body):
        self.var = self.ensureType(toExpr(var), Variable)
        self.body = self.ensureFormula(body)
        self.strRepn = None

    def computeStrRepn(self):
        return 'Forall(%s,%s)' % (self.var, self.body)
def AndList(forms):
    """Combine the conjuncts |forms| into one left-nested And formula.

    Returns AtomTrue for an empty list and the sole element for a singleton.
    """
    conjunction = AtomTrue
    for conjunct in forms:
        # While nothing real has been accumulated yet, just take the conjunct.
        conjunction = conjunct if conjunction == AtomTrue else And(conjunction, conjunct)
    return conjunction
def OrList(forms):
    """Combine the disjuncts |forms| into one left-nested Or formula.

    Returns AtomFalse for an empty list and the sole element for a singleton.
    """
    disjunction = AtomFalse
    for disjunct in forms:
        # While nothing real has been accumulated yet, just take the disjunct.
        disjunction = disjunct if disjunction == AtomFalse else Or(disjunction, disjunct)
    return disjunction
def flattenAnd(form):
    """Return the list of conjuncts of |form|, expanding nested And nodes left to right.

    Example: And(And(A, Or(B, C)), Not(D)) => [A, Or(B, C), Not(D)]
    """
    if not form.isa(And):
        return [form]
    return flattenAnd(form.arg1) + flattenAnd(form.arg2)
def flattenOr(form):
    """Return the list of disjuncts of |form|, expanding nested Or nodes left to right.

    Example: Or(Or(A, And(B, C)), D) => [A, And(B, C), D]
    """
    if not form.isa(Or):
        return [form]
    return flattenOr(form.arg1) + flattenOr(form.arg2)
# Syntactic sugar built from the primitive connectives.
def Equiv(a, b):
    """a <=> b, encoded as (a => b) and (b => a)."""
    forward = Implies(a, b)
    backward = Implies(b, a)
    return And(forward, backward)


def Xor(a, b):
    """Exclusive or: at least one of a, b holds, but not both."""
    atLeastOne = Or(a, b)
    notBoth = Not(And(a, b))
    return And(atLeastOne, notBoth)
def Equals(x, y):
    """Build the special Equals(x, y) atom used internally (e.g., in propositionalization).

    Arguments may be strings; Atom converts them with toExpr.
    """
    return Atom('Equals', x, y)
def AntiReflexive(predicate):
    """Return a formula asserting that binary |predicate| never relates an object to itself.

    For 'Parent' the result is
    Forall($x,Forall($y,Implies(Parent($x,$y),Not(Equals($x,$y))))).
    The simpler Forall($x,Not(Parent($x,$x))) would also work, but this form
    forces Equals() to be used so it shows up in the models.
    """
    related = Atom(predicate, '$x', '$y')
    distinct = Not(Equals('$x', '$y'))
    return Forall('$x', Forall('$y', Implies(related, distinct)))
############################################################
# Simple inference rules
# A Rule takes a sequence of argument Formulas and produces a list of result
# Formulas (possibly [] if the rule doesn't apply).
class Rule:
    """Marker base class for inference rules (see UnaryRule and BinaryRule)."""
class UnaryRule(Rule):
    """An inference rule applied to a single formula."""

    def applyRule(self, form):
        # Subclasses return the list of formulas derived from |form|.
        raise Exception('Override me')
class BinaryRule(Rule):
    """An inference rule applied to an ordered pair of formulas."""

    def applyRule(self, form1, form2):
        # Subclasses return the list of formulas derived from the pair.
        raise Exception('Override me')

    def symmetric(self):
        # Override to return True when argument order doesn't matter; the
        # knowledge base then applies the rule to only one ordering of each
        # pair, saving a factor of 2.
        return False
############################################################
# Unification
def unify(form1, form2, subst):
    """Try to unify |form1| with |form2|, recording variable bindings in |subst|.

    |subst| maps Variables to terms and is mutated in place. Returns whether
    unification succeeded; on failure |subst| may still hold partial bindings.
    Assumes the forms are in CNF (only Atom/Not/And/Or over terms).
    Note: no occurs check is done because there are no function symbols.
    """
    if form1.isa(Variable) or form1.isa(Constant):
        return unifyTerms(form1, form2, subst)
    if form1.isa(Atom):
        if not form2.isa(Atom) or form1.name != form2.name or len(form1.args) != len(form2.args):
            return False
        return all(unify(left, right, subst) for left, right in zip(form1.args, form2.args))
    if form1.isa(Not):
        return form2.isa(Not) and unify(form1.arg, form2.arg, subst)
    for connective in (And, Or):
        if form1.isa(connective):
            return form2.isa(connective) and unify(form1.arg1, form2.arg1, subst) and \
                unify(form1.arg2, form2.arg2, subst)
    raise Exception('Unhandled: %s' % form1)
def getSubst(subst, x):
    """Follow the chain of bindings in |subst| starting from |x| and return its end.

    Returns |x| itself when it is unbound.
    """
    while True:
        y = subst.get(x)
        # Identity test: `== None` would dispatch to Expression.__eq__, a
        # string comparison, on every step.
        if y is None:
            return x
        x = y
def unifyTerms(a, b, subst):
    """Unify two terms under |subst|, adding at most one new binding to it.

    Both terms are first resolved through existing bindings. Equal terms
    unify trivially. Otherwise a Variable is bound to the other term (the
    first argument is preferred when both are Variables); two distinct
    non-variable terms fail to unify.
    """
    a, b = getSubst(subst, a), getSubst(subst, b)
    if a == b:
        return True
    if not a.isa(Variable) and not b.isa(Variable):
        return False
    if a.isa(Variable):
        subst[a] = b
    else:
        subst[b] = a
    return True
def applySubst(form, subst):
    """Return |form| with each Variable replaced by the end of its binding chain in |subst|.

    Assumes |form| is in CNF (Atom/Not/And/Or over terms).
    """
    if not subst:
        return form
    if form.isa(Variable):
        return getSubst(subst, form)
    if form.isa(Constant):
        return form
    if form.isa(Atom):
        return Atom(form.name, *[applySubst(arg, subst) for arg in form.args])
    if form.isa(Not):
        return Not(applySubst(form.arg, subst))
    if form.isa(And):
        return And(applySubst(form.arg1, subst), applySubst(form.arg2, subst))
    if form.isa(Or):
        return Or(applySubst(form.arg1, subst), applySubst(form.arg2, subst))
    raise Exception('Unhandled: %s' % form)
############################################################
# Convert to CNF, Resolution rules
def withoutElementAt(items, i):
    """Return a new sequence equal to |items| minus the element at index |i| (0 <= i < len)."""
    before, after = items[:i], items[i + 1:]
    return before + after
def negateFormula(item):
    """Return the negation of |item|, stripping an outer Not instead of adding a second one."""
    if item.isa(Not):
        return item.arg
    return Not(item)
def reduceFormulas(items, mode):
    """Simplify a list of conjuncts (mode=And) or disjuncts (mode=Or).

    If some formula and its negation both occur, the list collapses to
    [AtomFalse] for a conjunction or [AtomTrue] for a disjunction (with any
    other mode, finding such a pair raises). Otherwise duplicates are removed
    and the result is sorted by string form.
    """
    for i, first in enumerate(items):
        for second in items[i + 1:]:
            if negateFormula(first) == second:
                if mode == And:
                    return [AtomFalse]
                if mode == Or:
                    return [AtomTrue]
                raise Exception("Invalid mode: %s" % mode)
    return sorted(set(items), key=str)
def allSubexpressions(form):
    """Return every subexpression of |form| (terms included) in pre-order, left to right.

    Example:
    - Input: And(Atom('A', Constant('a')), Atom('B'))
    - Output: [And(Atom('A', Constant('a')), Atom('B')), Atom('A', Constant('a')), Constant('a'), Atom('B')]
    """
    collected = []

    def childrenOf(node):
        if node.isa(Variable) or node.isa(Constant):
            return []
        if node.isa(Atom):
            return node.args
        if node.isa(Not):
            return [node.arg]
        if node.isa(And) or node.isa(Or) or node.isa(Implies):
            return [node.arg1, node.arg2]
        if node.isa(Exists) or node.isa(Forall):
            return [node.body]
        raise Exception("Unhandled: %s" % node)

    def visit(node):
        collected.append(node)
        for child in childrenOf(node):
            visit(child)

    visit(form)
    return collected
def allFreeVars(form):
    """Return the free variables of |form| in left-to-right order of occurrence.

    A variable occurring free several times is listed once per occurrence.
    """
    found = []

    def walk(node, bound):
        if node.isa(Variable):
            if node not in bound:
                found.append(node)
            return
        if node.isa(Constant):
            return
        if node.isa(Atom):
            for arg in node.args:
                walk(arg, bound)
            return
        if node.isa(Not):
            walk(node.arg, bound)
            return
        if node.isa(And) or node.isa(Or) or node.isa(Implies):
            walk(node.arg1, bound)
            walk(node.arg2, bound)
            return
        if node.isa(Exists) or node.isa(Forall):
            # The quantified variable is bound inside the body.
            walk(node.body, bound + [node.var])
            return
        raise Exception("Unhandled: %s" % node)

    walk(form, [])
    return found
def substituteFreeVars(form, var, obj):
    """Return |form| with every free occurrence of |var| replaced by |obj|.

    A quantifier that rebinds |var| shadows it, so its subtree is returned untouched.
    """
    def replace(node):
        if node.isa(Variable):
            return obj if node == var else node
        if node.isa(Constant):
            return node
        if node.isa(Atom):
            return Atom(node.name, *[replace(arg) for arg in node.args])
        if node.isa(Not):
            return Not(replace(node.arg))
        for connective in (And, Or, Implies):
            if node.isa(connective):
                return connective(replace(node.arg1), replace(node.arg2))
        for quantifier in (Exists, Forall):
            if node.isa(quantifier):
                if node.var == var:
                    return node  # Don't substitute inside
                return quantifier(node.var, replace(node.body))
        raise Exception("Unhandled: %s" % node)

    return replace(form)
def allConstants(form):
    """Return the Constants occurring in |form|, in pre-order, one entry per occurrence."""
    return list(filter(lambda expr: expr.isa(Constant), allSubexpressions(form)))
class ToCNFRule(UnaryRule):
    """Unary rule that converts a first-order formula into conjunctive normal form.

    applyRule returns a sorted, deduplicated list of clauses (each a literal
    or an Or of literals) whose conjunction is equivalent to the input, or
    [AtomFalse] / [AtomTrue] when the formula reduces to a constant.
    Raises Exception if the formula has a free variable.
    """
    def __init__(self):
        # For standardizing variables.
        # For each existing variable name, the number of times it has occurred.
        # Stored on the instance, so fresh names stay distinct across applyRule calls.
        self.varCounts = collections.Counter()
    def applyRule(self, form):
        newForm = form
        # Step 1: remove implications (A => B becomes Or(Not(A), B))
        def removeImplications(form):
            if form.isa(Atom): return form
            if form.isa(Not): return Not(removeImplications(form.arg))
            if form.isa(And): return And(removeImplications(form.arg1), removeImplications(form.arg2))
            if form.isa(Or): return Or(removeImplications(form.arg1), removeImplications(form.arg2))
            if form.isa(Implies): return Or(Not(removeImplications(form.arg1)), removeImplications(form.arg2))
            if form.isa(Exists): return Exists(form.var, removeImplications(form.body))
            if form.isa(Forall): return Forall(form.var, removeImplications(form.body))
            raise Exception("Unhandled: %s" % form)
        newForm = removeImplications(newForm)
        # Step 2: push negation inwards (de Morgan) until Not only wraps atoms
        def pushNegationInwards(form):
            if form.isa(Atom): return form
            if form.isa(Not):
                if form.arg.isa(Not): # Double negation
                    return pushNegationInwards(form.arg.arg)
                if form.arg.isa(And): # De Morgan
                    return Or(pushNegationInwards(Not(form.arg.arg1)), pushNegationInwards(Not(form.arg.arg2)))
                if form.arg.isa(Or): # De Morgan
                    return And(pushNegationInwards(Not(form.arg.arg1)), pushNegationInwards(Not(form.arg.arg2)))
                if form.arg.isa(Exists):
                    return Forall(form.arg.var, pushNegationInwards(Not(form.arg.body)))
                if form.arg.isa(Forall):
                    return Exists(form.arg.var, pushNegationInwards(Not(form.arg.body)))
                # Negated atom: already a literal.
                return form
            if form.isa(And): return And(pushNegationInwards(form.arg1), pushNegationInwards(form.arg2))
            if form.isa(Or): return Or(pushNegationInwards(form.arg1), pushNegationInwards(form.arg2))
            if form.isa(Implies): return Or(Not(pushNegationInwards(form.arg1)), pushNegationInwards(form.arg2))
            if form.isa(Exists): return Exists(form.var, pushNegationInwards(form.body))
            if form.isa(Forall): return Forall(form.var, pushNegationInwards(form.body))
            raise Exception("Unhandled: %s" % form)
        newForm = pushNegationInwards(newForm)
        # Step 3: standardize variables: make sure all variables are different
        # Don't modify subst; return a new version where var is mapped onto
        # something that hasn't been seen before (name + per-name counter, e.g. $x -> $x1).
        def updateSubst(subst, var):
            self.varCounts[var.name] += 1
            newVar = Variable(var.name + str(self.varCounts[var.name]))
            return dict(list(subst.items()) + [(var, newVar)])
        # Rename every quantified variable; a variable with no enclosing quantifier is an error.
        def standardizeVariables(form, subst):
            if form.isa(Variable):
                if form not in subst: raise Exception("Free variable found: %s" % form)
                return subst[form]
            if form.isa(Constant): return form
            if form.isa(Atom): return Atom(*([form.name] + [standardizeVariables(arg, subst) for arg in form.args]))
            if form.isa(Not): return Not(standardizeVariables(form.arg, subst))
            if form.isa(And): return And(standardizeVariables(form.arg1, subst), standardizeVariables(form.arg2, subst))
            if form.isa(Or): return Or(standardizeVariables(form.arg1, subst), standardizeVariables(form.arg2, subst))
            if form.isa(Exists):
                newSubst = updateSubst(subst, form.var)
                return Exists(newSubst[form.var], standardizeVariables(form.body, newSubst))
            if form.isa(Forall):
                newSubst = updateSubst(subst, form.var)
                return Forall(newSubst[form.var], standardizeVariables(form.body, newSubst))
            raise Exception("Unhandled: %s" % form)
        newForm = standardizeVariables(newForm, {})
        # Step 4: replace existentially quantified variables with Skolem functions
        # |subst| maps top-level existential variables to Skolem constants and is
        # shared (mutated) across the recursion; |scope| lists the enclosing
        # universally quantified variables.
        def skolemize(form, subst, scope):
            if form.isa(Variable): return subst.get(form, form)
            if form.isa(Constant): return form
            if form.isa(Atom): return Atom(*[form.name] + [skolemize(arg, subst, scope) for arg in form.args])
            if form.isa(Not): return Not(skolemize(form.arg, subst, scope))
            if form.isa(And): return And(skolemize(form.arg1, subst, scope), skolemize(form.arg2, subst, scope))
            if form.isa(Or): return Or(skolemize(form.arg1, subst, scope), skolemize(form.arg2, subst, scope))
            if form.isa(Exists):
                # Create a Skolem function that depends on the variables in the scope (list of variables)
                # Example:
                # - Suppose scope = [$x, $y] and form = Exists($z,F($z)).
                # - Normally, we would return F(Z($x,$y)), where Z is a brand new Skolem function.
                # - But since we don't have function symbols, we replace with a Skolem predicate:
                # Forall($z,Implies(Z($z,$x,$y),F($z)))
                # Important: when doing resolution, need to catch Not(Z($z,*,*)) as a contradiction.
                if len(scope) == 0:
                    # No enclosing Forall: a fresh Skolem constant suffices.
                    subst[form.var] = Constant('skolem' + form.var.name)
                    return skolemize(form.body, subst, scope)
                else:
                    skolem = Atom(*['Skolem' + form.var.name, form.var] + scope)
                    return Forall(form.var, Or(Not(skolem), skolemize(form.body, subst, scope)))
            if form.isa(Forall):
                return Forall(form.var, skolemize(form.body, subst, scope + [form.var]))
            raise Exception("Unhandled: %s" % form)
        newForm = skolemize(newForm, {}, [])
        # Step 5: remove universal quantifiers [note: need to do this before distribute, unlike Russell/Norvig book]
        # After standardization every variable is unique, so dropping Forall is safe.
        def removeUniversalQuantifiers(form):
            if form.isa(Atom): return form
            if form.isa(Not): return Not(removeUniversalQuantifiers(form.arg))
            if form.isa(And): return And(removeUniversalQuantifiers(form.arg1), removeUniversalQuantifiers(form.arg2))
            if form.isa(Or): return Or(removeUniversalQuantifiers(form.arg1), removeUniversalQuantifiers(form.arg2))
            if form.isa(Forall): return removeUniversalQuantifiers(form.body)
            raise Exception("Unhandled: %s" % form)
        newForm = removeUniversalQuantifiers(newForm)
        # Step 6: distribute Or over And (want And on the outside): Or(And(A,B),C) becomes And(Or(A,C),Or(B,C))
        def distribute(form):
            if form.isa(Atom): return form
            if form.isa(Not): return Not(distribute(form.arg))
            if form.isa(And): return And(distribute(form.arg1), distribute(form.arg2))
            if form.isa(Or):
                # First need to distribute as much as possible
                f1 = distribute(form.arg1)
                f2 = distribute(form.arg2)
                if f1.isa(And):
                    return And(distribute(Or(f1.arg1, f2)), distribute(Or(f1.arg2, f2)))
                if f2.isa(And):
                    return And(distribute(Or(f1, f2.arg1)), distribute(Or(f1, f2.arg2)))
                return Or(f1, f2)
            if form.isa(Exists): return Exists(form.var, distribute(form.body))
            if form.isa(Forall): return Forall(form.var, distribute(form.body))
            raise Exception("Unhandled: %s" % form)
        newForm = distribute(newForm)
        # Post-processing: break up conjuncts into conjuncts and sort the disjuncts in each conjunct
        # Remove instances of A and Not(A)
        # (a clause containing both A and Not(A) reduces to AtomTrue and is dropped below)
        conjuncts = [OrList(reduceFormulas(flattenOr(f), Or)) for f in flattenAnd(newForm)]
        # print rstr(form), rstr(conjuncts)
        assert len(conjuncts) > 0
        if any(x == AtomFalse for x in conjuncts): return [AtomFalse]
        if all(x == AtomTrue for x in conjuncts): return [AtomTrue]
        conjuncts = [x for x in conjuncts if x != AtomTrue]
        results = reduceFormulas(conjuncts, And)
        # NOTE(review): appears unreachable (conjuncts is non-empty here, and
        # reduceFormulas never empties a non-empty list); an empty conjunction
        # would mean True, not False.
        if len(results) == 0: results = [AtomFalse]
        # print 'CNF', form, rstr(results)
        return results
class ResolutionRule(BinaryRule):
    """Binary resolution (with unification) between two CNF clauses.

    applyRule returns [AtomFalse] as soon as a contradiction is found (the
    empty clause, or a lone negated Skolem atom); otherwise the list of
    distinct, non-tautological resolvents.
    """
    # Assume formulas are in CNF
    # Assume A and Not(A) don't both exist in a form (taken care of by CNF conversion)
    def applyRule(self, form1, form2):
        items1 = flattenOr(form1)
        items2 = flattenOr(form2)
        results = []
        # print 'RESOLVE', form1, form2
        for i, item1 in enumerate(items1):
            for j, item2 in enumerate(items2):
                # Fresh bindings per candidate pair: unify may leave partial bindings on failure.
                subst = {}
                if unify(negateFormula(item1), item2, subst):
                    # Resolvent: all remaining literals of both clauses, under the unifier.
                    newItems1 = withoutElementAt(items1, i)
                    newItems2 = withoutElementAt(items2, j)
                    newItems = [applySubst(item, subst) for item in newItems1 + newItems2]
                    if len(newItems) == 0: # Contradiction: False
                        results = [AtomFalse]
                        break
                    # print 'STEP: %s %s => %s %s' % (form1, form2, rstr(newItems), rstr(subst))
                    result = OrList(reduceFormulas(newItems, Or))
                    # Not(Skolem$x($x,...)) is a contradiction
                    # (see the Skolem predicate construction in ToCNFRule)
                    if isinstance(result, Not) and result.arg.name.startswith('Skolem'):
                        results = [AtomFalse]
                        break
                    # Don't add redundant stuff
                    if result == AtomTrue: continue
                    if result in results: continue
                    results.append(result)
            # The inner break only exits the inner loop; stop the outer one too.
            if results == [AtomFalse]: break
        # print 'RESOLUTION: %s %s => %s' % (form1, form2, rstr(results))
        return results
    # Resolution is treated as symmetric: the KB applies it to one ordering of each pair.
    def symmetric(self):
        return True
############################################################
# Model checking
def performModelChecking(allForms, findAll, objects=None, verbose=0):
    """Find models of the conjunction of the first-order formulas |allForms|.

    The formulas are propositionalized over |objects| (default: every constant
    they mention) and simplified with universalInterpret. A backtracking
    search then assigns truth values to the remaining atoms.

    Returns a list of models, each a set of the atoms that are true in it (all
    other atoms are false). With findAll=False the search stops at the first
    model. An empty list means the formulas are unsatisfiable.
    """
    if verbose >= 3:
        print('performModelChecking', rstr(allForms))
    # Propositionalize. (Also converting to CNF was found to be slower.)
    allForms = propositionalize(allForms, objects)
    # propositionalize reports a formula that grounds to false as a bare
    # AtomFalse, which universalInterpret cannot accept.
    if any(form is AtomFalse for form in allForms): return []  # No models
    allForms = [universalInterpret(form) for form in allForms]
    # A formula that simplifies to False (e.g. Not(Equals(a,a))) makes the
    # whole conjunction unsatisfiable: it must not be discarded like a True.
    if any(form is AtomFalse for form in allForms): return []  # No models
    allForms = list(set(allForms) - {AtomTrue})
    if verbose >= 3:
        print('All Forms:', rstr(allForms))
    if allForms == []: return [set()]  # One model: every atom false

    # Atoms are the variables of the search. Record the atoms of each formula
    # once instead of rescanning its subexpressions for every atom.
    atoms = set()
    usedAtoms = {}  # form => set of atoms occurring in it (forms are unique here)
    for form in allForms:
        used = [f for f in allSubexpressions(form) if f.isa(Atom)]
        usedAtoms[form] = set(used)
        for f in used: atoms.add(f)
    atoms = list(atoms)
    if verbose >= 3:
        print('Atoms:', rstr(atoms))
        print('Constraints:', rstr(allForms))

    # For each atom, the formulas that mention it.
    atomForms = [(atom, [form for form in allForms if atom in usedAtoms[form]]) for atom in atoms]
    # Degree heuristic: branch first on atoms that occur in the most formulas.
    atomForms.sort(key=lambda x: -len(x[1]))
    atoms = [atom for atom, forms in atomForms]

    # Attach each formula to the position where its last atom gets assigned,
    # so it is checked as soon as it can be fully evaluated.
    atomPrefixForms = []
    assigned = set()
    for atom, forms in atomForms:
        assigned.add(atom)
        atomPrefixForms.append((atom, [form for form in forms if usedAtoms[form] <= assigned]))
    if verbose >= 3:
        print('Plan:')
        for atom, forms in atomForms:
            print(" %s: %s" % (rstr(atom), rstr(forms)))
    assert sum(len(forms) for atom, forms in atomPrefixForms) == len(allForms)

    # Build up an interpretation
    N = len(atoms)
    models = []  # List of models which are true
    model = set()  # Set of true atoms, mutated over time

    def recurse(i):  # i: atom index
        if not findAll and len(models) > 0: return
        if i == N:  # Found a model on which the formulas are true
            models.append(set(model))
            return
        atom, forms = atomPrefixForms[i]
        # None: no fixed value, try both; otherwise only the fixed value is tried.
        result = universalInterpretAtom(atom)
        if result is not AtomTrue:  # atom = False
            if interpretForms(forms, model): recurse(i + 1)
        if result is not AtomFalse:  # atom = True
            model.add(atom)
            if interpretForms(forms, model): recurse(i + 1)
            model.remove(atom)

    recurse(0)
    if verbose >= 5:
        print('Models:')
        for m in models:
            print(" %s" % rstr(m))
    return models
def printModel(model):
    """Print |model| (the set of atoms that are true), one sorted atom per line.

    Atoms not in the model are false; a final line says so.
    """
    for x in sorted(map(str, model)):
        print('*', x, '=', 'True')
    print('*', '(other atoms if any)', '=', 'False')
def propositionalize(forms, objects=None):
    """Ground the first-order |forms| into propositional formulas (database semantics).

    Forall becomes an And over all |objects| and Exists an Or over them:
    - Forall('$x', Atom('Alive', '$x')) with objects ['alice', 'bob']
      grounds to And(Atom('Alive', 'alice'), Atom('Alive', 'bob'))
    - Exists('$x', Atom('Alive', '$x')) with objects ['alice', 'bob']
      grounds to Or(Atom('Alive', 'alice'), Atom('Alive', 'bob'))

    |objects| defaults to every constant mentioned in |forms|; strings are
    converted with toExpr. Returns the conjuncts of the grounded formulas
    (formulas grounding to true are dropped), or [AtomFalse] as soon as one
    grounds to false. Raises Exception on a free variable.
    """
    if objects is None:
        objects = set()
        for form in forms:
            objects |= set(allConstants(form))
        objects = list(objects)
    else:
        # Make sure objects are expressions: Convert ['a', 'b'] to [Constant('a'), Constant('b')]
        objects = [toExpr(obj) for obj in objects]

    # Recursively convert |form| to a quantifier-free form.
    # |subst| is a map from variables to constants.
    # With no objects a quantifier grounds to a bare bool (OrList([]) is
    # AtomFalse, AndList([]) is AtomTrue), which the connective constructors
    # reject, so the connectives below fold constant operands.
    def convert(form, subst):
        if form.isa(Variable):
            if form not in subst: raise Exception("Free variable found: %s" % form)
            return subst[form]
        if form.isa(Constant): return form
        if form.isa(Atom):
            return Atom(*[form.name] + [convert(arg, subst) for arg in form.args])
        if form.isa(Not):
            arg = convert(form.arg, subst)
            if isinstance(arg, bool): return not arg
            return Not(arg)
        if form.isa(And):
            a, b = convert(form.arg1, subst), convert(form.arg2, subst)
            if a is AtomFalse or b is AtomFalse: return AtomFalse
            if a is AtomTrue: return b
            if b is AtomTrue: return a
            return And(a, b)
        if form.isa(Or):
            a, b = convert(form.arg1, subst), convert(form.arg2, subst)
            if a is AtomTrue or b is AtomTrue: return AtomTrue
            if a is AtomFalse: return b
            if b is AtomFalse: return a
            return Or(a, b)
        if form.isa(Implies):
            a, b = convert(form.arg1, subst), convert(form.arg2, subst)
            if a is AtomFalse or b is AtomTrue: return AtomTrue
            if a is AtomTrue: return b
            if b is AtomFalse: return Not(a)
            return Implies(a, b)
        if form.isa(Exists):
            return OrList([convert(form.body, dict(list(subst.items()) + [(form.var, obj)])) for obj in objects])
        if form.isa(Forall):
            return AndList([convert(form.body, dict(list(subst.items()) + [(form.var, obj)])) for obj in objects])
        raise Exception("Unhandled: %s" % form)

    # Think of newForms as conjoined
    newForms = []
    for form in forms:
        newForm = convert(form, {})
        if newForm is AtomFalse: return [AtomFalse]
        if newForm is AtomTrue: continue
        newForms.extend(flattenAnd(newForm))
    return newForms
def universalInterpretAtom(atom):
    """Return the fixed truth value of |atom| if it has one, else None.

    Only Equals(x, y) has a fixed value: AtomTrue when its two arguments are
    the same expression, AtomFalse otherwise. Assumes |atom| is propositional.
    """
    if atom.name != 'Equals':
        return None
    first, second = atom.args[0], atom.args[1]
    return AtomTrue if first == second else AtomFalse
def universalInterpret(form):
    """Simplify |form| by evaluating fixed-value atoms (e.g., Equals(a,a) => True).

    The resulting constants are propagated through Not/And/Or/Implies, so the
    result is AtomTrue, AtomFalse, or a formula with no fixed-value atoms.
    Assumes |form| is propositional (no quantifiers).
    """
    if form.isa(Variable) or form.isa(Constant):
        return form
    if form.isa(Atom):
        fixed = universalInterpretAtom(form)
        if fixed is not None:
            return fixed
        return Atom(form.name, *[universalInterpret(arg) for arg in form.args])
    if form.isa(Not):
        inner = universalInterpret(form.arg)
        if inner == AtomTrue: return AtomFalse
        if inner == AtomFalse: return AtomTrue
        return Not(inner)
    if form.isa(And):
        left, right = universalInterpret(form.arg1), universalInterpret(form.arg2)
        if left == AtomFalse or right == AtomFalse: return AtomFalse
        if left == AtomTrue: return right
        if right == AtomTrue: return left
        return And(left, right)
    if form.isa(Or):
        left, right = universalInterpret(form.arg1), universalInterpret(form.arg2)
        if left == AtomTrue or right == AtomTrue: return AtomTrue
        if left == AtomFalse: return right
        if right == AtomFalse: return left
        return Or(left, right)
    if form.isa(Implies):
        premise, conclusion = universalInterpret(form.arg1), universalInterpret(form.arg2)
        if premise == AtomFalse or conclusion == AtomTrue: return AtomTrue
        if premise == AtomTrue: return conclusion
        if conclusion == AtomFalse: return Not(premise)
        return Implies(premise, conclusion)
    raise Exception("Unhandled: %s" % form)
def interpretForm(form, model):
    """Evaluate the propositional |form| in |model| (the set of atoms that are true)."""
    if form.isa(Atom):
        return form in model
    if form.isa(Not):
        return not interpretForm(form.arg, model)
    if form.isa(And):
        return interpretForm(form.arg2, model) if interpretForm(form.arg1, model) else False
    if form.isa(Or):
        return True if interpretForm(form.arg1, model) else interpretForm(form.arg2, model)
    if form.isa(Implies):
        return interpretForm(form.arg2, model) if interpretForm(form.arg1, model) else True
    raise Exception("Unhandled: %s" % form)
def interpretForms(forms, model):
    """Return whether every formula in |forms| is true in |model| (True for no forms)."""
    for form in forms:
        if not interpretForm(form, model):
            return False
    return True
############################################################
class Derivation:
    """A node of a derivation tree: formula |form| obtained from the derivations |children|.

    cost: derivation cost as computed by the caller (0 for user-added axioms).
    permanent: whether the derivation survives the removal of temporary ones.
    derived: whether this was derived (as opposed to added by the user).
    """

    def __init__(self, form, children, cost, derived):
        self.form, self.children, self.cost = form, children, cost
        self.permanent = False  # Marker for being able to extract.
        self.derived = derived

    def __repr__(self):
        fields = (self.form, self.cost, self.permanent, self.derived)
        return 'Derivation(%s, cost=%s, permanent=%s, derived=%s)' % fields
# Possible statuses of a knowledge-base query response.
ENTAILMENT, CONTINGENT, CONTRADICTION = "ENTAILMENT", "CONTINGENT", "CONTRADICTION"
class KBResponse:
    """A response to a KB query.

    query: what the query is (just a string description for printing)
    modify: whether we modified the knowledge base
    status: one of ENTAILMENT, CONTINGENT, CONTRADICTION
    trueModel: if available, a model consistent with the KB for which the query is true
    falseModel: if available, a model consistent with the KB for which the query is false
    """

    def __init__(self, query, modify, status, trueModel, falseModel):
        self.query, self.modify, self.status = query, modify, status
        self.trueModel, self.falseModel = trueModel, falseModel

    def show(self, verbose=1):
        """Print the natural-language answer and, if verbose >= 1, the query and example models."""
        print(' '.join(['>>>>>', self.responseStr()]))
        if verbose < 1:
            return
        print('Query: %s[%s]' % ('TELL' if self.modify else 'ASK', self.query))
        examples = (
            (self.trueModel, 'An example of a model where query is TRUE:'),
            (self.falseModel, 'An example of a model where query is FALSE:'),
        )
        for exampleModel, heading in examples:
            if exampleModel:
                print(heading)
                printModel(exampleModel)

    def responseStr(self):
        """Return the answer text; the wording depends on whether the query was a TELL."""
        if self.status == ENTAILMENT:
            return 'I already knew that.' if self.modify else 'Yes.'
        if self.status == CONTINGENT:
            return 'I learned something.' if self.modify else 'I don\'t know.'
        if self.status == CONTRADICTION:
            return 'I don\'t buy that.' if self.modify else 'No.'
        raise Exception("Invalid status: %s" % self.status)

    def __repr__(self):
        return self.responseStr()
def showKBResponse(response, verbose=1):
    """Print a query response.

    |response| is either a KBResponse, or (for a query with a free variable)
    a dict mapping (var, obj) to the KBResponse for that binding, in which
    case the objects are grouped by status.
    """
    if isinstance(response, KBResponse):
        response.show(verbose)
        return
    items = [(obj, r.status) for ((var, obj), r) in list(response.items())]
    groups = (('Yes: %s', ENTAILMENT), ('Maybe: %s', CONTINGENT), ('No: %s', CONTRADICTION))
    for template, wanted in groups:
        print(template % rstr([obj for obj, status in items if status == wanted]))
class KnowledgeBase:
    """A collection of Formulas supporting tell() and ask().

    - tell: modify the KB with a new formula.
    - ask: query the KB about a formula.
    Consistency is checked either by model checking (modelChecking=True) or by
    forward-applying the inference |rules| until a contradiction is derived.
    """

    # Derivations whose cost exceeds this are not added, which bounds how far
    # the rules are chained forward.
    maxCost = 100

    def __init__(self, standardizationRule, rules, modelChecking, verbose=0):
        # Rule to apply to each formula that's added to the KB (None is possible).
        self.standardizationRule = standardizationRule
        # List of inference rules (UnaryRule / BinaryRule instances).
        self.rules = rules
        # Use model checking as opposed to applying rules.
        self.modelChecking = modelChecking
        # For debugging
        self.verbose = verbose
        # Formulas currently in the KB, both permanent and temporary.
        self.derivations = {}  # Map from Derivation key (logical form) to Derivation
        # Model found by the latest model-checking consistency test (None when
        # not model checking, or before any check).
        self.consistentModel = None

    # Add a formula |form| to the KB if it doesn't contradict. Returns a KBResponse.
    def tell(self, form):
        return self.query(form, modify=True)

    # Ask whether the logical formula |form| is True, False, or unknown based
    # on the KB. Returns a KBResponse (or a dict, if |form| has free variables).
    def ask(self, form):
        return self.query(form, modify=False)

    def dump(self):
        """Print all derivations: '*' marks user-added formulas, '-' derived ones."""
        print('==== Knowledge base [%d derivations] ===' % len(self.derivations))
        for deriv in list(self.derivations.values()):
            print('-' if deriv.derived else '*', deriv if self.verbose >= 2 else deriv.form)

    ####### Internal functions

    def query(self, form, modify):
        """Answer (and, if |modify|, record) |form|.

        Returns a KBResponse, or if |form| has free variables, a dict mapping
        (var, obj) to the response for |form| with var bound to obj (nested
        when several variables are free).
        """
        # Handle wh-queries: try all possible values of the free variable, and recurse on query().
        freeVars = allFreeVars(form)
        if len(freeVars) > 0:
            if modify:
                raise Exception("Can't modify database with a query with free variables: %s" % form)
            var = freeVars[0]
            allForms = AndList([deriv.form for deriv in list(self.derivations.values())])
            if allForms == AtomTrue: return {}  # Weird corner case
            # allConstants lists a constant once per occurrence; dedupe (keeping
            # first-occurrence order) so each binding triggers only one inference.
            objects = list(dict.fromkeys(allConstants(allForms)))
            # Try binding |var| to |obj|
            response = {}
            for obj in objects:
                response[(var, obj)] = self.query(substituteFreeVars(form, var, obj), modify)
            return response

        # Assume no free variables from here on...
        formStr = '%s, standardized: %s' % (form, rstr(self.standardize(form)))

        # Models to serve as supporting evidence
        falseModel = None  # Makes the query false
        trueModel = None  # Makes the query true

        # Add Not(form): if that is inconsistent, the KB entails |form|.
        if not self.addAxiom(Not(form)):
            self.removeTemporary()
            status = ENTAILMENT
        else:
            # Inconclusive...
            falseModel = self.consistentModel
            self.removeTemporary()

            # Add form
            if self.addAxiom(form):
                if modify:
                    self.makeTemporaryPermanent()
                else:
                    self.removeTemporary()
                trueModel = self.consistentModel
                status = CONTINGENT
            else:
                self.removeTemporary()
                status = CONTRADICTION

        return KBResponse(query=formStr, modify=modify, status=status, trueModel=trueModel, falseModel=falseModel)

    # Apply the standardization rule to |form| (returns a list of formulas).
    def standardize(self, form):
        if self.standardizationRule:
            return self.standardizationRule.applyRule(form)
        return [form]

    # Return whether adding |form| is consistent with the current knowledge base.
    # Add |form| to the knowledge base if we can. Note: this is done temporarily!
    # Just calls addDerivation.
    def addAxiom(self, form):
        self.consistentModel = None
        for f in self.standardize(form):
            if f == AtomFalse: return False
            if f == AtomTrue: continue
            deriv = Derivation(f, children=[], cost=0, derived=False)
            if not self.addDerivation(deriv): return False
        return True

    # Return whether the Derivation is consistent with the KB.
    def addDerivation(self, deriv):
        # Derived a contradiction
        if deriv.form == AtomFalse: return False

        key = deriv.form
        oldDeriv = self.derivations.get(key)
        # Only add formulas not already present and not too costly to derive.
        if oldDeriv is None and deriv.cost <= self.maxCost:
            self.derivations[key] = deriv
            if self.verbose >= 3: print('add %s [%s derivations]' % (deriv, len(self.derivations)))

            if self.modelChecking:
                allForms = [d.form for d in list(self.derivations.values())]
                models = performModelChecking(allForms, findAll=False, verbose=self.verbose)
                if len(models) == 0:
                    return False
                else:
                    self.consistentModel = models[0]

            # Apply rules forward
            if not self.applyUnaryRules(deriv): return False
            # Snapshot: the recursive calls below may add derivations.
            for deriv2 in list(self.derivations.values()):
                if not self.applyBinaryRules(deriv, deriv2): return False
                if not self.applyBinaryRules(deriv2, deriv): return False

        return True

    # Raise an exception if |formulas| is not a list of Formulas.
    def ensureFormulas(self, rule, formulas):
        if isinstance(formulas, list) and all(formula == False or isinstance(formula, Formula) for formula in formulas):
            return formulas
        raise Exception('Expected list of Formulas, but %s returned %s' % (rule, formulas))

    # Return whether everything is okay (no contradiction).
    def applyUnaryRules(self, deriv):
        for rule in self.rules:
            if not isinstance(rule, UnaryRule): continue
            for newForm in self.ensureFormulas(rule, rule.applyRule(deriv.form)):
                if not self.addDerivation(Derivation(newForm, children=[deriv], cost=deriv.cost + 1, derived=True)):
                    return False
        return True

    # Return whether everything is okay (no contradiction).
    def applyBinaryRules(self, deriv1, deriv2):
        for rule in self.rules:
            if not isinstance(rule, BinaryRule): continue
            if rule.symmetric() and str(deriv1.form) >= str(deriv2.form): continue  # Optimization
            for newForm in self.ensureFormulas(rule, rule.applyRule(deriv1.form, deriv2.form)):
                if not self.addDerivation(
                        Derivation(newForm, children=[deriv1, deriv2], cost=deriv1.cost + deriv2.cost + 1,
                                   derived=True)):
                    return False
        return True

    # Remove all the temporary derivations from the KB.
    def removeTemporary(self):
        for key, value in list(self.derivations.items()):
            if not value.permanent:
                del self.derivations[key]

    # Mark all the derivations marked temporary to permanent.
    def makeTemporaryPermanent(self):
        for deriv in list(self.derivations.values()):
            deriv.permanent = True
def createResolutionKB():
    """Create an empty KB that converts formulas to CNF and reasons by resolution."""
    resolutionRules = [ResolutionRule()]
    return KnowledgeBase(standardizationRule=ToCNFRule(), rules=resolutionRules, modelChecking=False)
def createModelCheckingKB():
    """Create an empty KB that decides consistency by model checking (no rules, no standardization)."""
    return KnowledgeBase(None, [], True)