// @ts-check

import { env } from '@huggingface/transformers';
import * as webllm from '@mlc-ai/web-llm';

import { loadModelCore } from './load-model-core';

export class ModelCache {
  cache = new Map();
  /** @type {import('@huggingface/transformers').DeviceType | undefined} */
  backend = undefined;
  /** @type {{ possible: boolean, lastError?: string } | undefined} */
  webllmProbe = undefined;

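  // Expose the Transformers.js env object so callers can adjust global
  // settings on this instance; how it is used is up to the consumer.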
  env = env;

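  // Model names tried previously, annotated with the observed outcome.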
  knownModels = [
    'Xenova/llama2.c-stories15M', // nonsense
    'Xenova/phi-3-mini-4k-instruct', // huge
    'Xenova/all-MiniLM-L6-v2', // unsupported model type: bert
    'Xenova/phi-1.5', // gated
    'Qwen/Qwen2.5-3B', // cannot be loaded
    'microsoft/phi-1_5', // cannot be loaded
    'FlofloB/100k_fineweb_continued_pretraining_Qwen2.5-0.5B-Instruct_Unsloth_merged_16bit', // cannot be loaded 
    'ehristoforu/coolqwen-3b-it' // cannot be loaded
  ];

  /**
   * @param {{
   *  modelName: string
   * }} _
   * @return {ReturnType<typeof this._loadModelAndStore>}
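   * @example
   * // usage sketch (inside an async context); the model name is taken
   * // from knownModels above. Concurrent callers share one in-flight load.
   * const cache = new ModelCache();
   * const model = await cache.getModel({ modelName: 'Xenova/llama2.c-stories15M' });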
   */
  getModel({ modelName }) {
    return this.cache.get(modelName) || this._loadModelAndStore({ modelName });
  }

  /**
   * Lightweight probe to detect WebLLM API availability (advisory only)
   */
  probeWebLLM() {
    if (this.webllmProbe) return this.webllmProbe;
    
    try {
      // Check if basic WebLLM APIs are available
      const hasWebLLM = typeof webllm?.CreateMLCEngine === 'function' && 
                        typeof webllm?.prebuiltAppConfig !== 'undefined';
      this.webllmProbe = { possible: hasWebLLM };
    } catch (err) {
      this.webllmProbe = { possible: false, lastError: String(err) };
    }
    
    return this.webllmProbe;
  }

  /**
   * @param {{
   *  modelName: string
   * }} _
   */
  _loadModelAndStore({ modelName }) {
    if (!this.backend) this.backend = detectTransformersBackend();
    // Create a loader promise that will try multiple backends in order.
    const loader = this._loadWebLLMOrFallbackToTransformersModelNow({ modelName });

    // store the in-progress promise so concurrent requests reuse it
    this.cache.set(modelName, loader);
    loader.then(
      (model) => {
        // on success, loader already stored the model
        this.cache.set(modelName, model);
      },
      () => {
        this.cache.delete(modelName);
      }
    );

    return loader;
  }

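  /**
   * Try WebLLM first when the probe says it is available, validating the
   * engine with a tiny test prompt; on any failure, fall back to
   * Transformers.js. The WebLLM model id is assumed to be the last path
   * segment of the Hugging Face model name.
   * @param {{ modelName: string }} _
   */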
  async _loadWebLLMOrFallbackToTransformersModelNow({ modelName }) {
    const probe = this.probeWebLLM();
    
    // Try WebLLM first if probe suggests it's possible
    if (probe.possible) {
      try {
        const webLLMId = modelName.split('/').pop() || modelName;
        console.log(`Loading ${webLLMId} via WebLLM...`);
        const engine = await webllm.CreateMLCEngine(webLLMId, {
          appConfig: webllm.prebuiltAppConfig
        });
        
        // Quick end-to-end validation: run a very small prompt to ensure the
        // engine responds correctly before caching it. If this fails we
        // throw so the outer catch falls back to Transformers.js.
        try {
          const testResp = await engine.chat.completions.create({
            messages: [{ role: 'user', content: 'Hello' }],
            max_tokens: 8,
            temperature: 0.2
          });
          const testText = testResp?.choices?.[0]?.message?.content ?? '';
          if (!testText || String(testText).trim() === '') {
            throw new Error('WebLLM test prompt returned empty response');
          }
        } catch (e) {
          throw new Error('WebLLM validation failed: ' + String(e));
        }
        
        console.log(`WebLLM loaded: ${webLLMId}`);
        return engine;
      } catch (err) {
        console.log(`WebLLM failed for ${modelName}: ${err?.message ?? err}`);
        // Fall through to Transformers.js
      }
    }

    // Fallback to Transformers.js
    return this._loadTransformersModelNow({ modelName });
  }

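  /**
   * Load via Transformers.js, trying each device candidate in order and
   * recording the first one that works as the preferred backend.
   * @param {{ modelName: string }} _
   */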
  async _loadTransformersModelNow({ modelName }) {
    // Device candidates, tried in order. Currently a single 'auto' entry
    // lets Transformers.js pick the device itself; the loop below also
    // supports an explicit fallback chain such as ['webgpu', 'gpu', 'wasm'].
    const candidates = ['auto'];

    const errs = [];
    console.log('Trying device candidates:', candidates);
    for (const device of candidates) {
      try {
        const model = await loadModelCore({
          modelName,
          device: /** @type {import('@huggingface/transformers').DeviceType} */ (device)
        });
        // on success, update backend to the working device and store model
        this.backend = /** @type {import('@huggingface/transformers').DeviceType} */ (device);
        this.cache.set(modelName, model);
        return model;
      } catch (err) {
        console.log('Failed ', device, ' ', err);
        errs.push(device + ': ' + err.stack);
        // continue to next candidate
      }
    }

    // none succeeded
    const err = new Error(
      'Backends failed: ' + JSON.stringify(candidates) + ', errors:\n\n' +
      errs.join('\n\n'));
    throw err;
  }

}

/**
 * Detect the best available Transformers.js acceleration backend:
 * 'webgpu' when navigator.gpu is present, 'gpu' when a WebGL2/WebGL
 * context can be created, otherwise 'wasm'.
 * @returns {import('@huggingface/transformers').DeviceType}
 */
export function detectTransformersBackend() {
  /** @type {import('@huggingface/transformers').DeviceType} */
  let backend = 'wasm';
  try {
    const hasWebGPU = typeof navigator !== 'undefined' && !!/** @type {*} */(navigator).gpu;
    let hasWebGL2 = false;
    try {
      // In a worker environment, prefer OffscreenCanvas to test for WebGL2 support
      if (typeof OffscreenCanvas !== 'undefined') {
        const c = new OffscreenCanvas(1, 1);
        const gl = c.getContext('webgl2') || c.getContext('webgl');
        hasWebGL2 = !!gl;
      } else if (typeof document !== 'undefined') {
        const canvas = document.createElement('canvas');
        const gl = canvas.getContext('webgl2') || canvas.getContext('webgl');
        hasWebGL2 = !!gl;
      }
    } catch (e) {
      hasWebGL2 = false;
    }

    if (hasWebGPU) backend = 'webgpu';
    else if (hasWebGL2) backend = 'gpu';
  } catch (e) {
    backend = 'wasm';
  }

  return backend;
}
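
// Usage sketch (hypothetical worker-side consumer; the message shape below
// is illustrative, not part of this module's API):
//
//   const cache = new ModelCache();
//   self.onmessage = async (event) => {
//     const model = await cache.getModel({ modelName: event.data.modelName });
//     // model is either a WebLLM engine or whatever loadModelCore
//     // produces, depending on which loader succeeded.
//   };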