import {
  AutoProcessor,
  MultiModalityCausalLM,
  BaseStreamer,
  TextStreamer,
  InterruptableStoppingCriteria,
} from "@huggingface/transformers";
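
// Messages beginning with "/imagine " are routed to image generation;
// everything else is handled as streamed text chat.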
const IMAGE_GENERATION_COMMAND_PREFIX = "/imagine ";
const MAX_NEW_TEXT_TOKENS = 1024;
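
// Feature detection: check that WebGPU is available and whether the
// adapter supports fp16 ("shader-f16"), which determines the dtypes
// used when loading the model.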
let fp16_supported = false;
async function check() {
  try {
    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) {
      throw new Error("WebGPU is not supported (no adapter found)");
    }
    fp16_supported = adapter.features.has("shader-f16");
    self.postMessage({
      status: "success",
      data: fp16_supported,
    });
  } catch (e) {
    self.postMessage({
      status: "error",
      data: e.toString(),
    });
  }
}
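
// Lazily creates and caches the processor and model (singleton pattern),
// so repeated calls reuse the same in-flight promises.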
class ImageGenerationPipeline {
  static model_id = "onnx-community/Janus-1.3B-ONNX";

  static async getInstance(progress_callback = null) {
    this.processor ??= AutoProcessor.from_pretrained(this.model_id, {
      progress_callback,
    });

    this.model ??= MultiModalityCausalLM.from_pretrained(this.model_id, {
      // Per-module dtypes: prefer q4f16/fp16 variants when the adapter
      // supports shader-f16, otherwise fall back to q4/fp32.
      dtype: fp16_supported
        ? {
            prepare_inputs_embeds: "q4",
            language_model: "q4f16",
            lm_head: "fp16",
            gen_head: "fp16",
            gen_img_embeds: "fp16",
            image_decode: "fp32",
          }
        : {
            prepare_inputs_embeds: "fp32",
            language_model: "q4",
            lm_head: "fp32",
            gen_head: "fp32",
            gen_img_embeds: "fp32",
            image_decode: "fp32",
          },
      // Run everything on WebGPU except prepare_inputs_embeds, which runs on WASM.
      device: {
        prepare_inputs_embeds: "wasm",
        language_model: "webgpu",
        lm_head: "webgpu",
        gen_head: "webgpu",
        gen_img_embeds: "webgpu",
        image_decode: "webgpu",
      },
      progress_callback,
    });

    return Promise.all([this.processor, this.model]);
  }
}
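
// Streamer that reports image-generation progress: the model emits a
// fixed number of image tokens, so progress = tokens emitted / total.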
class ProgressStreamer extends BaseStreamer {
  constructor(total, on_progress) {
    super();
    this.total = total;
    this.on_progress = on_progress;

    this.count = null;
    this.start_time = null;
  }

  put(value) {
    if (this.count === null) {
      // Ignore the first batch of tokens (the prompt)
      this.count = 0;
      this.start_time = performance.now();
      return;
    }

    const progress = ++this.count / this.total;

    this.on_progress({
      count: this.count,
      total: this.total,
      progress,
      time: performance.now() - this.start_time,
    });
  }

  end() {
    // no-op: final progress was already reported via put()
  }
}
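
// Lets the UI interrupt an in-flight text generation.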
const stopping_criteria = new InterruptableStoppingCriteria();
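
// Handle a single user turn: "/imagine" prompts go to image generation,
// everything else to streamed text chat.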
async function generate(messages) {
  // Only respond to the most recent message
  const message = messages.at(-1);

  // Tell the main thread we are starting
  self.postMessage({ status: "start" });

  // Retrieve the pipeline (cached after the first call)
  const [processor, model] = await ImageGenerationPipeline.getInstance();

  if (message.content.startsWith(IMAGE_GENERATION_COMMAND_PREFIX)) {
    const text = message.content.replace(IMAGE_GENERATION_COMMAND_PREFIX, "");

    const conversation = [
      {
        role: "User",
        content: text,
      },
    ];
    const inputs = await processor(conversation, {
      chat_template: "text_to_image",
    });

    const callback_function = (output) => {
      self.postMessage({
        status: "image-update",
        ...output,
      });
    };

    const num_image_tokens = processor.num_image_tokens;
    const streamer = new ProgressStreamer(num_image_tokens, callback_function);

    const outputs = await model.generate_images({
      ...inputs,
      min_new_tokens: num_image_tokens,
      max_new_tokens: num_image_tokens,
      do_sample: true,
      streamer,
    });

    const blob = await outputs[0].toBlob();

    // Send the final image back to the main thread
    self.postMessage({
      status: "image-update",
      blob,
    });
  } else {
    const inputs = await processor(
      message.image
        ? [
            {
              role: "User",
              content: "<image_placeholder>\n" + message.content,
              images: [message.image],
            },
          ]
        : [
            {
              role: "System",
              content:
                "You are a helpful assistant. Answer the user's questions in a concise manner.",
            },
            {
              role: "User",
              content: message.content,
            },
          ],
    );

    // Track tokens/second, skipping the first token to avoid skewing the rate
    let startTime;
    let numTokens = 0;
    let tps;
    const token_callback_function = () => {
      startTime ??= performance.now();

      if (numTokens++ > 0) {
        tps = (numTokens / (performance.now() - startTime)) * 1000;
      }
    };
    const callback_function = (output) => {
      self.postMessage({
        status: "text-update",
        output,
        tps,
        numTokens,
      });
    };

    const streamer = new TextStreamer(processor.tokenizer, {
      skip_prompt: true,
      skip_special_tokens: true,
      callback_function,
      token_callback_function,
    });

    // Generate the response, streaming decoded text back as it arrives
    const outputs = await model.generate({
      ...inputs,
      max_new_tokens: MAX_NEW_TEXT_TOKENS,
      do_sample: false,
      streamer,
      stopping_criteria,
    });
  }

  // Tell the main thread the turn is finished
  self.postMessage({
    status: "complete",
  });
}
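
// Download and initialize the model, forwarding progress events to the
// main thread (used to render a loading bar).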
async function load() {
  self.postMessage({
    status: "loading",
    data: "Loading model...",
  });

  // Load the pipeline, forwarding download progress events to the main thread
  const [processor, model] = await ImageGenerationPipeline.getInstance((x) => {
    self.postMessage(x);
  });

  self.postMessage({ status: "ready" });
}
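
// Main entry point: the UI drives this worker with { type, data } messages.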
self.addEventListener("message", async (e) => {
  const { type, data } = e.data;

  switch (type) {
    case "check":
      check();
      break;

    case "load":
      load();
      break;

    case "generate":
      stopping_criteria.reset();
      generate(data);
      break;

    case "interrupt":
      stopping_criteria.interrupt();
      break;

    case "reset":
      stopping_criteria.reset();
      break;
  }
});
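
// A minimal sketch of how a main thread might drive this worker. The
// "worker.js" filename and exact message shape are assumptions inferred
// from the handlers above: generate() only reads `content` (and
// optionally `image`) from each message.
//
//   const worker = new Worker("worker.js", { type: "module" });
//   worker.addEventListener("message", (e) => console.log(e.data));
//   worker.postMessage({ type: "check" }); // WebGPU feature detection
//   worker.postMessage({ type: "load" }); // download + initialize the model
//   worker.postMessage({
//     type: "generate",
//     data: [{ content: "/imagine a cat wearing a top hat" }],
//   });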