|
|
|
|
|
import { pipeline, env } from '@huggingface/transformers'; |
|
import * as webllm from '@mlc-ai/web-llm'; |
|
|
|
import { loadModelCore } from './load-model-core'; |
|
|
|
/**
 * Lazily loads and caches in-browser language models, preferring the
 * WebLLM runtime and falling back to transformers.js when WebLLM is
 * unavailable or fails validation.
 *
 * Cache values are either an in-flight loader Promise (while loading)
 * or, once resolved, the loaded model/engine itself.
 */
export class ModelCache {

  /** modelName -> Promise<model> while loading, then the resolved model. */
  cache = new Map();

  /** Last transformers.js device that worked ('webgpu' | 'gpu' | 'wasm' | 'auto'). */
  backend = undefined;

  /** Memoized result of probeWebLLM(): { possible, lastError? }. */
  webllmProbe = undefined;

  /** transformers.js environment, re-exposed for callers that tweak settings. */
  env = env;

  /** Model ids known to work with this loader (informational list). */
  knownModels = [
    'Xenova/llama2.c-stories15M',
    'Xenova/phi-3-mini-4k-instruct',
    'Xenova/all-MiniLM-L6-v2',
    'Xenova/phi-1.5',
    'Qwen/Qwen2.5-3B',
    'microsoft/phi-1_5',
    'FlofloB/100k_fineweb_continued_pretraining_Qwen2.5-0.5B-Instruct_Unsloth_merged_16bit',
    'ehristoforu/coolqwen-3b-it'
  ];

  /**
   * Returns the cached model (or in-flight loader Promise) for modelName,
   * starting a new load if nothing is cached.
   * @param {{ modelName: string }} args
   * @returns {Promise<object>|object} loader Promise while loading, or the model
   */
  getModel({ modelName }) {
    // ?? rather than ||: cache entries are never falsy, but nullish is precise.
    return this.cache.get(modelName) ?? this._loadModelAndStore({ modelName });
  }

  /**
   * Checks once whether the WebLLM runtime looks usable; the result is
   * memoized on this instance for subsequent calls.
   * @returns {{ possible: boolean, lastError?: string }}
   */
  probeWebLLM() {
    if (this.webllmProbe) return this.webllmProbe;

    try {
      const hasWebLLM = typeof webllm?.CreateMLCEngine === 'function' &&
        typeof webllm?.prebuiltAppConfig !== 'undefined';
      this.webllmProbe = { possible: hasWebLLM };
    } catch (err) {
      this.webllmProbe = { possible: false, lastError: String(err) };
    }

    return this.webllmProbe;
  }

  /**
   * Starts a load, caches the in-flight Promise so concurrent callers share
   * it, then swaps the cache entry for the resolved model — or evicts it on
   * failure so a later getModel() call can retry.
   * @param {{ modelName: string }} args
   * @returns {Promise<object>} the loader Promise
   */
  _loadModelAndStore({ modelName }) {
    if (!this.backend) this.backend = detectTransformersBackend();

    const loader = this._loadWebLLMOrFallbackToTransformersModelNow({ modelName });

    this.cache.set(modelName, loader);
    loader.then(
      (model) => {
        this.cache.set(modelName, model);
      },
      () => {
        // Drop failed loads so the model can be retried.
        this.cache.delete(modelName);
      }
    );

    return loader;
  }

