import os
import time
import re
import requests
import streamlit as st

# Streamlit UI
st.set_page_config(
    page_title="Task Specific Q&A Chat (No agents needed)",
    layout="wide",
    initial_sidebar_state="expanded"  # or "collapsed"
)
st.title("🧠 Task Specific Q&A Chat (No agents needed)")
st.markdown("Ask a question and get a response from a RunPod model. Your session history is saved below.")

def run_inference(api_key: str, endpoint_id: str, instruction: str, prompt: str) -> str:
    """
    Submits a job to RunPod and polls until the result is ready.

    Args:
        api_key (str): RunPod API key from environment.
        endpoint_id (str): RunPod endpoint ID.
        instruction (str): Instruction type (e.g., "math", "SQL", "python").
        prompt (str): The prompt or question to submit.

    Returns:
        str: The full model response including <think> and <response> tags.
    """
    structured_prompt = (
        f"<instruction>This is a {instruction} problem.</instruction>"
        f"<question>{prompt}</question>"
    )
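    # e.g. instruction="math", prompt="What is 2+2?" produces:
    # "<instruction>This is a math problem.</instruction><question>What is 2+2?</question>"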
    url = f'https://api.runpod.ai/v2/{endpoint_id}/run'
    payload = {
        "input": {
            "prompt": structured_prompt,
            "sampling_params": {
                "temperature": 0.8,
                "max_tokens": 1024,
                "stop": "</response>"
            }
        }
    }
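    # Sampling parameters follow the vLLM-style schema used by RunPod's LLM
    # workers (assumed); the stop string ends generation at the closing </response> tag.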
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_key}'
    }

    # A request timeout guards against hanging indefinitely on network issues.
    submit_response = requests.post(url, headers=headers, json=payload, timeout=30)
    if submit_response.status_code != 200:
        return f"Job submission failed: {submit_response.text}"

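    # A successful submission returns JSON like {"id": "...", "status": "IN_QUEUE"}
    # (response shape assumed from RunPod's serverless API docs).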
    job_id = submit_response.json().get('id')
    if not job_id:
        return "Failed to retrieve job ID."

    status_url = f'https://api.runpod.ai/v2/{endpoint_id}/status/{job_id}'
    with st.spinner("Waiting for the model to respond..."):
        while True:
            res = requests.get(status_url, headers=headers, timeout=30)
            result = res.json()
            status = result.get("status", "")
            if status == 'COMPLETED':
                try:
                    # The stop sequence is not included in the returned text, so re-append it.
                    return result['output'][0]['choices'][0]['tokens'][0] + "</response>"
                except (KeyError, IndexError, TypeError) as e:
                    return f"Unexpected output format: {e}"
            elif status in ['FAILED', 'CANCELLED']:
                return f"Job failed or cancelled: {result}"
            time.sleep(2)
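
# Illustrative call (hypothetical values; the real key and endpoint ID are read
# from the environment below):
#   run_inference("rp_XXXXXXXX", "abc123xyz", "math", "What is the derivative of x**2?")
# would return something like "<think>...</think><response>2x</response>".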

def extract_sections(response: str) -> tuple[str, str]:
    """
    Extracts <think> and <response> sections from the model output.

    Args:
        response (str): Full text including <think> and <response> tags.

    Returns:
        tuple[str, str]: (reasoning, final_answer)
    """
    think_match = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
    response_match = re.search(r"<response>(.*?)</response>", response, re.DOTALL)

    reasoning = think_match.group(1).strip() if think_match else ""
    final_response = response_match.group(1).strip() if response_match else response.strip()

    return reasoning, final_response
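
# Illustrative behavior (tag format as produced by the model):
#   extract_sections("<think>add the numbers</think><response>4</response>")
#   returns ("add the numbers", "4"); without a <response> tag, the whole
#   stripped text is treated as the final answer.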

# Streamed response emulator: yields one word every 50 ms so st.write_stream
# renders the text with a typing effect.
def response_generator(response: str):
    for word in response.split():
        yield word + " "
        time.sleep(0.05)

# Load API key and endpoint ID
try:
    API_KEY = os.environ["RUNPOD_API_KEY"]
    ENDPOINT_ID = os.environ["RUNPOD_ENDPOINT_ID"]
except KeyError as e:
    st.error(f"Missing required environment variable: {str(e)}")
    st.stop()

# Initialize session state
if "messages" not in st.session_state:
    st.session_state.messages = []

# Render the sidebar before the chat history so the selected task and
# "Clear History" take effect on the current run.
with st.sidebar:
    instruction = st.selectbox(
        "What's the task you want this LLM to do?",
        ("math", "SQL", "python", "medical"),
    )

    # Handle clear history BEFORE rendering messages
    if st.button("Clear History"):
        st.session_state.messages = []
        st.rerun()

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Accept user input
if prompt := st.chat_input("Enter a question based on a selected task."):
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(prompt)

    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        # Generate answer
        raw_output = run_inference(API_KEY, ENDPOINT_ID, instruction, prompt)
        reasoning, final_response = extract_sections(raw_output)

        # Show reasoning
        with st.expander("See reasoning:", expanded=True, icon="πŸ’­"):
            st.write_stream(response_generator(reasoning))

        # Render fenced code blocks with syntax highlighting, plain markdown otherwise
        with st.expander("See response:", expanded=True, icon="✅"):
            code_match = re.search(r"```(python|sql)\s*(.*?)```", final_response, re.DOTALL)
            if code_match:
                st.code(code_match.group(2).strip(), language=code_match.group(1))
            else:
                st.markdown(final_response)

    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": raw_output})