import React, { useEffect, useRef, useState } from "react"
import {
  PiFinnTheHumanLight,
  PiFinnTheHumanBold,
  PiFinnTheHumanFill,
  PiFinnTheHumanDuotone,
} from "react-icons/pi"
import { Play } from "lucide-react"
import {
  Select,
  SelectContent,
  SelectItem,
  SelectTrigger,
  SelectValue,
} from "@/components/ui/select"
import { Button } from "@/components/ui/button"
import { Tabs, TabsList, TabsTrigger, TabsContent } from "@/components/ui/tabs"
import { Footer } from "@/components/footer"
import { Header } from "@/components/header"

// NOTE(review): in the copy this file was reviewed from, the JSX markup had been
// stripped down to its text nodes and `{...}` expressions. The element structure
// below is a minimal reconstruction from that surviving text and from the
// imported components; the original classNames/layout were not recoverable and
// should be re-applied.

// ----------------------------------------------------
// Category + Prompt Mapping
// ----------------------------------------------------
const categories = [
  { value: "lgbtq+", label: "LGBT+", prompt: "The gay people are" },
  { value: "nationality", label: "Nationality", prompt: "The French people are" },
  { value: "race", label: "Race", prompt: "The Black people are" },
]

// ----------------------------------------------------
// Dropdown label -> server key
// ----------------------------------------------------
// Replaces the former `x === "..." ? "emgsd" : "emgsd"` ternaries, whose two
// branches were identical. Adding a dataset/model is now a one-line change here,
// and the same lists drive the dropdown items.
const datasetOptions = [{ label: "Bias (EMGSD)", key: "emgsd" }]
const modelOptions = [{ label: "GPT-2", key: "gpt2" }]

/**
 * Look up the server key for a dropdown label.
 *
 * @param options  label/key pairs to search
 * @param label    the currently selected label
 * @param fallback key returned when `label` is not in `options`
 * @returns the matching key, or `fallback`
 */
function toServerKey(
  options: ReadonlyArray<{ label: string; key: string }>,
  label: string,
  fallback: string,
): string {
  return options.find((opt) => opt.label === label)?.key ?? fallback
}

// ----------------------------------------------------
// Generation modes
// ----------------------------------------------------
// The four generation "modes", streamed in this order. `type` is the value sent
// to the server. Hoisted to module scope because it never changes between renders.
const modelSequence = [
  { type: "original", title: "Original Model", key: "origin", icon: PiFinnTheHumanLight },
  { type: "origin+steer", title: "Bias Amplified Steering", key: "origin+steer", icon: PiFinnTheHumanBold },
  { type: "trained", title: "Bias Trained Model", key: "trained", icon: PiFinnTheHumanFill },
  { type: "trained-steer", title: "Bias Mitigated Steering", key: "trained-steer", icon: PiFinnTheHumanDuotone },
]

/** One empty output slot per generation mode. */
const emptyOutputs = (): string[] => modelSequence.map(() => "")

// ----------------------------------------------------
// Minimal card
// ----------------------------------------------------
/**
 * Card showing one model's (possibly partial) generated text.
 *
 * @param Icon      icon component rendered beside the title
 * @param title     card heading
 * @param text      text received so far
 * @param streaming when true, "●" is appended as a typing indicator
 */
function ModelCard({
  Icon,
  title,
  text,
  streaming,
}: {
  Icon: React.ElementType
  title: string
  text: string
  streaming?: boolean
}) {
  return (
    <div>
      <div>
        <Icon />
        <h3>{title}</h3>
      </div>
      {/* If streaming == true, we append "●" at the end to mimic a typing indicator */}
      <p>
        {text} {streaming && "●"} {/* // ⏺ or ⬤ or ● */}
      </p>
    </div>
  )
}

// ----------------------------------------------------
// Tab Panel that holds 4 model cards + "Play" button
// ----------------------------------------------------
/**
 * One category tab: the prompt, a Play button, and one card per generation mode.
 *
 * Clicking Play streams each mode's output from `POST /api/generate`, one mode
 * after another. The matching card updates as chunks arrive.
 */
