diff --git a/src/tests/testEngine.py b/src/tests/testEngine.py --- a/src/tests/testEngine.py +++ b/src/tests/testEngine.py @@ -32,8 +32,9 @@ class TestTransitions(TestCase): ]) g=SpecGo(3) eng=Engine(g) - eng.load(s1,s2-s1) - self.assertEqual(eng.dfs(s2,1),[(1,1,1)]) + eng.load(s1) + eng._moveList=eng._g.listRelevantMoves(s2-s1) + self.assertEqual(eng._dfs(s2,1),[(1,1,1)]) def testCapture(self): s1=BoardState([ @@ -49,8 +50,9 @@ class TestTransitions(TestCase): g=SpecGo(3) g.toMove=W eng=Engine(g) - eng.load(s1,s2-s1) - self.assertEqual(eng.dfs(s2,1),[(W,1,2)]) + eng.load(s1) + eng._moveList=eng._g.listRelevantMoves(s2-s1) + self.assertEqual(eng._dfs(s2,1),[(W,1,2)]) def testMulti(self): s1=BoardState([ @@ -65,8 +67,9 @@ class TestTransitions(TestCase): ]) g=SpecGo(3) eng=Engine(g) - eng.load(s1,s2-s1) - self.assertEqual(eng.dfs(s2,2),[(W,1,2),(B,1,1)]) + eng.load(s1) + eng._moveList=eng._g.listRelevantMoves(s2-s1) + self.assertEqual(eng._dfs(s2,2),[(W,1,2),(B,1,1)]) def testSnapback(self): s1=BoardState([ @@ -81,8 +84,9 @@ class TestTransitions(TestCase): ]) g=SpecGo(3) eng=Engine(g) - eng.load(s1,s2-s1) - self.assertEqual(eng.dfs(s2,2),[(W,2,1),(B,1,1)]) + eng.load(s1) + eng._moveList=eng._g.listRelevantMoves(s2-s1) + self.assertEqual(eng._dfs(s2,2),[(W,2,1),(B,1,1)]) s1=BoardState([ [_,_,_], @@ -94,8 +98,9 @@ class TestTransitions(TestCase): [W,B,B], [_,W,_] ]) - eng.load(s1,s2-s1) - self.assertEqual(eng.dfs(s2,2),[(W,2,1),(B,2,0)]) + eng.load(s1) + eng._moveList=eng._g.listRelevantMoves(s2-s1) + self.assertEqual(eng._dfs(s2,2),[(W,2,1),(B,2,0)]) def testReal(self): files=["O-Takao-20110106.sgf","Sakai-Iyama-20110110.sgf"] @@ -109,11 +114,11 @@ class TestTransitions(TestCase): toMove=B for (i,(s1,s2)) in enumerate(zip(states,states[k:])): diff=s2-s1 - eng.load(s1,diff) + eng.load(s1) colorIn=W if i&1 else B - colorOut=colorIn if k&1 else 1-colorIn - seq=eng.iterativelyDeepen(s2,colorIn,colorOut) - msg="\n"+s1.exportDiff(s2) + colorOut=colorIn if k&1 else -colorIn + seq=eng.iterativelyDeepen(s2,diff,colorIn,colorOut) + msg="\n>>> {0} ({1}, {2})\n>>> {3}\n".format(f,k,i,str(seq))+s1.exportDiff(s2) self.assertIsNotNone(seq,msg) self.assertLessEqual(len(seq),k,msg) if len(seq)!=k: _log.warning("shorter than expected transition sequence:" + msg + "\n" + str(seq))