Changeset - 9a3f61bf97f2
[Not reviewed]
default
0 3 0
Laman - 7 years ago 2017-12-21 20:26:37

StateBag: optimized a little, still slow
3 files changed with 65 insertions and 13 deletions:
0 comments (0 inline, 0 general)
src/benchmark.py
Show inline comments
 
import sys
 
import cProfile
 

	
 
from tests.testEngine import TestTransitions
 
from tests.testStatebag import TestStateBag
 

	
 

	
 
# t=TestTransitions()
 
#
 
# cProfile.run(r"""
 
# t.testReal()
 
# """)
 
arg=sys.argv[1]
 

	
 
t=TestStateBag()
 
cProfile.run(r"""t.testReal()""")
 
if arg=="engine":
 
	t=TestTransitions()
 
	cProfile.run(r"""t.testReal()""")
 
elif arg=="statebag":
 
	t=TestStateBag()
 
	cProfile.run(r"""t.testReal()""")
src/statebag.py
Show inline comments
 
@@ -22,52 +22,86 @@ from go.engine import getTransitionSeque
 

	
 
## Crude lower bound on edit distance between states.
 
def estimateDistance(diff):
 
	# lot of room for improvements
 
	additions=sum(1 for d in diff if d[2]=="+")
 
	deletions=sum(1 for d in diff if d[2]=="-")
 
	replacements=len(diff)-additions-deletions
 
	if additions>0: return additions+replacements
 
	elif replacements==0 and deletions>0: return 2 # take n, return 1
 
	return replacements+1 # ???
 

	
 

	
 
## Update contents of diff1 by contents of diff2.
 
def updateDiff(diff1,diff2):
 
	res=[]
 
	i=j=0
 
	m=len(diff1)
 
	n=len(diff2)
 
	while i<m and j<n:
 
		(r1,c1,action1,item1)=diff1[i]
 
		(r2,c2,action2,item2)=diff2[j]
 
		if (r1,c1)==(r2,c2):
 
			merged=transformSingle(action1,item1,action2,item2)
 
			if merged: res.append((r1,c1,*merged))
 
			i+=1
 
			j+=1
 
		elif (r1,c1)<(r2,c2):
 
			res.append(diff1[i])
 
			i+=1
 
		else:
 
			res.append(diff2[j])
 
			j+=1
 
	if i<m: res.extend(diff1[i:])
 
	else: res.extend(diff2[j:])
 
	return res
 

	
 

	
 
def transformSingle(action1,item1,action2,item2):
 
	if action1=="+":
 
		if action2!="-":
 
			return ("+",item2)
 
	elif action2=="-": return ("-",item2)
 
	elif (action1=="*" and item1==item2) or (action1=="-" and item1!=item2): return ("*",item2)
 
	return None
 

	
 

	
 
class GameTreeNode:
 
	def __init__(self,parent,color):
 
		self.parent=parent
 
		self.color=color # color closing the move sequence
 
		self.prev=None
 
		self.moves=[] # move sequence from prev to self
 
		self.weight=0
 

	
 
	## Connect itself after v if it gains itself more weight.
 
	def tryConnect(self,v,diff):
 
		moves=getTransitionSequence(v.parent,self.parent,-1*v.color,self.color,diff)
 
		if not moves: return
 
		w=v.weight+2-len(moves) # proper one move transition increases the weight by 1
 
		if w>self.weight:
 
			self.moves=moves
 
			self.prev=v
 
			self.weight=w
 

	
 

	
 
