import {
  AutoTokenizer,
  AutoModelForCausalLM,
  TextStreamer,
  InterruptableStoppingCriteria,
} from "@huggingface/transformers";
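
/**
 * Feature detection: confirm that a WebGPU adapter is available before the
 * model is loaded, and report an "error" status to the main thread if not.
 */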
async function check() {
  try {
    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) {
      throw new Error("WebGPU is not supported (no adapter found)");
    }
  } catch (e) {
    self.postMessage({
      status: "error",
      data: e.toString(),
    });
  }
}
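
/**
 * This class uses the singleton pattern to lazily create and cache a single
 * tokenizer and model instance, so repeated calls reuse the same objects.
 */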
class TextGenerationPipeline {
  static model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct";

  static async getInstance(progress_callback = null) {
    this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
      progress_callback,
    });

    this.model ??= AutoModelForCausalLM.from_pretrained(this.model_id, {
      dtype: "q4f16",
      device: "webgpu",
      progress_callback,
    });

    return Promise.all([this.tokenizer, this.model]);
  }
}

const stopping_criteria = new InterruptableStoppingCriteria();
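
// Cache the past key/value states between generations so that follow-up
// turns in the same conversation do not re-process earlier tokens.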
let past_key_values_cache = null;
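
/**
 * Generates a streamed response for the given chat messages and posts
 * "start", "update", and "complete" messages back to the main thread.
 */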
async function generate(messages) {
  // Retrieve the (cached) tokenizer and model instances.
  const [tokenizer, model] = await TextGenerationPipeline.getInstance();

  // Apply the chat template to turn the message history into model inputs.
  const inputs = tokenizer.apply_chat_template(messages, {
    add_generation_prompt: true,
    return_dict: true,
  });
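
  // Track decoding speed (tokens per second) as tokens stream in.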
  let startTime;
  let numTokens = 0;
  let tps;
  const token_callback_function = () => {
    startTime ??= performance.now();

    if (numTokens++ > 0) {
      tps = (numTokens / (performance.now() - startTime)) * 1000;
    }
  };
  // Forward each decoded chunk to the main thread, with speed statistics.
  const callback_function = (output) => {
    self.postMessage({
      status: "update",
      output,
      tps,
      numTokens,
    });
  };
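
  // Stream generated text; skip the prompt and special tokens in the output.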
  const streamer = new TextStreamer(tokenizer, {
    skip_prompt: true,
    skip_special_tokens: true,
    callback_function,
    token_callback_function,
  });

  // Signal the main thread that generation has started.
  self.postMessage({ status: "start" });
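
  // Generate a response, reusing the cached key/value states from previous
  // turns when available.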
  const { past_key_values, sequences } = await model.generate({
    ...inputs,
    past_key_values: past_key_values_cache,
    max_new_tokens: 1024,
    streamer,
    stopping_criteria,
    return_dict_in_generate: true,
  });
  past_key_values_cache = past_key_values;

  const decoded = tokenizer.batch_decode(sequences, {
    skip_special_tokens: true,
  });

  // Send the full decoded output back to the main thread.
  self.postMessage({
    status: "complete",
    output: decoded,
  });
}
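
/**
 * Loads the tokenizer and model, forwarding download progress to the main
 * thread, then runs a tiny warm-up generation to compile the WebGPU shaders.
 */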
async function load() {
  self.postMessage({
    status: "loading",
    data: "Loading model...",
  });

  // Load the pipeline, forwarding progress events (e.g. file downloads)
  // to the main thread as-is.
  const [tokenizer, model] = await TextGenerationPipeline.getInstance((x) => {
    self.postMessage(x);
  });

  self.postMessage({
    status: "loading",
    data: "Compiling shaders and warming up model...",
  });

  // Run a dummy generation so the shaders are compiled before the first
  // real request.
  const inputs = tokenizer("a");
  await model.generate({ ...inputs, max_new_tokens: 1 });
  self.postMessage({ status: "ready" });
}
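
// Handle messages from the main thread: "check", "load", "generate",
// "interrupt", and "reset".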
self.addEventListener("message", async (e) => { |
|
const { type, data } = e.data; |
|
|
|
switch (type) { |
|
case "check": |
|
check(); |
|
break; |
|
|
|
case "load": |
|
load(); |
|
break; |
|
|
|
case "generate": |
|
stopping_criteria.reset(); |
|
generate(data); |
|
break; |
|
|
|
case "interrupt": |
|
stopping_criteria.interrupt(); |
|
break; |
|
|
|
case "reset": |
|
past_key_values_cache = null; |
|
stopping_criteria.reset(); |
|
break; |
|
} |
|
}); |
|
|
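
/*
 * Rough sketch of how a main-thread script might drive this worker. The
 * "./worker.js" path, the exact message payloads, and the UI handling are
 * assumptions and not part of this file.
 *
 *   const worker = new Worker(new URL("./worker.js", import.meta.url), {
 *     type: "module",
 *   });
 *
 *   worker.addEventListener("message", (e) => {
 *     const { status } = e.data;
 *     if (status === "update") {
 *       // e.data.output is the newly decoded text chunk; e.data.tps is the
 *       // current decoding speed in tokens per second.
 *     } else if (status === "complete") {
 *       // e.data.output is an array containing the fully decoded text.
 *     }
 *   });
 *
 *   worker.postMessage({ type: "check" }); // feature-detect WebGPU
 *   worker.postMessage({ type: "load" });  // download and warm up the model
 *   worker.postMessage({
 *     type: "generate",
 *     data: [{ role: "user", content: "What is the capital of France?" }],
 *   });
 */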