Changeset - b06733513452
[Not reviewed]
default
0 4 2
Laman - 7 years ago 2017-12-19 12:41:55

StateBag: tests and fixes
6 files changed with 126 insertions and 33 deletions:
0 comments (0 inline, 0 general)
src/benchmark.py
Show inline comments
 
import cProfile
 

	
 
from tests.testEngine import TestTransitions
 
from tests.testStatebag import TestStateBag
 

	
 

	
 
t=TestTransitions()
 
# t=TestTransitions()
 
#
 
# cProfile.run(r"""
 
# t.testReal()
 
# """)
 

	
 
cProfile.run(r"""
 
t.testReal()
 
""")
 
t=TestStateBag()
 
cProfile.run(r"""t.testReal()""")
src/go/engine.py
Show inline comments
 
from .core import PASS
 
from . import core
 

	
 

	
 
## Compute move sequence from state1 to state2.
 
#
 
# @param colorIn {BLACK,WHITE}: color to start the sequence
 
# @param colorOut {BLACK,WHITE}: color to close the sequence
 
def getTransitionSequence(state1,state2,colorIn,colorOut,diff):
 
	eng.load(state1,diff)
 
	eng.iterativelyDeepen(state2,colorIn,colorOut)
 
	return eng.iterativelyDeepen(state2,colorIn,colorOut)
 

	
 

	
 
