# HG changeset patch # User Laman # Date 2017-12-21 20:26:37 # Node ID 9a3f61bf97f280fb68ffa09fdb1cb0bed0dcf869 # Parent b06733513452d3c06bc10d20b555f3cb3d62f829 StateBag: optimized a little, still slow diff --git a/src/benchmark.py b/src/benchmark.py --- a/src/benchmark.py +++ b/src/benchmark.py @@ -1,14 +1,15 @@ +import sys import cProfile from tests.testEngine import TestTransitions from tests.testStatebag import TestStateBag -# t=TestTransitions() -# -# cProfile.run(r""" -# t.testReal() -# """) +arg=sys.argv[1] -t=TestStateBag() -cProfile.run(r"""t.testReal()""") +if arg=="engine": + t=TestTransitions() + cProfile.run(r"""t.testReal()""") +elif arg=="statebag": + t=TestStateBag() + cProfile.run(r"""t.testReal()""") diff --git a/src/statebag.py b/src/statebag.py --- a/src/statebag.py +++ b/src/statebag.py @@ -31,6 +31,40 @@ def estimateDistance(diff): return replacements+1 # ??? +## Update contents of diff1 by contents of diff2. +def updateDiff(diff1,diff2): + res=[] + i=j=0 + m=len(diff1) + n=len(diff2) + while i3: return # we couldn't find every such move sequence anyway without a clever algorithm weightEst=s.getWeight()+2-distEst @@ -114,11 +148,13 @@ class StateBag: sn=BoardState(board) if len(self._states)>0: if sn==self._states[-1]: return None # no change - sn.diff2Prev=sn-self._states[-1] + sn.cachedDiff=sn-self._states[-1] else: sn.setWeight(1) + diff=sn.cachedDiff for s in reversed(self._states): - sn.tryConnect(s) + sn.tryConnect(s,diff) + diff=updateDiff(s.cachedDiff,diff) self._states.append(sn) return sn diff --git a/src/tests/testStatebag.py b/src/tests/testStatebag.py --- a/src/tests/testStatebag.py +++ b/src/tests/testStatebag.py @@ -5,7 +5,7 @@ import config as cfg from util import BLACK as B,WHITE as W,EMPTY as _ from go.engine import SpecGo,Engine import go.engine -from statebag import BoardState,StateBag +from statebag import BoardState,StateBag,updateDiff from .util import simpleLoadSgf,listStates @@ -52,6 +52,21 @@ class 
TestBoardState(TestCase):
 		self.assertIs(s3.getPrev(), s2)
 		self.assertEqual(s3.getWeight(), 1)
 
+	def testUpdateDiff(self): # updateDiff(diff1,diff2) must equal the diff computed directly between the outer states
+		files=["O-Takao-20110106.sgf","Sakai-Iyama-20110110.sgf"] # SGF files, resolved under <cfg.srcDir>/tests/data below
+
+		for f in files:
+			moves=simpleLoadSgf(os.path.join(cfg.srcDir,"tests/data",f))
+			states=listStates(moves) # NOTE(review): presumably one board state per move, indexable by move number - confirm in tests/util
+			for (i,j,k) in [(1,2,3),(10,20,30),(90,100,110),(20,70,120)]: # index triples into states, adjacent and far apart
+				s1=states[i]
+				s2=states[j]
+				s3=states[k]
+				diff1=s2-s1 # diff between states[i] and states[j]
+				diff2=s3-s2 # diff between states[j] and states[k]
+				with self.subTest(file=f,ijk=(i,j,k)): # each file/triple combination is reported as its own sub-test
+					self.assertEqual(s3-s1,updateDiff(diff1,diff2)) # merged diff has to match the direct states[i]..states[k] diff
+
 
 class TestStateBag(TestCase):
 	def testReal(self):