Changeset - 9a3f61bf97f2
[Not reviewed]
default
0 3 0
Laman - 7 years ago 2017-12-21 20:26:37

StateBag: optimized a little, still slow
3 files changed with 63 insertions and 11 deletions:
0 comments (0 inline, 0 general)
src/benchmark.py
Show inline comments
 
import sys
 
import cProfile
 

	
 
from tests.testEngine import TestTransitions
 
from tests.testStatebag import TestStateBag
 

	
 

	
 
# t=TestTransitions()
 
#
 
# cProfile.run(r"""
 
# t.testReal()
 
# """)
 
arg=sys.argv[1]
 

	
 
if arg=="engine":
 
	t=TestTransitions()
 
	cProfile.run(r"""t.testReal()""")
 
elif arg=="statebag":
 
t=TestStateBag()
 
cProfile.run(r"""t.testReal()""")
src/statebag.py
Show inline comments
 
@@ -10,76 +10,110 @@ In such a case we assume the new paramet
 
But the change might have been appropriate even earlier (before the user detected and fixed an error).
 
So we try to find the correct crossover point like this:
 
- construct a sequence S' = s'_i, ..., s'_n by reanalyzing the positions with a new set of parameters, where s_i is the point of previous user intervention or s_0
 
- for each s'_j:
 
	- try to append it to S[:j]
 
	- try to append it to S'[:j]
 
	- remember the better variant
 
- linearize the fork back by discarding s'_j-s preceding the crossover and s_j-s following the crossover
 
