Spitfire1970
use latest model
0886557
import os
import io
import torch
import requests
import chess.pgn
import numpy as np
from data_objects.game import Game
from encoder.model import Encoder
def generate_alternative_pgns(game):
if not game:
print("couldn't read game")
return [], None, None
# Set up board and get moves
board = game.board()
moves = list(game.mainline_moves())
# Play through the moves up to just before our target
for move in moves:
board.push(move)
# Get all legal moves from this position
legal_moves = list(board.legal_moves)
# Create a new PGN for each legal move
result_pgns = []
move_sans = []
for legal_move in legal_moves:
# Create a copy of the game up to the target position
new_game = chess.pgn.Game()
# Copy headers
for key in game.headers:
new_game.headers[key] = game.headers[key]
# Mark game as unfinished
if "Result" in new_game.headers:
new_game.headers["Result"] = "*"
# Create the move sequence
node = new_game
for move in moves:
node = node.add_variation(move)
# Add our alternative move
node = node.add_variation(legal_move)
# Convert to PGN string
new_pgn = io.StringIO()
exporter = chess.pgn.FileExporter(new_pgn)
new_game.accept(exporter)
# Save the PGN and the SAN notation of this move
result_pgns.append(new_pgn.getvalue())
move_sans.append(board.san(legal_move))
return result_pgns, move_sans
def process_game(game, prediction_mode = False):
def create_position_planes(board: chess.Board, positions_seen: set, cur_player: chess.Color) -> np.ndarray:
def bb_to_plane(bb: int, player: chess.Color) -> np.ndarray:
binary = format(bb, '064b')
h_flipped = np.fliplr(np.array([int(binary[i]) for i in range(64)], dtype=np.float32).reshape(8, 8))
if player:
return h_flipped
else:
return np.flip(h_flipped)
planes = np.zeros((13, 8, 8), dtype=np.float32)
piece_types = [chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING]
# white pieces (planes 1-6)
for i, piece_type in enumerate(piece_types):
bb = board.pieces_mask(piece_type, chess.WHITE)
planes[i] = bb_to_plane(bb, cur_player)
# black pieces (planes 7-12)
for i, piece_type in enumerate(piece_types):
bb = board.pieces_mask(piece_type, chess.BLACK)
planes[i + 6] = bb_to_plane(bb, cur_player)
# repetition plane (plane 13)
current_position = board.fen().split(' ')[0]
if list(positions_seen).count(current_position) > 1:
planes[12] = 1.0
return planes
board = chess.Board()
positions_seen = set()
positions_seen.add(board.fen().split(' ')[0])
white_moves = []
black_moves = []
node = game
while node.next():
node = node.next()
move = node.move
assert(move is not None)
cur_player = board.turn
current_planes = create_position_planes(board, positions_seen, cur_player)
board.push(move)
positions_seen.add(board.fen().split(' ')[0])
next_planes = create_position_planes(board, positions_seen, cur_player)
assert(not (current_planes==next_planes).all())
# print_planes(next_planes)
move_planes = np.zeros((34, 8, 8), dtype=np.float32)
# first 13 planes (before move)
move_planes[0:13] = current_planes
# next 13 planes (after move)
move_planes[13:26] = next_planes
# castling availability (planes 27-30)
move_planes[26] = float(board.has_queenside_castling_rights(chess.WHITE))
move_planes[27] = float(board.has_kingside_castling_rights(chess.WHITE))
move_planes[28] = float(board.has_queenside_castling_rights(chess.BLACK))
move_planes[29] = float(board.has_kingside_castling_rights(chess.BLACK))
# side to move (plane 31)
move_planes[30] = 1 if board.turn is chess.WHITE else 0
# fifty move counter (plane 32)
move_planes[31] = board.halfmove_clock / 100.0
# move time normalized between 0 and 1 (plane 33)
# change based on time control
clock_info = node.comment.strip('{}[] ').split()[1] if node.comment else "0:00:30"
try:
minutes, seconds = map(int, clock_info.split(':')[1:])
total_seconds = minutes * 60 + seconds
move_planes[32] = min(1.0, total_seconds / 180.0)
except:
move_planes[32] = 0.5
# all 1s (plane 34)
move_planes[33] = 1.0
if board.turn:
black_moves.append(move_planes)
else: # chess.BLACK is falsy
white_moves.append(move_planes)
if (not prediction_mode) and (len(white_moves) < 10 or len(black_moves) < 10):
return None
white_array = np.stack(white_moves, axis=0)
black_array = [] if not black_moves else np.stack(black_moves, axis=0)
return white_array, black_array
class EndpointHandler():
def __init__(self, model_dir):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load(os.path.join(model_dir, "6_3.pt"), self.device, weights_only=True)
self.model = Encoder(self.device)
state_dict = checkpoint['model_state']
self.model.load_state_dict(state_dict)
self.model = self.model.to(self.device)
self.model.eval()
self.d = {
0: self.say_hi,
1: self.create_user_embedding,
2: self.ai_move
}
def say_hi(self, _data):
print('entering test endpoint')
print('exiting test endpoint')
return {"reply": "hello from inference api!!"}
def create_user_embedding(self, data):
print('entering create_username endpoint')
username = data["username"]
pgn_content = data["pgn_content"]
games_per_player = data["games_per_player"]
l = []
while True:
game = chess.pgn.read_game(io.StringIO(pgn_content))
if game is None:
print("breaking main loop")
break
white = game.headers.get("White")
black = game.headers.get("Black")
if white == username:
color = "white"
elif black == username:
color = "black"
else:
raise Exception
try:
arrs = process_game(game)
except:
print("skipped")
continue
if arrs is None: # skip if less than 10 moves
print("skipped")
continue
if color == "white":
l.append(arrs[0])
else:
l.append(arrs[1])
if not l: return None
inputs = np.array([Game(g).random_partial() for g in l[:games_per_player]])
num_games = min(len(l), games_per_player)
tensor = torch.tensor(inputs).float().to(self.device)
with torch.no_grad():
embeds = self.model(tensor)
embeds = embeds.view((1, num_games, -1)).to(self.device)
centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True)
centroids_incl = centroids_incl.cpu().squeeze(1)
final_embeds = centroids_incl[0].numpy().tolist()
print('exiting create_username endpoint')
return {"reply": final_embeds}
def ai_move(self, data):
print('entering ai_move endpoint')
pgn_string = data["pgn_string"]
color = data["color"]
player_centroid = data["player_centroid"]
game = chess.pgn.read_game(io.StringIO(pgn_string))
alternative_pgns, move_sans = generate_alternative_pgns(game)
game = chess.pgn.read_game(io.StringIO(pgn_string))
inputs = []
for alt_pgn in alternative_pgns:
game_tensors = process_game(chess.pgn.read_game(io.StringIO(alt_pgn)), True)
game_tensor = game_tensors[0] if color == "white" else game_tensors[1]
inputs.append(game_tensor)
tensor = torch.tensor(np.array(inputs)).float().to(self.device)
with torch.no_grad():
embed = self.model(tensor)
embed = embed / torch.norm(embed)
arr = embed.cpu().numpy()
similarities = [np.dot(np.array(player_centroid), embed) for embed in arr]
result = move_sans[np.argmax(similarities)]
ordered_moves = np.argsort(similarities).tolist()[::-1]
try:
board = game.board()
moves = list(game.mainline_moves())
for move in moves:
board.push(move)
response = requests.post("http://13.49.80.182/stockfish_eval", json={"fen": board.fen()})
if response.status_code == 400:
print(response.text)
print('exiting ai_move endpoint status code before move')
return {"reply": result}
best_eval = response.json()["value"]
best_move = response.json()["best"]
best_move = chess.Move.from_uci(best_move)
best_move = board.san(best_move)
for move in ordered_moves:
test_board = board.copy()
test_board.push(board.parse_san(move_sans[move]))
response = requests.post("http://13.49.80.182/stockfish_eval", json={"fen": test_board.fen()})
if response.status_code == 500:
print('exiting ai_move endpoint status code after move')
return {"reply": best_move}
eval = response.json()["value"]
if (color == "white" and (best_eval - eval < 120)) or (color == "black" and (best_eval - eval > -120)):
print('exiting ai_move endpoint nice found!')
return {"reply": move_sans[move]}
print('exiting ai_move endpoint all moves are shit!')
return {"reply": best_move}
except Exception as e:
print('error sending to lichess', e)
print('exiting ai_move endpoint due to exception')
return {"reply": result}
def __call__(self, data):
data = data.get("inputs", data)
return self.d[data["endpoint_num"]](data)