# src/streamlit_app.py — Task-specific Q&A chat backed by a RunPod endpoint
# (author: eagle0504; last update commit 7a2dc35)
import os
import time
import re
import requests
import streamlit as st
# Streamlit UI: set_page_config must be the first st.* call in the script.
st.set_page_config(
page_title="Task Specific Q&A Chat (No agents needed)",
layout="wide",
initial_sidebar_state="expanded" # or "collapsed"
)
# Page header and one-line usage hint shown above the chat transcript.
st.title("🧠 Task Specific Q&A Chat (No agents needed)")
st.markdown("Ask a question and get a response from a RunPod model. Your session history is saved below.")
def run_inference(
    api_key: str,
    endpoint_id: str,
    instruction: str,
    prompt: str,
    *,
    request_timeout: float = 30.0,
    poll_interval: float = 2.0,
    max_wait: float = 600.0,
) -> str:
    """
    Submit a job to the RunPod serverless API and poll until the result is ready.

    Args:
        api_key (str): RunPod API key from environment.
        endpoint_id (str): RunPod endpoint ID.
        instruction (str): Instruction type (e.g., "math", "SQL", "python").
        prompt (str): The prompt or question to submit.
        request_timeout (float): Per-HTTP-request timeout in seconds.
            Without it, a dropped connection would hang the app forever.
        poll_interval (float): Seconds to sleep between status polls.
        max_wait (float): Overall polling budget in seconds before giving up.

    Returns:
        str: The full model response including <think> and <response> tags,
        or a human-readable error string on any failure (this function
        reports errors via its return value rather than raising).
    """
    # Frame the question with the selected task type so the model is steered.
    structured_prompt = (
        f"<instruction>This is a {instruction} problem.</instruction>"
        f"<question>{prompt}</question>"
    )
    url = f'https://api.runpod.ai/v2/{endpoint_id}/run'
    payload = {
        "input": {
            "prompt": structured_prompt,
            "sampling_params": {
                "temperature": 0.8,
                "max_tokens": 1024,
                # Generation halts at the closing tag; it is re-appended below.
                "stop": "</response>"
            }
        }
    }
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_key}'
    }
    try:
        submit_response = requests.post(
            url, headers=headers, json=payload, timeout=request_timeout
        )
    except requests.RequestException as e:
        return f"Job submission failed: {e}"
    if submit_response.status_code != 200:
        return f"Job submission failed: {submit_response.text}"
    job_id = submit_response.json().get('id')
    if not job_id:
        return "Failed to retrieve job ID."
    status_url = f'https://api.runpod.ai/v2/{endpoint_id}/status/{job_id}'
    # Bounded polling: the original looped forever; cap total wait time.
    deadline = time.monotonic() + max_wait
    with st.spinner("Waiting for the model to respond..."):
        while time.monotonic() < deadline:
            try:
                res = requests.get(status_url, headers=headers, timeout=request_timeout)
                result = res.json()
            except requests.RequestException as e:
                return f"Status polling failed: {e}"
            status = result.get("status", "")
            if status == 'COMPLETED':
                try:
                    # Re-append the stop sequence so downstream tag
                    # extraction sees a closed <response> block.
                    return result['output'][0]["choices"][0]["tokens"][0] + "</response>"
                except (KeyError, IndexError, TypeError) as e:
                    return f"Unexpected output format: {e}"
            if status in ['FAILED', 'CANCELLED']:
                return f"Job failed or cancelled: {result}"
            time.sleep(poll_interval)
    return f"Timed out after {max_wait:.0f}s waiting for the job to complete."
def extract_sections(response: str) -> tuple[str, str]:
    """
    Split a raw model response into its reasoning and answer parts.

    Args:
        response (str): Full text that may contain <think>...</think> and
            <response>...</response> sections.

    Returns:
        Tuple[str, str]: (reasoning, final_answer). Reasoning is "" when no
        <think> block exists; the answer falls back to the whole stripped
        response when no <response> block exists.
    """
    sections: dict[str, str] = {}
    for tag in ("think", "response"):
        found = re.search(rf"<{tag}>(.*?)</{tag}>", response, re.DOTALL)
        if found:
            sections[tag] = found.group(1).strip()
    return sections.get("think", ""), sections.get("response", response.strip())
# Streamed response emulator
def response_generator(response: str):
    """Emulate token streaming: yield each whitespace-separated word with a
    trailing space, pausing 50 ms after each one."""
    tokens = response.split()
    for token in tokens:
        yield f"{token} "
        time.sleep(0.05)
# Load API key and endpoint ID from the environment so secrets never live in code.
try:
    API_KEY = os.environ["RUNPOD_API_KEY"]          # RunPod account API key
    ENDPOINT_ID = os.environ["RUNPOD_ENDPOINT_ID"]  # serverless endpoint to call
except KeyError as e:
    # Fail fast with a visible error instead of crashing on first use later.
    st.error(f"Missing required environment variable: {str(e)}")
    st.stop()
# Initialize session state: the chat transcript must survive Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []
# The sidebar runs BEFORE the chat transcript is rendered so that
# "Clear History" can wipe st.session_state.messages prior to display.
with st.sidebar:
    # Task selector; the chosen value steers the prompt sent to the model.
    instruction = st.selectbox(
        "What's the task you want this LLM to do?",
        ("math", "SQL", "python", "medical"),
    )
    # Clearing history triggers an immediate rerun so stale messages never render.
    if st.button("Clear History"):
        st.session_state.messages = []
        st.rerun()
# Replay the saved transcript on every rerun (Streamlit re-executes the
# whole script top-to-bottom on each interaction).
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# Accept user input; the walrus keeps the whole handler inside one `if`.
if prompt := st.chat_input("Enter a question based on a selected task."):
    # Persist and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.chat_message("assistant"):
        # Blocking call: polls RunPod until the job completes (or errors).
        raw_output = run_inference(API_KEY, ENDPOINT_ID, instruction, prompt)
        reasoning, final_response = extract_sections(raw_output)
        # Stream the <think> section so the user can follow the reasoning.
        with st.expander("See reasoning:", expanded=True, icon="πŸ’­"):
            st.write_stream(response_generator(reasoning))
        # Render fenced code with highlighting; plain markdown otherwise.
        # NOTE(review): split(...)[-1] drops any prose before the fence, and
        # strip("`") assumes the fence closes the message — confirm this
        # matches real model output before relying on it.
        with st.expander("See response:", expanded=True, icon="βœ…"):
            if "```python" in final_response:
                st.code(final_response.split("```python", 1)[-1].strip("`"), language="python")
            elif "```sql" in final_response:
                st.code(final_response.split("```sql", 1)[-1].strip("`"), language="sql")
            else:
                st.markdown(final_response)
    # The RAW output (tags included) is stored, so the history replay above
    # shows <think>/<response> tags verbatim rather than the parsed answer.
    st.session_state.messages.append({"role": "assistant", "content": raw_output})