Spaces:
Anonymous-COFFEE
committed on
Create app.py
app.py
ADDED
@@ -0,0 +1,150 @@
import argparse
import re

import streamlit as st
from datasets import load_dataset
from langchain_community.llms import OpenAI
from tqdm import tqdm


def load_data(split="test"):
    """Load the requested split of HumanEvalPack and print its size."""
    data = load_dataset("bigcode/humanevalpack")
    print("=========== dataset statistics ===========")
    print(len(data[split]))
    print("==========================================")
    return data[split]

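# Example usage (a minimal sketch; the values are illustrative, but the field
# names are the ones this app reads below):
#
#     sample = load_data("test")[0]
#     sample["task_id"]          # problem identifier
#     sample["prompt"]           # function header + docstring
#     sample["buggy_solution"]   # the incorrect function body
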
def split_function_header_and_docstring(s):
    """Split a HumanEvalPack prompt into the function header and its docstring."""
    pattern = re.compile(r"(\"\"\"(.*?)\"\"\"|\'\'\'(.*?)\'\'\')", re.DOTALL)
    match = pattern.findall(s)
    if not match:
        raise ValueError(f"no docstring found in: {s!r}")
    # findall returns one tuple per match; the first group is the full quoted
    # docstring, triple quotes included.
    docstring = match[-1][0]
    code_without_docstring = s.replace(docstring, "").replace('"' * 6, "").strip()
    docstring = docstring.replace('"', "")
    return code_without_docstring, docstring

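# Example (illustrative input, not a dataset record):
#
#     header, doc = split_function_header_and_docstring(
#         'def add(a, b):\n    """Return the sum of a and b."""\n'
#     )
#     # header == 'def add(a, b):'
#     # doc    == 'Return the sum of a and b.'
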
def prepare_model_input(code_data):
    """Build the critic prompt for one HumanEvalPack record."""
    prompt = """Provide feedback on the errors in the given code and suggest the
correct code to address the described problem.

Problem Description:
{description}

Incorrect Code:
{wrong_code}"""

    description = code_data["prompt"]
    function_header, docstring = split_function_header_and_docstring(description)
    # Keep only the prose part of the docstring; drop the ">>>" doctest examples.
    problem = docstring.split(">>>")[0]

    wrong_code = function_header + code_data["buggy_solution"]
    template_dict = {"description": problem, "wrong_code": wrong_code}
    model_input = prompt.format(**template_dict)
    return model_input, problem, function_header

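# With the add() example above, prepare_model_input would render the template
# with "Return the sum of a and b." as the description and the header plus the
# buggy body as the incorrect code (a sketch; real records are longer).
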
def load_and_prepare_data():
    """Precompute the critic prompt for every problem, keyed by task_id."""
    dataset = load_data()
    all_model_inputs = {}
    print("### load and prepare data")
    for data in tqdm(dataset):
        problem_id = data["task_id"]
        buggy_solution = data["buggy_solution"]
        # model_input is superseded by the flattened single-line prompt below.
        model_input, problem, function_header = prepare_model_input(data)
        new_model_input = (
            "Provide feedback on the errors in the given code and suggest the "
            "correct code to address the described problem."
            f"\nProblem Description:{problem}\nIncorrect Code:\n{buggy_solution}\nFeedback:"
        )
        all_model_inputs[problem_id] = {
            "model_input": new_model_input,
            "header": function_header,
            "problem_description": problem,
            "data": data,
        }
    return all_model_inputs


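# Streamlit re-runs this script from the top on every UI interaction, so the
# dataset preparation below repeats on each rerun. A minimal sketch of one way
# to memoise it with Streamlit's built-in cache (an optional improvement, not
# part of the original app):
#
#     @st.cache_data
#     def cached_inputs():
#         return load_and_prepare_data()
#
#     all_model_inputs = cached_inputs()
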
# Load the same split again at module level for the UI widgets,
# ensuring consistent split usage.
dataset = load_dataset("bigcode/humanevalpack", split="test", trust_remote_code=True)

problem_ids = [problem["task_id"] for problem in dataset]
all_model_inputs = load_and_prepare_data()


# Ports for locally hosted models (kept for local runs; the public endpoints
# below are hardcoded). parse_known_args() ignores any extra flags that the
# streamlit launcher leaves on the command line.
parser = argparse.ArgumentParser()
parser.add_argument("--editor_port", type=str, default="6000")
parser.add_argument("--critic_port", type=str, default="6001")
args, _ = parser.parse_known_args()

# LangChain OpenAI clients pointed at the hosted COFFEEPOTS editor and critic.
editor_model = OpenAI(model="Anonymous-COFFEE/COFFEEPOTS-editor", api_key="EMPTY", openai_api_base="https://editor.jp.ngrok.io/v1")
critic_model = OpenAI(model="Anonymous-COFFEE/COFFEEPOTS-critic", api_key="EMPTY", openai_api_base="https://critic.jp.ngrok.io/v1")

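# Both endpoints are assumed to be OpenAI-compatible completion servers; the
# api_key="EMPTY" convention matches self-hosted servers such as vLLM's
# OpenAI-compatible API, which accept any placeholder key.
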
st.title("Demo for COFFEEPOTS")

selected_task_id = st.selectbox("Select a problem ID:", problem_ids)

# Retrieve the record for the selected problem.
problem_details = dataset[problem_ids.index(selected_task_id)]

st.write(f"**Selected Problem ID:** {problem_details['task_id']}")
st.write(f"**Problem Description:**\n{all_model_inputs[selected_task_id]['problem_description']}")
# Display the buggy code with syntax highlighting.
st.code(problem_details["buggy_solution"], language="python")

# Placeholders that the streaming calls below write into.
status_text = st.empty()
code_output = st.code("", language="python")

def generate_feedback():
    # Stream the critic's feedback tokens for the selected problem.
    return critic_model.stream(input=all_model_inputs[selected_task_id]["model_input"], logit_bias=None)


def generate_corrected_code(feedback):
    # Stream the editor's output, wrapped in a markdown code fence for display.
    # feedback is passed in explicitly rather than read from a module global.
    yield "```python"
    for text_chunk in editor_model.stream(
        input=f"[INST]Buggy Code:\n{problem_details['buggy_solution']}\nFeedback: {feedback}[/INST]",
        logit_bias=None,
    ):
        yield text_chunk
    yield "```"


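# st.write_stream (and the same method on placeholders) accepts either an
# iterator or a generator, renders chunks as they arrive, and returns the
# concatenated text, so feedback/corrected_code below end up as plain strings.
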
if st.button("Generate Feedback and Corrected Code"):
    with st.spinner("Generating feedback..."):
        print("model input for critic:")
        print(all_model_inputs[selected_task_id]["model_input"])
        # Stream the critic's feedback into the status placeholder;
        # write_stream returns the full text once streaming completes.
        feedback = status_text.write_stream(generate_feedback())

    with st.spinner("Generating corrected code..."):
        # Feed the critic's feedback to the editor and stream the repaired code.
        corrected_code = code_output.write_stream(generate_corrected_code(feedback))