diff --git a/src/benchmark.py b/src/benchmark.py --- a/src/benchmark.py +++ b/src/benchmark.py @@ -12,4 +12,4 @@ if arg=="engine": cProfile.run(r"""t.testReal()""") elif arg=="statebag": t=TestStateBag() - cProfile.run(r"""t.testReal()""") + cProfile.run(r"""t.testNoise()""") diff --git a/src/statebag/boardstate.py b/src/statebag/boardstate.py --- a/src/statebag/boardstate.py +++ b/src/statebag/boardstate.py @@ -1,14 +1,23 @@ from util import EMPTY,BLACK,WHITE, hashBoard,exportBoard -from go.engine import getTransitionSequence +from go.engine import getTransitionSequence,SpecGo + + +g=SpecGo() ## Crude lower bound on edit distance between states. -def estimateDistance(diff): +def estimateDistance(diff,s1,s2): # lot of room for improvements - additions=sum(1 for d in diff if d[2]=="+") - deletions=sum(1 for d in diff if d[2]=="-") - replacements=len(diff)-additions-deletions - if additions>0: return additions+replacements + additions=deletions=replacements=unaccounted=0 + for (r,c,d,color) in diff: + if d=="+": additions+=1 + elif d=="-": + deletions+=1 + for (ri,ci) in g.listNeighbours(r,c): + if s1[ri][ci]==EMPTY and s2[ri][ci]==EMPTY: + unaccounted+=1 + else: replacements+=1 + if additions>0 or unaccounted>0: return additions+replacements+unaccounted elif replacements==0 and deletions>0: return 2 # take n, return 1 return replacements+1 # ??? @@ -53,8 +62,10 @@ class BoardState: self._hash=None def tryConnect(self,s,diff=None): + """:param s: BoardState s + :param diff: [(r,c,change,color), ...], change in {+,-,*}, color in {BLACK,WHITE}""" if diff is None: diff=self-s - distEst=estimateDistance(diff) + distEst=estimateDistance(diff,s,self) if distEst>3: return # we couldn't find every such move sequence anyway without a clever algorithm weightEst=s.getWeight()+2-distEst if weightEst<=self.getWeight(): return diff --git a/src/tests/testStatebag.py b/src/tests/testStatebag.py --- a/src/tests/testStatebag.py +++ b/src/tests/testStatebag.py @@ -1,3 +1,4 @@ +import random import os.path from unittest import TestCase @@ -6,7 +7,7 @@ from util import BLACK as B,WHITE as W,E from go.engine import SpecGo,Engine import go.engine from statebag import BoardState,StateBag,updateDiff -from .util import simpleLoadSgf,listStates +from .util import simpleLoadSgf,listBoards,listStates class TestBoardState(TestCase): @@ -69,21 +70,47 @@ class TestBoardState(TestCase): class TestStateBag(TestCase): - def testReal(self): + def testSkips(self): go.engine.eng=Engine() files=["O-Takao-20110106.sgf","Sakai-Iyama-20110110.sgf"] for f in files: moves=simpleLoadSgf(os.path.join(cfg.srcDir,"tests/data",f)) - states=listStates(moves) + boards=listBoards(moves) for k in range(1,3): bag=StateBag() i=0 - for s_ in states: + for b in boards: i+=1 if i%(2*k-1)>=k: # keep k, skip k-1 continue - s=bag.pushState(s_) + s=bag.pushState(b) if len(bag._states)>1: self.assertIs(s.getPrev(), bag._states[-2]) + + def testNoise(self): + random.seed(361) + go.engine.eng=Engine() + files=["O-Takao-20110106.sgf","Sakai-Iyama-20110110.sgf"] + + for f in files: + moves=simpleLoadSgf(os.path.join(cfg.srcDir,"tests/data",f)) + boards=listBoards(moves) + + bag=StateBag() + for b in boards: + s=bag.pushState(b) + if len(bag._states)>1: + # correct state skipping the erroneous one, connected to the previous correct one + self.assertIs(s.getPrev(), bag._states[key]) + + if random.random()<0.9: + key=-2 + continue + for i in range(random.randrange(1,10)): + r=random.randrange(19) + c=random.randrange(19) + b[r][c]=(b[r][c]+random.choice((2,3)))%3-1 # random transformation [-1,1]->[-1,1] + bag.pushState(b) + key=-3 diff --git a/src/tests/util.py b/src/tests/util.py --- a/src/tests/util.py +++ b/src/tests/util.py @@ -11,10 +11,14 @@ def simpleLoadSgf(filename): return [g(m) for m in re.finditer(r"\b[BW]\[([a-z]{2})\]",contents)] -def listStates(moves): +def listBoards(moves): g=Go() - res=[BoardState(g.board)] + res=[tuple(list(x for x in row) for row in g.board)] for m in moves: g.doMove(g.toMove,*m) - res.append(BoardState(g.board)) + res.append(tuple(list(x for x in row) for row in g.board)) return res + + +def listStates(moves): + return [BoardState(b) for b in listBoards(moves)]