Changeset - eb6c3126cf26
[Not reviewed]
default
0 4 0
Laman - 6 years ago 2018-12-06 22:38:52

StateBag: optimized estimateDistance, tests for dealing with wrong states
4 files changed with 58 insertions and 16 deletions:
0 comments (0 inline, 0 general)
src/benchmark.py
Show inline comments
 
import sys
 
import cProfile
 

	
 
from tests.testEngine import TestTransitions
 
from tests.testStatebag import TestStateBag
 

	
 

	
 
arg=sys.argv[1]
 

	
 
if arg=="engine":
 
	t=TestTransitions()
 
	cProfile.run(r"""t.testReal()""")
 
elif arg=="statebag":
 
	t=TestStateBag()
 
	cProfile.run(r"""t.testReal()""")
 
	cProfile.run(r"""t.testNoise()""")
src/statebag/boardstate.py
Show inline comments
 
from util import EMPTY,BLACK,WHITE, hashBoard,exportBoard
 
from go.engine import getTransitionSequence
 
from go.engine import getTransitionSequence,SpecGo
 

	
 

	
 
g=SpecGo()
 

	
 

	
 
## Crude lower bound on edit distance between states.
 
def estimateDistance(diff):
 
def estimateDistance(diff,s1,s2):
 
	# lot of room for improvements
 
	additions=sum(1 for d in diff if d[2]=="+")
 
	deletions=sum(1 for d in diff if d[2]=="-")
 
	replacements=len(diff)-additions-deletions
 
	if additions>0: return additions+replacements
 
	additions=deletions=replacements=unaccounted=0
 
	for (r,c,d,color) in diff:
 
		if d=="+": additions+=1
 
		elif d=="-":
 
			deletions+=1
 
			for (ri,ci) in g.listNeighbours(r,c):
 
				if s1[ri][ci]==EMPTY and s2[ri][ci]==EMPTY:
 
					unaccounted+=1
 
		else: replacements+=1
 
	if additions>0 or unaccounted>0: return additions+replacements+unaccounted
 
	elif replacements==0 and deletions>0: return 2 # take n, return 1
 
	return replacements+1 # ???
 

	
 

	
 
class GameTreeNode:
 
	def __init__(self,parent,color):
 
		self.parent=parent
 
		self.color=color # color closing the move sequence
 
		self.prev=None
 
		self.moves=[] # move sequence from prev to self
 
		self.weight=0
 

	
 
	## Connect itself after v if it gains itself more weight.
 
	def tryConnect(self,v,diff):
 
		moves=getTransitionSequence(v.parent,self.parent,-1*v.color,self.color,diff)
 
		if not moves: return
 
		w=v.weight+2-len(moves) # proper one move transition increases the weight by 1
 
		if w>self.weight:
 
			self.moves=moves
 
			self.prev=v
 
			self.weight=w
 

	
 
	def exportRecord(self):
 
		""":return: [(c,row,col), ...]. c in {BLACK,WHITE} == {1,-1}"""
 
		sequence=[]
 
		v=self
 
		while v is not None:
 
			sequence.append(v)
 
			v=v.prev
 

	
 
		res=[]
 
		for v in reversed(sequence):
 
			res.extend(v.moves)
 
		return res
 

	
 

	
 
class BoardState:
 
	def __init__(self,board):
 
		self._board=tuple(tuple(x for x in row) for row in board)
 
		self.nodes=(GameTreeNode(self,BLACK),GameTreeNode(self,WHITE))
 
		self.cachedDiff=[]
 
		self._hash=None
 

	
 
	def tryConnect(self,s,diff=None):
 
		""":param s: BoardState s
 
		:param diff: [(r,c,change,color), ...], change in {+,-,*}, color in {BLACK,WHITE}"""
 
		if diff is None: diff=self-s
 
		distEst=estimateDistance(diff)
 
		distEst=estimateDistance(diff,s,self)
 
		if distEst>3: return # we couldn't find every such move sequence anyway without a clever algorithm
 
		weightEst=s.getWeight()+2-distEst
 
		if weightEst<=self.getWeight(): return
 
		for v1 in s.nodes:
 
			for v2 in self.nodes:
 
				v2.tryConnect(v1,diff)
 

	
 
	def hash(self):
 
		if self._hash is None: self._hash=hashBoard(self._board)
 
		return self._hash
 

	
 
	def export(self):
 
		return exportBoard(self._board)
 

	
 
	def exportDiff(self,s2):
 
		return "vvv\n{0}\n=== {1} ===\n{2}\n^^^".format(self.export(), s2-self, s2.export())
 

	
 
	def exportRecord(self):
 
		""":return: [(c,row,col), ...]. c in {BLACK,WHITE} == {1,-1}"""
 
		v=self.nodes[0] if self.nodes[0].weight>self.nodes[1].weight else self.nodes[1]
 
		return v.exportRecord()
 

	
 
	def __iter__(self): return iter(self._board)
 

	
 
	def __getitem__(self,key): return self._board[key]
 

	
 
	## Compute difference self-s.
 
	def __sub__(self,s):
 
		res=[]
 

	
 
		for (r,(row,rowS)) in enumerate(zip(self._board,s)):
 
			for (c,(item,itemS)) in enumerate(zip(row,rowS)):
 
				if item==itemS: continue
 
				elif itemS==EMPTY: res.append((r,c,"+",item))
 
				elif item==EMPTY: res.append((r,c,"-",itemS))
 
				else: res.append((r,c,"*",item)) # ->to
 
		return res
 

	
 
	def __eq__(self,x):
 
		return self.hash()==x.hash()
 

	
 
	def setWeight(self,val):
 
		for v in self.nodes: v.weight=val
 

	
 
	def getWeight(self):
 
		return max(v.weight for v in self.nodes)
 

	
 
	def getPrev(self):
