Changeset - 5b4c6f5a0a28
[Not reviewed]
default
0 3 1
Laman - 6 years ago 2018-11-29 17:30:15

transposition table
4 files changed with 63 insertions and 30 deletions:
0 comments (0 inline, 0 general)
src/go/engine.py
Show inline comments
 
from .core import PASS
 
from .transpositiontable import TranspositionTable
 
from . import core
 

	
 

	
 
## Compute move sequence from state1 to state2.
 
#
 
# @param colorIn {BLACK,WHITE}: color to start the sequence
 
# @param colorOut {BLACK,WHITE}: color to close the sequence
 
def getTransitionSequence(state1,state2,colorIn,colorOut,diff):
 
	eng.load(state1,diff)
 
	return eng.iterativelyDeepen(state2,colorIn,colorOut)
 
	eng.load(state1)
 
	return eng.iterativelyDeepen(state2,diff,colorIn,colorOut)
 

	
 

	
 
class SpecGo(core.Go):
 
	def __init__(self,boardSize=19):
 
		super().__init__(boardSize)
 

	
 
	def listRelevantMoves(self,diff):
 
		"""There can be 3 different changes in the diff: additions, deletions and replacements.
 
		Additions can be taken as relevant right away.
 
		Deletions and replacements had to be captured, so we add their liberties.
 
		Also any non-missing stones of partially deleted (or replaced) groups had to be replayed, so add them too.
 
		Needs to handle snapback, throw-in.
 
		There's no end to what could be theoretically relevant, but such sequences are long and we will pretend they won't happen."""
 
		res=({PASS},{PASS})
 
		for d in diff:
 
			(r,c,action,color)=d
 
			colorKey=(1-color)>>1 # {-1,1}->{1,0}
 
			if action!="-" and (r,c) not in res[colorKey]:
 
				res[colorKey].add((r,c))
 
				if action=="*":
 
				for (ri,ci) in self.listNeighbours(r,c): # in case a stone was played and captured. !! might want to add even more
 
					res[1-colorKey].add((ri,ci))
 
			# this is rather sloppy but correct. the time will show if it is effective enough
 
			# just floodFill from the current intersection, add everything you find and also all the neighbours to be sure
 
			if action!="+" and (r,c) not in res[colorKey] and (r,c) not in res[1-colorKey]:
 
				self._helper.clear()
 
				self._helper.floodFill(color if action=="-" else 1-color, r, c)
 
				res[colorKey].union(self._helper.getContinuousArea())
 
				for (ri,ci) in self._helper.getContinuousArea():
 
					res[colorKey].add((ri,ci))
 
					res[1-colorKey].add((ri,ci))
 
					for (rj,cj) in self.listNeighbours(ri,ci):
 
						res[colorKey].add((rj,cj))
 
						res[1-colorKey].add((rj,cj))
 
		return res
 

	
 
	def listNeighbours(self,r,c):
 
		if r>0: yield (r-1,c)
 
		if r+1<self.boardSize: yield (r+1,c)
 
		if c>0: yield (r,c-1)
 
		if c+1<self.boardSize: yield (r,c+1)
 

	
 

	
 
class Engine:
 
	"""Class searching for move sequences changing one board state into another."""
 
	def __init__(self,g=None):
 
		self._g=g or SpecGo()
 
		self._moveList=(set(),set())
 
		self._transpositions=TranspositionTable()
 

	
 
	def load(self,state1,diff):
 
	def load(self,state1):
 
		self._g.load(state1)
 

	
 
	def iterativelyDeepen(self,state2,diff,colorIn,colorOut):
 
		"""Search for a move sequence from the loaded state to state2. Tries progressively longer sequences."""
 
		self._moveList=self._g.listRelevantMoves(diff)
 

	
 
	def iterativelyDeepen(self,state2,colorIn,colorOut):
 
		startDepth=1 if colorIn==colorOut else 2
 
		self._g.toMove=colorIn
 

	
 
		for i in range(startDepth,6,2):
 
			seq=self.dfs(state2,i)
 
		for i in range(startDepth,5,2):
 
			seq=self._dfs(state2,i)
 
			if seq:
 
				seq.reverse()
 
				return seq
 
		return None
 

	
 
	def dfs(self,state2,limit):
 
	def _dfs(self,state2,limit):
 
		"""Search for a "limit" move sequence from the loaded state to state2."""
 
		g=self._g
 
		moveSet=self._moveList[(1-g.toMove)>>1]
 
		transKey=(g.hash()*state2.hash()*limit)&0xffffffff
 

	
 
		transSeq=self._transpositions.get(transKey)
 
		if transSeq is not None:
 
			return transSeq[:] if transSeq else transSeq
 

	
 
		for m in moveSet.copy():
 
			if not g.doMove(g.toMove,*m): continue
 
			captured=g.captures[:g.captureCount]
 
			moveSet.remove(m)
 
			if m==PASS: # no reason for both players to pass
 
				self._moveList[(1-g.toMove)>>1].remove(m)
 

	
 
			if limit>1:
 
				seq=self.dfs(state2,limit-1)
 
				seq=self._dfs(state2,limit-1)
 
				if seq:
 
					self._undoMove(m,captured)
 
					seq.append((g.toMove,*m))
 
					self._transpositions.put(transKey,seq[:])
 
					return seq
 
				else: self._transpositions.put(transKey,False)
 

	
 
			if limit==1 and g.hash()==state2.hash():
 
				self._undoMove(m,captured)
 
				return [(g.toMove,*m)]
 

	
 
			self._undoMove(m,captured)
 
		return None
 

	
 
	def _undoMove(self,move,captured):
 
		g=self._g
 
		g.undoMove(*move,captured)
 
		k=(1-g.toMove)>>1
 
		self._moveList[k].add(move)
 
		if move==PASS:
 
			self._moveList[1-k].add(move)
 

	
 
