Spitfire1970 commited on
Commit
7f16368
·
1 Parent(s): 84efe39
Files changed (1) hide show
  1. handler.py +33 -61
handler.py CHANGED
@@ -2,48 +2,11 @@ import os
2
  import io
3
  import torch
4
  import requests
5
- import time
6
  import chess.pgn
7
- import chess.engine
8
  import numpy as np
9
  from data_objects.game import Game
10
  from encoder.model import Encoder
11
 
12
- STOCKFISH_CP_THRESHOLD = 100 # Allow style move if within this many centipawns of Stockfish's best
13
- MAX_STOCKFISH_DEPTH = 20 # Limit depth to avoid long waits
14
-
15
- def query_stockfish_eval(fen):
16
- url = f"https://lichess.org/api/cloud-eval?fen={fen}&multiPv=3&depth={MAX_STOCKFISH_DEPTH}"
17
- headers = {"Accept": "application/json"}
18
- try:
19
- response = requests.get(url, headers=headers)
20
- if response.status_code == 200:
21
- return response.json()
22
- else:
23
- print(f"Stockfish query failed with status {response.status_code}")
24
- return None
25
- except Exception as e:
26
- print(f"Error querying Stockfish: {e}")
27
- return None
28
-
29
- def is_reasonable_move(style_move, stockfish_pvs):
30
- for pv in stockfish_pvs:
31
- if style_move in pv["moves"].split():
32
- return True
33
- best_cp = stockfish_pvs[0].get("cp", 0)
34
- for pv in stockfish_pvs:
35
- if abs(pv.get("cp", 0) - best_cp) <= STOCKFISH_CP_THRESHOLD:
36
- if style_move == pv["moves"].split()[0]:
37
- return True
38
- return False
39
-
40
- def get_fen_after_move(game, move_san):
41
- board = game.board()
42
- for move in game.mainline_moves():
43
- board.push(move)
44
- move = board.parse_san(move_san)
45
- board.push(move)
46
- return board.fen()
47
 
48
  def generate_alternative_pgns(game):
49
  if not game:
@@ -290,34 +253,43 @@ class EndpointHandler():
290
  embed = embed / torch.norm(embed)
291
 
292
  arr = embed.cpu().numpy()
293
- similarities = [np.dot(np.array(player_centroid), emb) for emb in arr]
294
- ordered_indices = np.argsort(similarities)[::-1]
295
-
296
- for idx in ordered_indices:
297
- move_san = move_sans[idx]
298
- fen_after_move = get_fen_after_move(game, move_san)
299
 
300
- sf_eval = query_stockfish_eval(fen_after_move)
301
- if sf_eval and "pvs" in sf_eval:
302
- if is_reasonable_move(move_san, sf_eval["pvs"]):
303
- return move_san
304
- else:
305
- # fallback to style move if no eval available
306
- return move_san
307
 
308
- time.sleep(0.2) # Respect API rate limit
 
 
 
 
 
 
 
 
 
309
 
310
- # Fallback: take Stockfish's best move from original position
311
- board = game.board()
312
- for move in game.mainline_moves():
313
- board.push(move)
314
- fallback_eval = query_stockfish_eval(board.fen())
315
- if fallback_eval and "pvs" in fallback_eval:
316
- return fallback_eval["pvs"][0]["moves"].split()[0]
 
 
 
 
 
317
 
318
- # Final fallback: return top style move
319
- return move_sans[ordered_indices[0]]
 
 
320
 
321
  def __call__(self, data):
322
  data = data.get("inputs", data)
323
- return self.d[data["endpoint_num"]](data)
 
2
  import io
3
  import torch
4
  import requests
 
5
  import chess.pgn
 
6
  import numpy as np
7
  from data_objects.game import Game
8
  from encoder.model import Encoder
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def generate_alternative_pgns(game):
12
  if not game:
 
253
  embed = embed / torch.norm(embed)
254
 
255
  arr = embed.cpu().numpy()
256
+ similarities = [np.dot(np.array(player_centroid), embed) for embed in arr]
257
+ result = move_sans[np.argmax(similarities)]
 
 
 
 
258
 
259
+ ordered_moves = np.argsort(similarities).tolist()[::-1]
260
+ try:
261
+ board = game.board()
262
+ moves = list(game.mainline_moves())
 
 
 
263
 
264
+ # Play through the moves up to just before our target
265
+ for move in moves:
266
+ board.push(move)
267
+ url = f"https://lichess.org/api/cloud-eval?fen={board.fen()}"
268
+ headers = {"Accept": "application/json"}
269
+ response = requests.get(url, headers=headers)
270
+ if response.status_code == 404:
271
+ print('exiting ai_move endpoint status code before move')
272
+ return {"reply": result}
273
+ best_eval = response.json()["pvs"][0]["cp"]
274
 
275
+ for move in ordered_moves:
276
+ board.push_san(move_sans[move])
277
+ url = f"https://lichess.org/api/cloud-eval?fen={board.fen()}"
278
+ headers = {"Accept": "application/json"}
279
+ response = requests.get(url, headers=headers)
280
+ if response.status_code == 404 or "pvs" not in response.json():
281
+ print('exiting ai_move endpoint status code after move')
282
+ return {"reply": result}
283
+ eval = response.json()["pvs"][0]["cp"]
284
+ if (color == "white" and (best_eval - eval < 100)) or (color == "black" and (best_eval - eval < -100)):
285
+ print('exiting ai_move endpoint nice found!')
286
+ return {"reply": move}
287
 
288
+ except:
289
+ print('error sending to lichess')
290
+ print('exiting ai_move endpoint all moves are shit!')
291
+ return {"reply": result}
292
 
293
  def __call__(self, data):
294
  data = data.get("inputs", data)
295
+ return self.d[data["endpoint_num"]](data)