import streamlit as st
from transformers import pipeline
import torch

st.set_page_config(page_title="Data Science Mentor", layout="wide")
# Cache models so each pipeline is loaded only once per session
@st.cache_resource
def load_model(topic):
    # Select model based on topic; device=0 uses the first GPU, -1 falls back to CPU
    device = 0 if torch.cuda.is_available() else -1
    if topic == "Python":
        return pipeline("text-generation", model="tiiuae/falcon-7b-instruct", device=device)
    elif topic == "GenAI":
        return pipeline("text2text-generation", model="google/flan-t5-large", device=device)
    elif topic == "Statistics":
        return pipeline("text-generation", model="databricks/dolly-v2-3b", device=device)
    elif topic == "SQL":
        return pipeline("text2text-generation", model="google/flan-t5-base", device=device)
    else:
        # Fallback for Power BI, Machine Learning, and Deep Learning
        return pipeline("text-generation", model="tiiuae/falcon-7b-instruct", device=device)
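# Note: tiiuae/falcon-7b-instruct needs roughly 14 GB of memory in fp16; on
# CPU-only or small-GPU hosts a lighter instruct model may be a more practical
# fallback than the one chosen above.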
def generate_answer(model, topic, level, question):
    prompt = f"You are a {level} level mentor in {topic}. Answer the following question in detail:\n{question}"
    # Compare tasks with ==: "text-generation" is a substring of
    # "text2text-generation", so an `in` check would match both
    if model.task == "text-generation":
        # max_new_tokens bounds the reply itself, independent of prompt length
        output = model(prompt, max_new_tokens=256, do_sample=True, top_k=50)
    else:
        output = model(prompt, max_length=256)
    answer = output[0]['generated_text']
    # Causal models echo the prompt; remove it from the answer if present
    if answer.lower().startswith(prompt.lower()):
        answer = answer[len(prompt):].strip()
    return answer
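# Example call (hypothetical):
#   generate_answer(load_model("SQL"), "SQL", "Beginner",
#                   "What is the difference between INNER and LEFT JOIN?")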
# --- Streamlit UI ---
st.title("🤖 Data Science Mentor")
with st.sidebar:
    st.header("Configure Your Mentor")
    topic = st.radio("Select Topic:", ["Python", "GenAI", "Statistics", "Power BI", "SQL", "Machine Learning", "Deep Learning"])
    level = st.radio("Select Experience Level:", ["Beginner", "Intermediate", "Advanced"])

# Load the model for the selected topic
model = load_model(topic)

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
st.subheader(f"Ask your {topic} question:")
user_input = st.text_area("Type your question here:", height=100)

if st.button("Get Answer"):
    if user_input.strip() == "":
        st.warning("Please enter a question.")
    else:
        with st.spinner("Mentor is thinking..."):
            answer = generate_answer(model, topic, level, user_input)
        st.session_state.chat_history.append(("You", user_input))
        st.session_state.chat_history.append(("Mentor", answer))
# Display chat history as question/answer pairs
if st.session_state.chat_history:
    for i in range(0, len(st.session_state.chat_history), 2):
        user_msg = st.session_state.chat_history[i][1]
        mentor_msg = st.session_state.chat_history[i + 1][1] if i + 1 < len(st.session_state.chat_history) else ""
        st.markdown(f"**You:** {user_msg}")
        st.markdown(f"**Mentor:** {mentor_msg}")
        st.markdown("---")
if st.button("Clear Chat"):
    st.session_state.chat_history = []
    st.rerun()  # rerun immediately so the cleared history disappears
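# To launch locally (assuming this file is saved as app.py):
#   streamlit run app.py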