import spaces
import torch
import gradio
import json
import time
from datetime import datetime
from transformers import AutoTokenizer, pipeline
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from optimum.onnxruntime import ORTModelForSequenceClassification

# CORS config - this middleware doesn't appear to take effect, so as a crude workaround, origins are also whitelisted manually inside classify() below.
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win","https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win","https://crunchatize-2-2b4f5b1479a6.c5v4v4jx6pq5.win","https://tamabotchi-2dba63df3bf1.c5v4v4jx6pq5.win"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
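
# Note (assumption, not confirmed): the middleware above may never see traffic because
# the app is ultimately served by gradio_interface.launch() at the bottom of this file;
# mounting the interface onto this FastAPI app is what routes requests through it.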

print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# "xenova/mobilebert-uncased-mnli" "typeform/mobilebert-uncased-mnli" Fast but small--same as bundled in Statosphere

model_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
tokenizer_name = "MoritzLaurer/deberta-v3-base-zeroshot-v2.0"
file_name = "onnx/model.onnx"

model_name_cpu = "MoritzLaurer/ModernBERT-large-zeroshot-v2.0"

# model_cpu = ORTModelForSequenceClassification.from_pretrained(model_id=model_name_cpu, file_name=file_name)
# tokenizer_cpu = AutoTokenizer.from_pretrained(model_name_cpu)

classifier_cpu = pipeline(task="zero-shot-classification", model=model_name_cpu, tokenizer=model_name_cpu)
classifier_gpu = pipeline(task="zero-shot-classification", model=model_name, tokenizer=tokenizer_name, device="cuda:0")
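
# The two pipelines above back the GPU path (invoked inside the @spaces.GPU-decorated
# function below) and the CPU fallback used when the GPU call fails or is skipped.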

def classify(data_string, request: gradio.Request):
    # Manual origin whitelist; headers.get() avoids a KeyError when no Origin header is present.
    if request:
        if request.headers.get("origin") not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://crunchatize-77a78ffcc6a6.c5v4v4jx6pq5.win", "https://crunchatize-2-2b4f5b1479a6.c5v4v4jx6pq5.win", "https://tamabotchi-2dba63df3bf1.c5v4v4jx6pq5.win", "https://ravenok-statosphere-backend.hf.space", "https://lord-raven.github.io"]:
            return "{}"
    data = json.loads(data_string)

    # Reset call counts so the pipelines don't log their sequential-use/batching suggestion on repeated calls.
    classifier_cpu.call_count = 0
    classifier_gpu.call_count = 0

    start_time = time.time()
    result = {}
    try:
        if 'cpu' not in data:
            result = zero_shot_classification_gpu(data)
            print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} - GPU classification took {time.time() - start_time:.2f} seconds.")
    except Exception as e:
        print(f"GPU classification failed: {e}\nFalling back to CPU.")
    if not result:
        result = zero_shot_classification_cpu(data)
        print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} - CPU classification took {time.time() - start_time:.2f} seconds.")
    return json.dumps(result)
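
# Illustrative payload for classify() (keys are taken from the code above; values are made up):
# {
#   "sequence": "The knight draws his sword.",
#   "candidate_labels": ["combat", "dialogue", "travel"],
#   "hypothesis_template": "This text is about {}.",
#   "multi_label": true
# }
# Including a "cpu" key skips the GPU path and forces the CPU pipeline.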

def zero_shot_classification_cpu(data):
    return classifier_cpu(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])

# spaces.GPU requests a ZeroGPU slot for this call; duration is the expected runtime in seconds.
@spaces.GPU(duration=3)
def zero_shot_classification_gpu(data):
    return classifier_gpu(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])

# Currently unused helper that pairs the sequence with each filled-in hypothesis.
def create_sequences(data):
    return [data['sequence'] + '\n' + data['hypothesis_template'].format(label) for label in data['candidate_labels']]
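
# For example (hypothetical input):
#   create_sequences({'sequence': 'I drew my sword.',
#                     'hypothesis_template': 'This text is about {}.',
#                     'candidate_labels': ['combat', 'travel']})
# returns:
#   ['I drew my sword.\nThis text is about combat.',
#    'I drew my sword.\nThis text is about travel.']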

gradio_interface = gradio.Interface(
    fn = classify,
    inputs = gradio.Textbox(label="JSON Input"),
    outputs = gradio.Textbox(label="JSON Output"),
    title = "Statosphere Backend",
    description = "This Space is a classification service for a set of chub.ai stages and not really intended for use through this UI."
)

# gradio.mount_gradio_app is the supported way to attach a Gradio app to a FastAPI app
# (a plain app.mount() call does not accept a gradio Interface), which should let
# requests pass through the CORS middleware above.
app = gradio.mount_gradio_app(app, gradio_interface, path="/gradio")

gradio_interface.launch()