|
|
|
""" |
|
Example usage of the Advanced Magnus Chess Model from Hugging Face |
|
""" |
|
|
|
import torch |
|
import chess |
|
import yaml |
|
import json |
|
from pathlib import Path |
|
import sys |
|
|
|
|
|
# Put the current working directory on the module search path.
# NOTE(review): presumably this is so the `advanced_magnus_predictor` module
# imported in load_model_from_hf() resolves when it sits next to the files the
# demo reads from the working directory - confirm against the repo layout.
sys.path.append(".")
|
|
|
|
|
def load_model_from_hf():
    """Load the Advanced Magnus predictor and print a short summary of it.

    The predictor class is imported inside the ``try`` block, so a missing
    ``advanced_magnus_predictor`` module (or one of its dependencies) is
    reported as a load failure rather than raising out of this function.

    Returns:
        The ``AdvancedMagnusPredictor`` instance on success, or ``None`` when
        importing, constructing, or inspecting the predictor failed. The
        failure reason is printed to stdout.
    """
    try:
        from advanced_magnus_predictor import AdvancedMagnusPredictor

        predictor = AdvancedMagnusPredictor()
        if predictor.model is None:
            raise RuntimeError("Failed to load model")

        # Collect everything the summary needs before announcing success, so
        # a predictor that cannot be inspected never prints a success banner.
        device = predictor.device
        vocab_size = predictor.vocab_size
        param_count = sum(p.numel() for p in predictor.model.parameters())
    except Exception as e:
        # Deliberate catch-all: this is the demo's top-level loading boundary
        # and any failure here is reported the same way to the user.
        print(f"❌ Failed to load model: {e}")
        return None

    print("✅ Advanced Magnus Chess Model loaded successfully!")
    print(f" Device: {device}")
    print(f" Vocabulary size: {vocab_size}")
    print(f" Parameters: {param_count:,}")
    return predictor
|
|
|
|
|
# Built-in showcase positions used when the caller supplies none.
_DEMO_POSITIONS = (
    {
        "name": "Opening - King's Pawn",
        "fen": "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1",
        "description": "Black to move after 1.e4",
    },
    {
        "name": "Sicilian Defense",
        "fen": "rnbqkbnr/pp1ppppp/8/2p5/4P3/8/PPPP1PPP/RNBQKBNR w KQkq c6 0 2",
        "description": "White to move after 1.e4 c5",
    },
    {
        "name": "Queen's Gambit",
        "fen": "rnbqkbnr/ppp1pppp/8/3p4/2PP4/8/PP2PPPP/RNBQKBNR b KQkq c3 0 2",
        "description": "Black to move after 1.d4 d5 2.c4",
    },
)


def _describe_move(board, uci):
    """Return ``"<SAN> (<UCI>)"`` for *uci* on *board*, or a fallback label."""
    try:
        san = board.san(chess.Move.from_uci(uci))
    except (ValueError, AssertionError):
        # NOTE(review): python-chess is expected to reject malformed or
        # illegal moves with a ValueError subclass (older releases assert) -
        # confirm against the installed version. Falling back to the raw UCI
        # text keeps the remaining predictions for this position visible.
        return f"{uci} (not playable in this position)"
    return f"{san} ({uci})"


def demo_predictions(predictor, positions=None, top_k=3):
    """Print the predictor's top moves for a list of positions.

    Args:
        predictor: Object exposing ``predict_moves(board, top_k=...)`` that
            returns a sequence of dicts with a ``"move"`` entry (UCI string)
            and a numeric ``"confidence"`` entry.
        positions: Optional iterable of dicts with ``"name"``, ``"fen"`` and
            ``"description"`` keys. Defaults to three built-in openings.
        top_k: Maximum number of predictions requested and printed per
            position.

    A failure while handling one position is printed and does not stop the
    remaining positions from being processed.
    """
    print("\n🎯 Magnus Style Move Predictions Demo")
    print("=" * 50)

    for pos in _DEMO_POSITIONS if positions is None else positions:
        print(f"\n📍 {pos['name']}")
        print(f" {pos['description']}")
        print(f" FEN: {pos['fen']}")

        try:
            board = chess.Board(pos["fen"])
            predictions = predictor.predict_moves(board, top_k=top_k)

            print(" 🧠 Magnus-style predictions:")
            # Slice defensively in case the predictor returns more than asked.
            for rank, pred in enumerate(predictions[:top_k], 1):
                label = _describe_move(board, pred["move"])
                print(f" {rank}. {label} - {pred['confidence']:.3f} confidence")
        except Exception as e:
            # Per-position boundary: report the problem and move on.
            print(f" ❌ Error predicting for this position: {e}")
|
|
|
|
|
def show_model_info():
    """Print model metadata found in ``config.yaml`` and ``version.json``.

    Both files are looked up relative to the current working directory and
    each is optional: a file that does not exist is skipped. If a file exists
    but cannot be read, parsed, or lacks an expected field, a one-line notice
    is printed and the function carries on, because this output is purely
    informational and should not stop the caller.
    """
    print("\n📊 Model Information")
    print("=" * 30)

    if Path("config.yaml").exists():
        try:
            with open("config.yaml", "r", encoding="utf-8") as f:
                config = yaml.safe_load(f)

            model = config["model"]
            training = config["training"]
            metrics = config["metrics"]

            print(f"Architecture: {model['architecture']}")
            print(f"Version: {model['version']}")
            print(f"Parameters: {training['total_params']:,}")
            print(f"Vocabulary: {training['vocab_size']} moves")
            print(f"Training time: {metrics['training_time_minutes']:.1f} minutes")
            print(f"Test accuracy: {metrics['test_accuracy']:.4f}")
            print(f"Top-3 accuracy: {metrics['test_top3_accuracy']:.4f}")
            print(f"Top-5 accuracy: {metrics['test_top5_accuracy']:.4f}")
        except (OSError, KeyError, TypeError, ValueError, yaml.YAMLError) as err:
            # TypeError covers an empty YAML document (safe_load -> None);
            # ValueError covers a field whose type does not fit its format.
            print(f"Could not read model info from config.yaml: {err!r}")

    if Path("version.json").exists():
        try:
            with open("version.json", "r", encoding="utf-8") as f:
                version = json.load(f)

            print(f"\nModel ID: {version['model_id']}")
            print(f"Timestamp: {version['timestamp']}")
            # Only the first 16 characters of the hash are shown.
            print(f"Hash: {version['model_hash'][:16]}...")
        except (OSError, KeyError, TypeError, ValueError) as err:
            # json.JSONDecodeError is a ValueError subclass.
            print(f"Could not read version info from version.json: {err!r}")
|
|
|
|
|
def main():
    """Run the demo: print metadata, load the model, then show predictions.

    Steps, in order: print a banner, call ``show_model_info()``, call
    ``load_model_from_hf()``, and - only if a predictor was returned - call
    ``demo_predictions()`` followed by a closing message. When loading fails
    a hint is printed and the function returns early.
    """
    print("🎯 Advanced Magnus Chess Model - Demo")
    print("🏆 Trained on Magnus Carlsen's games")
    print("=" * 60)

    show_model_info()

    predictor = load_model_from_hf()
    if predictor is None:
        print("Failed to load model. Please ensure all files are present.")
        return

    demo_predictions(predictor)

    print("\n" + "=" * 60)
    print("✨ Demo completed! Try your own positions with the predictor.")
|
|
|
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|