diff --git a/src/statebag/boardstate.py b/src/statebag/boardstate.py new file mode 100644 --- /dev/null +++ b/src/statebag/boardstate.py @@ -0,0 +1,88 @@ +from util import EMPTY,BLACK,WHITE, hashBoard,exportBoard +from go.engine import getTransitionSequence + + +## Crude lower bound on edit distance between states. +def estimateDistance(diff): + # lot of room for improvements + additions=sum(1 for d in diff if d[2]=="+") + deletions=sum(1 for d in diff if d[2]=="-") + replacements=len(diff)-additions-deletions + if additions>0: return additions+replacements + elif replacements==0 and deletions>0: return 2 # take n, return 1 + return replacements+1 # ??? + + +class GameTreeNode: + def __init__(self,parent,color): + self.parent=parent + self.color=color # color closing the move sequence + self.prev=None + self.moves=[] # move sequence from prev to self + self.weight=0 + + ## Connect itself after v if it gains itself more weight. + def tryConnect(self,v,diff): + moves=getTransitionSequence(v.parent,self.parent,-1*v.color,self.color,diff) + if not moves: return + w=v.weight+2-len(moves) # proper one move transition increases the weight by 1 + if w>self.weight: + self.moves=moves + self.prev=v + self.weight=w + + +class BoardState: + def __init__(self,board): + self._board=tuple(tuple(x for x in row) for row in board) + self.nodes=(GameTreeNode(self,BLACK),GameTreeNode(self,WHITE)) + self.cachedDiff=[] + self._hash=None + + def tryConnect(self,s,diff=None): + if diff is None: diff=self-s + distEst=estimateDistance(diff) + if distEst>3: return # we couldn't find every such move sequence anyway without a clever algorithm + weightEst=s.getWeight()+2-distEst + if weightEst<=self.getWeight(): return + for v1 in s.nodes: + for v2 in self.nodes: + v2.tryConnect(v1,diff) + + def hash(self): + if self._hash is None: self._hash=hashBoard(self._board) + return self._hash + + def export(self): + return exportBoard(self._board) + + def exportDiff(self,s2): + return "vvv\n{0}\n=== {1} ===\n{2}\n^^^".format(self.export(), s2-self, s2.export()) + + def __iter__(self): return iter(self._board) + + def __getitem__(self,key): return self._board[key] + + def __sub__(self,x): + res=[] + + for (r,(row1,row2)) in enumerate(zip(self._board,x)): + for (c,(item1,item2)) in enumerate(zip(row1,row2)): + if item1==item2: continue + elif item2==EMPTY: res.append((r,c,"+",item1)) + elif item1==EMPTY: res.append((r,c,"-",item2)) + else: res.append((r,c,"*",item1)) # ->to + return res + + def __eq__(self,x): + return self.hash()==x.hash() + + def setWeight(self,val): + for v in self.nodes: v.weight=val + + def getWeight(self): + return max(v.weight for v in self.nodes) + + def getPrev(self): + v=max(self.nodes, key=lambda v: v.weight) + return v.prev.parent if v.prev else None