Fixed solution-selection and root-transformation bugs.
This commit is contained in:
parent
38c71c882b
commit
74386c1710
|
@ -927,5 +927,18 @@ class ClusterSolver(Notifier):
|
||||||
## raise "cluster determined by more than one method"
|
## raise "cluster determined by more than one method"
|
||||||
## return result
|
## return result
|
||||||
|
|
||||||
|
|
||||||
|
##def _all_sources_constraint_in_cluster(self, constraint, cluster):
|
||||||
|
## if not self._contains_constraint(cluster, constraint):
|
||||||
|
## return Set()
|
||||||
|
## elif self._is_atomic(cluster):
|
||||||
|
## return Set([cluster])
|
||||||
|
## else:
|
||||||
|
## method = self._determining_method(cluster)
|
||||||
|
## sources = Set()
|
||||||
|
## for inp in method.input_clusters():
|
||||||
|
## sources.union_update(self._all_sources_constraint_in_cluster(constraint, inp))
|
||||||
|
## return sources
|
||||||
|
|
||||||
# class ClusterSolver
|
# class ClusterSolver
|
||||||
|
|
||||||
|
|
|
@ -88,6 +88,7 @@ class ClusterSolver3D(ClusterSolver):
|
||||||
|
|
||||||
# overriding ClusterSolver.set_root
|
# overriding ClusterSolver.set_root
|
||||||
def set_root(self, cluster):
|
def set_root(self, cluster):
|
||||||
|
"""Set root cluster, used for positionig and orienting the solutions"""
|
||||||
diag_print("set root "+str(self.rootcluster), "clsolver3D")
|
diag_print("set root "+str(self.rootcluster), "clsolver3D")
|
||||||
if self.rootcluster != None:
|
if self.rootcluster != None:
|
||||||
oldrootvar = rootname(self.rootcluster)
|
oldrootvar = rootname(self.rootcluster)
|
||||||
|
@ -99,18 +100,7 @@ class ClusterSolver3D(ClusterSolver):
|
||||||
|
|
||||||
# ------------ INTERNALLY USED METHODS --------
|
# ------------ INTERNALLY USED METHODS --------
|
||||||
|
|
||||||
def _all_sources_constraint_in_cluster(self, constraint, cluster):
|
|
||||||
if not self._contains_constraint(cluster, constraint):
|
|
||||||
return Set()
|
|
||||||
elif self._is_atomic(cluster):
|
|
||||||
return Set([cluster])
|
|
||||||
else:
|
|
||||||
method = self._determining_method(cluster)
|
|
||||||
sources = Set()
|
|
||||||
for inp in method.input_clusters():
|
|
||||||
sources.union_update(self._all_sources_constraint_in_cluster(constraint, inp))
|
|
||||||
return sources
|
|
||||||
|
|
||||||
# --------------
|
# --------------
|
||||||
# search methods
|
# search methods
|
||||||
# --------------
|
# --------------
|
||||||
|
@ -212,15 +202,19 @@ class ClusterSolver3D(ClusterSolver):
|
||||||
self._add_cluster(output)
|
self._add_cluster(output)
|
||||||
self._add_method(merge)
|
self._add_method(merge)
|
||||||
# remove input clusters from top_level
|
# remove input clusters from top_level
|
||||||
if not (hasattr(merge,"noremove") and merge.noremove == True):
|
merge.restore_toplevel = [] # make restore list in method
|
||||||
merge.restore_toplevel = [] # make restore list in method
|
for cluster in merge.input_clusters():
|
||||||
for cluster in merge.input_clusters():
|
# do not remove rigids from toplevel if method does not consider root
|
||||||
if num_constraints(cluster.intersection(output)) >= num_constraints(cluster):
|
if isinstance(cluster, Rigid):
|
||||||
diag_print("remove from top-level: "+str(cluster),"clsolver3D")
|
if hasattr(merge,"noremove") and merge.noremove == True:
|
||||||
self._rem_top_level(cluster)
|
continue
|
||||||
merge.restore_toplevel.append(cluster)
|
# remove input clusters when all its constraints are in output cluster
|
||||||
else:
|
if num_constraints(cluster.intersection(output)) >= num_constraints(cluster):
|
||||||
diag_print("keep top-level: "+str(cluster),"clsolver3D")
|
diag_print("remove from top-level: "+str(cluster),"clsolver3D")
|
||||||
|
self._rem_top_level(cluster)
|
||||||
|
merge.restore_toplevel.append(cluster)
|
||||||
|
else:
|
||||||
|
diag_print("keep top-level: "+str(cluster),"clsolver3D")
|
||||||
# add method to determine root-variable
|
# add method to determine root-variable
|
||||||
self._add_root_method(merge.input_clusters(),merge.outputs()[0])
|
self._add_root_method(merge.input_clusters(),merge.outputs()[0])
|
||||||
# add solution selection methods
|
# add solution selection methods
|
||||||
|
@ -544,7 +538,7 @@ class DeriveDAD(ClusterMethod):
|
||||||
return solutions
|
return solutions
|
||||||
|
|
||||||
class DeriveADD(ClusterMethod):
|
class DeriveADD(ClusterMethod):
|
||||||
"""Represents a merging of one distance and to distances"""
|
"""Represents a merging of one angle and two distances"""
|
||||||
def __init__(self, map):
|
def __init__(self, map):
|
||||||
# check inputs
|
# check inputs
|
||||||
self.a_cab = map["$a_cab"]
|
self.a_cab = map["$a_cab"]
|
||||||
|
@ -636,7 +630,7 @@ class DeriveAA(ClusterMethod):
|
||||||
return solutions
|
return solutions
|
||||||
|
|
||||||
class MergeSR(ClusterMethod):
|
class MergeSR(ClusterMethod):
|
||||||
"""Merge a Rigid from a Scalabe and a Rigid sharing two points"""
|
"""Merge a Scalabe and a Rigid sharing two points"""
|
||||||
def __init__(self, map):
|
def __init__(self, map):
|
||||||
# check inputs
|
# check inputs
|
||||||
in1 = map["$r"]
|
in1 = map["$r"]
|
||||||
|
|
|
@ -7,6 +7,7 @@ from multimethod import MultiVariable
|
||||||
class Distance:
|
class Distance:
|
||||||
"""A Distance represents a known distance"""
|
"""A Distance represents a known distance"""
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, a, b):
|
def __init__(self, a, b):
|
||||||
"""Create a new Distance
|
"""Create a new Distance
|
||||||
|
|
||||||
|
@ -15,7 +16,7 @@ class Distance:
|
||||||
b - point variable
|
b - point variable
|
||||||
"""
|
"""
|
||||||
self.vars = (a,b)
|
self.vars = (a,b)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "dist("\
|
return "dist("\
|
||||||
+str(self.vars[0])+","\
|
+str(self.vars[0])+","\
|
||||||
|
@ -63,6 +64,12 @@ class Angle:
|
||||||
|
|
||||||
class Cluster(MultiVariable):
|
class Cluster(MultiVariable):
|
||||||
"""A set of points, satisfying some constaint"""
|
"""A set of points, satisfying some constaint"""
|
||||||
|
|
||||||
|
staticcounter = 0
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
Cluster.staticcounter += 1
|
||||||
|
self.creationtime = Cluster.staticcounter
|
||||||
|
|
||||||
def intersection(self, other):
|
def intersection(self, other):
|
||||||
shared = Set(self.vars).intersection(other.vars)
|
shared = Set(self.vars).intersection(other.vars)
|
||||||
|
@ -125,7 +132,8 @@ class Rigid(Cluster):
|
||||||
|
|
||||||
keyword args:
|
keyword args:
|
||||||
vars - list of variables
|
vars - list of variables
|
||||||
"""
|
"""
|
||||||
|
Cluster.__init__(self)
|
||||||
self.vars = ImmutableSet(vars)
|
self.vars = ImmutableSet(vars)
|
||||||
self.overconstrained = False
|
self.overconstrained = False
|
||||||
|
|
||||||
|
@ -153,6 +161,7 @@ class Hedgehog(Cluster):
|
||||||
cvar - center variable
|
cvar - center variable
|
||||||
xvars - list of variables
|
xvars - list of variables
|
||||||
"""
|
"""
|
||||||
|
Cluster.__init__(self)
|
||||||
self.cvar = cvar
|
self.cvar = cvar
|
||||||
if len(xvars) < 2:
|
if len(xvars) < 2:
|
||||||
raise StandardError, "hedgehog must have at least three variables"
|
raise StandardError, "hedgehog must have at least three variables"
|
||||||
|
@ -182,6 +191,7 @@ class Balloon(Cluster):
|
||||||
keyword args:
|
keyword args:
|
||||||
vars - collection of PointVar's
|
vars - collection of PointVar's
|
||||||
"""
|
"""
|
||||||
|
Cluster.__init__(self)
|
||||||
if len(variables) < 3:
|
if len(variables) < 3:
|
||||||
raise StandardError, "balloon must have at least three variables"
|
raise StandardError, "balloon must have at least three variables"
|
||||||
self.vars = ImmutableSet(variables)
|
self.vars = ImmutableSet(variables)
|
||||||
|
|
|
@ -328,9 +328,9 @@ class GeometricSolver (Listener):
|
||||||
map[geocluster].append(drcluster)
|
map[geocluster].append(drcluster)
|
||||||
|
|
||||||
for geocluster in geoclusters:
|
for geocluster in geoclusters:
|
||||||
# pick drcluster with fewest solutions
|
# pick newest drcluster
|
||||||
drclusters = map[geocluster]
|
drclusters = map[geocluster]
|
||||||
drcluster = min(drclusters, key=lambda c: len(self.dr.get(drcluster)))
|
drcluster = max(drclusters, key=lambda c: c.creationtime)
|
||||||
# determine solutions
|
# determine solutions
|
||||||
solutions = self.dr.get(drcluster)
|
solutions = self.dr.get(drcluster)
|
||||||
underconstrained = False
|
underconstrained = False
|
||||||
|
|
10
test/test.py
10
test/test.py
|
@ -739,11 +739,11 @@ def test(problem, use_prototype=True):
|
||||||
#diag_select(".*")
|
#diag_select(".*")
|
||||||
print "problem:"
|
print "problem:"
|
||||||
print problem
|
print problem
|
||||||
|
print "use_prototype=",use_prototype
|
||||||
solver = GeometricSolver(problem, use_prototype)
|
solver = GeometricSolver(problem, use_prototype)
|
||||||
#solver.set_prototype_selection(use_prototype)
|
print "drplan:"
|
||||||
#print "drplan:"
|
print solver.dr
|
||||||
#print solver.dr
|
print "top-level rigids:",solver.dr.top_level()
|
||||||
#print "number of top-level rigids:",len(solver.dr.top_level())
|
|
||||||
result = solver.get_result()
|
result = solver.get_result()
|
||||||
print "result:"
|
print "result:"
|
||||||
print result
|
print result
|
||||||
|
@ -751,6 +751,7 @@ def test(problem, use_prototype=True):
|
||||||
check = True
|
check = True
|
||||||
if len(result.solutions) == 0:
|
if len(result.solutions) == 0:
|
||||||
check = False
|
check = False
|
||||||
|
diag_select("GeometricProblem.verify")
|
||||||
for sol in result.solutions:
|
for sol in result.solutions:
|
||||||
print "solution:",sol
|
print "solution:",sol
|
||||||
check = check and problem.verify(sol)
|
check = check and problem.verify(sol)
|
||||||
|
@ -758,7 +759,6 @@ def test(problem, use_prototype=True):
|
||||||
print "all solutions valid"
|
print "all solutions valid"
|
||||||
else:
|
else:
|
||||||
print "INVALID"
|
print "INVALID"
|
||||||
|
|
||||||
|
|
||||||
# ----- what to test today -------
|
# ----- what to test today -------
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user