Maximus Powers
final
3568151
import { Chess, Move } from 'chess.js'
import { pipeline } from '@xenova/transformers'
/**
 * Move generator for a chess.js game. It asks a small causal language model (loaded
 * through `@xenova/transformers`) to continue the game's PGN movetext. It falls back to
 * a simple heuristic when the model is not loaded, errors, times out, or produces
 * something that is not a legal move.
 */
export class ChessAI {
  // The transformers.js pipeline object is callable but loosely typed. It stays `any` so
  // the result of `pipeline(...)` can be assigned; its output is narrowed where it is used.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  private model: any = null
  private isLoading: boolean = false
  // In-flight load shared by concurrent initialize() calls; cleared once that load settles.
  private loadPromise: Promise<void> | null = null
  private readonly modelId: string

  /**
   * @param modelId Model id handed to `pipeline('text-generation', ...)`.
   */
  constructor(modelId: string = 'mlabonne/chesspythia-70m') {
    this.modelId = modelId
  }

  /**
   * Loads the model. Returns immediately if it is already loaded, and concurrent callers
   * all await the same in-flight load. A failed load is logged rather than thrown and
   * leaves the model unloaded, so getMove uses the fallback. A later initialize() call
   * retries the load.
   */
  async initialize(): Promise<void> {
    if (this.model) return
    if (!this.loadPromise) {
      // `.finally` callbacks run asynchronously, so the reset always happens after this assignment.
      this.loadPromise = this.loadModel().finally(() => {
        this.loadPromise = null
      })
    }
    await this.loadPromise
  }

  /** Performs a single load attempt, tracking progress in `isLoading`. */
  private async loadModel(): Promise<void> {
    this.isLoading = true
    try {
      this.model = await pipeline('text-generation', this.modelId)
    } catch (error) {
      console.error('Failed to load chess model:', error)
      this.model = null
    } finally {
      this.isLoading = false
    }
  }

  /**
   * Picks a move for the side to move in `chess`.
   *
   * Uses the model when it is loaded and waits for it at most `timeLimit` ms. On timeout,
   * on error, or when the answer cannot be parsed into a legal move, it uses
   * getFallbackMove instead. The board is only read, never modified.
   *
   * @param chess Game whose side to move should play.
   * @param timeLimit Maximum time in milliseconds to wait for the model.
   * @returns A legal move, or null when the side to move has no legal moves.
   */
  async getMove(chess: Chess, timeLimit: number = 10000): Promise<Move | null> {
    if (!this.model) {
      console.log('Using fallback AI (model not loaded)')
      return this.getFallbackMove(chess)
    }
    let timer: ReturnType<typeof setTimeout> | undefined
    try {
      const legalMoves = chess.moves({ verbose: true })
      if (legalMoves.length === 0) return null
      const prompt = this.createChessPrompt(chess)
      const startTime = Date.now()
      const timeout = new Promise<null>((_, reject) => {
        timer = setTimeout(() => reject(new Error('Timeout')), timeLimit)
      })
      // Promise.race does not cancel the loser. On timeout the model call keeps running
      // in the background and its result is discarded.
      const result = await Promise.race([this.generateMove(prompt, legalMoves), timeout])
      const elapsedTime = Date.now() - startTime
      console.log(`AI move generated in ${elapsedTime}ms`)
      return result ?? this.getFallbackMove(chess)
    } catch (error) {
      console.error('Error generating AI move:', error)
      return this.getFallbackMove(chess)
    } finally {
      // Without this, the timer outlives a fast generation and keeps the event loop busy.
      clearTimeout(timer)
    }
  }

  /**
   * Runs the model on `prompt` and maps its continuation onto one of `legalMoves`.
   * Returns null when the model is missing, generation fails, or no legal move parses.
   */
  private async generateMove(prompt: string, legalMoves: Move[]): Promise<Move | null> {
    if (!this.model) return null
    try {
      const output = await this.model(prompt, {
        max_new_tokens: 10,
        // temperature only takes effect when sampling is enabled. Without do_sample,
        // generation is greedy and the temperature setting is ignored.
        do_sample: true,
        temperature: 0.7
      })
      const first = Array.isArray(output) ? output[0] : output
      const generatedText: unknown = first?.generated_text
      console.log('Model output:', generatedText)
      if (typeof generatedText !== 'string') return null
      // The pipeline may echo the prompt ahead of the continuation; only parse what it added.
      const continuation = generatedText.startsWith(prompt)
        ? generatedText.slice(prompt.length)
        : generatedText
      return this.parseMove(continuation, legalMoves)
    } catch (error) {
      console.error('Error in model generation:', error)
      return null
    }
  }