class SpecGo(core.Go):
 
	def __init__(self,boardSize=19):
 
		super().__init__(boardSize)
 

	
 
	def listRelevantMoves(self,diff):
 
		"""There can be 3 different changes in the diff: additions, deletions and replacements.
 
		Additions can be taken as relevant right away.
 
		Deletions and replacements had to be captured, so we add their liberties.
 
		Also any non-missing stones of partially deleted (or replaced) groups had to be replayed, so add them too.
 
		Needs to handle snapback, throw-in.
 
@@ -60,46 +60,47 @@ class Engine:
 
		self._g.load(state1)
 
		self._moveList=self._g.listRelevantMoves(diff)
 

	
 
	def iterativelyDeepen(self,state2,colorIn,colorOut):
 
		startDepth=1 if colorIn==colorOut else 2
 
		self._g.toMove=colorIn
 

	
 
		for i in range(startDepth,10,2):
 
			seq=self.dfs(state2,i)
 
			if seq:
 
				seq.reverse()
 
				return seq
 
		return None
 

	
 
	def dfs(self,state2,limit):
 
		g=self._g
 
		moveSet=self._moveList[(1-g.toMove)>>1]
 
		for m in moveSet.copy():
 
			if not g.doMove(g.toMove,*m): continue
 
			captured=g.captures[:g.captureCount]
 
			moveSet.remove(m)
 
			if m==PASS: # no reason for both players to pass
 
				self._moveList[(1-g.toMove)>>1].remove(m)
 

	
 
			if g.hash()==state2.hash():
 
				self._undoMove(m,captured)
 
				return [(g.toMove,*m)]
 

	
 
			if limit>1:
 
				seq=self.dfs(state2,limit-1)
 
				if seq:
 
					self._undoMove(m,captured)
 
					seq.append((g.toMove,*m))
 
					return seq
 

	
 
			if limit==1 and g.hash()==state2.hash():
 
				self._undoMove(m,captured)
 
				return [(g.toMove,*m)]
 

	
 
			self._undoMove(m,captured)
 
		return False
 
		return None
 

	
 
	def _undoMove(self,move,captured):
 
		g=self._g
 
		g.undoMove(*move,captured)
 
		k=(1-g.toMove)>>1
 
		self._moveList[k].add(move)
 
		if move==PASS:
 
			self._moveList[1-k].add(move)
 

	
 
eng=Engine()
src/statebag.py
Show inline comments
 
@@ -33,43 +33,44 @@ def estimateDistance(diff):
 

	
 
class GameTreeNode:
 
	def __init__(self,parent,color):
 
		self.parent=parent
 
		self.color=color # color closing the move sequence
 
		self.prev=None
 
		self.moves=[] # move sequence from prev to self
 
		self.weight=0
 

	
 
	## Connect itself after v if it gains itself more weight.
 
	def tryConnect(self,v,diff):
 
		moves=getTransitionSequence(v.parent,self.parent,-1*v.color,self.color,diff)
 
		if not moves: return
 
		w=v.weight+2-len(moves) # proper one move transition increases the weight by 1
 
		if w>self.weight:
 
			self.moves=moves
 
			self.prev=v
 
			self.weight=w
 

	
 

	
 
class BoardState:
 
	def __init__(self,board):
 
		self._board=tuple(tuple(x for x in row) for row in board)
 
		self.nodes=(GameTreeNode(self,BLACK),GameTreeNode(self,WHITE))
 
		self.diff2Prev=None
 
		self._hash=None
 

	
 
	def tryConnect(self,s):
 
		diff=s-self
 
		diff=self-s
 
		distEst=estimateDistance(diff)
 
		if distEst>3: return # we couldn't find every such move sequence anyway without a clever algorithm
 
		weightEst=s.getWeight()-distEst
 
		weightEst=s.getWeight()+2-distEst
 
		if weightEst<=self.getWeight(): return
 
		for v1 in s.nodes:
 
			for v2 in self.nodes:
 
				v2.tryConnect(v1,diff)
 

	
 
	def hash(self):
 
		if self._hash is None: self._hash=hashBoard(self._board)
 
		return self._hash
 

	
 
	def export(self):
 
		return exportBoard(self._board)
 

	
 
@@ -85,30 +86,41 @@ class BoardState:
 

	
 
		for (r,(row1,row2)) in enumerate(zip(self._board,x)):
 
			for (c,(item1,item2)) in enumerate(zip(row1,row2)):
 
				if item1==item2: continue
 
				elif item2==EMPTY: res.append((r,c,"+",item1))
 
				elif item1==EMPTY: res.append((r,c,"-",item2))
 
				else: res.append((r,c,"*",item1)) # ->to
 
		return res
 

	
 
	def __eq__(self,x):
 
		return self.hash()==x.hash()
 

	
 
	def setWeight(self,val):
 
		for v in self.nodes: v.weight=val
 

	
 
	def getWeight(self):
 
		return max(v.weight for v in self.nodes)
 

	
 
	def getPrev(self):
 
		v=max(self.nodes, key=lambda v: v.weight)
 
		return v.prev.parent if v.prev else None
 

	
 

	
 
class StateBag:
 
	def __init__(self):
 
		self._states=[]
 

	
 
	def pushState(self,board):
 
		sn=BoardState(board)
 
		if self._states and sn==self._states[-1]: return # no change
 
		if len(self._states)>0:
 
			if sn==self._states[-1]: return None # no change
 
			sn.diff2Prev=sn-self._states[-1]
 
		else: sn.setWeight(1)
 

	
 
		for s in reversed(self._states):
 
			sn.tryConnect(s)
 
		self._states.append(sn)
 
		return sn
 

	
 
	def merge(self,branch):
 
		pass
src/tests/testEngine.py
Show inline comments
 
import re
 
import os.path
 
import logging as log
 
from unittest import TestCase
 

	
 
import config as cfg
 
from util import BLACK as B,WHITE as W,EMPTY as _
 
from go.core import Go
 
from go.engine import SpecGo,Engine
 
from statebag import BoardState
 
from .util import simpleLoadSgf,listStates
 

	
 

	
 
_log=log.getLogger(__name__)
 
_log.setLevel(log.INFO)
 
_log.propagate=False
 
formatter=log.Formatter("%(asctime)s %(levelname)s: %(message)s",datefmt="%Y-%m-%d %H:%M:%S")
 
handler=log.FileHandler("/tmp/oneeye.log",mode="w")
 
handler.setFormatter(formatter)
 
_log.addHandler(handler)
 

	
 

	
 
def simpleLoadSgf(filename):
 
	with open(filename) as f:
 
		contents=f.read()
 
	g=lambda m: tuple(ord(c)-ord('a') for c in reversed(m.group(1)))
 
	return [g(m) for m in re.finditer(r"\b[BW]\[([a-z]{2})\]",contents)]
 

	
 

	
 
def listStates(moves):
 
	g=Go()
 
	res=[BoardState(g.board)]
 
	for m in moves:
 
		g.doMove(g.toMove,*m)
 
		res.append(BoardState(g.board))
 
	return res
 

	
 

	
 
class TestTransitions(TestCase):
 
	def testBasic(self):
 
		s1=BoardState([
 
			[_,_,_],
 
			[_,_,_],
 
			[_,_,_]
 
		])
 
		s2=BoardState([
 
			[_,_,_],
 
			[_,B,_],
 
			[_,_,_]
 
		])
 
@@ -107,26 +90,25 @@ class TestTransitions(TestCase):
 
			[_,W,W]
 
		])
 
		s2=BoardState([
 
			[_,_,_],
 
			[W,B,B],
 
			[_,W,_]
 
		])
 
		eng.load(s1,s2-s1)
 
		self.assertEqual(eng.dfs(s2,2),[(W,2,1),(B,2,0)])
 

	
 
	def testReal(self):
 
		files=["O-Takao-20110106.sgf","Sakai-Iyama-20110110.sgf"]
 
		g=SpecGo()
 
		eng=Engine(g)
 
		eng=Engine()
 

	
 
		for f in files:
 
			moves=simpleLoadSgf(os.path.join(cfg.srcDir,"tests/data",f))
 
			states=listStates(moves)
 

	
 
			for k in range(1,4):
 
				toMove=B
 
				for (i,(s1,s2)) in enumerate(zip(states,states[k:])):
 
					diff=s2-s1
 
					eng.load(s1,diff)
 
					colorIn=W if i&1 else B
 
					colorOut=colorIn if k&1 else 1-colorIn
src/tests/testStatebag.py
Show inline comments
 
new file 100644
 
import os.path
 
from unittest import TestCase
 

	
 
import config as cfg
 
from util import BLACK as B,WHITE as W,EMPTY as _
 
from go.engine import SpecGo,Engine
 
import go.engine
 
from statebag import BoardState,StateBag
 
from .util import simpleLoadSgf,listStates
 

	
 

	
 
class TestBoardState(TestCase):
 
	def testBasic(self):
 
		s1=BoardState([
 
			[_,_,_],
 
			[_,_,_],
 
			[_,_,_]
 
		])
 
		s2=BoardState([
 
			[_,_,_],
 
			[_,B,_],
 
			[_,_,_]
 
		])
 
		g=SpecGo(3)
 
		go.engine.eng=Engine(g)
 

	
 
		s2.tryConnect(s1)
 
		self.assertIs(s2.getPrev(), s1)
 
		self.assertEqual(s2.getWeight(), 1)
 

	
 
	def test2ply(self):
 
		s1=BoardState([
 
			[_,_,_],
 
			[_,_,_],
 
			[_,_,_]
 
		])
 
		s2=BoardState([
 
			[_,_,_],
 
			[_,B,_],
 
			[_,_,_]
 
		])
 
		s3=BoardState([
 
			[_,W,_],
 
			[_,B,B],
 
			[_,_,_]
 
		])
 
		g=SpecGo(3)
 
		go.engine.eng=Engine(g)
 

	
 
		s2.tryConnect(s1)
 
		s3.tryConnect(s2)
 
		self.assertIs(s3.getPrev(), s2)
 
		self.assertEqual(s3.getWeight(), 1)
 

	
 

	
 
class TestStateBag(TestCase):
 
	def testReal(self):
 
		go.engine.eng=Engine()
 
		files=["O-Takao-20110106.sgf","Sakai-Iyama-20110110.sgf"]
 

	
 
		for f in files:
 
			moves=simpleLoadSgf(os.path.join(cfg.srcDir,"tests/data",f))
 
			states=listStates(moves)
 

	
 
			for k in range(1,3):
 
				bag=StateBag()
 
				i=0
 
				for s_ in states:
 
					i+=1
 
					if i%(2*k-1)>=k: # keep k, skip k-1
 
						continue
 
					s=bag.pushState(s_)
 
					if len(bag._states)>1:
 
						self.assertIs(s.getPrev(), bag._states[-2])
src/tests/util.py
Show inline comments
 
new file 100644
 
import re
 

	
 
from go.core import Go
 
from statebag import BoardState
 

	
 

	
 
def simpleLoadSgf(filename):
 
	with open(filename) as f:
 
		contents=f.read()
 
	g=lambda m: tuple(ord(c)-ord('a') for c in reversed(m.group(1)))
 
	return [g(m) for m in re.finditer(r"\b[BW]\[([a-z]{2})\]",contents)]
 

	
 

	
 
def listStates(moves):
 
	g=Go()
 
	res=[BoardState(g.board)]
 
	for m in moves:
 
		g.doMove(g.toMove,*m)
 
		res.append(BoardState(g.board))
 
	return res
0 comments (0 inline, 0 general)