import os
import time
import re
import requests
import streamlit as st
# Streamlit UI
st.set_page_config(
    page_title="Task Specific Q&A Chat (No agents needed)",
    layout="wide",
    initial_sidebar_state="expanded"  # or "collapsed"
)
st.title("π§ Task Specific Q&A Chat (No agents needed)")
st.markdown("Ask a question and get a response from a RunPod model. Your session history is saved below.")
def run_inference(api_key: str, endpoint_id: str, instruction: str, prompt: str) -> str:
    """
    Submits a job to RunPod and polls until the result is ready.

    Args:
        api_key (str): RunPod API key from environment.
        endpoint_id (str): RunPod endpoint ID.
        instruction (str): Instruction type (e.g., "math", "SQL", "python").
        prompt (str): The prompt or question to submit.

    Returns:
        str: The full model response including <think> and <response> tags.
    """
    structured_prompt = (
        f"<instruction>This is a {instruction} problem.</instruction>"
        f"<question>{prompt}</question>"
    )
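    # With instruction="math" and prompt="What is 7 * 8?" (illustrative inputs),
    # structured_prompt becomes:
    # "<instruction>This is a math problem.</instruction><question>What is 7 * 8?</question>"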
    url = f'https://api.runpod.ai/v2/{endpoint_id}/run'
    payload = {
        "input": {
            "prompt": structured_prompt,
            "sampling_params": {
                "temperature": 0.8,
                "max_tokens": 1024,
                # Stop once the closing tag is generated; it is re-appended after polling.
                "stop": "</response>"
            }
        }
    }
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_key}'
    }
    submit_response = requests.post(url, headers=headers, json=payload, timeout=30)
    if submit_response.status_code != 200:
        return f"Job submission failed: {submit_response.text}"
    job_id = submit_response.json().get('id')
    if not job_id:
        return "Failed to retrieve job ID."
    status_url = f'https://api.runpod.ai/v2/{endpoint_id}/status/{job_id}'
    with st.spinner("Waiting for the model to respond..."):
        while True:
            res = requests.get(status_url, headers=headers, timeout=30)
            result = res.json()
            status = result.get("status", "")
            if status == 'COMPLETED':
                try:
                    # The stop sequence consumes "</response>", so re-append it
                    # before handing the text to extract_sections().
                    return result['output'][0]["choices"][0]["tokens"][0] + "</response>"
                except Exception as e:
                    return f"Unexpected output format: {e}"
            elif status in ['FAILED', 'CANCELLED']:
                return f"Job failed or cancelled: {result}"
            time.sleep(2)
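# The indexing above implies a COMPLETED payload shaped roughly like
# (inferred from the code, not an official RunPod schema):
#   {"status": "COMPLETED",
#    "output": [{"choices": [{"tokens": ["<think>...</think><response>..."]}]}]}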
def extract_sections(response: str) -> tuple[str, str]:
    """
    Extracts <think> and <response> sections from the model output.

    Args:
        response (str): Full text including <think> and <response> tags.

    Returns:
        tuple[str, str]: (reasoning, final_answer)
    """
    think_match = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
    response_match = re.search(r"<response>(.*?)</response>", response, re.DOTALL)
    reasoning = think_match.group(1).strip() if think_match else ""
    final_response = response_match.group(1).strip() if response_match else response.strip()
    return reasoning, final_response
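# Example (hypothetical output string, for illustration):
#   extract_sections("<think>7 * 8 = 56</think><response>56</response>")
#   returns ("7 * 8 = 56", "56")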
# Streamed response emulator
def response_generator(response: str):
    """Yields the response word by word so it can be rendered incrementally."""
    for word in response.split():
        yield word + " "
        time.sleep(0.05)
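# Note: st.write_stream() in the assistant block below consumes this generator
# to animate the reply word by word.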
# Load API key and endpoint ID
try:
    API_KEY = os.environ["RUNPOD_API_KEY"]
    ENDPOINT_ID = os.environ["RUNPOD_ENDPOINT_ID"]
except KeyError as e:
    st.error(f"Missing required environment variable: {str(e)}")
    st.stop()
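# Set both variables before launching, e.g. (placeholder values; the script
# name "app.py" is assumed):
#   export RUNPOD_API_KEY=<your-api-key>
#   export RUNPOD_ENDPOINT_ID=<your-endpoint-id>
#   streamlit run app.py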
# Initialize session state
if "messages" not in st.session_state:
    st.session_state.messages = []
# 💡 Render the sidebar BEFORE chat messages so the selected task and any
# history clearing take effect on the current rerun
with st.sidebar:
    instruction = st.selectbox(
        "What's the task you want this LLM to do?",
        ("math", "SQL", "python", "medical"),
    )
    # Handle clear history BEFORE rendering messages
    if st.button("Clear History"):
        st.session_state.messages = []
        st.rerun()
# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# Accept user input
if prompt := st.chat_input("Enter a question based on a selected task."):
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(prompt)

    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        # Generate answer
        raw_output = run_inference(API_KEY, ENDPOINT_ID, instruction, prompt)
        reasoning, final_response = extract_sections(raw_output)

        # Show reasoning
        with st.expander("See reasoning:", expanded=True, icon="📖"):
            st.write_stream(response_generator(reasoning))

        # Detect if it's a code block and render appropriately
        with st.expander("See response:", expanded=True, icon="✅"):
            if "```python" in final_response:
                code_body = final_response.split("```python", 1)[-1].split("```", 1)[0]
                st.code(code_body.strip(), language="python")
            elif "```sql" in final_response:
                code_body = final_response.split("```sql", 1)[-1].split("```", 1)[0]
                st.code(code_body.strip(), language="sql")
            else:
                st.markdown(final_response)

    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": raw_output})