"""
 
from util import EMPTY,BLACK,WHITE, hashBoard,exportBoard
 
from go.engine import getTransitionSequence
 

	
 

	
 
## Crude lower bound on edit distance between states.
 
def estimateDistance(diff):
 
	# lot of room for improvements
 
	additions=sum(1 for d in diff if d[2]=="+")
 
	deletions=sum(1 for d in diff if d[2]=="-")
 
	replacements=len(diff)-additions-deletions
 
	if additions>0: return additions+replacements
 
	elif replacements==0 and deletions>0: return 2 # take n, return 1
 
	return replacements+1 # ???
 

	
 

	
 
## Update contents of diff1 by contents of diff2.
 
def updateDiff(diff1,diff2):
 
	res=[]
 
	i=j=0
 
	m=len(diff1)
 
	n=len(diff2)
 
	while i<m and j<n:
 
		(r1,c1,action1,item1)=diff1[i]
 
		(r2,c2,action2,item2)=diff2[j]
 
		if (r1,c1)==(r2,c2):
 
			merged=transformSingle(action1,item1,action2,item2)
 
			if merged: res.append((r1,c1,*merged))
 
			i+=1
 
			j+=1
 
		elif (r1,c1)<(r2,c2):
 
			res.append(diff1[i])
 
			i+=1
 
		else:
 
			res.append(diff2[j])
 
			j+=1
 
	if i<m: res.extend(diff1[i:])
 
	else: res.extend(diff2[j:])
 
	return res
 

	
 

	
 
def transformSingle(action1,item1,action2,item2):
 
	if action1=="+":
 
		if action2!="-":
 
			return ("+",item2)
 
	elif action2=="-": return ("-",item2)
 
	elif (action1=="*" and item1==item2) or (action1=="-" and item1!=item2): return ("*",item2)
 
	return None
 

	
 

	
 
class GameTreeNode:
 
	def __init__(self,parent,color):
 
		self.parent=parent
 
		self.color=color # color closing the move sequence
 
		self.prev=None
 
		self.moves=[] # move sequence from prev to self
 
		self.weight=0
 

	
 
	## Connect itself after v if it gains itself more weight.
 
	def tryConnect(self,v,diff):
 
		moves=getTransitionSequence(v.parent,self.parent,-1*v.color,self.color,diff)
 
		if not moves: return
 
		w=v.weight+2-len(moves) # proper one move transition increases the weight by 1
 
		if w>self.weight:
 
			self.moves=moves
 
			self.prev=v
 
			self.weight=w
 

	
 

	
 
class BoardState:
 
	def __init__(self,board):
 
		self._board=tuple(tuple(x for x in row) for row in board)
 
		self.nodes=(GameTreeNode(self,BLACK),GameTreeNode(self,WHITE))
 
		self.diff2Prev=None
 
		self.cachedDiff=[]
 
		self._hash=None
 

	
 
	def tryConnect(self,s):
 
		diff=self-s
 
	def tryConnect(self,s,diff=None):
 
		if diff is None: diff=self-s
 
		distEst=estimateDistance(diff)
 
		if distEst>3: return # we couldn't find every such move sequence anyway without a clever algorithm
 
		weightEst=s.getWeight()+2-distEst
 
		if weightEst<=self.getWeight(): return
 
		for v1 in s.nodes:
 
			for v2 in self.nodes:
 
				v2.tryConnect(v1,diff)
 

	
 
	def hash(self):
 
		if self._hash is None: self._hash=hashBoard(self._board)
 
		return self._hash
 

	
 
	def export(self):
 
		return exportBoard(self._board)
 

	
 
	def exportDiff(self,s2):
 
		return "vvv\n{0}\n=== {1} ===\n{2}\n^^^".format(self.export(), s2-self, s2.export())
 

	
 
	def __iter__(self): return iter(self._board)
 

	
 
	def __getitem__(self,key): return self._board[key]
 

	
 
	def __sub__(self,x):
 
		res=[]
 
@@ -93,34 +127,36 @@ class BoardState:
 
		return res
 

	
 
	def __eq__(self,x):
 
		return self.hash()==x.hash()
 

	
 
	def setWeight(self,val):
 
		for v in self.nodes: v.weight=val
 

	
 
	def getWeight(self):
 
		return max(v.weight for v in self.nodes)
 

	
 
	def getPrev(self):
 
		v=max(self.nodes, key=lambda v: v.weight)
 
		return v.prev.parent if v.prev else None
 

	
 

	
 
class StateBag:
 
	def __init__(self):
 
		self._states=[]
 

	
 
	def pushState(self,board):
 
		sn=BoardState(board)
 
		if len(self._states)>0:
 
			if sn==self._states[-1]: return None # no change
 
			sn.diff2Prev=sn-self._states[-1]
 
			sn.cachedDiff=sn-self._states[-1]
 
		else: sn.setWeight(1)
 

	
 
		diff=sn.cachedDiff
 
		for s in reversed(self._states):
 
			sn.tryConnect(s)
 
			sn.tryConnect(s,diff)
 
			diff=updateDiff(s.cachedDiff,diff)
 
		self._states.append(sn)
 
		return sn
 

	
 
	def merge(self,branch):
 
		pass
src/tests/testStatebag.py
Show inline comments
 
import os.path
 
from unittest import TestCase
 

	
 
import config as cfg
 
from util import BLACK as B,WHITE as W,EMPTY as _
 
from go.engine import SpecGo,Engine
 
import go.engine
 
from statebag import BoardState,StateBag
 
from statebag import BoardState,StateBag,updateDiff
 
from .util import simpleLoadSgf,listStates
 

	
 

	
 
class TestBoardState(TestCase):
 
	def testBasic(self):
 
		s1=BoardState([
 
			[_,_,_],
 
			[_,_,_],
 
			[_,_,_]
 
		])
 
		s2=BoardState([
 
			[_,_,_],
 
			[_,B,_],
 
			[_,_,_]
 
		])
 
		g=SpecGo(3)
 
		go.engine.eng=Engine(g)
 

	
 
		s2.tryConnect(s1)
 
		self.assertIs(s2.getPrev(), s1)
 
		self.assertEqual(s2.getWeight(), 1)
 

	
 
	def test2ply(self):
 
		s1=BoardState([
 
			[_,_,_],
 
			[_,_,_],
 
			[_,_,_]
 
		])
 
		s2=BoardState([
 
			[_,_,_],
 
			[_,B,_],
 
			[_,_,_]
 
		])
 
		s3=BoardState([
 
			[_,W,_],
 
			[_,B,B],
 
			[_,_,_]
 
		])
 
		g=SpecGo(3)
 
		go.engine.eng=Engine(g)
 

	
 
		s2.tryConnect(s1)
 
		s3.tryConnect(s2)
 
		self.assertIs(s3.getPrev(), s2)
 
		self.assertEqual(s3.getWeight(), 1)
 

	
 
	def testUpdateDiff(self):
 
		files=["O-Takao-20110106.sgf","Sakai-Iyama-20110110.sgf"]
 

	
 
		for f in files:
 
			moves=simpleLoadSgf(os.path.join(cfg.srcDir,"tests/data",f))
 
			states=listStates(moves)
 
			for (i,j,k) in [(1,2,3),(10,20,30),(90,100,110),(20,70,120)]:
 
				s1=states[i]
 
				s2=states[j]
 
				s3=states[k]
 
				diff1=s2-s1
 
				diff2=s3-s2
 
				with self.subTest(file=f,ijk=(i,j,k)):
 
					self.assertEqual(s3-s1,updateDiff(diff1,diff2))
 

	
 

	
 
class TestStateBag(TestCase):
 
	def testReal(self):
 
		go.engine.eng=Engine()
 
		files=["O-Takao-20110106.sgf","Sakai-Iyama-20110110.sgf"]
 

	
 
		for f in files:
 
			moves=simpleLoadSgf(os.path.join(cfg.srcDir,"tests/data",f))
 
			states=listStates(moves)
 

	
 
			for k in range(1,3):
 
				bag=StateBag()
 
				i=0
 
				for s_ in states:
 
					i+=1
 
					if i%(2*k-1)>=k: # keep k, skip k-1
 
						continue
 
					s=bag.pushState(s_)
 
					if len(bag._states)>1:
 
						self.assertIs(s.getPrev(), bag._states[-2])
0 comments (0 inline, 0 general)