Spaces:
Running
Running
Aditya Shankar
feat: added dataset recording; hf uploader, s3 uploader; runpod trainer (#6)
4384839
unverified
import { describe, it, expect, beforeEach } from "vitest"; | |
import { vi } from "vitest"; | |
import { LeRobotDatasetRecorder } from "../src/record"; | |
import { WebTeleoperator } from "../../web/src/teleoperators/base-teleoperator"; | |
// Replace the base-teleoperator module with a stub whose WebTeleoperator
// exposes only the recording methods, each backed by a vi.fn() spy.
vi.mock("../../web/src/teleoperators/base-teleoperator", () => {
  return {
    // The tests instantiate this mock with `new`, so the implementation is a
    // regular function rather than an arrow function: arrow functions are not
    // constructible, and newer Vitest major versions construct the provided
    // implementation directly when a mock is called as a constructor.
    WebTeleoperator: vi.fn().mockImplementation(function () {
      return {
        startRecording: vi.fn(),
        // Resolves to an empty array of recorded entries.
        stopRecording: vi.fn().mockResolvedValue([]),
        clearRecording: vi.fn(),
      };
    }),
  };
});
describe("LeRobotDatasetRecorder", () => { | |
let recorder: LeRobotDatasetRecorder; | |
beforeEach(() => {
  // Every test gets its own recorder, built from one stubbed teleoperator and
  // an empty video-stream object.
  const noVideoStreams = {};
  // Type errors on the constructor call are suppressed; presumably the real
  // constructor signature does not accept a zero-argument call (TODO confirm).
  // @ts-ignore
  const stubTeleoperator = new WebTeleoperator() as unknown as WebTeleoperator;
  // Third argument is 30 -- presumably the recorder's fps; verify against the
  // LeRobotDatasetRecorder constructor.
  recorder = new LeRobotDatasetRecorder([stubTeleoperator], noVideoStreams, 30);
});
describe("_interpolateAndCompleteLerobotData", () => { | |
it("should interpolate data to match the specified fps", async () => {
  // Four raw samples spaced 0.5 s apart (0 s .. 1.5 s). The first two carry
  // episode_index 0, the last two carry episode_index 1.
  const roughData = [
    {
      timestamp: 0,
      action: {
        shoulder_pan: 0,
        shoulder_lift: 0,
        elbow_flex: 0,
        wrist_flex: 0,
        wrist_roll: 0,
        gripper: 0,
      },
      "observation.state": {
        shoulder_pan: 0,
        shoulder_lift: 0,
        elbow_flex: 0,
        wrist_flex: 0,
        wrist_roll: 0,
        gripper: 0,
      },
      episode_index: 0,
      task_index: 0,
    },
    {
      timestamp: 0.5,
      action: {
        shoulder_pan: 10,
        shoulder_lift: 20,
        elbow_flex: 30,
        wrist_flex: 40,
        wrist_roll: 50,
        gripper: 60,
      },
      "observation.state": {
        shoulder_pan: 15,
        shoulder_lift: 25,
        elbow_flex: 35,
        wrist_flex: 45,
        wrist_roll: 55,
        gripper: 65,
      },
      episode_index: 0,
      task_index: 0,
    },
    {
      timestamp: 1.0,
      action: {
        shoulder_pan: 20,
        shoulder_lift: 40,
        elbow_flex: 60,
        wrist_flex: 80,
        wrist_roll: 100,
        gripper: 120,
      },
      "observation.state": {
        shoulder_pan: 30,
        shoulder_lift: 50,
        elbow_flex: 70,
        wrist_flex: 90,
        wrist_roll: 110,
        gripper: 130,
      },
      episode_index: 1, // New episode starts at this sample
      task_index: 0,
    },
    {
      timestamp: 1.5,
      action: {
        shoulder_pan: 25,
        shoulder_lift: 50,
        elbow_flex: 75,
        wrist_flex: 100,
        wrist_roll: 125,
        gripper: 150,
      },
      "observation.state": {
        shoulder_pan: 35,
        shoulder_lift: 55,
        elbow_flex: 75,
        wrist_flex: 95,
        wrist_roll: 115,
        gripper: 135,
      },
      episode_index: 1, // Same episode as the previous sample
      task_index: 0,
    },
  ];

  // Resample at 10 fps for this test.
  const fps = 10;
  const result = await recorder._interpolateAndCompleteLerobotData(fps, roughData);

  // Overall shape of the output.
  expect(result).toBeInstanceOf(Array);
  expect(result).toHaveLength(15); // 1.5 seconds at 10 fps = 15 frames

  // First frame: joint values are emitted as arrays and all counters start at 0.
  expect(result[0].timestamp).toBeCloseTo(0, 5);
  expect(result[0].action).toEqual([0, 0, 0, 0, 0, 0]);
  expect(result[0]["observation.state"]).toEqual([0, 0, 0, 0, 0, 0]);
  expect(result[0].episode_index).toBe(0);
  expect(result[0].task_index).toBe(0);
  expect(result[0].frame_index).toBe(0);
  expect(result[0].index).toBe(0);

  // Middle frame (0.3 seconds).
  const middleFrame = result[3];
  expect(middleFrame.timestamp).toBeCloseTo(0.3, 5);
  // At 0.3 seconds we are 60% of the way between the 0 s and 0.5 s samples,
  // so action.shoulder_pan should be 60% of 10 = 6.
  expect(middleFrame.action[0]).toBeCloseTo(6, 5);
  expect(middleFrame.episode_index).toBe(0);
  expect(middleFrame.frame_index).toBe(3);
  expect(middleFrame.index).toBe(3);

  // Frame at 0.5 seconds coincides with the second raw sample. The input only
  // switches to episode 1 at 1.0 s, so this frame is still in episode 0.
  const secondSampleFrame = result[5];
  expect(secondSampleFrame.timestamp).toBeCloseTo(0.5, 5);
  expect(secondSampleFrame.action[0]).toBeCloseTo(10, 5);
  expect(secondSampleFrame.episode_index).toBe(0);
  expect(secondSampleFrame.frame_index).toBe(5);

  // Last frame before the 1.0 s sample (0.9 seconds): still episode 0, so
  // frame_index keeps incrementing.
  const lastEpisodeZeroFrame = result[9];
  expect(lastEpisodeZeroFrame.timestamp).toBeCloseTo(0.9, 5);
  expect(lastEpisodeZeroFrame.episode_index).toBe(0);
  expect(lastEpisodeZeroFrame.frame_index).toBe(9);
  expect(lastEpisodeZeroFrame.index).toBe(9);
});
it("should handle episode index changes correctly", async () => {
  // Raw samples whose episode_index changes at 0.5 s (0 -> 1) and again at
  // 1.0 s (1 -> 2).
  const roughData = [
    { timestamp: 0.0, action: { shoulder_pan: 0 }, "observation.state": { shoulder_pan: 0 }, episode_index: 0, task_index: 0 },
    { timestamp: 0.3, action: { shoulder_pan: 30 }, "observation.state": { shoulder_pan: 30 }, episode_index: 0, task_index: 0 },
    { timestamp: 0.5, action: { shoulder_pan: 50 }, "observation.state": { shoulder_pan: 50 }, episode_index: 1, task_index: 0 },
    { timestamp: 0.8, action: { shoulder_pan: 80 }, "observation.state": { shoulder_pan: 80 }, episode_index: 1, task_index: 0 },
    { timestamp: 1.0, action: { shoulder_pan: 100 }, "observation.state": { shoulder_pan: 100 }, episode_index: 2, task_index: 0 },
  ];
  const fps = 10;
  const result = await recorder._interpolateAndCompleteLerobotData(fps, roughData);

  // Expected (episode_index, frame_index) at selected output positions:
  //   0 -> very first frame of episode 0
  //   5 -> first frame of episode 1, where frame_index restarts at 0
  //   9 -> fifth frame of episode 1 (0.9 s), just before the change at 1.0 s
  const expectations = [
    { position: 0, episode: 0, frame: 0 },
    { position: 5, episode: 1, frame: 0 },
    { position: 9, episode: 1, frame: 4 },
  ];
  for (const { position, episode, frame } of expectations) {
    expect(result[position].episode_index).toBe(episode);
    expect(result[position].frame_index).toBe(frame);
  }
});
it("should handle empty or minimal data", async () => {
  // NOTE(review): only the minimal two-sample case is exercised below; an
  // empty input array is not covered despite the test title.
  const fps = 10;
  const twoSamples = [
    { timestamp: 0.0, action: { shoulder_pan: 0 }, "observation.state": { shoulder_pan: 0 }, episode_index: 0, task_index: 0 },
    { timestamp: 0.1, action: { shoulder_pan: 10 }, "observation.state": { shoulder_pan: 10 }, episode_index: 0, task_index: 0 },
  ];

  const frames = await recorder._interpolateAndCompleteLerobotData(fps, twoSamples);

  // 0.1 seconds at 10fps = 1 frame, located at the first sample.
  expect(frames.length).toBe(1);
  const onlyFrame = frames[0];
  expect(onlyFrame.timestamp).toBeCloseTo(0, 5);
  expect(onlyFrame.action[0]).toBeCloseTo(0, 5);
});
}); | |
}); | |