import os
import sys

import torch

# Make the project root importable so `all_models` resolves when this file is
# executed directly rather than as part of the package. Must run before the
# `all_models` import below.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from all_models import models
def query_(query, doc):
    """Generate a single-paragraph answer to *query* grounded in *doc*.

    Builds a FLAN-style instruction prompt around the question and context,
    runs the shared FLAN model with sampled beam search, and returns the
    decoded answer text.

    Args:
        query: The user question as a plain string.
        doc: Context text the answer should be extracted from.

    Returns:
        The decoded answer string (special tokens stripped).
    """
    input_text = f"""
You are an AI assistant designed to extract relevant information from a document and generate a clear, concise answer.
Question: {query}
Provide a *single-paragraph response of 250 words* that summarizes key details, explains the answer logically, and avoids repetition. Ignore irrelevant details like page numbers, author names, and metadata.
Context:
"{doc}"
Answer:
"""
    # Tokenize on the same device as the model so generate() gets matching tensors.
    device = next(models.flan_model.parameters()).device
    # NOTE(review): no truncation is applied here; a very long `doc` may exceed
    # the tokenizer's model_max_length — confirm upstream chunking handles this.
    inputs = models.flan_tokenizer(input_text, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = models.flan_model.generate(
            **inputs,
            do_sample=True,
            # FLAN-T5 is an encoder-decoder model: `max_length`/`min_length`
            # count *decoder* (generated) tokens only. The previous
            # `max_length = input_length + 180` applied a decoder-only idiom,
            # wrongly tying the answer budget to the prompt size.
            # `max_new_tokens` caps the answer at 180 generated tokens
            # regardless of how long the prompt is.
            max_new_tokens=180,
            min_length=100,
            early_stopping=True,  # stop beam search once beams are done
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2,
            num_beams=3,  # beam-sample: sampling within 3 beams
        )
    answer = models.flan_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer