import os
import re
import time

import requests
import streamlit as st


st.set_page_config(
    page_title="Task Specific Q&A Chat (No agents needed)",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.title("🧠 Task Specific Q&A Chat (No agents needed)")
st.markdown("Ask a question and get a response from a RunPod model. Your session history is saved below.")


def run_inference(api_key: str, endpoint_id: str, instruction: str, prompt: str) -> str:
    """
    Submits a job to RunPod and polls until the result is ready.

    Args:
        api_key (str): RunPod API key from the environment.
        endpoint_id (str): RunPod endpoint ID.
        instruction (str): Task type (e.g., "math", "SQL", "python").
        prompt (str): The prompt or question to submit.

    Returns:
        str: The full model response, including <think> and <response> tags.
    """
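    # Wrap the task type and question in the XML-style tags the model expects.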
    structured_prompt = (
        f"<instruction>This is a {instruction} problem.</instruction>"
        f"<question>{prompt}</question>"
    )
    url = f"https://api.runpod.ai/v2/{endpoint_id}/run"
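    # vLLM-style sampling parameters; generation halts at the closing </response> tag.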
    payload = {
        "input": {
            "prompt": structured_prompt,
            "sampling_params": {
                "temperature": 0.8,
                "max_tokens": 1024,
                "stop": "</response>",
            },
        }
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

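    # /run is asynchronous: RunPod returns a job ID immediately rather than the finished result.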
    # A timeout keeps a dropped connection from hanging the app indefinitely.
    submit_response = requests.post(url, headers=headers, json=payload, timeout=30)
    if submit_response.status_code != 200:
        return f"Job submission failed: {submit_response.text}"

    job_id = submit_response.json().get("id")
    if not job_id:
        return "Failed to retrieve job ID."

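    # Poll the status endpoint every 2 seconds until the job completes or fails.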
    status_url = f"https://api.runpod.ai/v2/{endpoint_id}/status/{job_id}"
    with st.spinner("Waiting for the model to respond..."):
        while True:
            res = requests.get(status_url, headers=headers, timeout=30)
            result = res.json()
            status = result.get("status", "")
            if status == "COMPLETED":
                try:
                    # The stop sequence strips the closing tag from the output,
                    # so re-append it for extract_sections() to parse.
                    return result["output"][0]["choices"][0]["tokens"][0] + "</response>"
                except (KeyError, IndexError, TypeError) as e:
                    return f"Unexpected output format: {e}"
            elif status in ("FAILED", "CANCELLED"):
                return f"Job failed or was cancelled: {result}"
            time.sleep(2)


def extract_sections(response: str) -> tuple[str, str]:
    """
    Extracts the <think> and <response> sections from the model output.

    Args:
        response (str): Full text including <think> and <response> tags.

    Returns:
        tuple[str, str]: (reasoning, final_answer)
    """
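    # re.DOTALL lets "." match newlines, since both sections may span multiple lines.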
    think_match = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
    response_match = re.search(r"<response>(.*?)</response>", response, re.DOTALL)

    reasoning = think_match.group(1).strip() if think_match else ""
    # Fall back to the raw text if the model omitted the <response> tags.
    final_response = response_match.group(1).strip() if response_match else response.strip()

    return reasoning, final_response


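# For example, extract_sections("<think>2 + 2 = 4</think><response>4</response>")
# returns ("2 + 2 = 4", "4").

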
def response_generator(response: str):
    """Yield the response word by word so st.write_stream renders a typing effect."""
    for word in response.split():
        yield word + " "
        time.sleep(0.05)


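# Fail fast with a visible error if the RunPod credentials are not configured.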
try:
    API_KEY = os.environ["RUNPOD_API_KEY"]
    ENDPOINT_ID = os.environ["RUNPOD_ENDPOINT_ID"]
except KeyError as e:
    st.error(f"Missing required environment variable: {e}")
    st.stop()

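# Chat history lives in session state so it survives Streamlit's script reruns.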
if "messages" not in st.session_state: |
|
st.session_state.messages = [] |
|
|
|
|
|
|
|
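# The sidebar task selector feeds the <instruction> tag of every prompt.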
with st.sidebar:
    instruction = st.selectbox(
        "What's the task you want this LLM to do?",
        ("math", "SQL", "python", "medical"),
    )

    if st.button("Clear History"):
        st.session_state.messages = []
        st.rerun()

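# Re-render the stored conversation on each rerun so earlier turns stay visible.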
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("Enter a question based on the selected task."):
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        raw_output = run_inference(API_KEY, ENDPOINT_ID, instruction, prompt)
        reasoning, final_response = extract_sections(raw_output)

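        # Render the chain-of-thought and the final answer in separate expanders.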
with st.expander("See reasoning:", expanded=True, icon="π"): |
|
st.write_stream(response_generator(reasoning)) |
|
|
|
|
|
with st.expander("See response:", expanded=True, icon="β
"): |
|
if "```python" in final_response: |
|
st.code(final_response.split("```python", 1)[-1].strip("`"), language="python") |
|
elif "```sql" in final_response: |
|
st.code(final_response.split("```sql", 1)[-1].strip("`"), language="sql") |
|
else: |
|
st.markdown(final_response) |
|
|
|
|
|
st.session_state.messages.append({"role": "assistant", "content": raw_output}) |