Predictive State Networks (PSN) -- Robotics Implementation

Production implementation of the PSN architecture for model-based reinforcement learning, targeting state-of-the-art performance on the DeepMind Control Suite (DMControl) benchmarks.

Architecture

PSN replaces transformer/RSSM-based world models with a structured recurrent architecture built around three core innovations:

  1. Tri-partite state -- internal state is explicitly partitioned into memory (episodic retention), control (executive modulation), and predictive (forward model) components
  2. Joint predictive-control computation -- a single controller trunk generates both state-update control signals (gates, modulators) and one-step-ahead predictions
  3. Runtime prediction-error injection -- discrepancies between predictions and realized observations are transformed and fed back as first-class update signals, enabling online self-correction during inference
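The three ideas above can be seen together in a toy single step. This is a minimal NumPy sketch under loud assumptions: `TriState` and `ToyPSNStep` are hypothetical names, and the partition sizes, the single tanh trunk, the linear heads, and the exact update formula are illustrative choices, not the repository's `PSNCell` or `PSNState`. It shows a state split into memory, control, and predictive parts; one shared trunk that emits both update gates and a next-observation prediction; and the previous step's prediction error fed back into the update.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class TriState:
    """Hypothetical tri-partite state; the repo's PSNState may differ."""
    memory: np.ndarray      # episodic retention
    control: np.ndarray     # executive modulation
    predictive: np.ndarray  # forward-model component


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


class ToyPSNStep:
    """Hypothetical single step: shared trunk -> gates, candidates, prediction."""

    def __init__(self, obs_dim: int, part_dim: int, hidden: int = 64, seed: int = 0):
        rng = np.random.default_rng(seed)
        state_dim = 3 * part_dim
        in_dim = obs_dim + state_dim + obs_dim  # observation, state, last error

        def init(fan_in: int, fan_out: int) -> np.ndarray:
            return rng.normal(0.0, fan_in ** -0.5, (fan_in, fan_out))

        self.w_trunk = init(in_dim, hidden)    # shared controller trunk
        self.w_gate = init(hidden, state_dim)  # control head: update gates
        self.w_cand = init(hidden, state_dim)  # control head: candidate values
        self.w_pred = init(hidden, obs_dim)    # predictive head: next observation
        self.w_err = init(obs_dim, state_dim)  # error transform for injection

    def __call__(self, state: TriState, obs: np.ndarray, last_pred: np.ndarray):
        err = obs - last_pred                  # realized minus predicted
        flat = np.concatenate([state.memory, state.control, state.predictive])
        h = np.tanh(np.concatenate([obs, flat, err]) @ self.w_trunk)
        gate = sigmoid(h @ self.w_gate)
        cand = np.tanh(h @ self.w_cand)
        injected = np.tanh(err @ self.w_err)   # error fed back as an update signal
        new_flat = gate * flat + (1.0 - gate) * (cand + injected)
        pred = h @ self.w_pred                 # prediction for the next observation
        memory, control, predictive = np.split(new_flat, 3)
        return TriState(memory, control, predictive), pred, err
```

Call it once per timestep, passing in the prediction returned at the previous step; starting from a zero prediction at t=0 is an assumed initialization. Because the same trunk output `h` drives both the gates and the prediction, the controller and the forward model share their features, which is the point of the joint computation.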

Model sizes

| Config | State size | Params (world model) | Comparable to |
|---|---|---|---|
| Small | ~1,280 | ~5M | DreamerV3-XS |
| Medium | ~2,048 | ~40-50M | DreamerV3-S/M |
| Large | ~3,584 | ~100-150M | DreamerV3-L |
| XLarge | ~6,144 | ~200-500M | DreamerV3-XL |

Custom CUDA kernel

The PGSU (Predictive Gated State Update) operation is fused into a single Triton kernel that combines sigmoid gating, softplus modulation, gated interpolation, and error injection -- eliminating 5-7 intermediate kernel launches per partition update. Both forward and backward passes are implemented for full training support.
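To make the fused operation concrete, here is an eager NumPy reference for a PGSU-style elementwise update together with a hand-written backward pass. The exact formula (softplus-scaled candidate, additive injection of an error that is assumed to be already transformed) is an illustrative assumption; `fused_pgsu.py` may combine the terms differently. In eager PyTorch each commented line below costs at least one elementwise kernel launch, which is the overhead a single fused kernel removes.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def softplus(x):
    return np.logaddexp(0.0, x)  # log(1 + exp(x)) without overflow


def pgsu_forward(h, gate_logit, mod_logit, cand, err):
    """Hypothetical PGSU-style update on one state partition."""
    g = sigmoid(gate_logit)          # sigmoid gating
    m = softplus(mod_logit)          # softplus modulation
    upd = m * cand                   # modulated candidate
    mixed = g * h + (1.0 - g) * upd  # gated interpolation
    return mixed + err               # error injection (err already transformed)


def pgsu_backward(grad_out, h, gate_logit, mod_logit, cand, err):
    """Gradients of sum(grad_out * pgsu_forward(...)) w.r.t. each input.

    Intermediates are recomputed from the inputs instead of being stored,
    the usual trade-off in fused backward kernels.
    """
    g = sigmoid(gate_logit)
    m = softplus(mod_logit)
    upd = m * cand
    d_h = grad_out * g
    d_gate = grad_out * (h - upd) * g * (1.0 - g)              # sigmoid' = g(1-g)
    d_mod = grad_out * (1.0 - g) * cand * sigmoid(mod_logit)  # softplus' = sigmoid
    d_cand = grad_out * (1.0 - g) * m
    d_err = grad_out
    return d_h, d_gate, d_mod, d_cand, d_err
```

A reference like this is also a convenient oracle for testing a fused kernel: compare its outputs and gradients against the kernel's on random inputs.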

Installation

# install from a local clone of the repository
cd PSN
pip install -e .

# install DMControl
pip install dm_control "shimmy[dm-control]"

# (optional) for logging
pip install wandb

Requires Python 3.10+, PyTorch 2.2+, and Triton 3.0+. Triton is only needed for the fused kernel; on CPU and other non-CUDA devices the model falls back to an eager (unfused) implementation.
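The fallback rule can be expressed as a small dispatch check. This is a sketch, not the repository's code: `triton_available`, `select_pgsu_backend`, and the `prefer_fused` switch are hypothetical. In PyTorch the `device_type` argument would come from `tensor.device.type`.

```python
import importlib.util


def triton_available() -> bool:
    """True if the `triton` package can be imported in this environment."""
    return importlib.util.find_spec("triton") is not None


def select_pgsu_backend(device_type: str, prefer_fused: bool = True) -> str:
    """Hypothetical dispatch: use the fused Triton kernel only for CUDA
    tensors when Triton is installed; otherwise use the eager path."""
    if prefer_fused and device_type == "cuda" and triton_available():
        return "triton"
    return "eager"
```

Keeping an explicit `prefer_fused=False` escape hatch is useful for debugging, since it lets the same run be repeated on the eager path.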

Training

# single task, single seed
python train.py --env walker_walk --size medium --seed 0

# with wandb logging
python train.py --env walker_walk --size medium --seed 0 --wandb

# harder task, larger model
python train.py --env humanoid_walk --size large --steps 10000000

# multi-seed benchmark
python train.py --env walker_walk --size medium --seeds 0,1,2,3,4
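The flags above (`--env`, `--size`, `--seed`, `--seeds`, `--steps`, `--wandb`) could be parsed with the standard `argparse` module as below. This is a hypothetical sketch of such a CLI, not the actual `train.py`: the defaults (medium size, seed 0 when no seed is given) and the mutual exclusion of `--seed` and `--seeds` are assumptions.

```python
import argparse


def parse_seeds(text: str) -> list[int]:
    """Turn '0,1,2,3,4' into [0, 1, 2, 3, 4]."""
    return [int(s) for s in text.split(",") if s.strip()]


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="Hypothetical sketch of a PSN training CLI")
    p.add_argument("--env", required=True)  # e.g. walker_walk
    p.add_argument("--size", choices=["small", "medium", "large", "xlarge"], default="medium")
    seeds = p.add_mutually_exclusive_group()
    seeds.add_argument("--seed", type=int)
    seeds.add_argument("--seeds", type=parse_seeds)  # comma-separated list
    p.add_argument("--steps", type=int)  # None -> use the script's default budget
    p.add_argument("--wandb", action="store_true")
    return p


def seeds_to_run(args: argparse.Namespace) -> list[int]:
    """Expand --seed/--seeds into the list of runs to launch (assumed default: [0])."""
    if args.seeds is not None:
        return args.seeds
    return [args.seed if args.seed is not None else 0]
```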

Benchmarking

# easy suite (4 tasks)
python benchmark.py --suite easy --size medium --seeds 0,1,2

# full suite (9 tasks, including dog)
python benchmark.py --suite hard --size large --seeds 0,1,2,3,4

# custom task list
python benchmark.py --tasks walker_walk,humanoid_run --size xlarge --seeds 0,1,2
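Multi-seed benchmark scores are usually reported as a per-task mean with an uncertainty estimate. A minimal sketch of that aggregation, assuming each seed's final evaluation return has already been collected into a dict; `summarize` and `summarize_tasks` are hypothetical helpers, not part of `benchmark.py`, and the standard error of the mean is just one common choice.

```python
import math
from statistics import mean, stdev


def summarize(returns_by_seed: dict[int, float]) -> tuple[float, float]:
    """Mean return across seeds and its standard error (NaN for a single seed)."""
    values = list(returns_by_seed.values())
    m = mean(values)
    if len(values) < 2:
        return m, float("nan")
    return m, stdev(values) / math.sqrt(len(values))


def summarize_tasks(results: dict[str, dict[int, float]]) -> dict[str, tuple[float, float]]:
    """Apply `summarize` to every task, e.g. {'walker_walk': {0: ..., 1: ...}}."""
    return {task: summarize(per_seed) for task, per_seed in results.items()}
```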

Evaluation

python evaluate.py --checkpoint logs/walker_walk_medium_s0/final.pt --env walker_walk --episodes 100
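Evaluating over many episodes amounts to a rollout loop over a `dm_env`-style environment. The sketch below is generic: `split_task` and `evaluate` are hypothetical helpers, and `policy` stands in for the trained agent restored from the checkpoint (loading it is not shown). The `suite.load` call in the comment is the real `dm_control` entry point.

```python
from typing import Callable

import numpy as np


def split_task(name: str) -> tuple[str, str]:
    """'walker_walk' -> ('walker', 'walk'); fits the domain_task names in the task table."""
    domain, task = name.split("_", 1)
    return domain, task


def evaluate(env, policy: Callable, episodes: int) -> tuple[float, float]:
    """Run full episodes and return mean and std of the undiscounted return.

    Real usage (requires dm_control):
        from dm_control import suite
        env = suite.load(*split_task("walker_walk"), task_kwargs={"random": 0})
    """
    returns = []
    for _ in range(episodes):
        time_step = env.reset()
        total = 0.0
        while not time_step.last():
            action = policy(time_step.observation)
            time_step = env.step(action)
            total += time_step.reward or 0.0  # reward is None on the first step
        returns.append(total)
    return float(np.mean(returns)), float(np.std(returns))
```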

Supported DMControl tasks

| Task | Obs dim | Action dim | Difficulty |
|---|---|---|---|
| walker_walk | 24 | 6 | Easy |
| walker_run | 24 | 6 | Easy |
| cheetah_run | 17 | 6 | Easy |
| quadruped_walk | 78 | 12 | Medium |
| quadruped_run | 78 | 12 | Medium |
| humanoid_walk | 67 | 21 | Hard |
| humanoid_run | 67 | 21 | Hard |
| dog_walk | 223 | 38 | Very hard |
| dog_run | 223 | 38 | Very hard |
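dm_control returns observations as a dict of named arrays, so the obs dims above are the total size of those entries once flattened into one vector. A minimal sketch of that flattening; `flatten_observation` and `obs_dim_from_spec` are hypothetical helpers, and the argument to the latter is assumed to be the mapping returned by `env.observation_spec()`, whose entries expose `.shape`.

```python
from collections.abc import Mapping

import numpy as np


def flatten_observation(obs: Mapping) -> np.ndarray:
    """Concatenate every entry of an observation dict into one 1-D float32
    vector, in the dict's own key order; scalars become length-1 pieces."""
    return np.concatenate([np.asarray(v, dtype=np.float32).ravel() for v in obs.values()])


def obs_dim_from_spec(spec: Mapping) -> int:
    """Flattened observation size from a mapping of objects with a `.shape`."""
    return sum(int(np.prod(s.shape)) for s in spec.values())
```

Flattening in a fixed key order matters: the encoder sees positions, not names, so the order must be identical during training and evaluation.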

Project structure

psn/
├── core/
│   ├── config.py           # PSNConfig, RoboticsConfig, TrainingConfig
│   ├── state.py            # PSNState dataclass
│   ├── networks.py         # ResidualMLP, MLP, CNN encoder/decoder, symlog, twohot
│   ├── controller.py       # Controller module (shared trunk + multi-head)
│   ├── error_processor.py  # Prediction-error processing
│   ├── pgsu.py             # Predictive Gated State Update
│   └── psn_cell.py         # PSNCell (single step) + sequence processor
├── kernels/
│   └── fused_pgsu.py       # Custom Triton kernel for fused PGSU ops
├── robotics/
│   ├── world_model.py      # PSN world model (encoder + PSN + decoder + heads)
│   ├── actor_critic.py     # Actor, Critic, SlowCritic, return computation
│   ├── planner.py          # CEM and MPPI planners
│   └── agent.py            # Full PSNAgent
└── training/
    ├── replay_buffer.py    # Episodic replay buffer with chunk sampling
    ├── losses.py           # World model + actor-critic loss functions
    └── trainer.py          # Training orchestration
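`networks.py` lists `symlog` and `twohot`, the value transformation and target encoding used in DreamerV3 for rewards and returns of widely varying scale. A minimal NumPy sketch of the standard definitions; the 255 bins over [-20, 20] in the usage test are an assumption, not values read from this repository's config.

```python
import numpy as np


def symlog(x):
    """sign(x) * log(1 + |x|): compresses large magnitudes, near-identity around 0."""
    return np.sign(x) * np.log1p(np.abs(x))


def symexp(x):
    """Inverse of symlog: sign(x) * (exp(|x|) - 1)."""
    return np.sign(x) * np.expm1(np.abs(x))


def twohot(y: float, bins: np.ndarray) -> np.ndarray:
    """Encode scalar y as weights on its two neighbouring bins (bins ascending).

    The weights sum to 1 and their expectation over `bins` equals y
    (after clipping y to the bin range).
    """
    y = float(np.clip(y, bins[0], bins[-1]))
    k = int(np.searchsorted(bins, y, side="right")) - 1
    k = min(k, len(bins) - 2)  # y == bins[-1] falls in the last interval
    lo, hi = bins[k], bins[k + 1]
    w_hi = (y - lo) / (hi - lo)
    out = np.zeros(len(bins))
    out[k] = 1.0 - w_hi
    out[k + 1] = w_hi
    return out
```

A prediction head trained with cross-entropy against `twohot(symlog(r), bins)` is decoded with `symexp(probs @ bins)`, so the same bins cover both small and very large targets.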