from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
import os


class MakePipeline:
    # Model name
    MODEL_ID = "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B"

    # Instance attributes initialized in __init__:
    #   model_id, tokenizer, llm
    def __init__(self, model_id: str = MODEL_ID):
        print("[torch] is available:", torch.cuda.is_available())
        print("[device] default:", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        self.model_id = model_id
        self.tokenizer = None
        self.llm = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.config = {  # default generation parameters
            "temperature": 0.7,
            "top_p": 0.9,
            "repetition_penalty": 1.05,
            "max_new_tokens": 96
        }

    # Load the model
    def build(self, run_type: str):
        if run_type == 'ui':
            print("[build] UI test mode - skipping model load")
            return
        if run_type == 'hf':
            # Load the token registered as a Hugging Face Spaces secret
            access_token = os.environ.get("HF_TOKEN")
        else:
            # For local runs, load the token from token.txt
            with open("token.txt", "r") as f:
                access_token = f.read().strip()
        tokenizer = AutoTokenizer.from_pretrained(self.model_id, token=access_token)
        model = AutoModelForCausalLM.from_pretrained(self.model_id, token=access_token, trust_remote_code=True)
        self.tokenizer = tokenizer
        # Move the model to GPU before wrapping it in a pipeline, so the
        # pipeline places its input tensors on the same device
        if torch.cuda.is_available():
            model.to("cuda")
        # fp16 is not used when running on Hugging Face Spaces
        if run_type == 'hf':
            llm = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
            )
        else:
            model.eval()
            llm = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                torch_dtype=torch.float16
            )
        self.llm = llm

    # Update generation parameters
    def update_config(self, new_config: dict):
        self.config.update(new_config)
        print("[config] updated:", self.config)

    # Generate model output for a prompt
    def character_chat(self, prompt):
        print("[debug] generating with:", self.config)
        outputs = self.llm(
            prompt,
            do_sample=True,
            max_new_tokens=self.config["max_new_tokens"],
            temperature=self.config["temperature"],
            top_p=self.config["top_p"],
            repetition_penalty=self.config["repetition_penalty"],
            eos_token_id=self.tokenizer.eos_token_id,
            return_full_text=True
        )
        return outputs[0]["generated_text"]
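

# Example usage (a minimal sketch): exercises the class end to end.
# Assumes token.txt sits next to this file for local (non-'hf') runs;
# the prompt below is illustrative only.
if __name__ == "__main__":
    pipe = MakePipeline()
    pipe.build('ui')  # 'ui' skips model loading; use 'hf' on Spaces, anything else locally
    pipe.update_config({"temperature": 0.8, "max_new_tokens": 128})
    if pipe.llm is not None:  # llm stays None in 'ui' mode
        print(pipe.character_chat("Hello! Introduce yourself in one sentence."))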