src/tests/testStatebag.py
Show inline comments
 
import random
 
import os.path
 
from unittest import TestCase
 

	
 
import config as cfg
 
from util import BLACK as B,WHITE as W,EMPTY as _
 
from go.engine import SpecGo,Engine
 
import go.engine
 
from statebag import BoardState,StateBag,updateDiff
 
from .util import simpleLoadSgf,listStates
 
from .util import simpleLoadSgf,listBoards,listStates
 

	
 

	
 
class TestBoardState(TestCase):
 
	def testBasic(self):
 
		s1=BoardState([
 
			[_,_,_],
 
			[_,_,_],
 
			[_,_,_]
 
		])
 
		s2=BoardState([
 
			[_,_,_],
 
			[_,B,_],
 
			[_,_,_]
 
		])
 
		g=SpecGo(3)
 
		go.engine.eng=Engine(g)
 

	
 
		s2.tryConnect(s1)
 
		self.assertIs(s2.getPrev(), s1)
 
		self.assertEqual(s2.getWeight(), 1)
 

	
 
	def test2ply(self):
 
		s1=BoardState([
 
			[_,_,_],
 
			[_,_,_],
 
			[_,_,_]
 
		])
 
		s2=BoardState([
 
			[_,_,_],
 
			[_,B,_],
 
			[_,_,_]
 
		])
 
		s3=BoardState([
 
			[_,W,_],
 
			[_,B,B],
 
			[_,_,_]
 
		])
 
		g=SpecGo(3)
 
		go.engine.eng=Engine(g)
 

	
 
		s2.tryConnect(s1)
 
		s3.tryConnect(s2)
 
		self.assertIs(s3.getPrev(), s2)
 
		self.assertEqual(s3.getWeight(), 1)
 

	
 
	def testUpdateDiff(self):
 
		files=["O-Takao-20110106.sgf","Sakai-Iyama-20110110.sgf"]
 

	
 
		for f in files:
 
			moves=simpleLoadSgf(os.path.join(cfg.srcDir,"tests/data",f))
 
			states=listStates(moves)
 
			for (i,j,k) in [(1,2,3),(10,20,30),(90,100,110),(20,70,120)]:
 
				s1=states[i]
 
				s2=states[j]
 
				s3=states[k]
 
				diff1=s2-s1
 
				diff2=s3-s2
 
				with self.subTest(file=f,ijk=(i,j,k)):
 
					self.assertEqual(s3-s1,updateDiff(diff1,diff2))
 

	
 

	
 
class TestStateBag(TestCase):
 
	def testReal(self):
 
	def testSkips(self):
 
		go.engine.eng=Engine()
 
		files=["O-Takao-20110106.sgf","Sakai-Iyama-20110110.sgf"]
 

	
 
		for f in files:
 
			moves=simpleLoadSgf(os.path.join(cfg.srcDir,"tests/data",f))
 
			states=listStates(moves)
 
			boards=listBoards(moves)
 

	
 
			for k in range(1,3):
 
				bag=StateBag()
 
				i=0
 
				for s_ in states:
 
				for b in boards:
 
					i+=1
 
					if i%(2*k-1)>=k: # keep k, skip k-1
 
						continue
 
					s=bag.pushState(s_)
 
					s=bag.pushState(b)
 
					if len(bag._states)>1:
 
						self.assertIs(s.getPrev(), bag._states[-2])
 

	
 
	def testNoise(self):
 
		random.seed(361)
 
		go.engine.eng=Engine()
 
		files=["O-Takao-20110106.sgf","Sakai-Iyama-20110110.sgf"]
 

	
 
		for f in files:
 
			moves=simpleLoadSgf(os.path.join(cfg.srcDir,"tests/data",f))
 
			boards=listBoards(moves)
 

	
 
			bag=StateBag()
 
			for b in boards:
 
				s=bag.pushState(b)
 
				if len(bag._states)>1:
 
					# correct state skipping the erroneous one, connected to the previous correct one
 
					self.assertIs(s.getPrev(), bag._states[key])
 

	
 
				if random.random()<0.9:
 
					key=-2
 
					continue
 
				for i in range(random.randrange(1,10)):
 
					r=random.randrange(19)
 
					c=random.randrange(19)
 
					b[r][c]=(b[r][c]+random.choice((2,3)))%3-1 # random transformation [-1,1]->[-1,1]
 
				bag.pushState(b)
 
				key=-3
src/tests/util.py
Show inline comments
 
import re
 

	
 
from go.core import Go
 
from statebag import BoardState
 

	
 

	
 
def simpleLoadSgf(filename):
 
	with open(filename) as f:
 
		contents=f.read()
 
	g=lambda m: tuple(ord(c)-ord('a') for c in reversed(m.group(1)))
 
	return [g(m) for m in re.finditer(r"\b[BW]\[([a-z]{2})\]",contents)]
 

	
 

	
 
def listStates(moves):
 
def listBoards(moves):
 
	g=Go()
 
	res=[BoardState(g.board)]
 
	res=[tuple(list(x for x in row) for row in g.board)]
 
	for m in moves:
 
		g.doMove(g.toMove,*m)
 
		res.append(BoardState(g.board))
 
		res.append(tuple(list(x for x in row) for row in g.board))
 
	return res
 

	
 

	
 
def listStates(moves):
 
	return [BoardState(b) for b in listBoards(moves)]
0 comments (0 inline, 0 general)