  /**
   * Builds a PGN-style movetext prompt from the game so far. The prompt ends where the
   * side to move's next SAN token belongs: `1. e4 e5 2.` when White is to move, or
   * `1. e4 e5 2. Nf3` when Black is to move.
   *
   * Numbering is derived from the current move number. It therefore stays correct for
   * games loaded from a FEN: a history that starts with a Black move is written as
   * `N... san`. With no moves played and Black to move, the prompt is `N...`.
   *
   * NOTE(review): assumes the model was trained on standard PGN spacing ("1. e4 e5"); confirm.
   */
  private createChessPrompt(chess: Chess): string {
    const history = chess.history()
    const turn = chess.turn()
    const moveNumber = chess.moveNumber()
    // Ply index of the move about to be played (0 = White's first move). It is used to
    // number the history backwards even when the game did not start at move 1.
    const currentPly = (moveNumber - 1) * 2 + (turn === 'w' ? 0 : 1)
    const firstPly = currentPly - history.length
    const tokens: string[] = []
    history.forEach((san, i) => {
      const ply = firstPly + i
      const number = Math.floor(ply / 2) + 1
      if (ply % 2 === 0) {
        tokens.push(`${number}.`, san)
      } else if (i === 0) {
        tokens.push(`${number}...`, san)
      } else {
        tokens.push(san)
      }
    })
    if (turn === 'w') {
      tokens.push(`${moveNumber}.`)
    } else if (history.length === 0) {
      tokens.push(`${moveNumber}...`)
    }
    return tokens.join(' ')
  }

  /**
   * Maps model output onto a legal move.
   *
   * Only the first move token is considered. Move numbers such as `12.` or `12...` are
   * skipped, even when glued to the move. Check, mate and annotation suffixes
   * (`+ # ! ?`) are ignored on both sides, and the token may be SAN or LAN. The match
   * must be exact; there is no prefix or substring guessing.
   *
   * @returns The matching legal move, or null if the first token is not one.
   */
  private parseMove(generatedText: string, legalMoves: Move[]): Move | null {
    if (!generatedText) return null
    // Drop check/mate and annotation suffixes so "Qxf7", "Qxf7#" and "Qxf7#!" compare equal.
    const normalize = (text: string): string => text.replace(/[+#!?]+$/, '')
    const token = generatedText
      .trim()
      .split(/\s+/)
      .map(part => part.replace(/^\d*\.+/, ''))
      .find(part => part.length > 0)
    if (token === undefined) return null
    const wanted = normalize(token)
    const match = legalMoves.find(move => normalize(move.san) === wanted || move.lan === wanted)
    if (match) return match
    console.log(`Could not parse move "${token}" from legal moves:`, legalMoves.map(m => m.san))
    return null
  }

  /**
   * Heuristic move picker used when the model is unavailable or unhelpful. It chooses,
   * in order of preference:
   * - a random capture, if any exist;
   * - otherwise a random checking or mating move;
   * - otherwise any random legal move.
   *
   * @returns The chosen move, or null when there are no legal moves.
   */
  private getFallbackMove(chess: Chess): Move | null {
    const legalMoves = chess.moves({ verbose: true })
    if (legalMoves.length === 0) return null
    let candidateMoves = legalMoves.filter(move => move.captured)
    if (candidateMoves.length === 0) {
      // chess.js suffixes SAN with '+' (check) or '#' (mate). Checks can therefore be found
      // without playing and undoing each move on the caller's board.
      candidateMoves = legalMoves.filter(move => /[+#]$/.test(move.san))
    }
    if (candidateMoves.length === 0) {
      candidateMoves = legalMoves
    }
    return candidateMoves[Math.floor(Math.random() * candidateMoves.length)] ?? null
  }

  /** True once a model pipeline has been loaded successfully. */
  isModelLoaded(): boolean {
    return this.model !== null
  }

  /** True while a load started by initialize() is in progress. */
  isModelLoading(): boolean {
    return this.isLoading
  }

  /** Human-readable status line: loading, or the model id with its loaded state. */
  getModelInfo(): string {
    if (this.isLoading) return 'Loading...'
    if (this.model) return `${this.modelId} (Loaded)`
    return `${this.modelId} (Not loaded)`
  }
}