# Predictive State Networks (PSN) -- Robotics Implementation
Production implementation of the PSN architecture for model-based reinforcement learning, targeting state-of-the-art performance on DMControl benchmarks.
## Architecture
PSN replaces transformer/RSSM-based world models with a structured recurrent architecture built around three core innovations:
- **Tri-partite state** -- internal state is explicitly partitioned into memory (episodic retention), control (executive modulation), and predictive (forward model) components
- **Joint predictive-control computation** -- a single controller trunk generates both state-update control signals (gates, modulators) and one-step-ahead predictions
- **Runtime prediction-error injection** -- discrepancies between predictions and realized observations are transformed and fed back as first-class update signals, enabling online self-correction during inference
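To make the three ideas concrete, here is a minimal NumPy sketch of one PSN step. All names, shapes, and weight matrices are hypothetical illustrations; the actual controller, heads, and partition sizes live in `psn/core/` (`controller.py`, `psn_cell.py`) and will differ.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 8  # per-partition state size (hypothetical; real sizes come from PSNConfig)

# Tri-partite state: memory / control / predictive partitions.
state = {k: np.zeros(D) for k in ("memory", "control", "predictive")}

# Hypothetical controller trunk: one shared map feeding multiple heads.
W_trunk = rng.standard_normal((D, 2 * D)) * 0.1
W_gate = rng.standard_normal((D, D)) * 0.1
W_pred = rng.standard_normal((D, D)) * 0.1

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def psn_step(state, obs):
    """One PSN step (assumed flow): a single trunk produces both the
    state-update gates and a one-step-ahead prediction; the prediction
    error on the realized observation is injected into the update."""
    trunk_in = np.concatenate([state["control"], obs])
    h = np.tanh(W_trunk @ trunk_in)        # shared controller trunk
    gate = sigmoid(W_gate @ h)             # state-update control signal
    prediction = W_pred @ h                # one-step-ahead prediction head
    error = obs - state["predictive"]      # runtime prediction error
    # Gated interpolation with error injection (the PGSU core, simplified).
    state["memory"] = (1 - gate) * state["memory"] + gate * (h + 0.1 * error)
    state["control"] = h
    state["predictive"] = prediction
    return state

obs = rng.standard_normal(D)
state = psn_step(state, obs)
```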
### Model sizes
| Config | State size | Params (world model) | Comparable to |
|---|---|---|---|
| Small | ~1,280 | ~5M | DreamerV3-XS |
| Medium | ~2,048 | ~40-50M | DreamerV3-S/M |
| Large | ~3,584 | ~100-150M | DreamerV3-L |
| XLarge | ~6,144 | ~200-500M | DreamerV3-XL |
### Custom Triton kernel
The PGSU (Predictive Gated State Update) operation is fused into a single Triton kernel that combines sigmoid gating, softplus modulation, gated interpolation, and error injection -- eliminating 5-7 intermediate kernel launches per partition update. Both forward and backward passes are implemented for full training support.
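The operations the kernel fuses can be written out as an eager reference. The composition below is an assumed sketch in NumPy (`pgsu_reference`, `err_weight`, and the exact ordering of gating and modulation are illustrative, not the kernel's actual signature):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def softplus(x):
    return np.log1p(np.exp(x))  # adequate for moderate inputs

def pgsu_reference(state, candidate, gate_logits, mod_logits,
                   pred_error, err_weight=0.1):
    """Eager reference for the ops the fused kernel combines (assumed
    semantics): sigmoid gating -> softplus modulation -> gated
    interpolation -> error injection. In eager mode each line below
    would launch at least one separate kernel."""
    g = sigmoid(gate_logits)                         # update gate in (0, 1)
    m = softplus(mod_logits)                         # non-negative modulation
    mixed = (1.0 - g) * state + g * (m * candidate)  # gated interpolation
    return mixed + err_weight * pred_error           # prediction-error injection
```

Fusing these elementwise passes into one Triton launch avoids materializing the intermediates (`g`, `m`, `mixed`) in global memory, consistent with the 5-7 eliminated kernel launches noted above.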
## Installation

```bash
# clone and install
cd PSN
pip install -e .

# install DMControl (extras quoted so the brackets survive shells like zsh)
pip install dm_control "shimmy[dm-control]"

# (optional) for logging
pip install wandb
```
Requires Python 3.10+, PyTorch 2.2+, and Triton 3.0+ (for fused kernel; falls back to eager mode on CPU/non-CUDA).
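The Triton-or-eager fallback can be probed with a small dispatch helper. This is a hypothetical sketch (the repo's actual dispatch lives in `psn/kernels/fused_pgsu.py` and may use different criteria):

```python
import importlib.util

def pgsu_backend() -> str:
    """Pick a PGSU implementation: the fused Triton kernel needs both
    Triton and a CUDA device; anything else falls back to eager PyTorch."""
    has_triton = importlib.util.find_spec("triton") is not None
    has_cuda = False
    if importlib.util.find_spec("torch") is not None:
        import torch
        has_cuda = torch.cuda.is_available()
    return "triton-fused" if (has_triton and has_cuda) else "eager"

print(pgsu_backend())
```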
## Training

```bash
# single task, single seed
python train.py --env walker_walk --size medium --seed 0

# with wandb logging
python train.py --env walker_walk --size medium --seed 0 --wandb

# harder task, larger model
python train.py --env humanoid_walk --size large --steps 10000000

# multi-seed benchmark
python train.py --env walker_walk --size medium --seeds 0,1,2,3,4
```
## Benchmarking

```bash
# easy suite (4 tasks)
python benchmark.py --suite easy --size medium --seeds 0,1,2

# full suite (9 tasks, including dog)
python benchmark.py --suite hard --size large --seeds 0,1,2,3,4

# custom task list
python benchmark.py --tasks walker_walk,humanoid_run --size xlarge --seeds 0,1,2
```
## Evaluation

```bash
python evaluate.py --checkpoint logs/walker_walk_medium_s0/final.pt --env walker_walk --episodes 100
```
## Supported DMControl tasks
| Task | Obs dim | Action dim | Difficulty |
|---|---|---|---|
| walker_walk | 24 | 6 | Easy |
| walker_run | 24 | 6 | Easy |
| cheetah_run | 17 | 6 | Easy |
| quadruped_walk | 78 | 12 | Medium |
| quadruped_run | 78 | 12 | Medium |
| humanoid_walk | 67 | 21 | Hard |
| humanoid_run | 67 | 21 | Hard |
| dog_walk | 223 | 38 | Very hard |
| dog_run | 223 | 38 | Very hard |
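DMControl returns observations as a dict of named arrays, so the "Obs dim" column above is the size of their flattened concatenation (for `walker_walk`, dm_control's `orientations` (14), `height` (scalar), and `velocity` (9) fields total 24). A small helper, sketched here with hypothetical naming, makes that flattening explicit:

```python
import numpy as np

def flatten_observation(obs: dict) -> np.ndarray:
    """Concatenate a dm_control-style observation dict (name -> array)
    into one flat vector; sorting by key gives a deterministic ordering.
    Hypothetical helper -- the repo's actual wrapper may order fields
    differently."""
    return np.concatenate(
        [np.atleast_1d(np.asarray(v)).ravel() for _, v in sorted(obs.items())]
    )

# walker_walk-shaped example: 14 + 1 + 9 = 24 dims
obs = {"orientations": np.zeros(14), "height": np.array(0.5), "velocity": np.zeros(9)}
print(flatten_observation(obs).shape)  # (24,)
```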
## Project structure

```
psn/
├── core/
│   ├── config.py           # PSNConfig, RoboticsConfig, TrainingConfig
│   ├── state.py            # PSNState dataclass
│   ├── networks.py         # ResidualMLP, MLP, CNN encoder/decoder, symlog, twohot
│   ├── controller.py       # Controller module (shared trunk + multi-head)
│   ├── error_processor.py  # Prediction-error processing
│   ├── pgsu.py             # Predictive Gated State Update
│   └── psn_cell.py         # PSNCell (single step) + sequence processor
├── kernels/
│   └── fused_pgsu.py       # Custom Triton kernel for fused PGSU ops
├── robotics/
│   ├── world_model.py      # PSN world model (encoder + PSN + decoder + heads)
│   ├── actor_critic.py     # Actor, Critic, SlowCritic, return computation
│   ├── planner.py          # CEM and MPPI planners
│   └── agent.py            # Full PSNAgent
└── training/
    ├── replay_buffer.py    # Episodic replay buffer with chunk sampling
    ├── losses.py           # World model + actor-critic loss functions
    └── trainer.py          # Training orchestration
```