|
import { useEffect, useState } from 'react'; |
|
import { Database, ExternalLink, Shuffle, Lock } from 'lucide-react'; |
|
import DatasetValidator from '../components/validators/DatasetValidator'; |
|
import { DATASETS } from '../constants/datasets'; |
|
import type { JobConfig } from '../types'; |
|
|
|
/** Props accepted by the dataset configuration page. */
type Props = {
  /** Callback fired by the "Next: Model Config" button to advance the wizard. */
  onNext: () => void;
};
|
|
|
/**
 * Shape of the `extrasDraft` entry stored in localStorage.
 * This page reads it and writes it back without modifying it.
 */
type ExtrasDraft = {
  /**
   * NOTE(review): presumably a cap on how many dataset rows are used; the value
   * is not consumed anywhere in this file — confirm with the page that reads it.
   */
  datasetLimit?: number;
};
|
|
|
export default function DatasetConfigPage({ onNext }: Props) { |
|
// Job configuration assembled by the wizard. On this page only `dataset` is
// edited through the UI; the remaining keys start from these defaults and can
// be overwritten by a draft restored from localStorage.
const initialJobConfig: JobConfig = {
  dataset: '',
  languageModel: '',
  scorerModel: '',
  k: 5,
  numCounterfactuals: 3,
  metrictarget: 0.5,
  tau: 0.1,
  iterations: 1000,
  seed: 42,
  enableFineTuning: false,
  counterfactual: false,
};
const [cfg, setCfg] = useState<JobConfig>(initialJobConfig);

// Free-text Hugging Face dataset id, and whether its input box is shown.
const [customDataset, setCustomDataset] = useState('');
const [showCustomDatasetInput, setShowCustomDatasetInput] = useState(false);

// Counterfactual panel state.
type FieldStats = Record<string, Record<string, number>>; // domain -> category -> count
const [fieldStats, setFieldStats] = useState<FieldStats>({});
const [numCounterfactuals, setNumCounterfactuals] = useState(3);
const [selectedCfFields, setSelectedCfFields] = useState<string[]>([]); // "domain/category" keys
const [isLoadingFields, setIsLoadingFields] = useState(false);
const [fieldsError, setFieldsError] = useState<string | null>(null);

// Configs / splits reported by the metadata endpoint, plus the current choice.
const [metaConfigs, setMetaConfigs] = useState<string[]>([]);
const [metaSplits, setMetaSplits] = useState<string[]>([]);
const [selectedConfig, setSelectedConfig] = useState<string | null>(null);
const [selectedSplit, setSelectedSplit] = useState('train');
|
|
|
/** Replaces one key of the job configuration and leaves every other key as it was. */
function setField<K extends keyof JobConfig>(k: K, v: JobConfig[K]): void {
  setCfg((prev) => ({ ...prev, [k]: v }));
}
|
|
|
// Tailwind class presets shared by the JSX below. Each preset is assembled from
// space-joined fragments so long class lists stay readable.

// Glass-style card container used by both panels.
const card = [
  'group relative rounded-2xl p-8 border border-white/30 bg-white/60 backdrop-blur-xl',
  'shadow-[0_15px_40px_-20px_rgba(30,41,59,0.35)] transition-all duration-300',
  'hover:shadow-[0_20px_50px_-20px_rgba(79,70,229,0.45)] hover:-translate-y-0.5',
].join(' ');

// Heading inside a card.
const sectionTitle = 'text-xl font-bold tracking-tight text-slate-900';

// Focus treatment shared by text inputs and selects.
const focusRing =
  'focus:outline-none focus:border-indigo-500 focus:ring-4 focus:ring-indigo-500/20 transition-all';

// Text / number inputs.
const fieldInput = [
  'w-full rounded-xl border-2 border-slate-200/70 bg-white/70 px-4 py-3',
  focusRing,
].join(' ');

// <select> elements (slightly tighter padding than inputs).
const selectInput = [
  'w-full rounded-xl border-2 border-slate-200/70 bg-white/70 px-3 py-2.5',
  focusRing,
].join(' ');

// Clickable row wrapping a radio button and its description.
const choiceRow = [
  'flex items-start gap-4 cursor-pointer p-4 rounded-xl border transition-colors',
  'bg-white/60 hover:bg-white/80 border-slate-200/60 hover:border-indigo-300',
].join(' ');

// Prefix prepended to every relative backend path requested from this page.
const API_BASE = '/api';
|
|
|
|
|
/**
 * Maps the dataset value chosen in the UI to the id sent to the backend.
 * The `example` preset is an alias for `AmazonScience/bold`; every other value,
 * including an empty string, `null` and `undefined`, is returned unchanged.
 */
function resolveDatasetId(id: string | null | undefined) {
  switch (id) {
    case 'example':
      return 'AmazonScience/bold';
    default:
      return id;
  }
}
|
|
|
/**
 * Builds the relative `/dataset/fields` path for a dataset selection.
 *
 * @param datasetId dataset value from the UI; resolved through `resolveDatasetId`
 * @param config    dataset config; left out of the query when null or blank
 * @param split     dataset split; left out of the query when blank
 * @returns a path with a URL-encoded query string (without the API prefix)
 */
function buildFieldsURL(datasetId: string, config: string | null, split: string): string {
  const query = new URLSearchParams({ id: resolveDatasetId(datasetId)! });

  // Optional parameters are appended in a fixed order: config first, then split.
  const optional: Array<[string, string | null]> = [
    ['config', config],
    ['split', split],
  ];
  for (const [key, value] of optional) {
    if (value && value.trim() !== '') query.set(key, value);
  }

  return `/dataset/fields?${query.toString()}`;
}
|
|
|
/**
 * Fetches a URL and returns its JSON body typed as `T`.
 * The body is not validated at runtime; `T` is only a compile-time claim.
 *
 * @param url    a URL starting with `http` is used as-is; anything else is
 *               prefixed with `API_BASE`
 * @param signal optional abort signal handed to `fetch`
 * @throws Error with the message "<status> <statusText>" when the response is not ok
 */
async function fetchJSON<T>(url: string, signal?: AbortSignal): Promise<T> {
  const target = url.startsWith('http') ? url : `${API_BASE}${url}`;
  const response = await fetch(target, { signal });

  if (response.ok) {
    const payload: unknown = await response.json();
    return payload as T;
  }

  throw new Error(`${response.status} ${response.statusText}`);
}
|
|
|
|
|
// Restore a previously saved `cfgDraft` from localStorage once, on mount.
// The stored value is untrusted: it is validated before it reaches state so a
// corrupt entry cannot break rendering (e.g. a non-array `selectedCfFields`).
// NOTE(review): the dataset-change effect in this component also calls
// setSelectedCfFields([]) when it runs on mount, so the restored selection may
// be cleared again right away — confirm whether that is intended.
useEffect(() => {
  try {
    const raw = localStorage.getItem('cfgDraft');
    if (!raw) return;

    const parsed: unknown = JSON.parse(raw);
    // Only a plain object can be merged into the config.
    if (typeof parsed !== 'object' || parsed === null || Array.isArray(parsed)) return;

    setCfg((prev) => ({ ...prev, ...(parsed as Partial<JobConfig>) }));

    const extras = parsed as { numCounterfactuals?: unknown; selectedCfFields?: unknown };

    // Accept only a usable count and keep it inside the 1..20 range the input enforces.
    const savedCount = extras.numCounterfactuals;
    if (typeof savedCount === 'number' && Number.isFinite(savedCount) && savedCount >= 1) {
      setNumCounterfactuals(Math.min(20, Math.trunc(savedCount)));
    }

    // Accept only an array, and only its string entries.
    const savedFields = extras.selectedCfFields;
    if (Array.isArray(savedFields)) {
      setSelectedCfFields(savedFields.filter((f): f is string => typeof f === 'string'));
    }
  } catch (err) {
    // Unreadable storage or malformed JSON: keep the defaults, but leave a trace.
    console.warn('Failed to restore cfgDraft from localStorage', err);
  }
}, []);
|
|
|
|
|
// Runs whenever the chosen dataset changes:
//   1. clears the counterfactual selection and any previous error,
//   2. loads the dataset's configs/splits and picks defaults
//      (first config; the `train` split when present, otherwise the first split),
//   3. requests the fields endpoint once with those defaults; the response body is
//      discarded, so the call only serves to surface an error.
// A request superseded by a newer selection is aborted and its outcome ignored.
useEffect(() => {
  setSelectedCfFields([]);
  setFieldsError(null);

  if (!cfg.dataset || cfg.dataset === 'custom') {
    // No concrete dataset yet: drop the previous dataset's configs/splits so the
    // selectors do not keep offering stale options.
    setMetaConfigs([]);
    setMetaSplits([]);
    setSelectedConfig(null);
    setSelectedSplit('train');
    return;
  }

  const ac = new AbortController();
  const realId = resolveDatasetId(cfg.dataset)!;

  const run = async () => {
    try {
      const metaURL = `/dataset/meta?id=${encodeURIComponent(realId)}`;
      const meta = await fetchJSON<{
        datasetId: string;
        configs: string[];
        splits: string[];
      }>(metaURL, ac.signal);
      // The response may settle just before the abort; never apply stale results.
      if (ac.signal.aborted) return;

      setMetaConfigs(meta.configs || []);
      setMetaSplits(meta.splits || []);

      const defaultConfig = meta.configs?.length ? meta.configs[0] : null;
      const defaultSplit = meta.splits?.length
        ? meta.splits.includes('train') ? 'train' : meta.splits[0]
        : 'train';

      setSelectedConfig(defaultConfig);
      setSelectedSplit(defaultSplit);

      setIsLoadingFields(true);
      const fieldsURL = buildFieldsURL(cfg.dataset, defaultConfig, defaultSplit);
      await fetchJSON<{ fields: string[] }>(fieldsURL, ac.signal);
      if (ac.signal.aborted) return;

      setFieldsError(null);
    } catch (err: unknown) {
      // Aborting is how the cleanup cancels a superseded request; it is not an
      // error to show, and it must not wipe state owned by the newer request.
      if (ac.signal.aborted) return;

      setMetaConfigs([]);
      setMetaSplits([]);
      setSelectedConfig(null);
      setSelectedSplit('train');

      const fieldsURL = buildFieldsURL(cfg.dataset, null, 'train');
      const message = err instanceof Error ? err.message : '';
      setFieldsError(`(${fieldsURL}) → ${message || '欄位讀取失敗'}`);
    } finally {
      // An aborted run leaves the loading flag to the cleanup / the newer run.
      if (!ac.signal.aborted) setIsLoadingFields(false);
    }
  };

  void run();

  return () => {
    ac.abort();
    // The aborted run no longer touches the flag, so reset it here.
    setIsLoadingFields(false);
  };
}, [cfg.dataset]);
|
|
|
|
|
// Runs whenever the dataset, config or split changes: requests the fields endpoint
// for the current selection (the body is discarded, so the call only surfaces an
// error), then loads the category counts shown under "Optional fields".
// The stats request is hard-coded to field=domain / subfield=category and does
// not include the config or split.
// A request superseded by a newer selection is aborted and its outcome ignored.
useEffect(() => {
  if (!cfg.dataset || cfg.dataset === 'custom') {
    // No concrete dataset: do not keep showing the previous dataset's categories.
    // The functional form avoids a state change when the map is already empty.
    setFieldStats((prev) => (Object.keys(prev).length > 0 ? {} : prev));
    return;
  }

  const ac = new AbortController();
  const realId = resolveDatasetId(cfg.dataset)!;

  const run = async () => {
    try {
      setIsLoadingFields(true);

      const fieldsURL = buildFieldsURL(cfg.dataset, selectedConfig, selectedSplit);
      await fetchJSON<{ fields: string[] }>(fieldsURL, ac.signal);

      const statsURL = `/dataset/field-stats?id=${encodeURIComponent(realId)}&field=domain&subfield=category`;
      const statsData = await fetchJSON<{ counts: Record<string, Record<string, number>> }>(statsURL, ac.signal);
      // The response may settle just before the abort; never apply stale results.
      if (ac.signal.aborted) return;

      setFieldStats(statsData.counts || {});
      setFieldsError(null);

      if (cfg.dataset === 'example') {
        // The Example preset pre-selects the two actor categories when present.
        const keys: string[] = [];
        const cats = statsData?.counts?.domain || {};
        if ('American_actors' in cats) keys.push('domain/American_actors');
        if ('American_actresses' in cats) keys.push('domain/American_actresses');
        setSelectedCfFields(keys);
      } else {
        setSelectedCfFields([]);
      }
    } catch (err: unknown) {
      // Aborting is how the cleanup cancels a superseded request; it is not an
      // error to show, and it must not wipe state owned by the newer request.
      if (ac.signal.aborted) return;

      const fieldsURL = buildFieldsURL(cfg.dataset, selectedConfig, selectedSplit);
      const message = err instanceof Error ? err.message : '';
      setFieldStats({});
      setFieldsError(`(${fieldsURL}) → ${message || 'Field Read Failed'}`);
    } finally {
      // An aborted run leaves the loading flag to the cleanup / the newer run.
      if (!ac.signal.aborted) setIsLoadingFields(false);
    }
  };

  void run();

  return () => {
    ac.abort();
    // The aborted run no longer touches the flag, so reset it here.
    setIsLoadingFields(false);
  };
}, [cfg.dataset, selectedConfig, selectedSplit]);
|
|
|
|
|
// The Example preset always uses 20 counterfactuals; mirror that in state as
// soon as the preset becomes the selected dataset.
useEffect(() => {
  const exampleChosen = cfg.dataset === 'example';
  if (!exampleChosen) return;
  setNumCounterfactuals(20);
}, [cfg.dataset]);
|
|
|
|
|
// While the Example preset is active, select every "domain/category" key whose
// category is one of the two actor groups, whenever the loaded stats change.
// The selection is left untouched when no such category is present.
useEffect(() => {
  if (cfg.dataset !== 'example') return;

  const lockedCategories = ['American_actors', 'American_actresses'];
  const presetKeys = Object.entries(fieldStats).flatMap(([domain, categories]) =>
    Object.keys(categories)
      .filter((category) => lockedCategories.includes(category))
      .map((category) => `${domain}/${category}`),
  );

  if (presetKeys.length === 0) return;
  setSelectedCfFields(presetKeys);
}, [cfg.dataset, fieldStats]);
|
|
|
// Derived flags: the Example preset locks the counterfactual controls, and the
// "Next" button is enabled as soon as any dataset value is set.
const isExample = cfg.dataset === 'example';
const canNext = Boolean(cfg.dataset);
|
|
|
return ( |
|
<div className="space-y-10"> |
|
<div className="grid grid-cols-1 lg:grid-cols-6 gap-8"> |
|
{/* Dataset selection: preset datasets, the Example preset, or a custom Hugging Face id.
    While the custom input is open the typed id becomes cfg.dataset, so the radio
    states are driven by showCustomDatasetInput to keep exactly one radio checked. */}
<div className={`${card} lg:col-span-3`}>
  <div className="flex items-center gap-3 mb-8">
    <div className="p-3 rounded-xl bg-gradient-to-br from-indigo-600 to-fuchsia-600 shadow-md shadow-indigo-600/30">
      <Database className="w-6 h-6 text-white" />
    </div>
    <h3 className={sectionTitle}>Dataset Selection</h3>
  </div>

  <div className="space-y-4">
    {DATASETS.map((dataset) => (
      <label key={dataset.id} className={choiceRow}>
        <input
          type="radio"
          name="dataset"
          value={dataset.id}
          checked={!showCustomDatasetInput && cfg.dataset === dataset.id}
          onChange={(e) => {
            setField('dataset', e.target.value);
            setShowCustomDatasetInput(false);
            setCustomDataset('');
            setSelectedCfFields([]);
          }}
          className="mt-1 accent-indigo-600"
        />
        <div className="flex-1">
          <div className="font-semibold text-slate-900">{dataset.name}</div>
          <div className="flex items-center gap-4 text-xs text-slate-500 mt-2">
            {'entities' in dataset && (
              <span>📊 {(dataset as any).entities?.toLocaleString?.() || '-' } entities</span>
            )}
            {'groups' in dataset && <span>👥 {(dataset as any).groups || '-' } groups</span>}
          </div>
          <a
            href={`https://huggingface.co/datasets/${dataset.id}`}
            target="_blank"
            rel="noopener noreferrer"
            className="inline-flex items-center gap-1 text-indigo-600 hover:text-indigo-700 text-xs font-medium mt-2"
            onClick={(e) => e.stopPropagation()}
          >
            <ExternalLink className="w-3.5 h-3.5" />
            View on Hugging Face
          </a>
        </div>
      </label>
    ))}

    {/* Example dataset (preconfigured) */}
    <label className={choiceRow}>
      <input
        type="radio"
        name="dataset"
        value="example"
        checked={!showCustomDatasetInput && cfg.dataset === 'example'}
        onChange={(e) => {
          setField('dataset', e.target.value);
          setShowCustomDatasetInput(false);
          setCustomDataset('');
          setSelectedCfFields([]); // refilled automatically by the effect that reads fieldStats
          setNumCounterfactuals(20);
        }}
        className="mt-1 accent-violet-600"
      />
      <div className="flex-1">
        <div className="font-semibold text-slate-900 flex items-center gap-2">
          🧪 Example
          {isExample && (
            <span className="inline-flex items-center gap-1 text-[10px] font-semibold px-2 py-0.5 rounded-full bg-slate-900 text-white">
              <Lock className="w-3 h-3" /> CF fields locked
            </span>
          )}
        </div>
      </div>
    </label>

    {/* Custom dataset: stays checked while the id input is open, even after typing */}
    <label className={choiceRow}>
      <input
        type="radio"
        name="dataset"
        value="custom"
        checked={showCustomDatasetInput || cfg.dataset === 'custom'}
        onChange={(e) => {
          setField('dataset', e.target.value);
          setShowCustomDatasetInput(true);
          setSelectedCfFields([]);
        }}
        className="mt-1 accent-fuchsia-600"
      />
      <div className="flex-1">
        <div className="font-semibold text-slate-900">🔧 Custom Dataset Upload from Hugging Face</div>
      </div>
    </label>

    {showCustomDatasetInput && (
      <div className="pl-6 space-y-3 animate-in slide-in-from-top duration-300">
        <input
          type="text"
          placeholder="Input Hugging Face Dataset ID (e.g. AmazonScience/bold)"
          value={customDataset}
          onChange={(e) => {
            // The typed id is mirrored into cfg.dataset on every change.
            setCustomDataset(e.target.value);
            setField('dataset', e.target.value);
          }}
          className={fieldInput}
        />
        {customDataset && customDataset.includes('/') && (
          <DatasetValidator datasetId={customDataset} />
        )}
      </div>
    )}

    {cfg.dataset === 'AmazonScience/bold' && !showCustomDatasetInput && (
      <DatasetValidator datasetId="AmazonScience/bold" />
    )}
  </div>
</div>
|
|
|
{/* Counterfactual settings: count, dataset config/split, and optional domain/category fields.
    Every control is read-only while the Example preset is active. */}
<div className={`${card} lg:col-span-3`}>
  <div className="flex items-center gap-3 mb-8">
    <div className="p-3 rounded-xl bg-gradient-to-br from-pink-600 to-rose-600 shadow-md shadow-pink-600/30">
      <Shuffle className="w-6 h-6 text-white" />
    </div>
    <h3 className={sectionTitle}>Counterfactual Setting</h3>
  </div>

  <div className="space-y-6">
    <div className="pt-2">
      {/* htmlFor/id ties the label to its input for assistive technology and click-to-focus */}
      <label htmlFor="cf-count-input" className="block text-sm font-semibold text-slate-800 mb-1">
        Number of Counterfactual
      </label>
      <input
        id="cf-count-input"
        type="number"
        min={1}
        max={20}
        step={1}
        value={isExample ? 20 : numCounterfactuals}
        onChange={(e) => {
          if (isExample) return; // ignore edits while the Example preset is active
          // An empty box falls back to 3; anything else is clamped to 1..20.
          const v = parseInt(e.target.value || '3', 10);
          setNumCounterfactuals(Number.isFinite(v) ? Math.max(1, Math.min(20, v)) : 3);
        }}
        disabled={isExample}
        className={fieldInput + (isExample ? ' cursor-not-allowed opacity-80' : '')}
      />
      {isExample && (
        <div className="text-[11px] mt-1 text-slate-500 flex items-center gap-1">
          <Lock className="w-3 h-3" /> Locked to 20 for the Example preset.
        </div>
      )}
    </div>

    {(metaConfigs.length > 0 || metaSplits.length > 0) && (
      <div className="grid grid-cols-1 sm:grid-cols-2 gap-4">
        {metaConfigs.length > 0 && (
          <div>
            <label htmlFor="dataset-config-select" className="block text-sm font-semibold text-slate-800 mb-1">Dataset Config</label>
            <select
              id="dataset-config-select"
              value={selectedConfig || ''}
              onChange={(e) => setSelectedConfig(e.target.value || null)}
              className={selectInput}
            >
              {metaConfigs.map((c) => (
                <option key={c} value={c}>{c}</option>
              ))}
            </select>
          </div>
        )}

        {metaSplits.length > 0 && (
          <div>
            <label htmlFor="dataset-split-select" className="block text-sm font-semibold text-slate-800 mb-1">Split</label>
            <select
              id="dataset-split-select"
              value={selectedSplit}
              onChange={(e) => setSelectedSplit(e.target.value)}
              className={selectInput}
            >
              {metaSplits.map((s) => (
                <option key={s} value={s}>{s}</option>
              ))}
            </select>
          </div>
        )}
      </div>
    )}

    <div className="text-xs text-slate-500 flex items-center gap-2">
      <span>Selected Dataset</span>
      <span className="inline-flex items-center rounded-full bg-slate-800/90 text-white px-2.5 py-1">
        {cfg.dataset || 'Not Selected Yet'}
      </span>
      {selectedConfig && <span className="ml-1">/ {selectedConfig}</span>}
      {selectedSplit && <span className="ml-1">/ {selectedSplit}</span>}
    </div>

    {/* Optional fields (domain/category) */}
    <div>
      <div className="flex items-center justify-between mb-2">
        <div className="text-sm font-semibold text-slate-800">Optional fields</div>
        {isExample && (
          <span className="inline-flex items-center gap-1 text-[10px] font-semibold px-2 py-0.5 rounded-full bg-slate-900 text-white">
            <Lock className="w-3 h-3" /> Locked by Example
          </span>
        )}
        {isLoadingFields && <span className="text-xs text-slate-500">Loading</span>}
      </div>

      {!!fieldsError && (
        <div className="text-xs text-rose-600 mb-2">{fieldsError}</div>
      )}

      <div className="space-y-4 max-h-64 overflow-auto pr-1">
        {Object.entries(fieldStats).map(([domain, categories]) => (
          <div key={domain} className="bg-white/50 border border-slate-200 rounded-xl p-3 shadow-sm">
            <div className="font-semibold text-slate-700 text-sm mb-2">{domain}</div>
            <div className="grid grid-cols-1 sm:grid-cols-2 gap-x-4 gap-y-2 pl-1">
              {Object.entries(categories).map(([category, count]) => {
                const fieldKey = `${domain}/${category}`;
                const checked = selectedCfFields.includes(fieldKey);
                // Only the two preset categories get the "locked" badge; all boxes are disabled under Example.
                const locked = isExample && (category === 'American_actors' || category === 'American_actresses');
                return (
                  <label
                    key={fieldKey}
                    className={`flex items-center gap-2 text-sm text-slate-800 px-2 py-1 rounded-md transition-colors ${
                      isExample ? 'opacity-80 cursor-not-allowed' : 'hover:bg-white/60'
                    }`}
                  >
                    <input
                      type="checkbox"
                      checked={checked}
                      disabled={isExample}
                      onChange={() => {
                        if (isExample) return; // no manual changes while the Example preset is active
                        // Decide from the latest state rather than the render-time `checked`
                        // value, so queued updates cannot add a key twice.
                        setSelectedCfFields((prev) =>
                          prev.includes(fieldKey) ? prev.filter((x) => x !== fieldKey) : [...prev, fieldKey]
                        );
                      }}
                      className="accent-fuchsia-600"
                    />
                    <span>{category}</span>
                    <span className="text-xs text-slate-500">({count})</span>
                    {locked && (
                      <span className="ml-1 text-[10px] px-1.5 py-0.5 rounded bg-slate-900 text-white">locked</span>
                    )}
                  </label>
                );
              })}
            </div>
          </div>
        ))}
      </div>
    </div>
  </div>
</div>
|
|
|
</div> |
|
|
|
{/* Next: persist this page's draft, then advance to the model page */}
<div className="flex">
  <button
    type="button"
    onClick={() => {
      // Save the dataset / counterfactual draft to localStorage so the Model page can read it.
      const draft = {
        ...cfg,
        selectedCfFields,
        numCounterfactuals: isExample ? 20 : numCounterfactuals, // forced to 20 for the Example preset
        // keep the config / split the user picked
        datasetConfig: selectedConfig,
        datasetSplit: selectedSplit,
      };
      try {
        localStorage.setItem('cfgDraft', JSON.stringify(draft));
      } catch (err) {
        // Storage can throw (quota exceeded, storage disabled). Without the
        // draft the next page has nothing to read, so stay on this page.
        console.error('Failed to save cfgDraft to localStorage', err);
        return;
      }

      // Extras are not edited on this page; they are only carried over.
      // A corrupt entry is reset to an empty object instead of blocking navigation.
      try {
        let extrasDraft: ExtrasDraft = {};
        try {
          extrasDraft = JSON.parse(localStorage.getItem('extrasDraft') || '{}');
        } catch {
          extrasDraft = {};
        }
        localStorage.setItem('extrasDraft', JSON.stringify(extrasDraft));
      } catch (err) {
        console.warn('Failed to carry over extrasDraft', err);
      }

      onNext();
    }}
    disabled={!canNext}
    className="relative w-full group overflow-hidden rounded-2xl px-6 py-4 text-white font-semibold bg-gradient-to-r from-indigo-600 via-violet-600 to-fuchsia-600 shadow-lg shadow-indigo-600/20 enabled:hover:shadow-indigo-600/40 transition-all enabled:hover:translate-y-[-1px] enabled:active:translate-y-0 disabled:opacity-60 disabled:cursor-not-allowed"
  >
    <span className="relative z-10">Next: Model Config</span>
    <span className="absolute inset-0 opacity-0 group-hover:opacity-100 transition-opacity bg-[radial-gradient(1200px_200px_at_50%_-40%,rgba(255,255,255,0.35),transparent_60%)]" />
  </button>
</div>
|
</div> |
|
); |
|
} |
|
|