DataScience / app.py
67Ayush87's picture
Update app.py
ecf4384 verified
raw
history blame
3.05 kB
import streamlit as st
from transformers import pipeline
import torch
# Set the page title and use the wide layout.
# NOTE(review): Streamlit documents set_page_config as a call that should come
# before other st.* commands — keep it at the top of the script; confirm
# against the Streamlit version in use.
st.set_page_config(page_title="Data Science Mentor", layout="wide")
# (pipeline task, checkpoint) for each topic that has a dedicated model.
_TOPIC_PIPELINES = {
    "Python": ("text-generation", "tiiuae/falcon-7b-instruct"),
    "GenAI": ("text2text-generation", "google/flan-t5-large"),
    "Statistics": ("text-generation", "databricks/dolly-v2-3b"),
    "SQL": ("text2text-generation", "google/flan-t5-base"),
}

# Fallback for Power BI, ML, DL (any topic not listed above).
_FALLBACK_PIPELINE = ("text-generation", "tiiuae/falcon-7b-instruct")


# Cache models to avoid reloading
@st.cache_resource
def load_model(topic):
    """Build the Hugging Face pipeline that serves the given topic.

    Args:
        topic: Topic name as selected in the UI. Topics without a dedicated
            entry use the fallback checkpoint.

    Returns:
        The object returned by ``transformers.pipeline`` for the topic's task
        and checkpoint, placed on device 0 when CUDA is available and on the
        CPU (``-1``) otherwise.
    """
    task, checkpoint = _TOPIC_PIPELINES.get(topic, _FALLBACK_PIPELINE)
    device = 0 if torch.cuda.is_available() else -1
    return pipeline(task, model=checkpoint, device=device)
def generate_answer(model, topic, level, question):
    """Ask the mentor pipeline a question and return its answer text.

    Args:
        model: Callable pipeline object exposing a ``task`` attribute. It is
            called with the prompt and must return a sequence whose first
            item is a mapping with a ``"generated_text"`` key.
        topic: Subject inserted into the prompt.
        level: Experience level inserted into the prompt.
        question: The user's question, appended to the prompt.

    Returns:
        The generated text, with the prompt stripped off when the model
        echoed it at the start of its output.
    """
    prompt = f"You are a {level} level mentor in {topic}. Answer the following question in detail:\n{question}"

    # Compare the task exactly: "text2text-generation" ends with
    # "text-generation", so a substring test matches both task types and the
    # seq2seq branch below would never run.
    if model.task == "text-generation":
        # NOTE(review): for causal models max_length presumably counts the
        # prompt tokens as well; consider max_new_tokens — confirm against the
        # Transformers generation docs.
        output = model(prompt, max_length=256, do_sample=True, top_k=50)
    else:
        output = model(prompt, max_length=256)
    answer = output[0]["generated_text"]

    # Remove prompt from answer if echoed
    if answer.lower().startswith(prompt.lower()):
        answer = answer[len(prompt):].strip()
    return answer
# --- Streamlit UI ---
def _clear_chat():
    """Empty the stored conversation; used as the Clear Chat button callback."""
    st.session_state.chat_history = []


st.title("🤖 Data Science Mentor")

with st.sidebar:
    st.header("Configure Your Mentor")
    topic = st.radio(
        "Select Topic:",
        ["Python", "GenAI", "Statistics", "Power BI", "SQL", "Machine Learning", "Deep Learning"],
    )
    level = st.radio("Select Experience Level:", ["Beginner", "Intermediate", "Advanced"])

# Load model for topic
model = load_model(topic)

# The conversation is kept in session state as a flat list of
# ("You", text) / ("Mentor", text) tuples, appended in pairs.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

st.subheader(f"Ask your {topic} question:")
user_input = st.text_area("Type your question here:", height=100)

if st.button("Get Answer"):
    if not user_input.strip():
        st.warning("Please enter a question.")
    else:
        with st.spinner("Mentor is thinking..."):
            answer = generate_answer(model, topic, level, user_input)
        st.session_state.chat_history.append(("You", user_input))
        st.session_state.chat_history.append(("Mentor", answer))

# Display chat history, one question/answer exchange at a time.
history = st.session_state.chat_history
for start in range(0, len(history), 2):
    exchange = history[start:start + 2]
    user_msg = exchange[0][1]
    # Tolerate a trailing question that has no stored answer.
    mentor_msg = exchange[1][1] if len(exchange) > 1 else ""
    st.markdown(f"**You:** {user_msg}")
    st.markdown(f"**Mentor:** {mentor_msg}")
    st.markdown("---")

# Clear through a callback rather than `if st.button(...)`: the history above
# has already been drawn by the time the button's return value is known, so
# resetting it here would leave the old conversation on screen until the next
# interaction. The callback resets the state before the rerun renders.
st.button("Clear Chat", on_click=_clear_chat)