import type { ModelData, WidgetExampleAttribute } from "@huggingface/tasks";
import { parseJSON } from "../../../utils/ViewUtils.js";
import { ComputeType, type ModelLoadInfo, type TableData } from "./types.js";
import { LoadState } from "./types.js";
import { isLoggedIn } from "../stores.js";
import { get } from "svelte/store";

const KEYS_TEXT: WidgetExampleAttribute[] = ["text", "context", "candidate_labels"];
const KEYS_TABLE: WidgetExampleAttribute[] = ["table", "structured_data"];

type QueryParamVal = string | null | boolean | (string | number)[][];

export function getQueryParamVal(key: WidgetExampleAttribute): QueryParamVal {
	const searchParams = new URL(window.location.href).searchParams;
	const value = searchParams.get(key);
	if (KEYS_TEXT.includes(key)) {
		return value;
	} else if (KEYS_TABLE.includes(key)) {
		const table = convertDataToTable((parseJSON(value) as TableData) ?? {});
		return table;
	} else if (key === "multi_class") {
		return value === "true";
	}
	return value;
}
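
// Example (illustrative sketch only; the page URL below is hypothetical):
// with a URL like https://example.com/?text=Hello+world&multi_class=true,
//   getQueryParamVal("text")        // -> "Hello world"
//   getQueryParamVal("multi_class") // -> true
// "table" / "structured_data" keys are expected to hold JSON in the TableData shape
// and are returned as a 2D array via convertDataToTable.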

// Update current url search params, keeping existing keys intact.
export function updateUrl(obj: Partial<Record<WidgetExampleAttribute, string | undefined>>): void {
	if (typeof window === "undefined") {
		// Guard against non-browser environments (e.g. SSR), where `window` is not defined.
		return;
	}
	const sp = new URL(window.location.href).searchParams;
	for (const [k, v] of Object.entries(obj)) {
		if (v === undefined) {
			sp.delete(k);
		} else {
			sp.set(k, v);
		}
	}
	const path = `${window.location.pathname}?${sp.toString()}`;
	window.history.replaceState(null, "", path);
}
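
// Example (illustrative sketch; the keys shown are hypothetical widget inputs):
//   updateUrl({ text: "Hello world", context: undefined });
// sets the `text` query param, removes `context`, keeps every other existing param,
// and does not push a new history entry (replaceState).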

// Run through our own proxy to bypass CORS:
function proxify(url: string): string {
	return url.startsWith(`http://localhost`) || new URL(url).host === window.location.host
		? url
		: `https://widgets.hf.co/proxy?url=${url}`;
}

// Get BLOB from a given URL after proxifying the URL
export async function getBlobFromUrl(url: string): Promise<Blob> {
	const proxiedUrl = proxify(url);
	const res = await fetch(proxiedUrl);
	const blob = await res.blob();
	return blob;
}
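
// Example (illustrative sketch; the asset URL is hypothetical):
//   const audio = await getBlobFromUrl("https://example.com/sample.flac");
// Same-host and localhost URLs are fetched directly; anything else goes through
// the widgets.hf.co proxy so the browser's CORS policy does not block the request.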

interface Success<T> {
	computeTime: string;
	output: T;
	outputJson: string;
	response: Response;
	status: "success";
}

interface LoadingModel {
	error: string;
	estimatedTime: number;
	status: "loading-model";
}

interface Error {
	error: string;
	status: "error";
}

interface CacheNotFound {
	status: "cache not found";
}

type Result<T> = Success<T> | LoadingModel | Error | CacheNotFound;
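
// `Result<T>` is a discriminated union on `status`, so callers can narrow it with a
// switch. Sketch only; what to do in each branch is up to the caller:
//   switch (result.status) {
//     case "success":          /* use result.output / result.outputJson */ break;
//     case "loading-model":    /* retry after result.estimatedTime seconds */ break;
//     case "error":            /* surface result.error */ break;
//     case "cache not found":  /* nothing cached for an on-load call */ break;
//   }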

export async function callInferenceApi<T>(
	url: string,
	repoId: string,
	requestBody: Record<string, unknown>,
	apiToken = "",
	outputParsingFn: (x: unknown) => T,
	waitForModel = false, // If true, the server will only respond once the model has been loaded on Inference API (serverless)
	includeCredentials = false,
	isOnLoadCall = false, // If true, the server will try to answer from cache and not do anything if not
	useCache = true
): Promise<Result<T>> {
	const contentType =
		"file" in requestBody && requestBody["file"] && requestBody["file"] instanceof Blob && requestBody["file"].type
			? requestBody["file"]["type"]
			: "application/json";

	const headers = new Headers();
	headers.set("Content-Type", contentType);
	if (apiToken) {
		headers.set("Authorization", `Bearer ${apiToken}`);
	}
	if (waitForModel) {
		headers.set("X-Wait-For-Model", "true");
	}
	if (useCache === false && get(isLoggedIn)) {
		headers.set("X-Use-Cache", "false");
	}
	if (isOnLoadCall || !get(isLoggedIn)) {
		headers.set("X-Load-Model", "0");
	}

	// `File` is a subtype of `Blob`: therefore, checking for instanceof `Blob` also checks for instanceof `File`
	const reqBody: Blob | string =
		"file" in requestBody && requestBody["file"] instanceof Blob ? requestBody.file : JSON.stringify(requestBody);

	const response = await fetch(`${url}/models/${repoId}`, {
		method: "POST",
		body: reqBody,
		headers,
		credentials: includeCredentials ? "include" : "same-origin",
	});

	if (response.ok) {
		// Success
		const computeTime = response.headers.has("x-compute-time")
			? `${response.headers.get("x-compute-time")} s`
			: `cached`;
		const isMediaContent = (response.headers.get("content-type")?.search(/^(?:audio|image)/i) ?? -1) !== -1;
		const body = !isMediaContent ? await response.json() : await response.blob();

		try {
			const output = outputParsingFn(body);
			const outputJson = !isMediaContent ? JSON.stringify(body, null, 2) : "";
			return { computeTime, output, outputJson, response, status: "success" };
		} catch (e) {
			if (isOnLoadCall && body.error === "not loaded yet") {
				return { status: "cache not found" };
			}
			// Invalid output
			const error = `API Implementation Error: ${String(e).replace(/^Error: /, "")}`;
			return { error, status: "error" };
		}
	} else {
		// Error
		const bodyText = await response.text();
		const body = parseJSON<Record<string, unknown>>(bodyText) ?? {};

		if (
			body["error"] &&
			response.status === 503 &&
			body["estimated_time"] !== null &&
			body["estimated_time"] !== undefined
		) {
			// Model needs loading
			return { error: String(body["error"]), estimatedTime: +body["estimated_time"], status: "loading-model" };
		} else {
			// Other errors
			const { status, statusText } = response;
			return {
				// Fall back through the possible error fields; `??` avoids turning a missing
				// field into the literal string "undefined".
				error: String(body["error"] ?? body["traceback"] ?? `${status} ${statusText}`),
				status: "error",
			};
		}
	}
}
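
// Example usage (illustrative sketch; the endpoint, model id, and output parser are
// assumptions, not part of this module):
//   const result = await callInferenceApi(
//     "https://api-inference.huggingface.co",
//     "distilbert-base-uncased-finetuned-sst-2-english",
//     { inputs: "I love this!" },
//     apiToken,
//     (x) => x as Array<Array<{ label: string; score: number }>>
//   );
//   if (result.status === "loading-model") {
//     // the model is cold: retry with waitForModel = true after result.estimatedTime seconds
//   }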

export async function getModelLoadInfo(
	url: string,
	repoId: string,
	includeCredentials = false
): Promise<ModelLoadInfo> {
	const response = await fetch(`${url}/status/${repoId}`, {
		credentials: includeCredentials ? "include" : "same-origin",
	});
	const output: {
		state: LoadState;
		compute_type: ComputeType | Record<ComputeType, { [key in ComputeType]?: string } & { count: number }>;
		loaded: boolean;
		error: Error;
	} = await response.json();
	if (response.ok && typeof output === "object" && output.loaded !== undefined) {
		// eslint-disable-next-line @typescript-eslint/naming-convention
		const compute_type =
			typeof output.compute_type === "string"
				? output.compute_type
				: output.compute_type["gpu"]
					? ComputeType.GPU
					: ComputeType.CPU;
		return { compute_type, state: output.state };
	} else {
		console.warn(response.status, output.error);
		return { state: LoadState.Error };
	}
}
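
// Example (illustrative sketch; the endpoint and model id are hypothetical):
//   const info = await getModelLoadInfo("https://api-inference.huggingface.co", "gpt2");
// `info.compute_type` resolves to ComputeType.GPU or ComputeType.CPU and `info.state`
// mirrors the /status endpoint; on a non-ok response only { state: LoadState.Error } is returned.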

// Extend requestBody with user supplied parameters for Inference API (serverless)
export function addInferenceParameters(requestBody: Record<string, unknown>, model: ModelData): void {
	const inference = model?.cardData?.inference;
	if (typeof inference === "object") {
		const inferenceParameters = inference?.parameters;
		if (inferenceParameters) {
			if (requestBody.parameters) {
				requestBody.parameters = { ...requestBody.parameters, ...inferenceParameters };
			} else {
				requestBody.parameters = inferenceParameters;
			}
		}
	}
}
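
// Example (illustrative sketch; the model card snippet is hypothetical):
// with model.cardData.inference = { parameters: { temperature: 0.7 } } and
// requestBody = { inputs: "...", parameters: { max_length: 50 } }, the call
//   addInferenceParameters(requestBody, model);
// mutates requestBody.parameters to { max_length: 50, temperature: 0.7 }
// (card-level parameters win over per-request ones because they are spread last).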

/**
 * Converts table from [[Header0, Header1, Header2], [Column0Val0, Column1Val0, Column2Val0], ...]
 * to {Header0: [Column0Val0, ...], Header1: [Column1Val0, ...], Header2: [Column2Val0, ...]}
 */
export function convertTableToData(table: (string | number)[][]): TableData {
	return Object.fromEntries(
		table[0].map((cell, x) => {
			return [
				cell,
				table
					.slice(1)
					.flat()
					.filter((_, i) => i % table[0].length === x)
					.map((v) => String(v)), // some models can only handle strings (no numbers)
			];
		})
	);
}
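
// Example (sketch):
//   convertTableToData([["Name", "Age"], ["Ada", 36], ["Alan", 41]])
//   // -> { Name: ["Ada", "Alan"], Age: ["36", "41"] }  (all cells become strings)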

/**
 * Converts data from {Header0: [Column0Val0, ...], Header1: [Column1Val0, ...], Header2: [Column2Val0, ...]}
 * to [[Header0, Header1, Header2], [Column0Val0, Column1Val0, Column2Val0], ...]
 */
export function convertDataToTable(data: TableData): (string | number)[][] {
	const dataArray = Object.entries(data); // [header, cell[]][]
	const nbCols = dataArray.length;
	const nbRows = (dataArray[0]?.[1]?.length ?? 0) + 1;
	return Array(nbRows)
		.fill("")
		.map((_, y) =>
			Array(nbCols)
				.fill("")
				.map((__, x) => (y === 0 ? dataArray[x][0] : dataArray[x][1][y - 1]))
		);
}
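
// Example (sketch): the inverse of convertTableToData, modulo the string coercion:
//   convertDataToTable({ Name: ["Ada", "Alan"], Age: ["36", "41"] })
//   // -> [["Name", "Age"], ["Ada", "36"], ["Alan", "41"]]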