# HuggingFace Space metadata (from the Space page, not part of the program):
# Spaces: Running on Zero
# File size: 3,816 Bytes
import spaces
import torch
import gradio
import json
import time
from datetime import datetime
from transformers import AutoTokenizer, pipeline
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from optimum.onnxruntime import ORTModelForSequenceClassification
# CORS Config - This isn't actually working; instead, I am taking a gross
# approach to origin whitelisting within the service (see classify()).
app = FastAPI()

# Browser origins permitted by the CORS middleware.
_cors_origins = [
    "https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win",
    "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win",
    "https://crunchatize-2-2b4f5b1479a6.c5v4v4jx6pq5.win",
    "https://tamabotchi-2dba63df3bf1.c5v4v4jx6pq5.win",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Log GPU visibility at startup for debugging the ZeroGPU environment.
print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Alternatives considered:
# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere
model_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
tokenizer_name = model_name  # tokenizer comes from the same checkpoint
file_name = "onnx/model.onnx"
model_name_cpu = "MoritzLaurer/ModernBERT-large-zeroshot-v2.0"

# ONNX runtime variant, currently disabled:
# model_cpu = ORTModelForSequenceClassification.from_pretrained(model_id=model_name_cpu, file_name=file_name)
# tokenizer_cpu = AutoTokenizer.from_pretrained(model_name_cpu)

# Two pipelines: a CPU fallback and the primary CUDA-backed classifier.
classifier_cpu = pipeline(task="zero-shot-classification", model=model_name_cpu, tokenizer=model_name_cpu)
classifier_gpu = pipeline(task="zero-shot-classification", model=model_name, tokenizer=tokenizer_name, device="cuda:0")
def classify(data_string, request: gradio.Request):
    """Classify a JSON-encoded zero-shot request and return a JSON string.

    Expects ``data_string`` to decode to a dict with keys ``sequence``,
    ``candidate_labels``, ``hypothesis_template`` and ``multi_label``
    (an optional ``cpu`` key forces the CPU pipeline). Requests from
    unknown origins get an empty JSON object back.
    """
    if request:
        # Manual origin whitelist because the CORS middleware above isn't effective.
        # .get() instead of [] so a missing Origin header (e.g. a non-browser
        # client) is rejected with "{}" rather than raising KeyError.
        if request.headers.get("origin") not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://crunchatize-2-2b4f5b1479a6.c5v4v4jx6pq5.win", "https://tamabotchi-2dba63df3bf1.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space", "https://lord-raven.github.io"]:
            return "{}"
    data = json.loads(data_string)
    # Try to prevent batch suggestion warning in log.
    classifier_cpu.call_count = 0
    classifier_gpu.call_count = 0

    start_time = time.time()
    result = {}
    try:
        # GPU path unless the caller explicitly asks for CPU.
        if 'cpu' not in data:
            result = zero_shot_classification_gpu(data)
            print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} - GPU Classification took {time.time() - start_time}.")
    except Exception as e:
        # Best-effort: any GPU failure (e.g. ZeroGPU quota) falls through to CPU below.
        print(f"GPU classification failed: {e}\nFall back to CPU.")
    if not result:
        result = zero_shot_classification_cpu(data)
        print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} - CPU Classification took {time.time() - start_time}.")
    return json.dumps(result)
def zero_shot_classification_cpu(data):
    """Run the request payload through the CPU zero-shot pipeline."""
    return classifier_cpu(
        data['sequence'],
        candidate_labels=data['candidate_labels'],
        hypothesis_template=data['hypothesis_template'],
        multi_label=data['multi_label'],
    )
@spaces.GPU(duration=3)
def zero_shot_classification_gpu(data):
    """Run the request payload through the CUDA zero-shot pipeline.

    The spaces.GPU decorator requests a ZeroGPU slot for up to 3 seconds.
    """
    return classifier_gpu(
        data['sequence'],
        candidate_labels=data['candidate_labels'],
        hypothesis_template=data['hypothesis_template'],
        multi_label=data['multi_label'],
    )
def create_sequences(data):
    """Pair the input sequence with each filled-in hypothesis.

    Returns one "<sequence>\n<hypothesis>" string per candidate label.
    """
    sequence = data['sequence']
    template = data['hypothesis_template']
    sequences = []
    for label in data['candidate_labels']:
        sequences.append(sequence + '\n' + template.format(label))
    return sequences
# Minimal UI wrapper around classify(); the service is consumed via the API,
# not this textbox form.
gradio_interface = gradio.Interface(
    fn = classify,
    inputs = gradio.Textbox(label="JSON Input"),
    outputs = gradio.Textbox(label="JSON Output"),
    title = "Statosphere Backend",
    description = "This Space is a classification service for a set of chub.ai stages and not really intended for use through this UI."
)
# NOTE(review): gradio's documented FastAPI integration is
# gradio.mount_gradio_app(app, gradio_interface, path="/gradio");
# mounting an Interface directly may not serve correctly — confirm.
app.mount("/gradio", gradio_interface)
# NOTE(review): launch() starts gradio's own server on the Space's main port,
# which may be why the FastAPI CORS middleware above appears ineffective — confirm.
gradio_interface.launch()