Files @ 630c42e6d376
Branch filter:

Location: OneEye/src/statebag/boardstate.py

Laman
requirements.txt
from util import EMPTY,BLACK,WHITE, hashBoard,exportBoard
from go.engine import getTransitionSequence,SpecGo


# Shared module-level SpecGo instance; estimateDistance() uses its listNeighbours(r,c).
g=SpecGo()


## Crude lower bound on edit distance between states.
def estimateDistance(diff,s1,s2):
	"""Estimate how many moves separate board s1 from board s2.

	The diff entries are tallied by their change marker. For every removed
	stone, each neighbour reported by g.listNeighbours that is EMPTY in both
	s1 and s2 is counted as one "unaccounted" point.

	:param diff: [(r,c,change,color), ...], change in {+,-,*}; color is not used here
	:param s1: board indexable as s1[r][c] (the state diff is taken against)
	:param s2: board indexable as s2[r][c]
	:return: additions+replacements+unaccounted when any addition or unaccounted
		point exists; 2 when the diff consists of deletions only; replacements+1 otherwise
		(which is 1 for an empty diff)"""
	# lot of room for improvements (original author's note)
	placed=removed=recolored=emptyInBoth=0
	for (row,col,change,_color) in diff:
		if change=="-":
			removed+=1
			# neighbours of the removed stone that are empty before and after
			emptyInBoth+=sum(
				1 for (nr,nc) in g.listNeighbours(row,col)
				if s1[nr][nc]==EMPTY and s2[nr][nc]==EMPTY
			)
		elif change=="+":
			placed+=1
		else:
			# any other marker is treated as a replacement
			recolored+=1

	if placed or emptyInBoth:
		return placed+recolored+emptyInBoth
	if recolored==0 and removed>0:
		return 2 # original author's note: take n, return 1
	return recolored+1 # original author's note: ???


class GameTreeNode:
	"""One node of the game tree, tied to an owning state and a color.

	Nodes are chained backwards through ``prev``; each node stores the move
	sequence leading from its predecessor to itself and an accumulated weight."""

	def __init__(self,parent,color):
		""":param parent: the object owning this node (BoardState passes itself)
		:param color: color closing the move sequence"""
		self.parent=parent
		self.color=color
		self.prev=None # predecessor node, None until a connection is accepted
		self.moves=[] # move sequence from prev to self
		self.weight=0

	## Connect itself after v if it gains itself more weight.
	def tryConnect(self,v,diff):
		"""Adopt v as predecessor when the resulting weight beats the current one.

		The candidate weight is v.weight+2-len(moves), so a proper one move
		transition increases the weight by 1. Nothing changes when
		getTransitionSequence returns a falsy result or the candidate weight
		is not strictly greater than the current weight.

		:param v: candidate predecessor node
		:param diff: passed through unchanged to getTransitionSequence"""
		# third argument is the negated color of v -- presumably the color moving first; confirm in go.engine
		sequence=getTransitionSequence(v.parent,self.parent,-1*v.color,self.color,diff)
		if sequence:
			candidate=v.weight+2-len(sequence)
			if candidate>self.weight:
				self.prev=v
				self.moves=sequence
				self.weight=candidate

	def exportRecord(self):
		""":return: [(c,row,col), ...]. c in {BLACK,WHITE} == {1,-1}"""
		# walk back to the chain's start, remembering each node's moves
		segments=[]
		node=self
		while node is not None:
			segments.append(node.moves)
			node=node.prev

		# oldest segment first, flattened into one fresh list
		return [move for segment in reversed(segments) for move in segment]


class BoardState:
	"""Snapshot of a board together with its two game tree nodes.

	The board is stored as a tuple of tuples. ``nodes`` holds one GameTreeNode
	constructed with BLACK and one with WHITE, in that order; the state's
	weight is the larger of the two node weights."""

	def __init__(self,board):
		""":param board: iterable of rows, each row an iterable of point values"""
		self._board=tuple(tuple(row) for row in board)
		self.nodes=(GameTreeNode(self,BLACK),GameTreeNode(self,WHITE))
		self.cachedDiff=[] # not read or written by any method of this class
		self._hash=None # computed lazily by hash()

	def tryConnect(self,s,diff=None):
		"""Try to link every node of s as a predecessor of every node of self.

		The attempt is skipped when the estimated distance exceeds 3 or when the
		best weight achievable under that estimate would not beat the current weight.

		:param s: BoardState s
		:param diff: [(r,c,change,color), ...], change in {+,-,*}, color in {BLACK,WHITE};
			computed as self-s when omitted"""
		if diff is None: diff=self-s
		distEst=estimateDistance(diff,s,self)
		if distEst>3: return # we couldn't find every such move sequence anyway without a clever algorithm
		# same formula as GameTreeNode.tryConnect, with the estimate standing in for len(moves)
		weightEst=s.getWeight()+2-distEst
		if weightEst<=self.getWeight(): return
		for v1 in s.nodes:
			for v2 in self.nodes:
				v2.tryConnect(v1,diff)

	def hash(self):
		""":return: hashBoard() of the board, computed on first use and cached"""
		if self._hash is None: self._hash=hashBoard(self._board)
		return self._hash

	def export(self):
		""":return: result of exportBoard() for the board"""
		return exportBoard(self._board)

	def exportDiff(self,s2):
		""":param s2: BoardState to compare with
		:return: text block with this board's export, the diff s2-self and s2's export"""
		return "vvv\n{0}\n=== {1} ===\n{2}\n^^^".format(self.export(), s2-self, s2.export())

	def exportRecord(self):
		""":return: [(c,row,col), ...]. c in {BLACK,WHITE} == {1,-1}"""
		# NOTE(review): on equal weights this picks nodes[1], whereas getPrev() picks nodes[0]
		# (max() returns the first maximal item) -- confirm whether the difference is intended.
		v=self.nodes[0] if self.nodes[0].weight>self.nodes[1].weight else self.nodes[1]
		return v.exportRecord()

	def __iter__(self): return iter(self._board)

	def __getitem__(self,key): return self._board[key]

	## Compute difference self-s.
	def __sub__(self,s):
		"""List the points where self differs from s.

		:param s: iterable of rows (e.g. another BoardState)
		:return: [(r,c,change,color), ...] where change is "+" (s is EMPTY there,
			color taken from self), "-" (self is EMPTY there, color taken from s)
			or "*" (both occupied and different, color taken from self)"""
		res=[]

		for (r,(row,rowS)) in enumerate(zip(self._board,s)):
			for (c,(item,itemS)) in enumerate(zip(row,rowS)):
				if item==itemS: continue
				elif itemS==EMPTY: res.append((r,c,"+",item))
				elif item==EMPTY: res.append((r,c,"-",itemS))
				else: res.append((r,c,"*",item)) # ->to
		return res

	def __eq__(self,x):
		"""Compare two states by their board hashes.

		:param x: object to compare with
		:return: True/False for objects offering a callable hash(); NotImplemented
			otherwise, so that comparing with None or an unrelated object yields
			False instead of raising AttributeError"""
		otherHash=getattr(x,"hash",None)
		if not callable(otherHash): return NotImplemented
		return self.hash()==otherHash()

	def setWeight(self,val):
		"""Assign the same weight to both nodes."""
		for v in self.nodes: v.weight=val

	def getWeight(self):
		""":return: the larger of the two node weights"""
		return max(v.weight for v in self.nodes)

	def getPrev(self):
		""":return: parent of the heaviest node's predecessor, or None when it has none"""
		v=max(self.nodes, key=lambda v: v.weight)
		return v.prev.parent if v.prev else None