diff --git a/src/statebag.py b/src/statebag/__init__.py copy from src/statebag.py copy to src/statebag/__init__.py --- a/src/statebag.py +++ b/src/statebag/__init__.py @@ -16,19 +16,7 @@ So we try to find the correct crossover - remember the better variant - linearize the fork back by discarding s'_j-s preceding the crossover and s_j-s following the crossover """ -from util import EMPTY,BLACK,WHITE, hashBoard,exportBoard -from go.engine import getTransitionSequence - - -## Crude lower bound on edit distance between states. -def estimateDistance(diff): - # lot of room for improvements - additions=sum(1 for d in diff if d[2]=="+") - deletions=sum(1 for d in diff if d[2]=="-") - replacements=len(diff)-additions-deletions - if additions>0: return additions+replacements - elif replacements==0 and deletions>0: return 2 # take n, return 1 - return replacements+1 # ??? +from .boardstate import BoardState ## Update contents of diff1 by contents of diff2. @@ -65,81 +53,6 @@ def transformSingle(action1,item1,action return None -class GameTreeNode: - def __init__(self,parent,color): - self.parent=parent - self.color=color # color closing the move sequence - self.prev=None - self.moves=[] # move sequence from prev to self - self.weight=0 - - ## Connect itself after v if it gains itself more weight. - def tryConnect(self,v,diff): - moves=getTransitionSequence(v.parent,self.parent,-1*v.color,self.color,diff) - if not moves: return - w=v.weight+2-len(moves) # proper one move transition increases the weight by 1 - if w>self.weight: - self.moves=moves - self.prev=v - self.weight=w - - -class BoardState: - def __init__(self,board): - self._board=tuple(tuple(x for x in row) for row in board) - self.nodes=(GameTreeNode(self,BLACK),GameTreeNode(self,WHITE)) - self.cachedDiff=[] - self._hash=None - - def tryConnect(self,s,diff=None): - if diff is None: diff=self-s - distEst=estimateDistance(diff) - if distEst>3: return # we couldn't find every such move sequence anyway without a clever algorithm - weightEst=s.getWeight()+2-distEst - if weightEst<=self.getWeight(): return - for v1 in s.nodes: - for v2 in self.nodes: - v2.tryConnect(v1,diff) - - def hash(self): - if self._hash is None: self._hash=hashBoard(self._board) - return self._hash - - def export(self): - return exportBoard(self._board) - - def exportDiff(self,s2): - return "vvv\n{0}\n=== {1} ===\n{2}\n^^^".format(self.export(), s2-self, s2.export()) - - def __iter__(self): return iter(self._board) - - def __getitem__(self,key): return self._board[key] - - def __sub__(self,x): - res=[] - - for (r,(row1,row2)) in enumerate(zip(self._board,x)): - for (c,(item1,item2)) in enumerate(zip(row1,row2)): - if item1==item2: continue - elif item2==EMPTY: res.append((r,c,"+",item1)) - elif item1==EMPTY: res.append((r,c,"-",item2)) - else: res.append((r,c,"*",item1)) # ->to - return res - - def __eq__(self,x): - return self.hash()==x.hash() - - def setWeight(self,val): - for v in self.nodes: v.weight=val - - def getWeight(self): - return max(v.weight for v in self.nodes) - - def getPrev(self): - v=max(self.nodes, key=lambda v: v.weight) - return v.prev.parent if v.prev else None - - class StateBag: def __init__(self): self._states=[]