eng=Engine()
src/go/transpositiontable.py
Show inline comments
 
new file 100644
 
class TranspositionTable:
 
	def __init__(self,capacity=2**20):
 
		self._capacity=capacity
 
		self._table=[None]*capacity
 

	
 
	def put(self,key,val):
 
		self._table[key%self._capacity]=(key,val)
 

	
 
	def get(self,key):
 
		res=self._table[key%self._capacity]
 
		if res is None: return None
 
		elif res[0]==key: return res[1]
 
		else: return None
src/statebag/boardstate.py
Show inline comments
 
@@ -42,47 +42,48 @@ class BoardState:
 
	def tryConnect(self,s,diff=None):
 
		if diff is None: diff=self-s
 
		distEst=estimateDistance(diff)
 
		if distEst>3: return # we couldn't find every such move sequence anyway without a clever algorithm
 
		weightEst=s.getWeight()+2-distEst
 
		if weightEst<=self.getWeight(): return
 
		for v1 in s.nodes:
 
			for v2 in self.nodes:
 
				v2.tryConnect(v1,diff)
 

	
 
	def hash(self):
 
		if self._hash is None: self._hash=hashBoard(self._board)
 
		return self._hash
 

	
 
	def export(self):
 
		return exportBoard(self._board)
 

	
 
	def exportDiff(self,s2):
 
		return "vvv\n{0}\n=== {1} ===\n{2}\n^^^".format(self.export(), s2-self, s2.export())
 

	
 
	def __iter__(self): return iter(self._board)
 

	
 
	def __getitem__(self,key): return self._board[key]
 

	
 
	def __sub__(self,x):
 
	## Compute difference self-s.
 
	def __sub__(self,s):
 
		res=[]
 

	
 
		for (r,(row1,row2)) in enumerate(zip(self._board,x)):
 
			for (c,(item1,item2)) in enumerate(zip(row1,row2)):
 
				if item1==item2: continue
 
				elif item2==EMPTY: res.append((r,c,"+",item1))
 
				elif item1==EMPTY: res.append((r,c,"-",item2))
 
				else: res.append((r,c,"*",item1)) # ->to
 
		for (r,(row,rowS)) in enumerate(zip(self._board,s)):
 
			for (c,(item,itemS)) in enumerate(zip(row,rowS)):
 
				if item==itemS: continue
 
				elif itemS==EMPTY: res.append((r,c,"+",item))
 
				elif item==EMPTY: res.append((r,c,"-",itemS))
 
				else: res.append((r,c,"*",item)) # ->to
 
		return res
 

	
 
	def __eq__(self,x):
 
		return self.hash()==x.hash()
 

	
 
	def setWeight(self,val):
 
		for v in self.nodes: v.weight=val
 

	
 
	def getWeight(self):
 
		return max(v.weight for v in self.nodes)
 

	
 
	def getPrev(self):
 
		v=max(self.nodes, key=lambda v: v.weight)
 
		return v.prev.parent if v.prev else None
src/tests/testEngine.py
Show inline comments
 
@@ -11,110 +11,115 @@ from .util import simpleLoadSgf,listStat
 

	
 
_log=log.getLogger(__name__)
 
_log.setLevel(log.INFO)
 
_log.propagate=False
 
formatter=log.Formatter("%(asctime)s %(levelname)s: %(message)s",datefmt="%Y-%m-%d %H:%M:%S")
 
handler=log.FileHandler("/tmp/oneeye.log",mode="w")
 
handler.setFormatter(formatter)
 
_log.addHandler(handler)
 

	
 

	
 