  /**
   * Tries WebLLM first (when the probe says it is possible), validating the
   * engine with a tiny smoke-test prompt; on any WebLLM failure, falls back
   * to transformers.js.
   * @param {{ modelName: string }} args
   * @returns {Promise<object>} a WebLLM engine or a transformers.js model
   */
  async _loadWebLLMOrFallbackToTransformersModelNow({ modelName }) {
    const probe = this.probeWebLLM();

    if (probe.possible) {
      try {
        // WebLLM's prebuilt config keys models by bare id, without the org prefix.
        const webLLMId = modelName.split('/').pop() || modelName;
        console.log(`Loading ${webLLMId} via WebLLM...`);
        const engine = await webllm.CreateMLCEngine(webLLMId, {
          appConfig: webllm.prebuiltAppConfig
        });

        // Smoke-test: an engine that cannot answer a trivial prompt is
        // treated as a failed load (and we fall back below).
        try {
          const testResp = await engine.chat.completions.create({
            messages: [{ role: 'user', content: 'Hello' }],
            max_tokens: 8,
            temperature: 0.2
          });
          const testText = testResp?.choices?.[0]?.message?.content ?? '';
          if (!testText || String(testText).trim() === '') {
            throw new Error('WebLLM test prompt returned empty response');
          }
        } catch (e) {
          throw new Error('WebLLM validation failed: ' + String(e));
        }

        console.log(`WebLLM loaded: ${webLLMId}`);
        return engine;
      } catch (err) {
        // BUGFIX: err is not guaranteed to be an Error instance; reading
        // err.message unguarded could itself throw on primitives.
        console.log(`WebLLM failed for ${modelName}: ${err?.message ?? String(err)}`);
        // Fall through to the transformers.js path below.
      }
    }

    return this._loadTransformersModelNow({ modelName });
  }

  /**
   * Loads the model via transformers.js, trying each candidate device in
   * order and recording the first one that works in this.backend.
   * @param {{ modelName: string }} args
   * @returns {Promise<object>} the loaded model
   * @throws {Error} when every candidate device fails
   */
  async _loadTransformersModelNow({ modelName }) {
    // 'auto' lets transformers.js choose the best available device itself.
    // (A manual webgpu/gpu/wasm fallback chain used to be computed here but
    // was immediately overwritten with ['auto']; the dead code is removed.)
    const candidates = ['auto'];

    const errs = [];
    console.log('Trying candidates ', candidates);
    for (const device of candidates) {
      try {
        const model = await loadModelCore({
          modelName,
          device
        });

        this.backend = device;
        this.cache.set(modelName, model);
        return model;
      } catch (err) {
        console.log('Failed ', device, ' ', err);
        // BUGFIX: err.stack is undefined for non-Error throws; keep the
        // report useful by falling back to String(err).
        errs.push(device + ': ' + (err?.stack ?? String(err)));
      }
    }

    throw new Error(
      'Backends failed: ' + JSON.stringify(candidates) + ', errors:\n\n' +
      errs.join('\n\n'));
  }

}
|
|
|
/**
 * Detects the best transformers.js backend available in the current
 * environment: 'webgpu' when navigator.gpu exists, 'gpu' when a
 * WebGL/WebGL2 context can be created, 'wasm' otherwise (the universal
 * fallback, also used when any probe throws).
 * @returns {'webgpu'|'gpu'|'wasm'} backend identifier
 */
export function detectTransformersBackend() {
  // Probe WebGL availability via OffscreenCanvas when present, falling back
  // to a DOM canvas; any failure simply counts as "no WebGL".
  const probeWebGL = () => {
    try {
      if (typeof OffscreenCanvas !== 'undefined') {
        const surface = new OffscreenCanvas(1, 1);
        return !!(surface.getContext('webgl2') || surface.getContext('webgl'));
      }
      if (typeof document !== 'undefined') {
        const surface = document.createElement('canvas');
        return !!(surface.getContext('webgl2') || surface.getContext('webgl'));
      }
    } catch (ignored) {
      // Context creation failed — treat as unavailable.
    }
    return false;
  };

  let detected = 'wasm';
  try {
    const webgpuAvailable = typeof navigator !== 'undefined' && !!navigator.gpu;
    const webglAvailable = probeWebGL();

    if (webgpuAvailable) detected = 'webgpu';
    else if (webglAvailable) detected = 'gpu';
  } catch (ignored) {
    // Any unexpected probe failure means the safe default.
    detected = 'wasm';
  }

  return detected;
}