moved icpm code (back) to seperate file

This commit is contained in:
kwikrick 2012-11-08 17:22:52 +00:00
parent d0381e7e9b
commit 4ca358c1ef

View File

@ -976,249 +976,3 @@ def rootname(cluster):
"""returns the name of the root variable associated with the name of a cluster variable"""
return "root#"+str(id(cluster))
class KindFilter(Filter):
def __init__(self, kind, incrset):
kind2class["rigid"]=Rigid
kind2class["hog"]=Hedgehog
kind2class["balloon"]=Balloon
kind2class["distance"]=Rigid
kind2class["point"]=Rigid
kind2maxpoints["rigid"]=0
kind2maxpoints["hog"]=0
kind2maxpoints["balloon"]=0
kind2maxpoints["distance"]=2
kind2maxpoints["point"]=1
self._kind = kind
self._classobj = kind2class(kind)
self._maxpoints = kind2maxpoints(kind)
IncrementalSet.__init__(self,[incrset])
def _receive_add(self, source, object):
if isinstance(object,self._classobj) and len(object.vars) == self.maxpoints:
self._add(object)
def _receive_remove(self, source, object):
self._remove(object)
def __eq__(self, other):
if isinstance(other, KindFilter):
return (self._incrset, self._kind)==(other._incrset,other._kind)
else:
return False
def __hash__(self):
return hash((self._incrset, self._kind))
def __repr__(self):
return "KindFilter(%s,%s)"%(str(self._kind),str(self._incrset))
# ---------------------------------------------------------
# ---------- incremental cluster pattern matching ---------
# ---------------------------------------------------------
class KindFilter(Filter):
def __init__(self, kind, minpoints, incrset):
kind2class = {}
kind2class["rigid"]=Rigid
kind2class["hog"]=Hedgehog
kind2class["balloon"]=Balloon
kind2class["distance"]=Rigid
kind2class["point"]=Rigid
kind2maxpoints = {}
kind2maxpoints["rigid"]=0
kind2maxpoints["hog"]=0
kind2maxpoints["balloon"]=0
kind2maxpoints["distance"]=2
kind2maxpoints["point"]=1
self._kind = kind
self._classobj = kind2class[kind]
self._maxpoints = kind2maxpoints[kind]
self._input = incrset
self._minpoints = minpoints
IncrementalSet.__init__(self,[incrset])
def _receive_add(self, source, object):
if (isinstance(object,self._classobj) and
(len(object.vars) <= self._maxpoints or self._maxpoints == 0) and
len(object.vars) >= self._minpoints):
self._add(object)
def _receive_remove(self, source, object):
self._remove(object)
def __eq__(self, other):
if isinstance(other, KindFilter):
return (self._input, self._kind, self._minpoints)==(other._input,other._kind, self._minpoints)
else:
return False
def __hash__(self):
return hash((self._input, self._kind, self._minpoints))
def __repr__(self):
return "KindFilter(%s,%s,%s)"%(str(self._kind),str(self._minpoints),str(self._input))
class NConnectedPairs(IncrementalSet):
"""Incremental set of all unordered pairs of N-connected clusters in 1 incremental sets."""
def __init__(self, solver, n, incrset1, incrset2):
"""Creates an incremental set of all pairs frozetset([c1, c2]) from incrset1 and incrset2 respectively,
that are connected with N variables, according to solver"""
# defining variables
self._solver = solver
self._incrset1 = incrset1
self._incrset2 = incrset2
self._n = n
# map from objects to sets of pairs
self._map = {}
# super init
IncrementalSet.__init__(self, [incrset1, incrset2])
def _receive_add(self,source, obj):
# determine 1-connected objects
connected = set()
for var in obj.vars:
dependend = self._solver.find_dependend(var)
# check that connected objects in both sets
if source == self._incrset1:
dependend = filter(lambda x: x in self._incrset2, dependend)
elif source == self._incrset2:
dependend = filter(lambda x: x in self._incrset1, dependend)
connected.update(dependend)
# dont pair (obj,obj)
if obj in connected:
connected.remove(obj)
# for each connected object, check that #shared vars == n
for obj2 in connected:
shared = obj.vars.intersection(obj2.vars)
if len(shared) == self._n:
# add new pair
pair = frozenset([obj, obj2])
self._add(pair)
# add to mapping
if obj not in self._map:
self._map[obj] = set()
if obj2 not in self._map:
self._map[obj2] = set()
self._map[obj].add(pair)
self._map[obj2].add(pair)
def _receive_remove(self,source, obj):
# remove all pairs that contain obj
for pair in self._map[obj]:
(obj1,obj2) = pair
self._remove(pair)
# remove pair from mapping
self._map[obj1].remove(par)
self._map[obj2].remove(pair)
# clean up mapping
if len(self._map[obj1]) == 0:
del self._map[obj1]
if len(self._map[obj2]) == 0:
del self._map[obj2]
def __eq__(self, other):
if isinstance(other, NConnectedPairs):
return (self._solver == other._solver and
self._incrset1 == other._incrset1 and
self._incrset2 == other._incrset2 and
self._n == other._n)
else:
return False
def __hash__(self):
return hash((self._solver, self._incrset1, self._incrset2, self._n))
class PatternMatches(IncrementalSet):
"""Incrementally matches patterns of clusters"""
def __init__(self, pattern, solver):
self._solver = solver
# convert pattern to a set of tuples
listoftuples = []
for clusterpattern in pattern:
(kind, clustername, pointnames) = clusterpattern
listoftuples.append(tuple([kind, clustername, tuple(pointnames)]))
setoftuples = frozenset(listoftuples)
self._pattern = setoftuples
# create sub-sets for correct type of input clusters
self._subs = []
# NOTE: we cannot use source like this to link it to a pattern, because sources may be equal!
self._source2pattern = {}
for clusterpattern in self._pattern:
(kind, clustername, pointnames) = clusterpattern
sub = KindFilter(kind, len(pointnames),self._solver.top_level())
self._subs.append(sub)
# NOTE: we cannot use source like this to link it to a pattern, because sources may be equal!
self._source2pattern[sub]=[clusterpattern]
print "creating",sub,"for",self._source2pattern[sub]
#rof
# create sub-sets for all pairs of clusters
listoftuples = list(setoftuples)
l = len(listoftuples)
for i in range(l):
for j in range(i+1,l):
cp1 = listoftuples[i]
cp2 = listoftuples[j]
(kind1, clustername1, pointnames1) = cp1
(kind2, clustername2, pointnames2) = cp2
shared = set(pointnames1).intersection(set(pointnames2))
n = len(shared)
if n > 0:
kindfilter1 = KindFilter(kind1, len(pointnames1),self._solver.top_level())
kindfilter2 = KindFilter(kind2, len(pointnames2),self._solver.top_level())
sub = NConnectedPairs(solver, n, kindfilter1,kindfilter2)
self._subs.append(sub)
# NOTE: we cannot use source like this to link it to a pattern, because sources may be equal!
self._source2pattern[sub]=[cp1,cp2]
print "creating",sub,"for",self._source2pattern[sub]
IncrementalSet.__init__(self, self._subs)
def _receive_add(self, source, object):
# NOTE: we cannot use source like this to link it to a pattern, because sources may be equal!
print "receive",object
pattern = self._source2pattern[source]
print "matching sub-pattern", pattern
#raise NotImplementedError
def _receive_remove(self, source, object):
raise NotImplementedError
def __eq__(self, other):
if isinstance(other, PatternMatches):
return (self._solver, self._pattern)==(other._solver,other._pattern)
else:
return False
def __hash__(self):
return hash((self._solver, self._pattern))
def __repr__(self):
return "PatternMatches(%s,%s)"%(str(self._pattern),str(self._solver))
# -------------------------------------------------------
# -------------------------- test code ------------------
# -------------------------------------------------------
def test_icpm():
solver = ClusterSolver([])
solver.add(Rigid(["p5"]))
solver.add(Rigid(["p1","p2"]))
solver.add(Rigid(["p2","p3"]))
solver.add(Rigid(["p1","p3","p4","p5"]))
pattern = [["rigid","$d_ab",["$a", "$b"]],
["rigid", "$d_ac",["$a", "$c"]],
["rigid", "$d_bc",["$b","$c"]]]
matches = PatternMatches(pattern, solver)
print list(matches)
if __name__ == "__main__":
test_icpm()