// @ts-check

import { ModelCache } from './model-cache.js';
import { listChatModelsIterator } from './list-chat-models.js';

import curatedList from './curated-model-list.json' assert { type: 'json' };

export function bootWorker() {
  const modelCache = new ModelCache();
  let selectedModel = modelCache.knownModels[0];
  // Report starting
  try {
    self.postMessage({ type: 'status', status: 'initializing' });
  } catch (e) {
    // ignore if postMessage not available for some reason
  }

  self.postMessage({ type: 'status', status: 'backend-detected', backend: modelCache.backend });

  // signal ready to main thread (worker script loaded; model runtime may still be pending)
  self.postMessage({ type: 'ready', env: modelCache.env, backend: modelCache.backend });

  // handle incoming requests from the UI thread
  self.addEventListener('message', handleMessage);
  // track cancellable tasks by id
  const activeTasks = new Map();
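
  // Supported request messages (each carries a caller-chosen `id` used to
  // correlate replies with requests):
  //   { id, type: 'listChatModels', params? }                   -> batched 'progress' messages, then a final 'response'
  //   { id, type: 'cancelListChatModels' }                      -> cancels the listing started with the same id
  //   { id, type: 'loadModel', modelName? }                     -> loads/caches a model (defaults to the first known model)
  //   { id, type: 'runPrompt', prompt, modelName?, options? }   -> runs text generation and replies with the generated text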

  async function handleMessage({ data }) {
    const { id } = data;
    try {
      if (data.type === 'listChatModels') {
        // kick off the long-running listing/classification task
        handleListChatModels(data).catch(err => {
          self.postMessage({ id, type: 'error', error: String(err) });
        });
      } else if (data.type === 'cancelListChatModels') {
        const task = activeTasks.get(id);
        if (task && task.abort) task.abort();
        self.postMessage({ id, type: 'response', result: { cancelled: true } });
      } else if (data.type === 'loadModel') {
        const { modelName = modelCache.knownModels[0] } = data;
        try {
          await modelCache.getModel({ modelName }); // ensure the model is cached and ready
          selectedModel = modelName;
          self.postMessage({ id, type: 'response', result: { model: modelName, status: 'loaded' } });
        } catch (err) {
          self.postMessage({ id, type: 'error', error: String(err) });
        }
      } else if (data.type === 'runPrompt') {
        handleRunPrompt(data);
      } else {
        if (id) self.postMessage({ id, type: 'error', error: 'unknown-message-type' });
      }
    } catch (err) {
      if (id) self.postMessage({ id, type: 'error', error: String(err) });
    }
  }

  async function handleRunPrompt({ prompt, modelName = selectedModel, id, options }) {
    try {
      const engine = await modelCache.getModel({ modelName });
      if (!engine) throw new Error('engine not available');
      
      self.postMessage({ id, type: 'status', status: 'inference-start', model: modelName });
      
      // Duck-typing to detect engine type and route accordingly
      let text;
      if (/** @type {any} */(engine).chat?.completions?.create) {
        // WebLLM engine detected
        try {
          const webllmEngine = /** @type {any} */(engine);
          const response = await webllmEngine.chat.completions.create({
            messages: [{ role: "user", content: prompt }],
            max_tokens: options?.max_new_tokens ?? 250,
            temperature: options?.temperature ?? 0.7
          });
          text = response.choices[0]?.message?.content ?? '';
        } catch (err) {
          console.error(`WebLLM inference failed for ${modelName}:`, err);
          throw err; // Re-throw since we can't easily fallback mid-inference
        }
      } else if (typeof engine === 'function') {
        // Transformers.js pipeline detected
        const out = await engine(prompt, {
          max_new_tokens: 250,
          temperature: 0.7,
          do_sample: true,
          pad_token_id: engine.tokenizer?.eos_token_id,
          return_full_text: false,
          ...options
        });
        text = extractText(out);
      } else {
        throw new Error('Unknown engine type');
      }
      
      self.postMessage({ id, type: 'status', status: 'inference-done', model: modelName });
      self.postMessage({ id, type: 'response', result: text });
    } catch (err) {
      self.postMessage({ id, type: 'error', error: String(err) });
    }
  }

  // Implementation of the listChatModels worker action using the async-iterator action.
  async function handleListChatModels({ id, params = {} }) {

    // Short-circuit: respond immediately with the bundled curated list. The
    // iterator-based listing/classification below is currently unreachable
    // because of this early return.
    self.postMessage({ id, type: 'response', result: { models: curatedList } });
    return;

    const iterator = listChatModelsIterator(params);
    let sawDone = false;
    // batching buffer
    let batchBuffer = [];
    let batchTimer = null;
    const BATCH_MS = 50;
    const BATCH_MAX = 50;

    function flushBatch() {
      if (!batchBuffer || batchBuffer.length === 0) return;
      try {
        console.log('Loading: ', batchBuffer[batchBuffer.length - 1]);
        self.postMessage({ id, type: 'progress', batch: true, items: batchBuffer.splice(0) });
      } catch (e) { /* ignore postMessage failures */ }
      if (batchTimer) { clearTimeout(batchTimer); batchTimer = null; }
    }

    function enqueueProgress(delta) {
      batchBuffer.push(delta);
      if (batchBuffer.length >= BATCH_MAX) return flushBatch();
      if (!batchTimer) {
        batchTimer = setTimeout(() => { flushBatch(); }, BATCH_MS);
      }
    }

    activeTasks.set(id, { abort: () => iterator.return() });
    let lastBatchDelta;
    try {
      for await (const delta of iterator) {
        try { enqueueProgress(delta); } catch (e) { /* ignore batching errors */ }
        if (delta && delta.models) lastBatchDelta = delta;
        if (delta && delta.status === 'done') {
          sawDone = true;
        }
      }

      // flush any remaining progress messages synchronously
      flushBatch();
      if (!sawDone) {
        // iterator exited early (likely cancelled)
        self.postMessage({ id, type: 'response', result: { cancelled: true } });
      } else {
        self.postMessage({ id, type: 'response', result: lastBatchDelta });
      }
    } catch (err) {
      flushBatch();
      self.postMessage({ id, type: 'error', error: String(err), code: err.code || null });
    } finally {
      activeTasks.delete(id);
    }
  }
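
  // Consumer sketch (assumed main-thread side, not part of this worker):
  // progress deltas arrive as batched 'progress' messages, followed by one
  // final 'response'. `listRequestId`, `renderDelta` and `finishListing` are
  // hypothetical names used only for illustration.
  //
  //   worker.addEventListener('message', ({ data }) => {
  //     if (data.id !== listRequestId) return;
  //     if (data.type === 'progress' && data.batch) {
  //       data.items.forEach(renderDelta);
  //     } else if (data.type === 'response') {
  //       finishListing(data.result);
  //     } else if (data.type === 'error') {
  //       console.error('listChatModels failed:', data.error);
  //     }
  //   });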

  // helper: fetchConfigForModel
  // Note: fetchConfigForModel and classifyModel were moved to the
  // `src/worker/list-chat-models.js` async-iterator action. Keep this file
  // minimal and delegate to the iterator for listing/classification logic.
}

// helper to extract generated text from various runtime outputs
function extractText(output) {
  // typical shapes: [{ generated_text: '...' }] or [{ text: '...' }] or string
  try {
    if (!output) return '';
    if (typeof output === 'string') return output;
    if (Array.isArray(output) && output.length > 0) {
      return output
        .map(el => {
          // Some runtimes return an array of plain strings
          if (typeof el === 'string') return el;
          if (el && el.generated_text) return el.generated_text;
          if (el && el.text) return el.text;
          return '';
        })
        .join('');
    }
    // Fallback: coerce whatever we got to a string
    return String(output);
  } catch (e) {
    return '';
  }
}
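
// Usage sketch (assumed main-thread side, not part of this file): a worker
// entry module is expected to import and call bootWorker(), and the UI thread
// creates it as a module worker. The entry file name './chat-worker.js' and
// the id scheme below are illustrative assumptions, not part of this module.
//
//   const worker = new Worker(new URL('./chat-worker.js', import.meta.url), { type: 'module' });
//   worker.addEventListener('message', ({ data }) => {
//     if (data.type === 'ready') {
//       worker.postMessage({ id: 1, type: 'loadModel' }); // load the default model
//     } else if (data.type === 'response' && data.id === 1) {
//       worker.postMessage({ id: 2, type: 'runPrompt', prompt: 'Hello!' });
//     } else if (data.type === 'response' && data.id === 2) {
//       console.log('generated:', data.result);
//     } else if (data.type === 'error') {
//       console.error('worker error:', data.error);
//     }
//   });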