|
import os |
|
import io |
|
import torch |
|
import requests |
|
import chess.pgn |
|
import numpy as np |
|
from data_objects.game import Game |
|
from encoder.model import Encoder |
|
|
|
|
|
def generate_alternative_pgns(game):
    """Build one PGN per legal continuation of the final position of ``game``.

    The mainline of ``game`` is replayed on ``game.board()``; for every legal
    move in the resulting position a new game is created that carries the
    original headers (with ``Result`` reset to ``"*"``), the original mainline
    and that legal move appended as the last move.

    Args:
        game: A parsed game (e.g. the value returned by
            ``chess.pgn.read_game``), or a falsy value when parsing failed.

    Returns:
        A ``(pgns, sans)`` tuple of two equally long lists. ``pgns[i]`` is the
        exported PGN text of the i-th alternative game and ``sans[i]`` is the
        SAN of the move that was appended to it. Both lists are empty when
        ``game`` is falsy.
    """
    if not game:
        print("couldn't read game")
        # Same arity as the success path so callers can always unpack two
        # values.
        return [], []

    # Replay the game to reach the position the candidate moves start from.
    board = game.board()
    played_moves = list(game.mainline_moves())
    for played_move in played_moves:
        board.push(played_move)

    # Materialised up front: board.san() below is called while looping.
    legal_moves = list(board.legal_moves)

    result_pgns = []
    move_sans = []

    for legal_move in legal_moves:
        new_game = chess.pgn.Game()

        # Carry over every header of the source game.
        for key in game.headers:
            new_game.headers[key] = game.headers[key]

        # The alternative line is unfinished, so it must not claim a result.
        if "Result" in new_game.headers:
            new_game.headers["Result"] = "*"

        # Rebuild the mainline, then append the candidate move at its end.
        node = new_game
        for played_move in played_moves:
            node = node.add_variation(played_move)
        node.add_variation(legal_move)

        new_pgn = io.StringIO()
        new_game.accept(chess.pgn.FileExporter(new_pgn))

        result_pgns.append(new_pgn.getvalue())
        move_sans.append(board.san(legal_move))

    return result_pgns, move_sans
|
|
|
def process_game(game, prediction_mode = False):
    """Encode every mainline move of ``game`` as a ``(34, 8, 8)`` float32 array.

    Plane layout of each per-move array:

    * ``0-12``  position before the move (see ``create_position_planes``)
    * ``13-25`` position after the move, same orientation as planes 0-12
    * ``26-29`` castling rights after the move: white queenside, white
      kingside, black queenside, black kingside
    * ``30``    1 when it is white's turn after the move, else 0
    * ``31``    halfmove clock after the move divided by 100
    * ``32``    clock value read from the move comment, in seconds divided
      by 180 and capped at 1.0; 0.5 when the comment cannot be parsed
    * ``33``    constant 1.0

    Args:
        game: Root node of a parsed game; it is walked with ``next()`` and
            must provide ``board()``, and each node ``move`` and ``comment``.
        prediction_mode: When false, games in which either side made fewer
            than 10 moves are rejected.

    Returns:
        ``None`` when the game is rejected by the minimum-length rule,
        otherwise a ``(white_array, black_array)`` tuple. Each element is the
        per-move arrays of that side stacked along axis 0, or an empty list
        when that side made no move.
    """

    def create_position_planes(board: chess.Board, positions_seen: set, cur_player: chess.Color) -> np.ndarray:
        """Return 13 planes: 6 white piece types, 6 black piece types, repetition."""

        def bb_to_plane(bb: int, player: chess.Color) -> np.ndarray:
            """Turn a 64-bit mask into an 8x8 plane oriented for ``player``."""
            bits = format(bb, '064b')
            h_flipped = np.fliplr(np.array([int(bit) for bit in bits], dtype=np.float32).reshape(8, 8))
            if player:
                return h_flipped
            # For the other side the plane is flipped along both axes.
            return np.flip(h_flipped)

        planes = np.zeros((13, 8, 8), dtype=np.float32)
        piece_types = [chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING]

        # Planes 0-5 hold white pieces, planes 6-11 hold black pieces.
        for offset, color in ((0, chess.WHITE), (6, chess.BLACK)):
            for i, piece_type in enumerate(piece_types):
                planes[i + offset] = bb_to_plane(board.pieces_mask(piece_type, color), cur_player)

        current_position = board.fen().split(' ')[0]
        # NOTE(review): positions_seen is a set, so this count can never
        # exceed 1 and plane 12 always stays zero. Left unchanged on purpose:
        # enabling it would change the features fed to an existing model --
        # confirm against the training pipeline before fixing.
        if list(positions_seen).count(current_position) > 1:
            planes[12] = 1.0

        return planes

    # Start from the game's own initial position so that games carrying a
    # FEN/SetUp header are replayed on the right board.
    board = game.board()
    positions_seen = set()
    positions_seen.add(board.fen().split(' ')[0])

    white_moves = []
    black_moves = []

    node = game
    while node.next():
        node = node.next()
        move = node.move
        assert move is not None
        cur_player = board.turn

        # Both snapshots are oriented for the player who makes the move.
        current_planes = create_position_planes(board, positions_seen, cur_player)
        board.push(move)
        positions_seen.add(board.fen().split(' ')[0])
        next_planes = create_position_planes(board, positions_seen, cur_player)
        assert not (current_planes == next_planes).all()

        move_planes = np.zeros((34, 8, 8), dtype=np.float32)
        move_planes[0:13] = current_planes
        move_planes[13:26] = next_planes

        # Castling rights, side to move and halfmove clock are all read
        # after the move has been pushed.
        move_planes[26] = float(board.has_queenside_castling_rights(chess.WHITE))
        move_planes[27] = float(board.has_kingside_castling_rights(chess.WHITE))
        move_planes[28] = float(board.has_queenside_castling_rights(chess.BLACK))
        move_planes[29] = float(board.has_kingside_castling_rights(chess.BLACK))
        move_planes[30] = 1 if board.turn is chess.WHITE else 0
        move_planes[31] = board.halfmove_clock / 100.0

        # The second whitespace-separated token of the comment is read as an
        # "h:mm:ss" clock; the hours field is dropped. Moves without a
        # comment use 0:00:30. Any comment that does not fit falls back to
        # the neutral value 0.5 instead of aborting the whole game.
        try:
            clock_info = node.comment.strip('{}[] ').split()[1] if node.comment else "0:00:30"
            minutes, seconds = map(int, clock_info.split(':')[1:])
            total_seconds = minutes * 60 + seconds
            move_planes[32] = min(1.0, total_seconds / 180.0)
        except (IndexError, ValueError):
            move_planes[32] = 0.5

        move_planes[33] = 1.0

        # After the push, white to move means black has just played.
        if board.turn:
            black_moves.append(move_planes)
        else:
            white_moves.append(move_planes)

    if (not prediction_mode) and (len(white_moves) < 10 or len(black_moves) < 10):
        return None

    # np.stack raises on an empty sequence, so both sides are guarded.
    white_array = [] if not white_moves else np.stack(white_moves, axis=0)
    black_array = [] if not black_moves else np.stack(black_moves, axis=0)

    return white_array, black_array
