# NOTE(review): this patch was reconstructed from a copy whose line breaks had
# been collapsed. The original whitespace is lost, so indentation below assumes
# 4 spaces per level -- TODO confirm against the repository before applying.
# Lines before the first file header are ignored by git apply and patch.
#
# Summary: src/statebag.py becomes the package src/statebag/__init__.py;
# estimateDistance, GameTreeNode and BoardState move into the new module
# src/statebag/boardstate.py (the moved lines are textually identical in this
# copy), and __init__.py now imports only BoardState from it.
# NOTE(review): GameTreeNode / estimateDistance are no longer importable from
# the `statebag` package itself, and the `from util import ...` line was dropped
# from __init__.py -- verify the code remaining there (StateBag and the diff
# helpers) no longer uses EMPTY/BLACK/WHITE/hashBoard/exportBoard.
diff --git a/src/statebag.py b/src/statebag/__init__.py
rename from src/statebag.py
rename to src/statebag/__init__.py
--- a/src/statebag.py
+++ b/src/statebag/__init__.py
@@ -16,19 +16,7 @@ So we try to find the correct crossover
 - remember the better variant
 - linearize the fork back by discarding s'_j-s preceding the crossover and s_j-s following the crossover
 """
-from util import EMPTY,BLACK,WHITE, hashBoard,exportBoard
-from go.engine import getTransitionSequence
-
-
-## Crude lower bound on edit distance between states.
-def estimateDistance(diff):
-    # lot of room for improvements
-    additions=sum(1 for d in diff if d[2]=="+")
-    deletions=sum(1 for d in diff if d[2]=="-")
-    replacements=len(diff)-additions-deletions
-    if additions>0: return additions+replacements
-    elif replacements==0 and deletions>0: return 2 # take n, return 1
-    return replacements+1 # ???
+from .boardstate import BoardState
 
 
 ## Update contents of diff1 by contents of diff2.
@@ -65,81 +53,6 @@ def transformSingle(action1,item1,action
     return None
 
 
-class GameTreeNode:
-    def __init__(self,parent,color):
-        self.parent=parent
-        self.color=color # color closing the move sequence
-        self.prev=None
-        self.moves=[] # move sequence from prev to self
-        self.weight=0
-
-    ## Connect itself after v if it gains itself more weight.
-    def tryConnect(self,v,diff):
-        moves=getTransitionSequence(v.parent,self.parent,-1*v.color,self.color,diff)
-        if not moves: return
-        w=v.weight+2-len(moves) # proper one move transition increases the weight by 1
-        if w>self.weight:
-            self.moves=moves
-            self.prev=v
-            self.weight=w
-
-
-class BoardState:
-    def __init__(self,board):
-        self._board=tuple(tuple(x for x in row) for row in board)
-        self.nodes=(GameTreeNode(self,BLACK),GameTreeNode(self,WHITE))
-        self.cachedDiff=[]
-        self._hash=None
-
-    def tryConnect(self,s,diff=None):
-        if diff is None: diff=self-s
-        distEst=estimateDistance(diff)
-        if distEst>3: return # we couldn't find every such move sequence anyway without a clever algorithm
-        weightEst=s.getWeight()+2-distEst
-        if weightEst<=self.getWeight(): return
-        for v1 in s.nodes:
-            for v2 in self.nodes:
-                v2.tryConnect(v1,diff)
-
-    def hash(self):
-        if self._hash is None: self._hash=hashBoard(self._board)
-        return self._hash
-
-    def export(self):
-        return exportBoard(self._board)
-
-    def exportDiff(self,s2):
-        return "vvv\n{0}\n=== {1} ===\n{2}\n^^^".format(self.export(), s2-self, s2.export())
-
-    def __iter__(self): return iter(self._board)
-
-    def __getitem__(self,key): return self._board[key]
-
-    def __sub__(self,x):
-        res=[]
-
-        for (r,(row1,row2)) in enumerate(zip(self._board,x)):
-            for (c,(item1,item2)) in enumerate(zip(row1,row2)):
-                if item1==item2: continue
-                elif item2==EMPTY: res.append((r,c,"+",item1))
-                elif item1==EMPTY: res.append((r,c,"-",item2))
-                else: res.append((r,c,"*",item1)) # ->to
-        return res
-
-    def __eq__(self,x):
-        return self.hash()==x.hash()
-
-    def setWeight(self,val):
-        for v in self.nodes: v.weight=val
-
-    def getWeight(self):
-        return max(v.weight for v in self.nodes)
-
-    def getPrev(self):
-        v=max(self.nodes, key=lambda v: v.weight)
-        return v.prev.parent if v.prev else None
-
-
 class StateBag:
     def __init__(self):
         self._states=[]
diff --git a/src/statebag/boardstate.py b/src/statebag/boardstate.py
new file mode 100644
--- /dev/null
+++ b/src/statebag/boardstate.py
@@ -0,0 +1,88 @@
+from util import EMPTY,BLACK,WHITE, hashBoard,exportBoard
+from go.engine import getTransitionSequence
+
+
+## Crude lower bound on edit distance between states.
+def estimateDistance(diff):
+    # lot of room for improvements
+    additions=sum(1 for d in diff if d[2]=="+")
+    deletions=sum(1 for d in diff if d[2]=="-")
+    replacements=len(diff)-additions-deletions
+    if additions>0: return additions+replacements
+    elif replacements==0 and deletions>0: return 2 # take n, return 1
+    return replacements+1 # ???
+
+
+class GameTreeNode:
+    def __init__(self,parent,color):
+        self.parent=parent
+        self.color=color # color closing the move sequence
+        self.prev=None
+        self.moves=[] # move sequence from prev to self
+        self.weight=0
+
+    ## Connect itself after v if it gains itself more weight.
+    def tryConnect(self,v,diff):
+        moves=getTransitionSequence(v.parent,self.parent,-1*v.color,self.color,diff)
+        if not moves: return
+        w=v.weight+2-len(moves) # proper one move transition increases the weight by 1
+        if w>self.weight:
+            self.moves=moves
+            self.prev=v
+            self.weight=w
+
+
+class BoardState:
+    def __init__(self,board):
+        self._board=tuple(tuple(x for x in row) for row in board)
+        self.nodes=(GameTreeNode(self,BLACK),GameTreeNode(self,WHITE))
+        self.cachedDiff=[]
+        self._hash=None
+
+    def tryConnect(self,s,diff=None):
+        if diff is None: diff=self-s
+        distEst=estimateDistance(diff)
+        if distEst>3: return # we couldn't find every such move sequence anyway without a clever algorithm
+        weightEst=s.getWeight()+2-distEst
+        if weightEst<=self.getWeight(): return
+        for v1 in s.nodes:
+            for v2 in self.nodes:
+                v2.tryConnect(v1,diff)
+
+    def hash(self):
+        if self._hash is None: self._hash=hashBoard(self._board)
+        return self._hash
+
+    def export(self):
+        return exportBoard(self._board)
+
+    def exportDiff(self,s2):
+        return "vvv\n{0}\n=== {1} ===\n{2}\n^^^".format(self.export(), s2-self, s2.export())
+
+    def __iter__(self): return iter(self._board)
+
+    def __getitem__(self,key): return self._board[key]
+
+    def __sub__(self,x):
+        res=[]
+
+        for (r,(row1,row2)) in enumerate(zip(self._board,x)):
+            for (c,(item1,item2)) in enumerate(zip(row1,row2)):
+                if item1==item2: continue
+                elif item2==EMPTY: res.append((r,c,"+",item1))
+                elif item1==EMPTY: res.append((r,c,"-",item2))
+                else: res.append((r,c,"*",item1)) # ->to
+        return res
+
+    def __eq__(self,x):
+        return self.hash()==x.hash()
+
+    def setWeight(self,val):
+        for v in self.nodes: v.weight=val
+
+    def getWeight(self):
+        return max(v.weight for v in self.nodes)
+
+    def getPrev(self):
+        v=max(self.nodes, key=lambda v: v.weight)
+        return v.prev.parent if v.prev else None