# HG changeset patch # User Laman # Date 2017-12-19 12:41:55 # Node ID b06733513452d3c06bc10d20b555f3cb3d62f829 # Parent 7ef3360afbe51272d3aa1248821a6bfe99897e33 StateBag: tests and fixes diff --git a/src/benchmark.py b/src/benchmark.py --- a/src/benchmark.py +++ b/src/benchmark.py @@ -1,10 +1,14 @@ import cProfile from tests.testEngine import TestTransitions +from tests.testStatebag import TestStateBag -t=TestTransitions() +# t=TestTransitions() +# +# cProfile.run(r""" +# t.testReal() +# """) -cProfile.run(r""" -t.testReal() -""") +t=TestStateBag() +cProfile.run(r"""t.testReal()""") diff --git a/src/go/engine.py b/src/go/engine.py --- a/src/go/engine.py +++ b/src/go/engine.py @@ -8,7 +8,7 @@ from . import core # @param colorOut {BLACK,WHITE}: color to close the sequence def getTransitionSequence(state1,state2,colorIn,colorOut,diff): eng.load(state1,diff) - eng.iterativelyDeepen(state2,colorIn,colorOut) + return eng.iterativelyDeepen(state2,colorIn,colorOut) class SpecGo(core.Go): @@ -69,6 +69,7 @@ class Engine: if seq: seq.reverse() return seq + return None def dfs(self,state2,limit): g=self._g @@ -80,10 +81,6 @@ class Engine: if m==PASS: # no reason for both players to pass self._moveList[(1-g.toMove)>>1].remove(m) - if g.hash()==state2.hash(): - self._undoMove(m,captured) - return [(g.toMove,*m)] - if limit>1: seq=self.dfs(state2,limit-1) if seq: @@ -91,8 +88,12 @@ class Engine: seq.append((g.toMove,*m)) return seq + if limit==1 and g.hash()==state2.hash(): + self._undoMove(m,captured) + return [(g.toMove,*m)] + self._undoMove(m,captured) - return False + return None def _undoMove(self,move,captured): g=self._g diff --git a/src/statebag.py b/src/statebag.py --- a/src/statebag.py +++ b/src/statebag.py @@ -42,6 +42,7 @@ class GameTreeNode: ## Connect itself after v if it gains itself more weight. def tryConnect(self,v,diff): moves=getTransitionSequence(v.parent,self.parent,-1*v.color,self.color,diff) + if not moves: return w=v.weight+2-len(moves) # proper one move transition increases the weight by 1 if w>self.weight: self.moves=moves @@ -57,10 +58,10 @@ class BoardState: self._hash=None def tryConnect(self,s): - diff=s-self + diff=self-s distEst=estimateDistance(diff) if distEst>3: return # we couldn't find every such move sequence anyway without a clever algorithm - weightEst=s.getWeight()-distEst + weightEst=s.getWeight()+2-distEst if weightEst<=self.getWeight(): return for v1 in s.nodes: for v2 in self.nodes: @@ -94,9 +95,16 @@ class BoardState: def __eq__(self,x): return self.hash()==x.hash() + def setWeight(self,val): + for v in self.nodes: v.weight=val + def getWeight(self): return max(v.weight for v in self.nodes) + def getPrev(self): + v=max(self.nodes, key=lambda v: v.weight) + return v.prev.parent if v.prev else None + class StateBag: def __init__(self): @@ -104,11 +112,15 @@ class StateBag: def pushState(self,board): sn=BoardState(board) - if self._states and sn==self._states[-1]: return # no change + if len(self._states)>0: + if sn==self._states[-1]: return None # no change + sn.diff2Prev=sn-self._states[-1] + else: sn.setWeight(1) for s in reversed(self._states): sn.tryConnect(s) self._states.append(sn) + return sn def merge(self,branch): pass diff --git a/src/tests/testEngine.py b/src/tests/testEngine.py --- a/src/tests/testEngine.py +++ b/src/tests/testEngine.py @@ -1,13 +1,12 @@ -import re import os.path import logging as log from unittest import TestCase import config as cfg from util import BLACK as B,WHITE as W,EMPTY as _ -from go.core import Go from go.engine import SpecGo,Engine from statebag import BoardState +from .util import simpleLoadSgf,listStates _log=log.getLogger(__name__) @@ -19,22 +18,6 @@ handler.setFormatter(formatter) _log.addHandler(handler) -def simpleLoadSgf(filename): - with open(filename) as f: - contents=f.read() - g=lambda m: tuple(ord(c)-ord('a') for c in reversed(m.group(1))) - return [g(m) for m in re.finditer(r"\b[BW]\[([a-z]{2})\]",contents)] - - -def listStates(moves): - g=Go() - res=[BoardState(g.board)] - for m in moves: - g.doMove(g.toMove,*m) - res.append(BoardState(g.board)) - return res - - class TestTransitions(TestCase): def testBasic(self): s1=BoardState([ @@ -116,8 +99,7 @@ class TestTransitions(TestCase): def testReal(self): files=["O-Takao-20110106.sgf","Sakai-Iyama-20110110.sgf"] - g=SpecGo() - eng=Engine(g) + eng=Engine() for f in files: moves=simpleLoadSgf(os.path.join(cfg.srcDir,"tests/data",f)) diff --git a/src/tests/testStatebag.py b/src/tests/testStatebag.py new file mode 100644 --- /dev/null +++ b/src/tests/testStatebag.py @@ -0,0 +1,74 @@ +import os.path +from unittest import TestCase + +import config as cfg +from util import BLACK as B,WHITE as W,EMPTY as _ +from go.engine import SpecGo,Engine +import go.engine +from statebag import BoardState,StateBag +from .util import simpleLoadSgf,listStates + + +class TestBoardState(TestCase): + def testBasic(self): + s1=BoardState([ + [_,_,_], + [_,_,_], + [_,_,_] + ]) + s2=BoardState([ + [_,_,_], + [_,B,_], + [_,_,_] + ]) + g=SpecGo(3) + go.engine.eng=Engine(g) + + s2.tryConnect(s1) + self.assertIs(s2.getPrev(), s1) + self.assertEqual(s2.getWeight(), 1) + + def test2ply(self): + s1=BoardState([ + [_,_,_], + [_,_,_], + [_,_,_] + ]) + s2=BoardState([ + [_,_,_], + [_,B,_], + [_,_,_] + ]) + s3=BoardState([ + [_,W,_], + [_,B,B], + [_,_,_] + ]) + g=SpecGo(3) + go.engine.eng=Engine(g) + + s2.tryConnect(s1) + s3.tryConnect(s2) + self.assertIs(s3.getPrev(), s2) + self.assertEqual(s3.getWeight(), 1) + + +class TestStateBag(TestCase): + def testReal(self): + go.engine.eng=Engine() + files=["O-Takao-20110106.sgf","Sakai-Iyama-20110110.sgf"] + + for f in files: + moves=simpleLoadSgf(os.path.join(cfg.srcDir,"tests/data",f)) + states=listStates(moves) + + for k in range(1,3): + bag=StateBag() + i=0 + for s_ in states: + i+=1 + if i%(2*k-1)>=k: # keep k, skip k-1 + continue + s=bag.pushState(s_) + if len(bag._states)>1: + self.assertIs(s.getPrev(), bag._states[-2]) diff --git a/src/tests/util.py b/src/tests/util.py new file mode 100644 --- /dev/null +++ b/src/tests/util.py @@ -0,0 +1,20 @@ +import re + +from go.core import Go +from statebag import BoardState + + +def simpleLoadSgf(filename): + with open(filename) as f: + contents=f.read() + g=lambda m: tuple(ord(c)-ord('a') for c in reversed(m.group(1))) + return [g(m) for m in re.finditer(r"\b[BW]\[([a-z]{2})\]",contents)] + + +def listStates(moves): + g=Go() + res=[BoardState(g.board)] + for m in moves: + g.doMove(g.toMove,*m) + res.append(BoardState(g.board)) + return res