from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
import os
class MakePipeline:
    # Model name
    MODEL_ID = "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B"

    # Attributes initialized here:
    #   model_id
    #   tokenizer
    #   llm
    def __init__(self, model_id: str = MODEL_ID):
        print("[torch] cuda available:", torch.cuda.is_available())
        print("[device] default:", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        self.model_id = model_id
        self.tokenizer = None
        self.llm = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.config = {  # default generation parameters
            "temperature": 0.7,
            "top_p": 0.9,
            "repetition_penalty": 1.05,
            "max_new_tokens": 96
        }
    # Load the model
    def build(self, type: str):
        if type == 'ui':
            print("[build] UI test mode - skipping model load")
            return
        if type == 'hf':
            # Load the token registered in the Hugging Face Space secrets
            access_token = os.environ.get("HF_TOKEN")
        else:
            # For local runs, load the token from token.txt
            with open("token.txt", "r") as f:
                access_token = f.read().strip()
        tokenizer = AutoTokenizer.from_pretrained(self.model_id, token=access_token)
        model = AutoModelForCausalLM.from_pretrained(self.model_id, token=access_token, trust_remote_code=True)
        self.tokenizer = tokenizer
        # Do not use float16 when deployed to the Hugging Face Space
        if type == 'hf':
            llm = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
            )
        else:
            model.eval()
            llm = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                torch_dtype=torch.float16
            )
        if torch.cuda.is_available():
            model.to("cuda")
        self.llm = llm
    # Update generation parameters
    def update_config(self, new_config: dict):
        self.config.update(new_config)
        print("[config] updated:", self.config)
    # Generate a model response for the given prompt
    def character_chat(self, prompt):
        print("[debug] generating with:", self.config)
        outputs = self.llm(
            prompt,
            do_sample=True,
            max_new_tokens=self.config["max_new_tokens"],
            temperature=self.config["temperature"],
            top_p=self.config["top_p"],
            repetition_penalty=self.config["repetition_penalty"],
            eos_token_id=self.tokenizer.eos_token_id,
            return_full_text=True
        )
        return outputs[0]["generated_text"]
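
# --- Usage sketch (not part of the original file; a minimal, assumed example) ---
# Assumes a local run with a token.txt file in the working directory; any mode
# string other than 'ui'/'hf' (here "local") selects the token.txt branch of build().
if __name__ == "__main__":
    pipe = MakePipeline()
    pipe.build("local")
    pipe.update_config({"max_new_tokens": 128})
    print(pipe.character_chat("Introduce yourself in one sentence."))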