|
|
|
|
|
class EndpointHandler():
    """Inference entry point that dispatches requests to numbered endpoints.

    A request is a dict (optionally wrapped under the ``"inputs"`` key) whose
    ``"endpoint_num"`` selects the handler:

    * ``0`` -> ``say_hi``
    * ``1`` -> ``create_user_embedding``
    * ``2`` -> ``ai_move``
    """

    # Evaluation service queried by ``ai_move`` to vet the model's picks.
    STOCKFISH_URL = "http://13.49.80.182/stockfish_eval"
    # Seconds to wait for the evaluation service before giving up; without a
    # timeout a stalled service would block the request forever.
    STOCKFISH_TIMEOUT = 10

    def __init__(self, model_dir):
        """Load the ``6_3.pt`` checkpoint from ``model_dir`` into an ``Encoder``."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        checkpoint = torch.load(os.path.join(model_dir, "6_3.pt"), map_location=self.device, weights_only=True)
        self.model = Encoder(self.device)
        self.model.load_state_dict(checkpoint['model_state'])
        self.model = self.model.to(self.device)
        self.model.eval()
        # Dispatch table: endpoint number -> bound handler.
        self.d = {
            0: self.say_hi,
            1: self.create_user_embedding,
            2: self.ai_move
        }

    def say_hi(self, _data):
        """Health-check endpoint; ignores its input and returns a fixed reply."""
        print('entering test endpoint')
        print('exiting test endpoint')
        return {"reply": "hello from inference api!!"}

    def create_user_embedding(self, data):
        """Compute a unit-length centroid embedding for one player's games.

        Args:
            data: Dict with ``"username"``, ``"pgn_content"`` (text holding
                one or more PGN games) and ``"games_per_player"`` (maximum
                number of games to embed).

        Returns:
            ``{"reply": embedding}`` where ``embedding`` is a list of floats,
            or ``None`` when no game could be processed.

        Raises:
            ValueError: If a game lists ``username`` as neither White nor
                Black.
        """
        print('entering create_username endpoint')
        username = data["username"]
        pgn_content = data["pgn_content"]
        games_per_player = data["games_per_player"]

        # A single stream shared by all reads, so each read_game call moves on
        # to the next game; re-creating the stream per iteration would return
        # the first game forever.
        pgn_stream = io.StringIO(pgn_content)
        player_games = []
        while True:
            game = chess.pgn.read_game(pgn_stream)
            if game is None:
                print("breaking main loop")
                break

            # process_game returns (white_array, black_array); pick the
            # element that belongs to the requested player.
            if game.headers.get("White") == username:
                side_index = 0
            elif game.headers.get("Black") == username:
                side_index = 1
            else:
                raise ValueError("username is neither White nor Black in one of the games")

            try:
                arrs = process_game(game)
            except Exception:
                # Best effort: a game that cannot be encoded is dropped.
                print("skipped")
                continue
            if arrs is None:
                print("skipped")
                continue
            player_games.append(arrs[side_index])

        if not player_games:
            return None

        selected_games = player_games[:games_per_player]
        inputs = np.array([Game(g).random_partial() for g in selected_games])
        num_games = len(selected_games)

        tensor = torch.tensor(inputs).float().to(self.device)
        with torch.no_grad():
            embeds = self.model(tensor)
            embeds = embeds.view((1, num_games, -1)).to(self.device)
            # Mean over the games, then scale the centroid to unit length.
            centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
            centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True)
            centroids_incl = centroids_incl.cpu().squeeze(1)
            final_embeds = centroids_incl[0].numpy().tolist()

        print('exiting create_username endpoint')
        return {"reply": final_embeds}

    def ai_move(self, data):
        """Pick the move that best matches a player's style, vetted by Stockfish.

        Every legal move is embedded and ranked by dot product with
        ``player_centroid``. Candidates are then tried from best to worst
        match, and the first one whose evaluation stays within 120 of the
        engine's evaluation of the current position is returned. If none
        qualifies, the engine's best move is returned. If the evaluation
        service fails, the top-ranked candidate is returned.

        Args:
            data: Dict with ``"pgn_string"`` (the game so far), ``"color"``
                (``"white"`` or ``"black"``) and ``"player_centroid"``
                (sequence of floats).

        Returns:
            ``{"reply": san}`` with the chosen move in SAN.

        Raises:
            ValueError: If ``pgn_string`` holds no game or the final position
                has no legal moves.
        """
        print('entering ai_move endpoint')
        pgn_string = data["pgn_string"]
        color = data["color"]
        player_centroid = data["player_centroid"]

        game = chess.pgn.read_game(io.StringIO(pgn_string))
        if game is None:
            raise ValueError("pgn_string does not contain a game")
        alternative_pgns, move_sans = generate_alternative_pgns(game)
        if not alternative_pgns:
            raise ValueError("no legal moves available in the current position")

        # process_game returns (white_array, black_array).
        side_index = 0 if color == "white" else 1
        inputs = []
        for alt_pgn in alternative_pgns:
            game_tensors = process_game(chess.pgn.read_game(io.StringIO(alt_pgn)), True)
            inputs.append(game_tensors[side_index])

        tensor = torch.tensor(np.array(inputs)).float().to(self.device)
        with torch.no_grad():
            embed = self.model(tensor)
            # One norm over the whole batch: a common positive scale, so the
            # ranking of the candidates is unaffected.
            embed = embed / torch.norm(embed)

        candidate_embeds = embed.cpu().numpy()
        centroid = np.array(player_centroid)
        similarities = [np.dot(centroid, candidate) for candidate in candidate_embeds]
        result = move_sans[np.argmax(similarities)]

        # Candidate indices, best style match first.
        ordered_moves = np.argsort(similarities).tolist()[::-1]
        try:
            board = game.board()
            for played_move in game.mainline_moves():
                board.push(played_move)

            response = requests.post(
                self.STOCKFISH_URL,
                json={"fen": board.fen()},
                timeout=self.STOCKFISH_TIMEOUT,
            )
            if response.status_code == 400:
                print(response.text)
                print('exiting ai_move endpoint status code before move')
                return {"reply": result}
            engine_reply = response.json()
            best_eval = engine_reply["value"]
            best_move = board.san(chess.Move.from_uci(engine_reply["best"]))

            for candidate_index in ordered_moves:
                candidate_san = move_sans[candidate_index]
                test_board = board.copy()
                test_board.push(board.parse_san(candidate_san))
                response = requests.post(
                    self.STOCKFISH_URL,
                    json={"fen": test_board.fen()},
                    timeout=self.STOCKFISH_TIMEOUT,
                )
                if response.status_code == 500:
                    print('exiting ai_move endpoint status code after move')
                    return {"reply": best_move}
                candidate_eval = response.json()["value"]
                # Accept the candidate unless it shifts the evaluation by 120
                # or more against the moving side.
                if (color == "white" and (best_eval - candidate_eval < 120)) or (color == "black" and (best_eval - candidate_eval > -120)):
                    print('exiting ai_move endpoint nice found!')
                    return {"reply": candidate_san}

            print('exiting ai_move endpoint all moves are shit!')
            return {"reply": best_move}

        except Exception as e:
            # Best effort: any failure while vetting falls back to the
            # model's top pick.
            print('error sending to lichess', e)
            print('exiting ai_move endpoint due to exception')
            return {"reply": result}

    def __call__(self, data):
        """Unwrap an optional ``"inputs"`` envelope and run the selected endpoint."""
        data = data.get("inputs", data)
        return self.d[data["endpoint_num"]](data)