Spaces:
Build error
Update app.py
app.py
CHANGED
@@ -6,7 +6,6 @@ import numpy as np
 import torch
 from sentence_transformers import SentenceTransformer
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from peft import PeftModel
 from reportlab.lib.pagesizes import A4
 from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
 from reportlab.lib.styles import getSampleStyleSheet
@@ -56,21 +55,19 @@ def retrieve_milestone(user_input):
     return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
 
 # Initialize IBM Granite Model
-BASE_NAME = "ibm-granite/granite-3.0-
-LORA_NAME = "ibm-granite/granite-rag-3.0-8b-lora"
+BASE_NAME = "ibm-granite/granite-3.0-2b-base"
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 tokenizer = AutoTokenizer.from_pretrained(BASE_NAME, padding_side='left', trust_remote_code=True)
 model_base = AutoModelForCausalLM.from_pretrained(BASE_NAME, device_map="auto")
-model_rag = PeftModel.from_pretrained(model_base, LORA_NAME)
 
 def generate_response(user_input, child_age):
     relevant_milestone = retrieve_milestone(user_input)
     question_chat = [
         {
             "role": "system",
-            "content": "
+            "content": f"The child is {child_age} months old. Based on the given traits: {user_input}, determine whether the child is meeting expected milestones. Relevant milestone: {relevant_milestone}. If there are any concerns, suggest steps the parents can take."
         },
         {
             "role": "user",
@@ -79,7 +76,7 @@ def generate_response(user_input, child_age):
     ]
     input_text = tokenizer.apply_chat_template(question_chat, tokenize=False, add_generation_prompt=True)
     inputs = tokenizer(input_text, return_tensors="pt")
-    output =
+    output = model_base.generate(inputs["input_ids"].to(device), attention_mask=inputs["attention_mask"].to(device), max_new_tokens=500)
     output_text = tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
     return output_text
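For reference, a minimal sketch of how the updated generate_response would be invoked from the app's UI code; the traits string and age below are illustrative placeholders, not part of this commit:

# Usage sketch for the updated generate_response (illustrative inputs).
report = generate_response(
    user_input="babbles, sits without support, not yet crawling",
    child_age=9,  # age in months, matching the new system prompt's phrasing
)
print(report)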
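If the build error was triggered by the from peft import PeftModel line itself (for example, peft missing from the Space's requirements.txt) rather than by the adapter, an alternative fix would have been to keep the RAG LoRA and declare the dependency. A sketch, assuming BASE_NAME points at the matching 8B Granite base; note the 8b adapter cannot load onto the granite-3.0-2b-base this commit switches to, since LoRA weights must match the base model's dimensions:

# Sketch: restore the RAG LoRA this commit removes.
# Assumes peft is listed in requirements.txt and BASE_NAME is the
# 8B Granite base that granite-rag-3.0-8b-lora was trained against.
from peft import PeftModel

LORA_NAME = "ibm-granite/granite-rag-3.0-8b-lora"
model_rag = PeftModel.from_pretrained(model_base, LORA_NAME)
# generate_response would then call model_rag.generate(...) in place of
# model_base.generate(...).

As committed, switching to the smaller 2b base avoids both the peft dependency and the adapter, at the cost of the RAG fine-tuning.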