Spaces:
Running
Running
Aditya Shankar
feat: added dataset recording; hf uploader, s3 uploader; runpod trainer (#6)
4384839
unverified
import React from "react"; | |
import { | |
LineChart, | |
Line, | |
XAxis, | |
YAxis, | |
CartesianGrid, | |
Tooltip, | |
Legend, | |
ResponsiveContainer | |
} from "recharts"; | |
import { NonIndexedLeRobotDatasetRow } from "@lerobot/web"; | |
/** Props for {@link TeleoperatorJointGraph}. */
interface TeleoperatorJointGraphProps {
  /** Recorded dataset rows to plot; a row's position in this array is its x-axis value. */
  frames: NonIndexedLeRobotDatasetRow[];
  /**
   * Labels for the elements of each frame's `action` and `observation.state`
   * arrays, in index order (element i is labelled `jointNames[i]`).
   * Optional: defaults to {@link DEFAULT_JOINT_NAMES}. Array elements beyond
   * the end of this list are not plotted. Names should be unique, because
   * they become both chart data keys and React keys.
   */
  jointNames?: readonly string[];
}

/**
 * Joint labels used when the caller passes no `jointNames`. Per the original
 * author's note, these match the LeRobot dataset format.
 */
const DEFAULT_JOINT_NAMES: readonly string[] = [
  "shoulder_pan",
  "shoulder_lift",
  "elbow_flex",
  "wrist_flex",
  "wrist_roll",
  "gripper",
];

/**
 * Line colors indexed by joint position; wraps around for more joints than
 * entries. Every entry is distinct, so no two of the first ten joints share a color.
 */
const JOINT_COLORS: readonly string[] = [
  "#8884d8", "#82ca9d", "#ffc658", "#ff8042", "#0088fe",
  "#00C49F", "#FFBB28", "#e377c2", "#a4de6c", "#d0ed57",
];

/** One x-axis sample: the frame index, its timestamp, and one key per plotted series. */
type ChartPoint = {
  name: number;
  timestamp: NonIndexedLeRobotDatasetRow["timestamp"];
  [seriesKey: string]: unknown;
};

/**
 * Writes `values[i]` into `point` under the key `${prefix}_${jointNames[i]}`
 * for every joint that has a matching array element.
 * Non-array `values` are ignored, so a frame without that field simply
 * contributes no points to those series.
 */
function addJointSeries(
  point: ChartPoint,
  prefix: "action" | "state",
  values: unknown,
  jointNames: readonly string[],
): void {
  if (!Array.isArray(values)) {
    return;
  }
  // Plot only positions that have both a label and a value.
  const count = Math.min(jointNames.length, values.length);
  for (let i = 0; i < count; i++) {
    point[`${prefix}_${jointNames[i]}`] = values[i];
  }
}

/**
 * Line chart of each joint's `action` value (dashed) and
 * `observation.state` value (solid) across the recorded frames.
 * Both series for a joint share one color so they can be compared side by side.
 *
 * @param frames - Dataset rows to plot, in recording order.
 * @param jointNames - Optional labels for the array elements; see
 *   {@link TeleoperatorJointGraphProps.jointNames}.
 * @returns The chart, or `null` when there are no frames.
 */
export function TeleoperatorJointGraph({
  frames,
  jointNames = DEFAULT_JOINT_NAMES,
}: TeleoperatorJointGraphProps) {
  // Nothing to draw without data.
  if (!frames || frames.length === 0) {
    return null;
  }

  const chartData = frames.map((frame, index): ChartPoint => {
    const point: ChartPoint = { name: index, timestamp: frame.timestamp };
    addJointSeries(point, "action", frame.action, jointNames);
    addJointSeries(point, "state", frame["observation.state"], jointNames);
    return point;
  });

  return (
    <div className="w-full bg-gray-800/50 rounded-md p-4 mb-4">
      <h3 className="text-sm font-medium text-gray-300 mb-2">Joint Positions Over Time</h3>
      <ResponsiveContainer width="100%" height={300}>
        <LineChart
          data={chartData}
          margin={{
            top: 5,
            right: 30,
            left: 20,
            bottom: 5
          }}
        >
          <CartesianGrid strokeDasharray="3 3" stroke="#444" />
          <XAxis
            dataKey="name"
            label={{ value: 'Frame Index', position: 'insideBottomRight', offset: -10 }}
            stroke="#aaa"
          />
          <YAxis stroke="#aaa" />
          <Tooltip
            contentStyle={{ backgroundColor: '#333', borderColor: '#555' }}
            labelStyle={{ color: '#eee' }}
            itemStyle={{ color: '#eee' }}
          />
          <Legend />
          {/* One dashed action line and one solid state line per joint, in joint order. */}
          {jointNames.flatMap((jointName, jointIndex) => {
            const color = JOINT_COLORS[jointIndex % JOINT_COLORS.length];
            return [
              <Line
                key={`action_${jointName}`}
                type="monotone"
                dataKey={`action_${jointName}`}
                name={`Action: ${jointName}`}
                stroke={color}
                strokeDasharray="5 5"
                dot={false}
                activeDot={{ r: 4 }}
              />,
              <Line
                key={`state_${jointName}`}
                type="monotone"
                dataKey={`state_${jointName}`}
                name={`State: ${jointName}`}
                stroke={color}
                dot={false}
                activeDot={{ r: 4 }}
              />,
            ];
          })}
        </LineChart>
      </ResponsiveContainer>
    </div>
  );
}