import {
  AutoModel,
  AutoTokenizer,
  AutoModelForCausalLM,
  TextStreamer,
  InterruptableStoppingCriteria,
  Tensor,
  pipeline,
} from "@huggingface/transformers";

import { KokoroTTS, TextSplitterStream } from "kokoro-js";

import {
  MAX_BUFFER_DURATION,
  INPUT_SAMPLE_RATE,
  SPEECH_THRESHOLD,
  EXIT_THRESHOLD,
  SPEECH_PAD_SAMPLES,
  MAX_NUM_PREV_BUFFERS,
  MIN_SILENCE_DURATION_SAMPLES,
  MIN_SPEECH_DURATION_SAMPLES,
} from "./constants";

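// Load the Kokoro text-to-speech model on WebGPU. The active voice is chosen
// later via a "set_voice" message from the main thread.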
const model_id = "onnx-community/Kokoro-82M-v1.0-ONNX";
let voice;
const tts = await KokoroTTS.from_pretrained(model_id, {
  dtype: "fp32",
  device: "webgpu",
});

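// Report progress to the main thread while the remaining models load.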
const device = "webgpu";
self.postMessage({ type: "info", message: `Using device: "${device}"` });
self.postMessage({
  type: "info",
  message: "Loading models...",
  duration: "until_next",
});

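// Load the Silero VAD model, used to detect speech in the incoming microphone audio.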
const silero_vad = await AutoModel.from_pretrained(
  "onnx-community/silero-vad",
  {
    config: { model_type: "custom" },
    dtype: "fp32",
  },
).catch((error) => {
  self.postMessage({ error });
  throw error;
});

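// Per-device dtype configuration for the Whisper speech-recognition pipeline:
// full precision on WebGPU, quantized (q8) decoder on WASM.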
const DEVICE_DTYPE_CONFIGS = {
  webgpu: {
    encoder_model: "fp32",
    decoder_model_merged: "fp32",
  },
  wasm: {
    encoder_model: "fp32",
    decoder_model_merged: "q8",
  },
};

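// Load the Whisper model for speech-to-text.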
const transcriber = await pipeline(
  "automatic-speech-recognition",
  "onnx-community/whisper-base",
  {
    device,
    dtype: DEVICE_DTYPE_CONFIGS[device],
  },
).catch((error) => {
  self.postMessage({ error });
  throw error;
});

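// Warm up the transcriber by running a dummy forward pass on one second of silence.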
await transcriber(new Float32Array(INPUT_SAMPLE_RATE));

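// Load the conversational language model and its tokenizer.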
const llm_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct";
const tokenizer = await AutoTokenizer.from_pretrained(llm_model_id);
const llm = await AutoModelForCausalLM.from_pretrained(llm_model_id, {
  dtype: "q4f16",
  device: "webgpu",
});

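// System prompt for the assistant, followed by a single-token generation to warm up the LLM.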
const SYSTEM_MESSAGE = {
  role: "system",
  content:
    "You're a helpful and conversational voice assistant. Keep your responses short, clear, and casual.",
};
await llm.generate({ ...tokenizer("x"), max_new_tokens: 1 });

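// Conversation history and generation state, then signal the main thread that everything is ready.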
let messages = [SYSTEM_MESSAGE];
let past_key_values_cache;
let stopping_criteria;
self.postMessage({
  type: "status",
  status: "ready",
  message: "Ready!",
  voices: tts.voices,
});

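// Global audio buffer that accumulates incoming samples while the user is speaking.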
const BUFFER = new Float32Array(MAX_BUFFER_DURATION * INPUT_SAMPLE_RATE);
let bufferPointer = 0;

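// Sample-rate tensor and recurrent state required by the Silero VAD model.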
const sr = new Tensor("int64", [INPUT_SAMPLE_RATE], []);
let state = new Tensor("float32", new Float32Array(2 * 1 * 128), [2, 1, 128]);

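// Whether we are currently recording a speech segment, and whether TTS audio is playing back.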
let isRecording = false;
let isPlaying = false;

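/**
 * Run Silero VAD on a chunk of audio and update the recurrent state.
 * Returns true if the chunk is considered speech: either the probability
 * exceeds SPEECH_THRESHOLD, or we are already recording and it stays above
 * the (lower) EXIT_THRESHOLD.
 */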
async function vad(buffer) {
  const input = new Tensor("float32", buffer, [1, buffer.length]);

  const { stateN, output } = await silero_vad({ input, sr, state });
  state = stateN;

  const isSpeech = output.data[0];

  return (
    isSpeech > SPEECH_THRESHOLD ||
    (isRecording && isSpeech >= EXIT_THRESHOLD)
  );
}

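/**
 * Full speech-to-speech pipeline: transcribe the recorded audio, generate a
 * reply with the LLM, and stream the reply through Kokoro TTS back to the
 * main thread as it is produced.
 */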
const speechToSpeech = async (buffer, data) => {
  isPlaying = true;

  // Transcribe the recorded audio.
  const text = await transcriber(buffer).then(({ text }) => text.trim());
  if (["", "[BLANK_AUDIO]"].includes(text)) {
    // Ignore empty or blank-audio transcriptions.
    return;
  }
  messages.push({ role: "user", content: text });

  // Set up a streaming text-to-speech pipeline: text pushed into the splitter
  // is synthesized chunk by chunk and forwarded to the main thread for playback.
  const splitter = new TextSplitterStream();
  const stream = tts.stream(splitter, {
    voice,
  });
  (async () => {
    for await (const { text, phonemes, audio } of stream) {
      self.postMessage({ type: "output", text, result: audio });
    }
  })();

  // Generate the assistant's reply, streaming tokens into the TTS splitter as they arrive.
  const inputs = tokenizer.apply_chat_template(messages, {
    add_generation_prompt: true,
    return_dict: true,
  });
  const streamer = new TextStreamer(tokenizer, {
    skip_prompt: true,
    skip_special_tokens: true,
    callback_function: (text) => {
      splitter.push(text);
    },
    token_callback_function: () => {},
  });

  stopping_criteria = new InterruptableStoppingCriteria();
  const { past_key_values, sequences } = await llm.generate({
    ...inputs,
    past_key_values: past_key_values_cache,
    do_sample: false,
    max_new_tokens: 1024,
    streamer,
    stopping_criteria,
    return_dict_in_generate: true,
  });
  past_key_values_cache = past_key_values;

  // No more text will be pushed; let the TTS stream finish.
  splitter.close();

  // Decode only the newly generated tokens and add them to the conversation history.
  const decoded = tokenizer.batch_decode(
    sequences.slice(null, [inputs.input_ids.dims[1], null]),
    { skip_special_tokens: true },
  );

  messages.push({ role: "assistant", content: decoded[0] });
};

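// Samples received since speech last ended, plus a helper that resets the
// recording state (optionally keeping `offset` overflow samples in the buffer).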
let postSpeechSamples = 0;
const resetAfterRecording = (offset = 0) => {
  self.postMessage({
    type: "status",
    status: "recording_end",
    message: "Transcribing...",
    duration: "until_next",
  });
  BUFFER.fill(0, offset);
  bufferPointer = offset;
  isRecording = false;
  postSpeechSamples = 0;
};

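// Package the recorded audio (prepended with a little pre-speech context) and
// hand it off to the speech-to-speech pipeline, then reset the audio buffer.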
const dispatchForTranscriptionAndResetAudioBuffer = (overflow) => {
  // Compute approximate start/end timestamps of the speech segment.
  const now = Date.now();
  const end =
    now - ((postSpeechSamples + SPEECH_PAD_SAMPLES) / INPUT_SAMPLE_RATE) * 1000;
  const start = end - (bufferPointer / INPUT_SAMPLE_RATE) * 1000;
  const duration = end - start;
  const overflowLength = overflow?.length ?? 0;

  // Take the recorded audio plus a small padding after the detected end of speech.
  const buffer = BUFFER.slice(0, bufferPointer + SPEECH_PAD_SAMPLES);

  // Prepend the chunks buffered just before speech started, so the beginning
  // of the utterance is not cut off.
  const prevLength = prevBuffers.reduce((acc, b) => acc + b.length, 0);
  const paddedBuffer = new Float32Array(prevLength + buffer.length);
  let offset = 0;
  for (const prev of prevBuffers) {
    paddedBuffer.set(prev, offset);
    offset += prev.length;
  }
  paddedBuffer.set(buffer, offset);
  speechToSpeech(paddedBuffer, { start, end, duration });

  // If the incoming chunk overflowed the buffer, carry the overflow over to the next segment.
  if (overflow) {
    BUFFER.set(overflow, 0);
  }
  resetAfterRecording(overflowLength);
};

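// Rolling window of the most recent audio chunks received while not recording,
// used as pre-speech context for the next segment.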
let prevBuffers = [];

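// Main message handler: control messages from the UI and raw audio chunks from the microphone.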
self.onmessage = async (event) => {
  const { type, buffer } = event.data;

  // Ignore new audio while the assistant is speaking.
  if (type === "audio" && isPlaying) return;

  switch (type) {
    case "start_call": {
      const name = tts.voices[voice ?? "af_heart"]?.name ?? "Heart";
      greet(`Hey there, my name is ${name}! How can I help you today?`);
      return;
    }
    case "end_call":
      messages = [SYSTEM_MESSAGE];
      past_key_values_cache = null;
    // fall through: ending the call also interrupts any ongoing generation
    case "interrupt":
      stopping_criteria?.interrupt();
      return;
    case "set_voice":
      voice = event.data.voice;
      return;
    case "playback_ended":
      isPlaying = false;
      return;
  }

  const wasRecording = isRecording;
  const isSpeech = await vad(buffer);

  if (!wasRecording && !isSpeech) {
    // Not recording and no speech detected: keep a small rolling window of
    // recent chunks so the start of the next utterance can be recovered.
    if (prevBuffers.length >= MAX_NUM_PREV_BUFFERS) {
      prevBuffers.shift();
    }
    prevBuffers.push(buffer);
    return;
  }

  const remaining = BUFFER.length - bufferPointer;
  if (buffer.length >= remaining) {
    // The buffer is full: fill what we can, then dispatch the segment and
    // carry the overflow into the next one.
    BUFFER.set(buffer.subarray(0, remaining), bufferPointer);
    bufferPointer += remaining;

    const overflow = buffer.subarray(remaining);
    dispatchForTranscriptionAndResetAudioBuffer(overflow);
    return;
  } else {
    // Otherwise, append the chunk to the buffer.
    BUFFER.set(buffer, bufferPointer);
    bufferPointer += buffer.length;
  }

  if (isSpeech) {
    if (!isRecording) {
      // Speech just started.
      self.postMessage({
        type: "status",
        status: "recording_start",
        message: "Listening...",
        duration: "until_next",
      });
    }
    // Still speaking: keep recording and reset the post-speech counter.
    isRecording = true;
    postSpeechSamples = 0;
    return;
  }

  postSpeechSamples += buffer.length;

  // We were recording but this chunk is not speech. Only treat the segment as
  // finished once enough silence has accumulated.
  if (postSpeechSamples < MIN_SILENCE_DURATION_SAMPLES) {
    return;
  }

  // Discard segments that are too short to be meaningful speech.
  if (bufferPointer < MIN_SPEECH_DURATION_SAMPLES) {
    resetAfterRecording();
    return;
  }

  dispatchForTranscriptionAndResetAudioBuffer();
};

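// Greet the caller: push a fixed greeting through the TTS stream and record it
// in the conversation history as an assistant message.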
function greet(text) {
  isPlaying = true;
  const splitter = new TextSplitterStream();
  const stream = tts.stream(splitter, { voice });
  (async () => {
    for await (const { text: chunkText, audio } of stream) {
      self.postMessage({ type: "output", text: chunkText, result: audio });
    }
  })();
  splitter.push(text);
  splitter.close();
  messages.push({ role: "assistant", content: text });
}