added test code for incremental cluster pattern matching

This commit is contained in:
kwikrick 2012-11-02 08:59:52 +00:00
parent fc9f9f350e
commit d0381e7e9b
2 changed files with 451 additions and 1 deletions

View File

@ -13,7 +13,7 @@ from cluster import *
from configuration import Configuration from configuration import Configuration
from gmatch import gmatch from gmatch import gmatch
from method import OrMethod,SetMethod from method import OrMethod,SetMethod
from incremental import MutableSet,Union,Filter from incremental import IncrementalSet,MutableSet,Union,Filter
# -------------------------------------------------- # --------------------------------------------------
# ---------- ClusterSolver main class -------------- # ---------- ClusterSolver main class --------------
@ -976,4 +976,249 @@ def rootname(cluster):
"""returns the name of the root variable associated with the name of a cluster variable""" """returns the name of the root variable associated with the name of a cluster variable"""
return "root#"+str(id(cluster)) return "root#"+str(id(cluster))
class KindFilter(Filter):
def __init__(self, kind, incrset):
kind2class["rigid"]=Rigid
kind2class["hog"]=Hedgehog
kind2class["balloon"]=Balloon
kind2class["distance"]=Rigid
kind2class["point"]=Rigid
kind2maxpoints["rigid"]=0
kind2maxpoints["hog"]=0
kind2maxpoints["balloon"]=0
kind2maxpoints["distance"]=2
kind2maxpoints["point"]=1
self._kind = kind
self._classobj = kind2class(kind)
self._maxpoints = kind2maxpoints(kind)
IncrementalSet.__init__(self,[incrset])
def _receive_add(self, source, object):
if isinstance(object,self._classobj) and len(object.vars) == self.maxpoints:
self._add(object)
def _receive_remove(self, source, object):
self._remove(object)
def __eq__(self, other):
if isinstance(other, KindFilter):
return (self._incrset, self._kind)==(other._incrset,other._kind)
else:
return False
def __hash__(self):
return hash((self._incrset, self._kind))
def __repr__(self):
return "KindFilter(%s,%s)"%(str(self._kind),str(self._incrset))
# ---------------------------------------------------------
# ---------- incremental cluster pattern matching ---------
# ---------------------------------------------------------
class KindFilter(Filter):
def __init__(self, kind, minpoints, incrset):
kind2class = {}
kind2class["rigid"]=Rigid
kind2class["hog"]=Hedgehog
kind2class["balloon"]=Balloon
kind2class["distance"]=Rigid
kind2class["point"]=Rigid
kind2maxpoints = {}
kind2maxpoints["rigid"]=0
kind2maxpoints["hog"]=0
kind2maxpoints["balloon"]=0
kind2maxpoints["distance"]=2
kind2maxpoints["point"]=1
self._kind = kind
self._classobj = kind2class[kind]
self._maxpoints = kind2maxpoints[kind]
self._input = incrset
self._minpoints = minpoints
IncrementalSet.__init__(self,[incrset])
def _receive_add(self, source, object):
if (isinstance(object,self._classobj) and
(len(object.vars) <= self._maxpoints or self._maxpoints == 0) and
len(object.vars) >= self._minpoints):
self._add(object)
def _receive_remove(self, source, object):
self._remove(object)
def __eq__(self, other):
if isinstance(other, KindFilter):
return (self._input, self._kind, self._minpoints)==(other._input,other._kind, self._minpoints)
else:
return False
def __hash__(self):
return hash((self._input, self._kind, self._minpoints))
def __repr__(self):
return "KindFilter(%s,%s,%s)"%(str(self._kind),str(self._minpoints),str(self._input))
class NConnectedPairs(IncrementalSet):
"""Incremental set of all unordered pairs of N-connected clusters in 1 incremental sets."""
def __init__(self, solver, n, incrset1, incrset2):
"""Creates an incremental set of all pairs frozetset([c1, c2]) from incrset1 and incrset2 respectively,
that are connected with N variables, according to solver"""
# defining variables
self._solver = solver
self._incrset1 = incrset1
self._incrset2 = incrset2
self._n = n
# map from objects to sets of pairs
self._map = {}
# super init
IncrementalSet.__init__(self, [incrset1, incrset2])
def _receive_add(self,source, obj):
# determine 1-connected objects
connected = set()
for var in obj.vars:
dependend = self._solver.find_dependend(var)
# check that connected objects in both sets
if source == self._incrset1:
dependend = filter(lambda x: x in self._incrset2, dependend)
elif source == self._incrset2:
dependend = filter(lambda x: x in self._incrset1, dependend)
connected.update(dependend)
# dont pair (obj,obj)
if obj in connected:
connected.remove(obj)
# for each connected object, check that #shared vars == n
for obj2 in connected:
shared = obj.vars.intersection(obj2.vars)
if len(shared) == self._n:
# add new pair
pair = frozenset([obj, obj2])
self._add(pair)
# add to mapping
if obj not in self._map:
self._map[obj] = set()
if obj2 not in self._map:
self._map[obj2] = set()
self._map[obj].add(pair)
self._map[obj2].add(pair)
def _receive_remove(self,source, obj):
# remove all pairs that contain obj
for pair in self._map[obj]:
(obj1,obj2) = pair
self._remove(pair)
# remove pair from mapping
self._map[obj1].remove(par)
self._map[obj2].remove(pair)
# clean up mapping
if len(self._map[obj1]) == 0:
del self._map[obj1]
if len(self._map[obj2]) == 0:
del self._map[obj2]
def __eq__(self, other):
if isinstance(other, NConnectedPairs):
return (self._solver == other._solver and
self._incrset1 == other._incrset1 and
self._incrset2 == other._incrset2 and
self._n == other._n)
else:
return False
def __hash__(self):
return hash((self._solver, self._incrset1, self._incrset2, self._n))
class PatternMatches(IncrementalSet):
"""Incrementally matches patterns of clusters"""
def __init__(self, pattern, solver):
self._solver = solver
# convert pattern to a set of tuples
listoftuples = []
for clusterpattern in pattern:
(kind, clustername, pointnames) = clusterpattern
listoftuples.append(tuple([kind, clustername, tuple(pointnames)]))
setoftuples = frozenset(listoftuples)
self._pattern = setoftuples
# create sub-sets for correct type of input clusters
self._subs = []
# NOTE: we cannot use source like this to link it to a pattern, because sources may be equal!
self._source2pattern = {}
for clusterpattern in self._pattern:
(kind, clustername, pointnames) = clusterpattern
sub = KindFilter(kind, len(pointnames),self._solver.top_level())
self._subs.append(sub)
# NOTE: we cannot use source like this to link it to a pattern, because sources may be equal!
self._source2pattern[sub]=[clusterpattern]
print "creating",sub,"for",self._source2pattern[sub]
#rof
# create sub-sets for all pairs of clusters
listoftuples = list(setoftuples)
l = len(listoftuples)
for i in range(l):
for j in range(i+1,l):
cp1 = listoftuples[i]
cp2 = listoftuples[j]
(kind1, clustername1, pointnames1) = cp1
(kind2, clustername2, pointnames2) = cp2
shared = set(pointnames1).intersection(set(pointnames2))
n = len(shared)
if n > 0:
kindfilter1 = KindFilter(kind1, len(pointnames1),self._solver.top_level())
kindfilter2 = KindFilter(kind2, len(pointnames2),self._solver.top_level())
sub = NConnectedPairs(solver, n, kindfilter1,kindfilter2)
self._subs.append(sub)
# NOTE: we cannot use source like this to link it to a pattern, because sources may be equal!
self._source2pattern[sub]=[cp1,cp2]
print "creating",sub,"for",self._source2pattern[sub]
IncrementalSet.__init__(self, self._subs)
def _receive_add(self, source, object):
# NOTE: we cannot use source like this to link it to a pattern, because sources may be equal!
print "receive",object
pattern = self._source2pattern[source]
print "matching sub-pattern", pattern
#raise NotImplementedError
def _receive_remove(self, source, object):
raise NotImplementedError
def __eq__(self, other):
if isinstance(other, PatternMatches):
return (self._solver, self._pattern)==(other._solver,other._pattern)
else:
return False
def __hash__(self):
return hash((self._solver, self._pattern))
def __repr__(self):
return "PatternMatches(%s,%s)"%(str(self._pattern),str(self._solver))
# -------------------------------------------------------
# -------------------------- test code ------------------
# -------------------------------------------------------
def test_icpm():
solver = ClusterSolver([])
solver.add(Rigid(["p5"]))
solver.add(Rigid(["p1","p2"]))
solver.add(Rigid(["p2","p3"]))
solver.add(Rigid(["p1","p3","p4","p5"]))
pattern = [["rigid","$d_ab",["$a", "$b"]],
["rigid", "$d_ac",["$a", "$c"]],
["rigid", "$d_bc",["$b","$c"]]]
matches = PatternMatches(pattern, solver)
print list(matches)
if __name__ == "__main__":
test_icpm()

