Spaces:
Sleeping
Sleeping
import { Chess, Move } from 'chess.js' | |
import { pipeline } from '@xenova/transformers' | |
/**
 * Chess move generator backed by a text-generation model loaded through
 * `pipeline('text-generation', …)`.
 *
 * The model is prompted with the game so far as compact PGN move text
 * (e.g. `1.e4 e5 2.`). The first move token of its continuation is matched
 * against the legal moves. When the model is not loaded, errors, times out, or
 * produces an unmatchable move, a heuristic fallback is used: a random
 * capture, else a random checking move, else any random legal move.
 */
export class ChessAI {
  // The loaded text-generation pipeline, or null while unloaded / after a failed load.
  private model: any = null
  private isLoading: boolean = false
  // In-flight load shared by concurrent initialize() callers; reset once it settles.
  private loadPromise: Promise<void> | null = null
  private modelId: string

  /**
   * @param modelId - Model identifier passed to `pipeline('text-generation', …)`.
   */
  constructor(modelId: string = 'mlabonne/chesspythia-70m') {
    this.modelId = modelId
  }

  /**
   * Loads the model if it is not already loaded.
   *
   * Concurrent callers all await the same load instead of returning early.
   * A load failure is logged and leaves the model unloaded, so a later call
   * may retry.
   */
  async initialize(): Promise<void> {
    if (this.model) return
    if (!this.loadPromise) {
      // Clear via .finally (always async) rather than inside loadModel: if
      // pipeline() threw synchronously, loadModel would settle before this
      // assignment ran and a stale promise would block every retry.
      this.loadPromise = this.loadModel().finally(() => {
        this.loadPromise = null
      })
    }
    await this.loadPromise
  }

  /** Performs a single model load attempt; never rejects. */
  private async loadModel(): Promise<void> {
    this.isLoading = true
    try {
      this.model = await pipeline('text-generation', this.modelId)
    } catch (error) {
      console.error('Failed to load chess model:', error)
      this.model = null
    } finally {
      this.isLoading = false
    }
  }

  /**
   * Chooses a move for the side to move in `chess`.
   *
   * @param chess - Game whose side to move should play. It is temporarily
   *   mutated (move/undo) by the fallback heuristic and restored afterwards.
   * @param timeLimit - Milliseconds to wait for the model before falling back.
   * @returns A legal verbose move, or null when there are no legal moves.
   */
  async getMove(chess: Chess, timeLimit: number = 10000): Promise<Move | null> {
    if (!this.model) {
      console.log('Using fallback AI (model not loaded)')
      return this.getFallbackMove(chess)
    }

    let timer: ReturnType<typeof setTimeout> | undefined
    try {
      const legalMoves = chess.moves({ verbose: true })
      if (legalMoves.length === 0) return null

      const prompt = this.createChessPrompt(chess)
      const startTime = Date.now()
      const timeout = new Promise<null>((_, reject) => {
        timer = setTimeout(() => reject(new Error('Timeout')), timeLimit)
      })
      const result = await Promise.race([this.generateMove(prompt, legalMoves), timeout])
      const elapsedTime = Date.now() - startTime
      console.log(`AI move generated in ${elapsedTime}ms`)
      return result ?? this.getFallbackMove(chess)
    } catch (error) {
      console.error('Error generating AI move:', error)
      return this.getFallbackMove(chess)
    } finally {
      // Otherwise the timer outlives a fast model response and keeps the
      // event loop busy for up to `timeLimit` ms.
      clearTimeout(timer)
    }
  }

  /**
   * Runs the model on `prompt` and maps its output to one of `legalMoves`.
   * Returns null on any generation error or when no legal move matches.
   */
  private async generateMove(prompt: string, legalMoves: Move[]): Promise<Move | null> {
    if (!this.model) return null
    try {
      const output = await this.model(prompt, {
        max_new_tokens: 10,
        // NOTE(review): in transformers-style generation, temperature only takes
        // effect when sampling (do_sample) is enabled; confirm greedy is intended.
        temperature: 0.7
      })
      const generatedText = Array.isArray(output) ? output[0]?.generated_text : output?.generated_text
      console.log('Model output:', generatedText)
      return this.parseMove(generatedText, legalMoves, prompt)
    } catch (error) {
      console.error('Error in model generation:', error)
      return null
    }
  }