function TabPanel({
  datasetKey,
  modelKey,
  categoryKey,
  prompt,
}: {
  datasetKey: string
  modelKey: string
  categoryKey: string
  prompt: string
}) {
  // Partial or final text for each slot of modelSequence.
  const [outputs, setOutputs] = useState<string[]>(emptyOutputs)
  // Index of the slot currently streaming; -1 if none.
  const [activeIndex, setActiveIndex] = useState(-1)
  // True while a Play run is in progress. Used to block overlapping runs.
  const [isPlaying, setIsPlaying] = useState(false)
  // Controller for the in-flight run, so the run can be cancelled on unmount.
  const abortRef = useRef<AbortController | null>(null)

  // Cancel any in-flight stream when this panel unmounts. This matters if the
  // Tabs component unmounts inactive panels, which is Radix's default (assuming
  // @/components/ui/tabs wraps Radix; confirm). Otherwise the download and the
  // state updates would keep running.
  useEffect(() => () => abortRef.current?.abort(), [])

  /** Replace the text of one slot without touching the others. */
  function setOutputAt(index: number, text: string) {
    setOutputs((prev) => prev.map((t, i) => (i === index ? text : t)))
  }

  /**
   * Stream one generation into slot `index`.
   *
   * @param genType value sent as `type` in the request body
   * @param index   slot of `outputs` to fill
   * @param signal  aborts the request and the body read
   * @throws Error on a non-2xx response. Network, read and abort errors propagate.
   */
  async function fetchInChunks(genType: string, index: number, signal: AbortSignal): Promise<void> {
    const payload = {
      model: modelKey,
      dataset: datasetKey,
      category: categoryKey,
      type: genType,
    }

    const apiBaseUrl = import.meta.env.VITE_API_BASE_URL || ""
    const response = await fetch(`${apiBaseUrl}/api/generate`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(payload),
      signal,
    })

    // Without this check, an error page/body would be streamed into the card as
    // if it were model output.
    if (!response.ok) {
      throw new Error(`HTTP ${response.status} ${response.statusText}`.trim())
    }

    const reader = response.body?.getReader()
    if (!reader) return

    const decoder = new TextDecoder("utf-8")
    let partial = ""
    try {
      for (;;) {
        const { done, value } = await reader.read()
        if (done) break
        partial += decoder.decode(value, { stream: true })
        setOutputAt(index, partial)
      }
      // With `stream: true`, the trailing bytes of a multi-byte character split
      // across the final chunk stay buffered in the decoder. Flush them.
      const tail = decoder.decode()
      if (tail) {
        partial += tail
        setOutputAt(index, partial)
      }
    } finally {
      reader.releaseLock()
    }
  }

  /** Called on "Play": reset every slot, then stream each mode in sequence. */
  async function handlePlay(): Promise<void> {
    if (isPlaying) return

    const controller = new AbortController()
    abortRef.current = controller
    setIsPlaying(true)
    setOutputs(emptyOutputs())
    setActiveIndex(-1)

    let current = -1
    try {
      for (const [i, step] of modelSequence.entries()) {
        current = i
        setActiveIndex(i)
        await fetchInChunks(step.type, i, controller.signal)
      }
    } catch (err: unknown) {
      // Aborted because the panel unmounted, so there is nothing left to update.
      if (controller.signal.aborted) return
      const message = err instanceof Error ? err.message : String(err)
      const failed = current
      // Keep any partial text and mark the failed slot; later modes are skipped.
      setOutputs((prev) =>
        prev.map((t, i) => (i === failed ? `${t} [error: ${message}]` : t)),
      )
    } finally {
      if (!controller.signal.aborted) {
        setActiveIndex(-1)
        setIsPlaying(false)
        abortRef.current = null
      }
    }
  }

  return (
    <div>
      <div>
        {/* NOTE(review): the prompt is display-only. The request sends only
            categoryKey, so the server must map it to this same prompt; confirm. */}
        <p>{prompt}</p>
        <Button onClick={() => void handlePlay()} disabled={isPlaying} aria-label="Play">
          <Play />
        </Button>
      </div>
      {modelSequence.map((seq, i) => {
        const Icon = seq.icon
        return (
          <ModelCard
            key={seq.key}
            Icon={Icon}
            title={seq.title}
            text={outputs[i] ?? ""}
            streaming={activeIndex === i}
          />
        )
      })}
    </div>
  )
}

// ----------------------------------------------------
// Main App
// ----------------------------------------------------
/**
 * Page root. It holds:
 * - dataset/model dropdowns,
 * - one tab per entry in `categories`, each rendering a TabPanel.
 */
export default function App() {
  const [dataset, setDataset] = useState("Bias (EMGSD)")
  const [model, setModel] = useState("GPT-2")
  const [category, setCategory] = useState(categories[0]?.value ?? "")

  // Convert front-end selection to server keys
  const datasetKey = toServerKey(datasetOptions, dataset, "emgsd")
  const modelKey = toServerKey(modelOptions, model, "gpt2")

  return (
    <div>
      <Header />
      {/* Main content */}
      <main>
        <h1>CorrSteer</h1>
        <p>
          Text Classification dataset can be used to Steer LLMs,
          Correlating with SAE features
        </p>

        {/* Dropdowns */}
        <div>
          <Select value={dataset} onValueChange={setDataset}>
            <SelectTrigger>
              <SelectValue />
            </SelectTrigger>
            <SelectContent>
              {datasetOptions.map((opt) => (
                <SelectItem key={opt.key} value={opt.label}>
                  {opt.label}
                </SelectItem>
              ))}
            </SelectContent>
          </Select>
          <Select value={model} onValueChange={setModel}>
            <SelectTrigger>
              <SelectValue />
            </SelectTrigger>
            <SelectContent>
              {modelOptions.map((opt) => (
                <SelectItem key={opt.key} value={opt.label}>
                  {opt.label}
                </SelectItem>
              ))}
            </SelectContent>
          </Select>
        </div>

        {/* Tabs: one per category -> each has its own content */}
        <Tabs value={category} onValueChange={setCategory}>
          <TabsList>
            {categories.map((cat) => (
              <TabsTrigger key={cat.value} value={cat.value}>
                {cat.label}
              </TabsTrigger>
            ))}
          </TabsList>
          {categories.map((cat) => (
            <TabsContent key={cat.value} value={cat.value}>
              <TabPanel
                datasetKey={datasetKey}
                modelKey={modelKey}
                categoryKey={cat.value}
                prompt={cat.prompt}
              />
            </TabsContent>
          ))}
        </Tabs>
      </main>
      <Footer />
    </div>
  )
}