class BoardState:
 
	def __init__(self,board):
 
		self._board=tuple(tuple(x for x in row) for row in board)
 
		self.nodes=(GameTreeNode(self,BLACK),GameTreeNode(self,WHITE))
 
		self.diff2Prev=None
 
		self.cachedDiff=[]
 
		self._hash=None
 

	
 
	def tryConnect(self,s):
 
		diff=self-s
 
	def tryConnect(self,s,diff=None):
 
		if diff is None: diff=self-s
 
		distEst=estimateDistance(diff)
 
		if distEst>3: return # we couldn't find every such move sequence anyway without a clever algorithm
 
		weightEst=s.getWeight()+2-distEst
 
		if weightEst<=self.getWeight(): return
 
		for v1 in s.nodes:
 
			for v2 in self.nodes:
 
				v2.tryConnect(v1,diff)
 

	
 
	def hash(self):
 
		if self._hash is None: self._hash=hashBoard(self._board)
 
		return self._hash
 

	
 
@@ -105,22 +139,24 @@ class BoardState:
 
		v=max(self.nodes, key=lambda v: v.weight)
 
		return v.prev.parent if v.prev else None
 

	
 

	
 
class StateBag:
 
	def __init__(self):
 
		self._states=[]
 

	
 
	def pushState(self,board):
 
		sn=BoardState(board)
 
		if len(self._states)>0:
 
			if sn==self._states[-1]: return None # no change
 
			sn.diff2Prev=sn-self._states[-1]
 
			sn.cachedDiff=sn-self._states[-1]
 
		else: sn.setWeight(1)
 

	
 
		diff=sn.cachedDiff
 
		for s in reversed(self._states):
 
			sn.tryConnect(s)
 
			sn.tryConnect(s,diff)
 
			diff=updateDiff(s.cachedDiff,diff)
 
		self._states.append(sn)
 
		return sn
 

	
 
	def merge(self,branch):
 
		pass
src/tests/testStatebag.py
Show inline comments
 
import os.path
 
from unittest import TestCase
 

	
 
import config as cfg
 
from util import BLACK as B,WHITE as W,EMPTY as _
 
from go.engine import SpecGo,Engine
 
import go.engine
 
from statebag import BoardState,StateBag
 
from statebag import BoardState,StateBag,updateDiff
 
from .util import simpleLoadSgf,listStates
 

	
 

	
 
class TestBoardState(TestCase):
 
	def testBasic(self):
 
		s1=BoardState([
 
			[_,_,_],
 
			[_,_,_],
 
			[_,_,_]
 
		])
 
		s2=BoardState([
 
			[_,_,_],
 
@@ -43,24 +43,39 @@ class TestBoardState(TestCase):
 
			[_,W,_],
 
			[_,B,B],
 
			[_,_,_]
 
		])
 
		g=SpecGo(3)
 
		go.engine.eng=Engine(g)
 

	
 
		s2.tryConnect(s1)
 
		s3.tryConnect(s2)
 
		self.assertIs(s3.getPrev(), s2)
 
		self.assertEqual(s3.getWeight(), 1)
 

	
 
	def testUpdateDiff(self):
 
		files=["O-Takao-20110106.sgf","Sakai-Iyama-20110110.sgf"]
 

	
 
		for f in files:
 
			moves=simpleLoadSgf(os.path.join(cfg.srcDir,"tests/data",f))
 
			states=listStates(moves)
 
			for (i,j,k) in [(1,2,3),(10,20,30),(90,100,110),(20,70,120)]:
 
				s1=states[i]
 
				s2=states[j]
 
				s3=states[k]
 
				diff1=s2-s1
 
				diff2=s3-s2
 
				with self.subTest(file=f,ijk=(i,j,k)):
 
					self.assertEqual(s3-s1,updateDiff(diff1,diff2))
 

	
 

	
 
class TestStateBag(TestCase):
 
	def testReal(self):
 
		go.engine.eng=Engine()
 
		files=["O-Takao-20110106.sgf","Sakai-Iyama-20110110.sgf"]
 

	
 
		for f in files:
 
			moves=simpleLoadSgf(os.path.join(cfg.srcDir,"tests/data",f))
 
			states=listStates(moves)
 

	
 
			for k in range(1,3):
 
				bag=StateBag()
0 comments (0 inline, 0 general)