diff --git a/src/statebag.py b/src/statebag.py --- a/src/statebag.py +++ b/src/statebag.py @@ -16,8 +16,8 @@ So we try to find the correct crossover - remember the better variant - linearize the fork back by discarding s'_j-s preceding the crossover and s_j-s following the crossover """ -from util import EMPTY, hashBoard,exportBoard -from go.engine import transitionSequence +from util import EMPTY,BLACK,WHITE, hashBoard,exportBoard +from go.engine import getTransitionSequence ## Crude lower bound on edit distance between states. @@ -31,16 +31,41 @@ def estimateDistance(diff): return replacements+1 # ??? +class GameTreeNode: + def __init__(self,parent,color): + self.parent=parent + self.color=color # color closing the move sequence + self.prev=None + self.moves=[] # move sequence from prev to self + self.weight=0 + + ## Connect itself after v if it gains itself more weight. + def tryConnect(self,v,diff): + moves=getTransitionSequence(v.parent,self.parent,-1*v.color,self.color,diff) + w=v.weight+2-len(moves) # proper one move transition increases the weight by 1 + if w>self.weight: + self.moves=moves + self.prev=v + self.weight=w + + class BoardState: def __init__(self,board): self._board=tuple(tuple(x for x in row) for row in board) - self.prev=None - self._next=None - self.moves=[] - self.weight=0 + self.nodes=(GameTreeNode(self,BLACK),GameTreeNode(self,WHITE)) self.diff2Prev=None self._hash=None + def tryConnect(self,s): + diff=s-self + distEst=estimateDistance(diff) + if distEst>3: return # we couldn't find every such move sequence anyway without a clever algorithm + weightEst=s.getWeight()-distEst + if weightEst<=self.getWeight(): return + for v1 in s.nodes: + for v2 in self.nodes: + v2.tryConnect(v1,diff) + def hash(self): if self._hash is None: self._hash=hashBoard(self._board) return self._hash @@ -69,6 +94,9 @@ class BoardState: def __eq__(self,x): return self.hash()==x.hash() + def getWeight(self): + return max(v.weight for v in self.nodes) + class StateBag: def __init__(self): @@ -79,18 +107,7 @@ class StateBag: if self._states and sn==self._states[-1]: return # no change for s in reversed(self._states): - diff=sn-s - distEst=estimateDistance(diff) - if distEst>3: continue # we couldn't find every such move sequence anyway without a clever algorithm - weightEst=s.weight-distEst - if weightEst<=sn.weight: continue - moves=transitionSequence(s,sn,diff) - weight=s.weight-len(moves) - if weight<=sn.weight: continue - sn.prev=s - sn.diff2Prev=diff - sn.moves=moves - sn.weight=weight + sn.tryConnect(s) self._states.append(sn) def merge(self,branch):