import streamlit as st
from openai import OpenAI

# Initialize session state
if 'messages' not in st.session_state:
    st.session_state.messages = []
# Function to generate the system prompt based on user inputs
def create_system_prompt(classification_type, num_to_generate, domain, min_words, max_words, labels):
    system_prompt = f"You are a professional {classification_type.lower()} expert. Your role is to generate exactly {num_to_generate} data examples for {domain}. "
    system_prompt += f"Each example should consist of between {min_words} and {max_words} words. "
    system_prompt += "Use the following labels: " + ", ".join(labels) + ". Please do not add any extra commentary or explanation. "
    system_prompt += "Format each example like this:\nExample: <text>, Label: <label>\n"
    return system_prompt
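# For illustration (not part of the app), with Sentiment Analysis, 10
# examples, the "Restaurant reviews" domain, and a 20-40 word range, the
# function above returns:
#   "You are a professional sentiment analysis expert. Your role is to
#   generate exactly 10 data examples for Restaurant reviews. Each example
#   should consist of between 20 and 40 words. Use the following labels:
#   Positive, Negative, Neutral. Please do not add any extra commentary or
#   explanation. Format each example like this:
#   Example: <text>, Label: <label>"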
# OpenAI client setup (replace with your OpenAI API key)
client = OpenAI(api_key='YOUR_API_KEY')
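# A common alternative (an assumption, not in the original snippet) is to
# read the key from the environment instead of hardcoding it:
#   import os
#   client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])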
# App title
st.title("Data Generation for Classification")

# Choice between Data Generation or Data Labeling
mode = st.radio("Choose Task:", ["Data Generation", "Data Labeling"])
if mode == "Data Generation":
    # Step 1: Choose the classification type
    classification_type = st.radio(
        "Select Classification Type:",
        ["Sentiment Analysis", "Binary Classification", "Multi-Class Classification"]
    )

    # Step 2: Choose labels based on the classification type
    if classification_type == "Sentiment Analysis":
        labels = ["Positive", "Negative", "Neutral"]
    elif classification_type == "Binary Classification":
        class1 = st.text_input("Enter First Class for Binary Classification")
        class2 = st.text_input("Enter Second Class for Binary Classification")
        labels = [class1, class2]
    elif classification_type == "Multi-Class Classification":
        num_classes = st.slider("Number of Classes (Max 10):", 2, 10, 3)
        labels = [st.text_input(f"Enter Class {i+1}") for i in range(num_classes)]
    # Step 3: Choose the domain
    domain = st.radio(
        "Select Domain:",
        ["Restaurant reviews", "E-commerce reviews", "Custom"]
    )
    if domain == "Custom":
        domain = st.text_input("Enter Custom Domain")
    # Step 4: Specify example length (min and max words)
    min_words = st.slider("Minimum Words per Example", 10, 90, 20)
    max_words = st.slider("Maximum Words per Example", 10, 90, 40)
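    # A small guard (an addition, not in the original flow): warn when the
    # sliders produce an inverted range, since the prompt would then ask
    # for an impossible word count.
    if max_words < min_words:
        st.warning("Maximum words should be at least the minimum words.")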
    # Step 5: Ask if the user wants few-shot examples
    use_few_shot = st.checkbox("Use Few-Shot Examples?")
    few_shot_examples = []
    if use_few_shot:
        num_few_shots = st.slider("Number of Few-Shot Examples (Max 5):", 1, 5, 2)
        for i in range(num_few_shots):
            example_text = st.text_area(f"Enter Example {i+1} Text")
            example_label = st.selectbox(f"Select Label for Example {i+1}", labels)
            few_shot_examples.append(f"Example: {example_text}, Label: {example_label}")
    # Step 6: Specify the number of examples to generate
    num_to_generate = st.number_input("Number of Examples to Generate", min_value=1, max_value=50, value=10)
    # Step 7: Generate the examples when the button is clicked. Requests are
    # made in chunks of at most 5, and the system prompt is rebuilt for each
    # chunk so the model is asked for exactly the number expected back.
    if st.button("Generate Examples"):
        all_generated_examples = []
        remaining_examples = num_to_generate
        with st.spinner("Generating..."):
            while remaining_examples > 0:
                chunk_size = min(remaining_examples, 5)
                try:
                    # Rebuild the message list from scratch for each chunk:
                    # one system prompt followed by any few-shot examples.
                    # This avoids appending duplicate prompts across iterations.
                    system_prompt = create_system_prompt(classification_type, chunk_size, domain, min_words, max_words, labels)
                    st.session_state.messages = [{"role": "system", "content": system_prompt}]
                    for example in few_shot_examples:
                        st.session_state.messages.append({"role": "user", "content": example})
                    # Stream the API request to generate examples
                    stream = client.chat.completions.create(
                        model="gpt-3.5-turbo",
                        messages=st.session_state.messages,
                        temperature=0.7,
                        stream=True,
                        max_tokens=3000,
                    )
                    # Capture the streamed response (the v1 client yields
                    # objects, so deltas are attributes, not dict keys)
                    response = ""
                    for chunk in stream:
                        delta = chunk.choices[0].delta
                        if delta.content is not None:
                            response += delta.content
                    # Split the response into individual examples on the
                    # "Example: " marker, keeping at most chunk_size of them
                    generated_examples = response.split("Example: ")[1:chunk_size+1]
                    cleaned_examples = [ex.strip() for ex in generated_examples]

                    # Store the new examples (numbering is added once, at
                    # display time, to avoid a doubled "Example N:" prefix)
                    all_generated_examples.extend(cleaned_examples)
                    remaining_examples -= chunk_size
                except Exception as e:
                    st.error("Error during generation.")
                    st.write(e)
                    break
        # Display all generated examples, numbered consecutively across chunks
        for idx, example in enumerate(all_generated_examples):
            st.write(f"Example {idx+1}: {example}")

        # Clear session state to avoid repeating old prompts
        st.session_state.messages = []  # Reset after each generation
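To try the app locally (assuming the script above is saved as app.py and a valid API key has been filled in), run:

streamlit run app.py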