| |
"""PSN Robotics -- systematic benchmarking across DMControl tasks.

Runs every requested task × seed combination at a single model size
(``--size``) and produces a summary table, including DreamerV3 reference
scores, for comparison against DreamerV3 / TD-MPC2 / other baselines.
The results are also written to a JSON file (``--output``).

Tasks are taken from ``--tasks`` if given, otherwise from ``--suite``;
with neither flag the "easy" suite is used.

Usage:
    python benchmark.py --tasks walker_walk,cheetah_run,humanoid_walk --size medium --seeds 0,1,2
    python benchmark.py --suite easy --size medium --seeds 0,1,2,3,4
    python benchmark.py --suite hard --size large --seeds 0,1,2
"""
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
|
|
# Predefined task suites selectable with --suite. Each suite is a superset
# of the previous one, so the lists are built incrementally.
_EASY_TASKS = ["walker_walk", "walker_run", "cheetah_run", "quadruped_walk"]
_MEDIUM_TASKS = [*_EASY_TASKS, "quadruped_run", "humanoid_walk"]
_HARD_TASKS = [*_MEDIUM_TASKS, "humanoid_run", "dog_walk", "dog_run"]

TASK_SUITES = {
    "easy": _EASY_TASKS,
    "medium": _MEDIUM_TASKS,
    "hard": _HARD_TASKS,
}
|
|
| |
# Per-task reference scores shown in the "DreamerV3" column of the summary
# table and used as the denominator of the "vs DV3" percentage.
# NOTE(review): presumably approximate published DreamerV3 returns -- the
# source of these numbers is not recorded here; confirm before citing them.
DREAMER_V3_REFS = dict(
    walker_walk=955,
    walker_run=750,
    cheetah_run=850,
    quadruped_walk=900,
    quadruped_run=700,
    humanoid_walk=650,
    humanoid_run=400,
    dog_walk=500,
    dog_run=350,
)
|
|
|
|
def _split_csv(raw):
    """Split a comma-separated CLI value into stripped, non-empty items."""
    return [item.strip() for item in raw.split(",") if item.strip()]


def _summarize_task(task, returns):
    """Build the result record for one task from its per-seed returns."""
    mean = float(np.mean(returns))
    std = float(np.std(returns))
    ref = DREAMER_V3_REFS.get(task)
    return {
        "mean": mean,
        "std": std,
        "seeds": returns,
        "dreamerv3_ref": ref,
        # A missing (or zero) reference cannot serve as a denominator.
        "vs_dreamerv3": f"{mean/ref*100:.1f}%" if ref else "N/A",
    }


def _overall_ratio(results):
    """Return mean(PSN) / mean(reference) over referenced tasks, or None."""
    # Pair each PSN mean with its own task's reference so that tasks without
    # a reference are excluded from BOTH sides of the ratio.
    paired = [
        (record["mean"], record["dreamerv3_ref"])
        for record in results.values()
        if record["dreamerv3_ref"]
    ]
    if not paired:
        return None
    psn_means, ref_means = zip(*paired)
    return float(np.mean(psn_means) / np.mean(ref_means))


def _print_summary(results, args, seeds, elapsed):
    """Print the per-task comparison table and the overall ratio."""
    print(f"\n{'='*80}")
    print(f" PSN ROBOTICS BENCHMARK RESULTS ({args.size}, {len(seeds)} seeds, {args.steps:,} steps)")
    print(f" Total time: {elapsed/3600:.1f} hours")
    print(f"{'='*80}")
    print(f"{'Task':<20} {'PSN':>12} {'DreamerV3':>12} {'vs DV3':>10}")
    print(f"{'-'*54}")
    for task, r in results.items():
        ref_str = f"{r['dreamerv3_ref']}" if r["dreamerv3_ref"] else "N/A"
        print(f"{task:<20} {r['mean']:>8.1f}±{r['std']:>3.1f} {ref_str:>12} {r['vs_dreamerv3']:>10}")

    ratio = _overall_ratio(results)
    if ratio is not None:
        print(f"\n Overall PSN / DreamerV3 ratio: {ratio*100:.1f}%")


def main():
    """Run the benchmark from the command line.

    Parses the CLI flags, trains every task × seed combination through
    ``train.run``, prints a summary table and writes the results as JSON
    to the path given by ``--output``.

    Task selection: ``--tasks`` wins over ``--suite``; with neither flag the
    "easy" suite is used.

    Raises:
        SystemExit: via ``argparse`` when ``--tasks`` or ``--seeds`` contain
            no usable entries or ``--seeds`` is not a list of integers.
    """
    parser = argparse.ArgumentParser(description="PSN Robotics Benchmark")
    parser.add_argument("--tasks", type=str, default=None,
                        help="Comma-separated task list")
    parser.add_argument("--suite", type=str, default=None,
                        choices=list(TASK_SUITES.keys()),
                        help="Predefined task suite")
    parser.add_argument("--size", type=str, default="medium",
                        choices=["small", "medium", "large", "xlarge", "serious_v1"])
    parser.add_argument("--seeds", type=str, default="0,1,2",
                        help="Comma-separated seeds")
    parser.add_argument("--steps", type=int, default=1_000_000)
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--output", type=str, default="benchmark_results.json")
    args = parser.parse_args()

    if args.tasks:
        # Tolerate "a, b" and trailing commas instead of passing " b" or ""
        # on as a task name.
        tasks = _split_csv(args.tasks)
        if not tasks:
            parser.error("--tasks did not contain any task names")
    elif args.suite:
        tasks = TASK_SUITES[args.suite]
    else:
        tasks = TASK_SUITES["easy"]

    try:
        seeds = [int(s) for s in _split_csv(args.seeds)]
    except ValueError:
        parser.error(f"--seeds must be comma-separated integers, got {args.seeds!r}")
    if not seeds:
        # An empty seed list would otherwise produce NaN means for every task.
        parser.error("--seeds did not contain any seeds")

    # Imported once, after argument validation, rather than on every
    # iteration of the seed loop.
    from train import run

    results = {}
    t0 = time.time()

    for task in tasks:
        print(f"\n{'#'*70}")
        print(f" TASK: {task} | SIZE: {args.size} | SEEDS: {seeds}")
        print(f"{'#'*70}")

        task_returns = []
        for seed in seeds:
            # NOTE(review): attribute set presumably mirrors train.py's own
            # CLI namespace -- confirm against train.run.
            cmd_args = argparse.Namespace(
                env=task,
                size=args.size,
                seed=seed,
                seeds=None,
                steps=args.steps,
                wandb=args.wandb,
                device=None,
            )
            result = run(cmd_args)
            # Coerce to a builtin float so the final json.dump cannot fail on
            # a non-builtin scalar after all training has already finished.
            task_returns.append(float(result["eval/return_mean"]))

        results[task] = _summarize_task(task, task_returns)

    elapsed = time.time() - t0

    _print_summary(results, args, seeds, elapsed)

    output_path = Path(args.output)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump({
            "config": {"size": args.size, "seeds": seeds, "steps": args.steps},
            "results": results,
            "elapsed_hours": elapsed / 3600,
        }, f, indent=2)
    print(f"\n Results saved to {output_path}")
|
|
|
|
# Run the benchmark only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
|
|