205
geosolver/icpm.py Normal file
View File

@ -0,0 +1,205 @@
"""Incremental Cluster Pattern Matching"""
from clsolver import ClusterSolver
from diagnostic import diag_print
from cluster import Rigid, Hedgehog, Balloon
from incremental import IncrementalSet,MutableSet,Filter
class KindFilter(Filter):
def __init__(self, kind, minpoints, incrset):
kind2class = {}
kind2class["rigid"]=Rigid
kind2class["hog"]=Hedgehog
kind2class["balloon"]=Balloon
kind2class["distance"]=Rigid
kind2class["point"]=Rigid
kind2maxpoints = {}
kind2maxpoints["rigid"]=0
kind2maxpoints["hog"]=0
kind2maxpoints["balloon"]=0
kind2maxpoints["distance"]=2
kind2maxpoints["point"]=1
self._kind = kind
self._classobj = kind2class[kind]
self._maxpoints = kind2maxpoints[kind]
self._input = incrset
self._minpoints = minpoints
IncrementalSet.__init__(self,[incrset])
def _receive_add(self, source, object):
if (isinstance(object,self._classobj) and
(len(object.vars) <= self._maxpoints or self._maxpoints == 0) and
len(object.vars) >= self._minpoints):
self._add(object)
def _receive_remove(self, source, object):
self._remove(object)
def __eq__(self, other):
if isinstance(other, KindFilter):
return (self._input, self._kind, self._minpoints)==(other._input,other._kind, self._minpoints)
else:
return False
def __hash__(self):
return hash((self._input, self._kind, self._minpoints))
def __repr__(self):
return "KindFilter(%s,%s,%s)"%(str(self._kind),str(self._minpoints),str(self._input))
class NConnectedPairs(IncrementalSet):
"""Incremental set of all unordered pairs of N-connected clusters in 1 incremental sets."""
def __init__(self, solver, n, incrset1, incrset2):
"""Creates an incremental set of all pairs frozetset([c1, c2]) from incrset1 and incrset2 respectively,
that are connected with N variables, according to solver"""
# defining variables
self._solver = solver
self._incrset1 = incrset1
self._incrset2 = incrset2
self._n = n
# map from objects to sets of pairs
self._map = {}
# super init
IncrementalSet.__init__(self, [incrset1, incrset2])
def _receive_add(self,source, obj):
# determine 1-connected objects
connected = set()
for var in obj.vars:
dependend = self._solver.find_dependend(var)
# check that connected objects in both sets
if source == self._incrset1:
dependend = filter(lambda x: x in self._incrset2, dependend)
elif source == self._incrset2:
dependend = filter(lambda x: x in self._incrset1, dependend)
connected.update(dependend)
# dont pair (obj,obj)
if obj in connected:
connected.remove(obj)
# for each connected object, check that #shared vars == n
for obj2 in connected:
shared = obj.vars.intersection(obj2.vars)
if len(shared) == self._n:
# add new pair
pair = frozenset([obj, obj2])
self._add(pair)
# add to mapping
if obj not in self._map:
self._map[obj] = set()
if obj2 not in self._map:
self._map[obj2] = set()
self._map[obj].add(pair)
self._map[obj2].add(pair)
def _receive_remove(self,source, obj):
# remove all pairs that contain obj
for pair in self._map[obj]:
(obj1,obj2) = pair
self._remove(pair)
# remove pair from mapping
self._map[obj1].remove(par)
self._map[obj2].remove(pair)
# clean up mapping
if len(self._map[obj1]) == 0:
del self._map[obj1]
if len(self._map[obj2]) == 0:
del self._map[obj2]
def __eq__(self, other):
if isinstance(other, NConnectedPairs):
return (self._solver == other._solver and
self._incrset1 == other._incrset1 and
self._incrset2 == other._incrset2 and
self._n == other._n)
else:
return False
def __hash__(self):
return hash((self._solver, self._incrset1, self._incrset2, self._n))
class PatternMatches(IncrementalSet):
"""Incrementally matches patterns of clusters"""
def __init__(self, pattern, solver):
self._solver = solver
# convert pattern to a set of tuples
listoftuples = []
for clusterpattern in pattern:
(kind, clustername, pointnames) = clusterpattern
listoftuples.append(tuple([kind, clustername, tuple(pointnames)]))
setoftuples = frozenset(listoftuples)
self._pattern = setoftuples
# create sub-sets for correct type of input clusters
self._subs = []
# NOTE: we cannot use source like this to link it to a pattern, because sources may be equal!
self._source2pattern = {}
for clusterpattern in self._pattern:
(kind, clustername, pointnames) = clusterpattern
sub = KindFilter(kind, len(pointnames),self._solver.top_level())
self._subs.append(sub)
# NOTE: we cannot use source like this to link it to a pattern, because sources may be equal!
self._source2pattern[sub]=[clusterpattern]
print "creating",sub,"for",self._source2pattern[sub]
#rof
# create sub-sets for all pairs of clusters
listoftuples = list(setoftuples)
l = len(listoftuples)
for i in range(l):
for j in range(i+1,l):
cp1 = listoftuples[i]
cp2 = listoftuples[j]
(kind1, clustername1, pointnames1) = cp1
(kind2, clustername2, pointnames2) = cp2
shared = set(pointnames1).intersection(set(pointnames2))
n = len(shared)
if n > 0:
kindfilter1 = KindFilter(kind1, len(pointnames1),self._solver.top_level())
kindfilter2 = KindFilter(kind2, len(pointnames2),self._solver.top_level())
sub = NConnectedPairs(solver, n, kindfilter1,kindfilter2)
self._subs.append(sub)
# NOTE: we cannot use source like this to link it to a pattern, because sources may be equal!
self._source2pattern[sub]=[cp1,cp2]
print "creating",sub,"for",self._source2pattern[sub]
IncrementalSet.__init__(self, self._subs)
def _receive_add(self, source, object):
# NOTE: we cannot use source like this to link it to a pattern, because sources may be equal!
print "receive",object
pattern = self._source2pattern[source]
print "matching sub-pattern", pattern
#raise NotImplementedError
def _receive_remove(self, source, object):
raise NotImplementedError
def __eq__(self, other):
if isinstance(other, PatternMatches):
return (self._solver, self._pattern)==(other._solver,other._pattern)
else:
return False
def __hash__(self):
return hash((self._solver, self._pattern))
def __repr__(self):
return "PatternMatches(%s,%s)"%(str(self._pattern),str(self._solver))
def test_icpm():
solver = ClusterSolver([])
solver.add(Rigid(["p5"]))
solver.add(Rigid(["p1","p2"]))
solver.add(Rigid(["p2","p3"]))
solver.add(Rigid(["p1","p3","p4","p5"]))
pattern = [["rigid","$d_ab",["$a", "$b"]],
["rigid", "$d_ac",["$a", "$c"]],
["rigid", "$d_bc",["$b","$c"]]]
matches = PatternMatches(pattern, solver)
print list(matches)
if __name__ == "__main__":
test_icpm()