|
import path from 'node:path'; |
|
import fs from 'node:fs'; |
|
import process from 'node:process'; |
|
import { Buffer } from 'node:buffer'; |
|
|
|
import { pipeline, env, RawImage } from 'sillytavern-transformers'; |
|
import { getConfigValue } from './util.js'; |
|
import { serverDirectory } from './server-directory.js'; |
|
|
|
configureTransformers();

/**
 * Applies global transformers.js environment settings at module load.
 * Forces the ONNX WASM backend to a single thread and points the runtime
 * at the WASM binaries bundled with the sillytavern-transformers package.
 */
function configureTransformers() {
    // Run the ONNX WASM backend on a single thread.
    env.backends.onnx.wasm.numThreads = 1;
    // Directory containing the bundled WASM files, with a trailing separator.
    env.backends.onnx.wasm.wasmPaths = `${path.join(serverDirectory, 'node_modules', 'sillytavern-transformers', 'dist')}${path.sep}`;
}
|
|
|
/**
 * Supported transformers.js tasks and their pipeline configuration.
 * Each entry holds:
 * - defaultModel: model identifier used when no override is configured
 * - pipeline: lazily-created pipeline instance, cached per task (set by getPipeline)
 * - configField: config key that overrides the default model
 * - quantized: whether to load the quantized variant of the model
 */
const tasks = {
    'text-classification': {
        defaultModel: 'Cohee/distilbert-base-uncased-go-emotions-onnx',
        pipeline: null,
        configField: 'extensions.models.classification',
        quantized: true,
    },
    'image-to-text': {
        defaultModel: 'Xenova/vit-gpt2-image-captioning',
        pipeline: null,
        configField: 'extensions.models.captioning',
        quantized: true,
    },
    'feature-extraction': {
        defaultModel: 'Xenova/all-mpnet-base-v2',
        pipeline: null,
        configField: 'extensions.models.embedding',
        quantized: true,
    },
    'automatic-speech-recognition': {
        defaultModel: 'Xenova/whisper-small',
        pipeline: null,
        configField: 'extensions.models.speechToText',
        quantized: true,
    },
    'text-to-speech': {
        defaultModel: 'Xenova/speecht5_tts',
        pipeline: null,
        configField: 'extensions.models.textToSpeech',
        quantized: false,
    },
};
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Decodes a base64-encoded image into a RawImage instance.
 * @param {string} image Base64-encoded image data.
 * @returns {Promise<RawImage|null>} Decoded image, or null when decoding fails.
 */
export async function getRawImage(image) {
    try {
        const bytes = new Uint8Array(Buffer.from(image, 'base64'));
        // Await inside the try so decoding failures are caught here.
        return await RawImage.fromBlob(new Blob([bytes]));
    } catch {
        // Undecodable input is an expected case; callers check for null.
        return null;
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Resolves the model to use for a given transformers.js task.
 * Reads the task's config override and falls back to the task's default model.
 * @param {string} task Task name (a key of the tasks map).
 * @returns {string} Model identifier to load.
 */
function getModelForTask(task) {
    const defaultModel = tasks[task].defaultModel;

    try {
        const model = getConfigValue(tasks[task].configField, null);
        return model || defaultModel;
    } catch (error) {
        // Fix: the old message claimed "classification" for every task and dropped the error.
        console.warn(`Failed to read config.yaml, using default ${task} model.`, error);
        return defaultModel;
    }
}
|
|
|
/**
 * Moves model cache files from the legacy ./cache directory into the
 * data root's _cache directory. Ensures the target directory exists even
 * when there is nothing to migrate. Failures for individual entries are
 * logged and skipped (the affected model will simply be re-downloaded).
 */
async function migrateCacheToDataDir() {
    const legacyDir = path.join(process.cwd(), 'cache');
    const targetDir = path.join(globalThis.DATA_ROOT, '_cache');

    // Always make sure the target cache directory exists.
    if (!fs.existsSync(targetDir)) {
        fs.mkdirSync(targetDir, { recursive: true });
    }

    // Nothing to do if the legacy location is absent or not a directory.
    if (!fs.existsSync(legacyDir) || !fs.statSync(legacyDir).isDirectory()) {
        return;
    }

    const entries = fs.readdirSync(legacyDir);
    if (entries.length === 0) {
        return;
    }

    console.log('Migrating model cache files to data directory. Please wait...');

    for (const entry of entries) {
        const source = path.join(legacyDir, entry);
        const destination = path.join(targetDir, entry);
        try {
            // Copy first, then remove the original, so a failed copy keeps the source intact.
            fs.cpSync(source, destination, { recursive: true, force: true });
            fs.rmSync(source, { recursive: true, force: true });
        } catch (error) {
            console.warn('Failed to migrate cache file. The model will be re-downloaded.', error);
        }
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Gets or creates a transformers.js pipeline for the given task.
 * Pipelines are cached per task; requesting a different model via
 * forceModel disposes the cached pipeline and initializes a new one.
 * @param {string} task Task name (a key of the tasks map).
 * @param {string} [forceModel] Optional model that overrides the configured one.
 * @returns {Promise<object>} Pipeline instance for the task.
 */
export async function getPipeline(task, forceModel = '') {
    await migrateCacheToDataDir();

    if (tasks[task].pipeline) {
        // Reuse the cached pipeline unless a different model was explicitly requested.
        if (forceModel === '' || tasks[task].currentModel === forceModel) {
            return tasks[task].pipeline;
        }
        // Fix: log message previously read "for for task".
        console.log('Disposing transformers.js pipeline for task', task, 'with model', tasks[task].currentModel);
        await tasks[task].pipeline.dispose();
    }

    const cacheDir = path.join(globalThis.DATA_ROOT, '_cache');
    const model = forceModel || getModelForTask(task);
    // localOnly forbids downloads when auto-download is disabled in config.
    const localOnly = !getConfigValue('extensions.models.autoDownload', true, 'boolean');
    console.log('Initializing transformers.js pipeline for task', task, 'with model', model);
    const instance = await pipeline(task, model, { cache_dir: cacheDir, quantized: tasks[task].quantized ?? true, local_files_only: localOnly });
    tasks[task].pipeline = instance;
    tasks[task].currentModel = model;

    return instance;
}
|
|
|
// Default export mirrors the named exports for consumers importing the module as a whole.
export default {
    getRawImage,
    getPipeline,
};
|
|