  /**
   * Builds the prompt as compact PGN move text, e.g. `1.e4 e5 2.` with white
   * to move or `1.e4` with black to move.
   *
   * Move numbers are reconstructed from the current move number and side to
   * move, so games started from a FEN are numbered correctly (a first black
   * move is written as `N...san`). With no history, this reduces to `N.` or
   * `N...`.
   */
  private createChessPrompt(chess: Chess): string {
    const moveNumber = chess.moveNumber()
    const whiteToMove = chess.turn() === 'w'
    const history = chess.history()

    // Absolute half-move index of the current position, and of the first
    // move in the recorded history.
    const currentPly = 2 * (moveNumber - 1) + (whiteToMove ? 0 : 1)
    const firstPly = currentPly - history.length

    const parts: string[] = []
    // A negative start means the history is inconsistent with the move
    // counter; omit it rather than emit nonsensical numbering.
    if (firstPly >= 0) {
      history.forEach((san, i) => {
        const ply = firstPly + i
        const number = Math.floor(ply / 2) + 1
        if (ply % 2 === 0) {
          parts.push(`${number}.${san}`)
        } else if (i === 0) {
          parts.push(`${number}...${san}`)
        } else {
          parts.push(san)
        }
      })
    }

    if (whiteToMove) {
      parts.push(`${moveNumber}.`)
    } else if (parts.length === 0) {
      parts.push(`${moveNumber}...`)
    }
    return parts.join(' ')
  }

  /**
   * Extracts the first move token from the model's continuation and matches
   * it against `legalMoves` by SAN (ignoring a trailing `+`/`#`) or by LAN.
   *
   * @param generatedText - Raw pipeline output, possibly echoing the prompt.
   * @param legalMoves - Verbose legal moves for the side to move.
   * @param prompt - The prompt that was sent; stripped from the front of
   *   `generatedText` when echoed.
   */
  private parseMove(generatedText: string, legalMoves: Move[], prompt: string = ''): Move | null {
    if (!generatedText) return null

    const continuation = prompt && generatedText.startsWith(prompt)
      ? generatedText.slice(prompt.length)
      : generatedText

    // Skip bare move numbers ("3.", "3..."), drop a leading number glued to
    // a move ("3.Bb5"), and strip check / annotation suffixes.
    let candidate = ''
    for (const token of continuation.split(/\s+/)) {
      const cleaned = token.replace(/^\d*\.+/, '').replace(/[+#!?]+$/, '')
      if (cleaned) {
        candidate = cleaned
        break
      }
    }

    if (candidate) {
      const match = legalMoves.find(
        move => move.san.replace(/[+#]$/, '') === candidate || move.lan === candidate
      )
      if (match) return match
    }

    console.log(`Could not parse move "${candidate}" from legal moves:`, legalMoves.map(m => m.san))
    return null
  }

  /**
   * Heuristic move choice used when the model is unavailable or unhelpful:
   * a random capture if any, otherwise a random checking move, otherwise a
   * random legal move. Temporarily plays and undoes each move on `chess`
   * to test for check.
   */
  private getFallbackMove(chess: Chess): Move | null {
    const legalMoves = chess.moves({ verbose: true })
    if (legalMoves.length === 0) return null

    let candidateMoves = legalMoves.filter(move => move.captured)
    if (candidateMoves.length === 0) {
      candidateMoves = legalMoves.filter(move => {
        chess.move(move)
        const isCheck = chess.inCheck()
        chess.undo()
        return isCheck
      })
    }
    if (candidateMoves.length === 0) {
      candidateMoves = legalMoves
    }

    const randomIndex = Math.floor(Math.random() * candidateMoves.length)
    return candidateMoves[randomIndex] ?? null
  }

  /** True once a model has been loaded successfully. */
  isModelLoaded(): boolean {
    return this.model !== null
  }

  /** True while a load attempt is in progress. */
  isModelLoading(): boolean {
    return this.isLoading
  }

  /** Human-readable status string: loading, loaded, or not loaded. */
  getModelInfo(): string {
    if (this.isLoading) return 'Loading...'
    if (this.model) return `${this.modelId} (Loaded)`
    return `${this.modelId} (Not loaded)`
  }
}