# conversationbot/app.py
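# Streamlit chat app that loads one of two Hugging Face models (a
# calculation model and a small T5 QA model) and wraps it in a chat UI.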
import streamlit as st
import torch
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM
)
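# Note: T5Tokenizer is the slow, SentencePiece-based tokenizer, so the
# sentencepiece package must be installed alongside transformers.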
# ----- Streamlit page config -----
st.set_page_config(page_title="Chat", layout="wide")
# ----- Sidebar: Model controls -----
st.sidebar.title("Model Controls")
model_options = {
    "1": "karthikeyan-r/calculation_model_11k",
    "2": "karthikeyan-r/slm-custom-model_6k"
}
model_choice = st.sidebar.selectbox(
    "Select Model",
    options=list(model_options.values())
)
load_model_button = st.sidebar.button("Load Model")
clear_conversation_button = st.sidebar.button("Clear Conversation")
clear_model_button = st.sidebar.button("Clear Model")
# ----- Session States -----
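# st.session_state persists across Streamlit reruns, so the loaded model,
# tokenizer, pipeline, and chat history survive each user interaction.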
if "model" not in st.session_state:
st.session_state["model"] = None
if "tokenizer" not in st.session_state:
st.session_state["tokenizer"] = None
if "qa_pipeline" not in st.session_state:
st.session_state["qa_pipeline"] = None
if "conversation" not in st.session_state:
st.session_state["conversation"] = []
# ----- Load Model -----
def load_model():
    # Only loads when nothing is cached yet; use "Clear Model" in the
    # sidebar before switching to a different model.
    if st.session_state["model"] is None or st.session_state["tokenizer"] is None:
        with st.spinner("Loading model..."):
            try:
                if model_choice == model_options["1"]:
                    # Load the calculation model (causal LM)
                    tokenizer = AutoTokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
                    model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir="./model_cache")
                    # Add special tokens if needed and resize the embeddings to match
                    if tokenizer.pad_token is None:
                        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                        model.resize_token_embeddings(len(tokenizer))
                    if tokenizer.eos_token is None:
                        tokenizer.add_special_tokens({'eos_token': '[EOS]'})
                        model.resize_token_embeddings(len(tokenizer))
                    model.config.pad_token_id = tokenizer.pad_token_id
                    model.config.eos_token_id = tokenizer.eos_token_id
                    st.session_state["model"] = model
                    st.session_state["tokenizer"] = tokenizer
                    st.session_state["qa_pipeline"] = None  # Not needed for the calculation model
                elif model_choice == model_options["2"]:
                    # Load the T5 model for general QA behind a text2text pipeline
                    device = 0 if torch.cuda.is_available() else -1
                    model = T5ForConditionalGeneration.from_pretrained(model_choice, cache_dir="./model_cache")
                    tokenizer = T5Tokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
                    qa_pipe = pipeline(
                        "text2text-generation",
                        model=model,
                        tokenizer=tokenizer,
                        device=device
                    )
                    st.session_state["model"] = model
                    st.session_state["tokenizer"] = tokenizer
                    st.session_state["qa_pipeline"] = qa_pipe
                st.success("Model loaded successfully and ready!")
            except Exception as e:
                st.error(f"Error loading model: {e}")
if load_model_button:
    load_model()
# ----- Clear Model -----
if clear_model_button:
    st.session_state["model"] = None
    st.session_state["tokenizer"] = None
    st.session_state["qa_pipeline"] = None
    st.success("Model cleared.")
# ----- Clear Conversation -----
if clear_conversation_button:
    st.session_state["conversation"] = []
    st.success("Conversation cleared.")
# ----- Title -----
st.title("Chat Conversation UI")
# ----- User Input and Processing -----
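# st.chat_input returns None until the user submits a message, so the
# block below runs only when there is fresh input to answer.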
user_input = st.chat_input("Enter your query:")
if user_input:
    # Save user input
    st.session_state["conversation"].append({
        "role": "user",
        "content": user_input
    })
    # Generate a response with whichever model is loaded
    if st.session_state["qa_pipeline"]:
        try:
            response = st.session_state["qa_pipeline"](f"Q: {user_input}", max_length=250)
            answer = response[0]["generated_text"]
        except Exception as e:
            answer = f"Error: {e}"
    elif st.session_state["model"] and model_choice == model_options["1"]:
        try:
            tokenizer = st.session_state["tokenizer"]
            model = st.session_state["model"]
            inputs = tokenizer(f"Input: {user_input}\nOutput:", return_tensors="pt", padding=True, truncation=True)
            output = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=250,
                pad_token_id=tokenizer.pad_token_id
            )
            # Keep only the text after the "Output:" marker in the prompt
            answer = tokenizer.decode(output[0], skip_special_tokens=True).split("Output:")[-1].strip()
        except Exception as e:
            answer = f"Error: {e}"
    else:
        answer = "No model is loaded. Please select and load a model."
    # Save assistant response
    st.session_state["conversation"].append({
        "role": "assistant",
        "content": answer
    })
# Display conversation
for message in st.session_state["conversation"]:
    with st.chat_message(message["role"]):
        st.write(message["content"])
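# Run locally with: streamlit run app.py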