eurekacrew / agents /hypothesis.py
gaur3009's picture
Create hypothesis.py
9b91884 verified
raw
history blame contribute delete
678 Bytes
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Module-level setup: these statements run once at import time, so the
# tokenizer and model weights are loaded (and downloaded if not cached)
# before generate() is ever called.
device = torch.device("cpu")  # all inference is pinned to the CPU
model_id = "TheBloke/Mistral-7B-Instruct-v0.1"  # Hugging Face Hub repo id

tokenizer = AutoTokenizer.from_pretrained(model_id)

# NOTE(review): no torch_dtype is passed, so the precision the weights are
# loaded in is whatever the installed transformers version defaults to; a
# 7B model in float32 needs tens of GB of RAM on CPU — confirm intended.
model = AutoModelForCausalLM.from_pretrained(model_id)
model = model.to(device)
def generate(summaries, *, max_new_tokens=100):
    """Ask the language model to propose a new AI research hypothesis.

    The ``'summary'`` value of every item in *summaries* is joined with single
    spaces into one context string. That string is embedded in a fixed
    instruction prompt and completed by the module-level ``model``.

    Args:
        summaries: Iterable of items supporting ``item['summary']``. Presumably
            these are dicts whose ``'summary'`` value is a string; only the key
            lookup is visible here.
        max_new_tokens: Upper bound on the number of tokens generated after the
            prompt. Defaults to 100, the previously hard-coded value.

    Returns:
        Only the newly generated text (the prompt is excluded), with
        surrounding whitespace stripped.

    Raises:
        Propagates any error from the ``'summary'`` lookup, e.g. ``KeyError``
        for a dict item that lacks the key.
    """
    context = " ".join(p["summary"] for p in summaries)
    # NOTE(review): Mistral-Instruct models are normally prompted through
    # their chat template ([INST] ... [/INST], see
    # tokenizer.apply_chat_template). This raw prompt is kept as-is; consider
    # switching.
    prompt = (
        "Based on these paper summaries:\n"
        f"{context}\n\n"
        "Suggest a new AI research hypothesis:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # For a decoder-only model the returned sequence starts with the prompt
    # tokens. Strip them by token count, not by slicing the decoded text with
    # len(prompt): decoding does not reliably reproduce the prompt string
    # exactly, so a character-offset slice can cut into the answer or leave
    # prompt residue behind.
    prompt_len = inputs["input_ids"].shape[1]
    new_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()