diff --git a/src/tests/testStatebag.py b/src/tests/testStatebag.py new file mode 100644 --- /dev/null +++ b/src/tests/testStatebag.py @@ -0,0 +1,74 @@ +import os.path +from unittest import TestCase + +import config as cfg +from util import BLACK as B,WHITE as W,EMPTY as _ +from go.engine import SpecGo,Engine +import go.engine +from statebag import BoardState,StateBag +from .util import simpleLoadSgf,listStates + + +class TestBoardState(TestCase): + def testBasic(self): + s1=BoardState([ + [_,_,_], + [_,_,_], + [_,_,_] + ]) + s2=BoardState([ + [_,_,_], + [_,B,_], + [_,_,_] + ]) + g=SpecGo(3) + go.engine.eng=Engine(g) + + s2.tryConnect(s1) + self.assertIs(s2.getPrev(), s1) + self.assertEqual(s2.getWeight(), 1) + + def test2ply(self): + s1=BoardState([ + [_,_,_], + [_,_,_], + [_,_,_] + ]) + s2=BoardState([ + [_,_,_], + [_,B,_], + [_,_,_] + ]) + s3=BoardState([ + [_,W,_], + [_,B,B], + [_,_,_] + ]) + g=SpecGo(3) + go.engine.eng=Engine(g) + + s2.tryConnect(s1) + s3.tryConnect(s2) + self.assertIs(s3.getPrev(), s2) + self.assertEqual(s3.getWeight(), 1) + + +class TestStateBag(TestCase): + def testReal(self): + go.engine.eng=Engine() + files=["O-Takao-20110106.sgf","Sakai-Iyama-20110110.sgf"] + + for f in files: + moves=simpleLoadSgf(os.path.join(cfg.srcDir,"tests/data",f)) + states=listStates(moves) + + for k in range(1,3): + bag=StateBag() + i=0 + for s_ in states: + i+=1 + if i%(2*k-1)>=k: # keep k, skip k-1 + continue + s=bag.pushState(s_) + if len(bag._states)>1: + self.assertIs(s.getPrev(), bag._states[-2])