import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler())
# Set the level explicitly: without this, the logger inherits the root
# WARNING level and the logger.info() call below prints nothing.
logger.setLevel(logging.INFO)


class LocalModel:
    """Thin wrapper around a Hugging Face text-generation pipeline."""

    def __init__(self, model_name: str, max_tokens: int, temperature: float):
        self.max_tokens = max_tokens
        self.temperature = temperature

        # Load the weights in half precision to halve the memory footprint.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16
        )
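        # Note: no device is specified, so the model stays on the CPU.
        # If a GPU is available, adding device_map="auto" to the
        # from_pretrained() call above (requires the accelerate package)
        # would place the weights automatically; it is omitted here to
        # keep the dependencies minimal.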
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, prompt: str, **kwargs) -> str:
        # Temperature only has an effect when sampling is enabled, so turn
        # sampling on unless the caller explicitly overrides it.
        kwargs.setdefault("do_sample", True)
        result = self.pipeline(
            prompt,
            max_new_tokens=self.max_tokens,
            temperature=self.temperature,
            **kwargs,
        )
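
        # By default "generated_text" echoes the prompt followed by the
        # completion; passing return_full_text=False to the pipeline call
        # above would return the completion alone.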
        output = result[0]["generated_text"]
        logger.info(f"Model output: {output}")

        return output


if __name__ == "__main__":
    # Quick smoke test: downloads the model on first run, then generates
    # a single completion.
    local_model = LocalModel("Qwen/Qwen2.5-1.5B", max_tokens=100, temperature=0.5)
    output = local_model("A big foot")
    print(output)