class TestTransitions(TestCase):
 
	def testBasic(self):
 
		s1=BoardState([
 
			[_,_,_],
 
			[_,_,_],
 
			[_,_,_]
 
		])
 
		s2=BoardState([
 
			[_,_,_],
 
			[_,B,_],
 
			[_,_,_]
 
		])
 
		g=SpecGo(3)
 
		eng=Engine(g)
 
		eng.load(s1,s2-s1)
 
		self.assertEqual(eng.dfs(s2,1),[(1,1,1)])
 
		eng.load(s1)
 
		eng._moveList=eng._g.listRelevantMoves(s2-s1)
 
		self.assertEqual(eng._dfs(s2,1),[(1,1,1)])
 

	
 
	def testCapture(self):
 
		s1=BoardState([
 
			[_,W,_],
 
			[W,B,_],
 
			[_,W,_]
 
		])
 
		s2=BoardState([
 
			[_,W,_],
 
			[W,_,W],
 
			[_,W,_]
 
		])
 
		g=SpecGo(3)
 
		g.toMove=W
 
		eng=Engine(g)
 
		eng.load(s1,s2-s1)
 
		self.assertEqual(eng.dfs(s2,1),[(W,1,2)])
 
		eng.load(s1)
 
		eng._moveList=eng._g.listRelevantMoves(s2-s1)
 
		self.assertEqual(eng._dfs(s2,1),[(W,1,2)])
 

	
 
	def testMulti(self):
 
		s1=BoardState([
 
			[_,_,_],
 
			[_,_,_],
 
			[_,_,_]
 
		])
 
		s2=BoardState([
 
			[_,_,_],
 
			[_,B,W],
 
			[_,_,_]
 
		])
 
		g=SpecGo(3)
 
		eng=Engine(g)
 
		eng.load(s1,s2-s1)
 
		self.assertEqual(eng.dfs(s2,2),[(W,1,2),(B,1,1)])
 
		eng.load(s1)
 
		eng._moveList=eng._g.listRelevantMoves(s2-s1)
 
		self.assertEqual(eng._dfs(s2,2),[(W,1,2),(B,1,1)])
 

	
 
	def testSnapback(self):
 
		s1=BoardState([
 
			[B,B,B],
 
			[B,_,B],
 
			[B,W,B]
 
		])
 
		s2=BoardState([
 
			[_,_,_],
 
			[_,_,_],
 
			[_,W,_]
 
		])
 
		g=SpecGo(3)
 
		eng=Engine(g)
 
		eng.load(s1,s2-s1)
 
		self.assertEqual(eng.dfs(s2,2),[(W,2,1),(B,1,1)])
 
		eng.load(s1)
 
		eng._moveList=eng._g.listRelevantMoves(s2-s1)
 
		self.assertEqual(eng._dfs(s2,2),[(W,2,1),(B,1,1)])
 

	
 
		s1=BoardState([
 
			[_,_,_],
 
			[W,B,B],
 
			[_,W,W]
 
		])
 
		s2=BoardState([
 
			[_,_,_],
 
			[W,B,B],
 
			[_,W,_]
 
		])
 
		eng.load(s1,s2-s1)
 
		self.assertEqual(eng.dfs(s2,2),[(W,2,1),(B,2,0)])
 
		eng.load(s1)
 
		eng._moveList=eng._g.listRelevantMoves(s2-s1)
 
		self.assertEqual(eng._dfs(s2,2),[(W,2,1),(B,2,0)])
 

	
 
	def testReal(self):
 
		files=["O-Takao-20110106.sgf","Sakai-Iyama-20110110.sgf"]
 
		eng=Engine()
 

	
 
		for f in files:
 
			moves=simpleLoadSgf(os.path.join(cfg.srcDir,"tests/data",f))
 
			states=listStates(moves)
 

	
 
			for k in range(1,4):
 
				toMove=B
 
				for (i,(s1,s2)) in enumerate(zip(states,states[k:])):
 
					diff=s2-s1
 
					eng.load(s1,diff)
 
					eng.load(s1)
 
					colorIn=W if i&1 else B
 
					colorOut=colorIn if k&1 else 1-colorIn
 
					seq=eng.iterativelyDeepen(s2,colorIn,colorOut)
 
					msg="\n"+s1.exportDiff(s2)
 
					colorOut=colorIn if k&1 else -colorIn
 
					seq=eng.iterativelyDeepen(s2,diff,colorIn,colorOut)
 
					msg="\n>>> {0} ({1}, {2})\n>>> {3}\n".format(f,k,i,str(seq))+s1.exportDiff(s2)
 
					self.assertIsNotNone(seq,msg)
 
					self.assertLessEqual(len(seq),k,msg)
 
					if len(seq)!=k: _log.warning("shorter than expected transition sequence:" + msg + "\n" + str(seq))
 
					toMove*=-1
0 comments (0 inline, 0 general)