// @ts-check
import { env } from '@huggingface/transformers';
import * as webllm from '@mlc-ai/web-llm';
import { loadModelCore } from './load-model-core';
/**
 * Caches loaded models by name, trying WebLLM first and falling back to
 * Transformers.js. Cached values are either resolved models/engines or
 * in-flight load promises.
 */
export class ModelCache {
  /** modelName -> model/engine instance, or the in-flight load promise */
  cache = new Map();
/** @type {import('@huggingface/transformers').DeviceType | undefined} */
backend = undefined;
/** @type {{ possible: boolean, lastError?: string } | undefined} */
webllmProbe = undefined;
  // expose the Transformers.js environment object on the instance
  env = env;
  /** Candidate model ids, annotated with observed load behaviour */
  knownModels = [
'Xenova/llama2.c-stories15M', // nonsense
'Xenova/phi-3-mini-4k-instruct', // huge
'Xenova/all-MiniLM-L6-v2', // unsupported model type: bert
'Xenova/phi-1.5', // gated
'Qwen/Qwen2.5-3B', // cannot be loaded
'microsoft/phi-1_5', // cannot be loaded
'FlofloB/100k_fineweb_continued_pretraining_Qwen2.5-0.5B-Instruct_Unsloth_merged_16bit', // cannot be loaded
'ehristoforu/coolqwen-3b-it' // cannot be loaded
];
  /**
   * Return the cached model (or in-flight load) for `modelName`,
   * starting a new load if none is cached yet.
   * @param {{ modelName: string }} _
   * @returns {ReturnType<typeof this._loadModelAndStore>}
   */
getModel({ modelName }) {
return this.cache.get(modelName) || this._loadModelAndStore({ modelName });
}
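  // Usage sketch (illustrative only, not part of this file): callers await
  // getModel and receive either a WebLLM MLCEngine or a Transformers.js
  // pipeline, e.g.
  //   const cache = new ModelCache();
  //   const model = await cache.getModel({ modelName: 'Xenova/llama2.c-stories15M' });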
/**
* Lightweight probe to detect WebLLM API availability (advisory only)
*/
probeWebLLM() {
if (this.webllmProbe) return this.webllmProbe;
try {
// Check if basic WebLLM APIs are available
const hasWebLLM = typeof webllm?.CreateMLCEngine === 'function' &&
typeof webllm?.prebuiltAppConfig !== 'undefined';
this.webllmProbe = { possible: hasWebLLM };
} catch (err) {
this.webllmProbe = { possible: false, lastError: String(err) };
}
return this.webllmProbe;
}
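  // The probe is advisory only: a negative result simply means the WebLLM path
  // is skipped and loading goes straight to Transformers.js. Sketch:
  //   const { possible, lastError } = cache.probeWebLLM();
  //   if (!possible) console.warn('WebLLM unavailable:', lastError);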
/**
* @param {{
* modelName: string
* }} _
*/
_loadModelAndStore({ modelName }) {
if (!this.backend) this.backend = detectTransformersBackend();
// Create a loader promise that will try multiple backends in order.
const loader = this._loadWebLLMOrFallbackToTransformersModelNow({ modelName });
// store the in-progress promise so concurrent requests reuse it
this.cache.set(modelName, loader);
    loader.then(
      (model) => {
        // replace the in-flight promise with the resolved model
        this.cache.set(modelName, model);
      },
      () => {
        // drop failed loads so a later getModel call can retry
        this.cache.delete(modelName);
      }
    );
return loader;
}
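  // Because the in-flight promise is cached immediately, concurrent requests
  // for the same model share one load, and a failed load is evicted so it can
  // be retried. Sketch:
  //   const [a, b] = await Promise.all([
  //     cache.getModel({ modelName: 'Xenova/phi-1.5' }),
  //     cache.getModel({ modelName: 'Xenova/phi-1.5' })
  //   ]); // a === b: both resolve from the same loader promise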
async _loadWebLLMOrFallbackToTransformersModelNow({ modelName }) {
const probe = this.probeWebLLM();
// Try WebLLM first if probe suggests it's possible
if (probe.possible) {
try {
        // use the repo name without the org prefix as the WebLLM model id
        const webLLMId = modelName.split('/').pop() || modelName;
console.log(`Loading ${webLLMId} via WebLLM...`);
const engine = await webllm.CreateMLCEngine(webLLMId, {
appConfig: webllm.prebuiltAppConfig
});
// Quick end-to-end validation: run a very small prompt to ensure the
// engine responds correctly before caching it. If this fails we
// throw so the outer catch falls back to Transformers.js.
        try {
          const testResp = await engine.chat.completions.create({
            messages: [{ role: 'user', content: 'Hello' }],
            max_tokens: 8,
            temperature: 0.2
          });
const testText = testResp?.choices?.[0]?.message?.content ?? '';
if (!testText || String(testText).trim() === '') {
throw new Error('WebLLM test prompt returned empty response');
}
} catch (e) {
throw new Error('WebLLM validation failed: ' + String(e));
}
console.log(`WebLLM loaded: ${webLLMId}`);
return engine;
      } catch (err) {
        console.log(`WebLLM failed for ${modelName}: ${err instanceof Error ? err.message : err}`);
// Fall through to Transformers.js
}
}
// Fallback to Transformers.js
return this._loadTransformersModelNow({ modelName });
}
async _loadTransformersModelNow({ modelName }) {
    // Device candidates tried in order. 'auto' lets Transformers.js pick the
    // best available backend; an explicit ordered list such as
    // ['webgpu', 'gpu', 'wasm'] (starting from the detected backend) could be
    // tried instead.
    const candidates = ['auto'];
    const errs = [];
    console.log('Trying candidates ', candidates);
for (const device of candidates) {
try {
const model = await loadModelCore({
modelName,
device: /** @type {import('@huggingface/transformers').DeviceType} */ (device)
});
// on success, update backend to the working device and store model
this.backend = /** @type {import('@huggingface/transformers').DeviceType} */ (device);
this.cache.set(modelName, model);
return model;
} catch (err) {
console.log('Failed ', device, ' ', err);
errs.push(device + ': ' + err.stack);
// continue to next candidate
}
}
// none succeeded
const err = new Error(
'Backends failed: ' + JSON.stringify(candidates) + ', errors:\n\n' +
errs.join('\n\n'));
throw err;
}
}
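// One possible wiring (an assumption, not shown in this file): keep a single
// module-level instance so every request in the worker or page shares the cache.
//   const modelCache = new ModelCache();
//   export const getSharedModel = (modelName) => modelCache.getModel({ modelName });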
/**
 * Detect the best available Transformers.js acceleration backend:
 * 'webgpu' when WebGPU is exposed, 'gpu' when WebGL/WebGL2 is, otherwise 'wasm'.
 * @returns {import('@huggingface/transformers').DeviceType}
 */
export function detectTransformersBackend() {
  /** @type {import('@huggingface/transformers').DeviceType} */
  let backend = 'wasm';
try {
const hasWebGPU = typeof navigator !== 'undefined' && !!/** @type {*} */(navigator).gpu;
let hasWebGL2 = false;
try {
// In a worker environment prefer OffscreenCanvas to test webgl2
if (typeof OffscreenCanvas !== 'undefined') {
const c = new OffscreenCanvas(1, 1);
const gl = c.getContext('webgl2') || c.getContext('webgl');
hasWebGL2 = !!gl;
} else if (typeof document !== 'undefined') {
const canvas = document.createElement('canvas');
const gl = canvas.getContext('webgl2') || canvas.getContext('webgl');
hasWebGL2 = !!gl;
}
} catch (e) {
hasWebGL2 = false;
}
if (hasWebGPU) backend = 'webgpu';
else if (hasWebGL2) backend = 'gpu';
} catch (e) {
backend = 'wasm';
}
return backend;
}
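// Minimal sketch of using the detected backend directly with Transformers.js
// (assumes a text-generation model; 'Xenova/llama2.c-stories15M' is the small
// model from knownModels above). loadModelCore presumably does something similar.
//   import { pipeline } from '@huggingface/transformers';
//   const device = detectTransformersBackend();
//   const generator = await pipeline('text-generation', 'Xenova/llama2.c-stories15M', { device });
//   const out = await generator('Once upon a time', { max_new_tokens: 16 });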