#!/usr/bin/env python3
"""
Gradio app for QP-RNN interactive demo.
Suitable for deployment on Hugging Face Spaces.
"""
import gradio as gr
import torch
import numpy as np
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend
import matplotlib.pyplot as plt
class MinimalQPRNN(torch.nn.Module):
"""Minimal QP-RNN for demonstration."""
def __init__(self, position_gain=3.0, velocity_gain=1.5, control_cost=10.0):
super().__init__()
        # Fixed, hand-tuned values for demonstration; in a trained QP-RNN the
        # QP cost terms P and K would be learned parameters.
        self.P = torch.tensor([[control_cost]], dtype=torch.float32)
        self.K = torch.tensor([position_gain, velocity_gain], dtype=torch.float32)
def forward(self, state, reference=None):
if reference is None:
reference = torch.zeros_like(state)
error = state - reference
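        # The next three lines solve, in closed form, the scalar box-constrained QP
        #     min_u  0.5 * P * u^2 + q * u   s.t.  -1 <= u <= 1
        # with q = K @ error: the objective is convex in the single variable u, so
        # clamping the unconstrained minimizer -q / P to the box yields exactly
        # the constrained optimum.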
q = torch.sum(self.K * error, dim=-1, keepdim=True)
u_unconstrained = -q / self.P
u = torch.clamp(u_unconstrained, -1.0, 1.0)
return u
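
# Minimal usage sketch for the controller above (illustrative values, not a
# trained model):
#
#     ctrl = MinimalQPRNN(position_gain=3.0, velocity_gain=1.5, control_cost=10.0)
#     u = ctrl(torch.tensor([2.0, 0.0]))  # q = 3.0 * 2.0 = 6.0 -> u = -6/10 = -0.6
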
def simulate_system(position_gain, velocity_gain, control_cost,
initial_position, initial_velocity,
target_position, simulation_time):
"""Run simulation with given parameters."""
# Create controller
controller = MinimalQPRNN(position_gain, velocity_gain, control_cost)
# Setup
dt = 0.05
T = int(simulation_time / dt)
x0 = torch.tensor([initial_position, initial_velocity])
x_ref = torch.tensor([target_position, 0.0])
# Simulate
states = [x0.numpy()]
controls = []
x = x0.clone()
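    # Double integrator dynamics (position' = velocity, velocity' = u),
    # discretized with explicit (forward) Euler at step size dt.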
for t in range(T):
u = controller(x, x_ref)
x_next = torch.zeros_like(x)
x_next[0] = x[0] + x[1] * dt
x_next[1] = x[1] + u.item() * dt
states.append(x_next.numpy())
controls.append(u.item())
x = x_next
return np.array(states), np.array(controls), dt
def create_plots(states, controls, dt, target_position=0.0):
"""Create visualization plots."""
time = np.arange(len(states)) * dt
time_control = time[:-1]
# Create figure with subplots
fig = plt.figure(figsize=(12, 10))
# Position subplot
ax1 = plt.subplot(3, 2, 1)
ax1.plot(time, states[:, 0], 'b-', linewidth=2)
    ax1.axhline(y=target_position, color='r', linestyle='--', alpha=0.5)
ax1.set_ylabel('Position')
ax1.set_title('Position vs Time')
ax1.grid(True, alpha=0.3)
# Velocity subplot
ax2 = plt.subplot(3, 2, 2)
ax2.plot(time, states[:, 1], 'g-', linewidth=2)
ax2.axhline(y=0, color='r', linestyle='--', alpha=0.5)
ax2.set_ylabel('Velocity')
ax2.set_title('Velocity vs Time')
ax2.grid(True, alpha=0.3)
# Control subplot
ax3 = plt.subplot(3, 2, 3)
ax3.plot(time_control, controls, 'r-', linewidth=2)
ax3.axhline(y=1, color='k', linestyle=':', alpha=0.5)
ax3.axhline(y=-1, color='k', linestyle=':', alpha=0.5)
ax3.set_ylabel('Control Input')
ax3.set_xlabel('Time (s)')
ax3.set_title('Control Input vs Time')
ax3.grid(True, alpha=0.3)
ax3.set_ylim(-1.2, 1.2)
# Phase portrait
ax4 = plt.subplot(3, 2, 4)
ax4.plot(states[:, 0], states[:, 1], 'b-', linewidth=2)
ax4.scatter([states[0, 0]], [states[0, 1]], color='green', s=100, marker='o', label='Start')
ax4.scatter([states[-1, 0]], [states[-1, 1]], color='red', s=100, marker='x', label='End')
ax4.set_xlabel('Position')
ax4.set_ylabel('Velocity')
ax4.set_title('Phase Portrait')
ax4.legend()
ax4.grid(True, alpha=0.3)
# QP visualization
ax5 = plt.subplot(3, 2, 5)
    # Fraction of time steps where the control is at (or within 1% of) the bound
    pct_saturated = np.sum(np.abs(controls) >= 0.99) / len(controls) * 100
    labels = ['Saturated', 'Unsaturated']
    sizes = [pct_saturated, 100 - pct_saturated]
colors = ['red', 'blue']
ax5.pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%')
ax5.set_title('Control Saturation')
# Metrics text
ax6 = plt.subplot(3, 2, 6)
ax6.axis('off')
    metrics_text = f"""Performance Metrics:
Final Position Error: {abs(states[-1, 0] - target_position):.4f}
Final Velocity: {states[-1, 1]:.4f}
Control Effort (L1): {np.sum(np.abs(controls)):.2f}
Control Effort (L2): {np.sqrt(np.sum(controls**2)):.2f}
Simulation Time: {len(controls) * dt:.1f}s
Peak |Position|: {np.max(np.abs(states[:, 0])):.2f}
"""
ax6.text(0.1, 0.5, metrics_text, fontsize=12, verticalalignment='center',
fontfamily='monospace', bbox=dict(boxstyle="round,pad=0.5", facecolor="lightgray"))
plt.suptitle('QP-RNN Control Simulation Results', fontsize=16)
plt.tight_layout()
return fig
def run_qp_rnn_demo(position_gain, velocity_gain, control_cost,
initial_position, initial_velocity,
target_position, simulation_time):
"""Main function for Gradio interface."""
# Run simulation
states, controls, dt = simulate_system(
position_gain, velocity_gain, control_cost,
initial_position, initial_velocity,
target_position, simulation_time
)
# Create plots
    fig = create_plots(states, controls, dt, target_position)
# Create description
description = f"""
### QP-RNN Control Results
The QP-RNN controller solves the following optimization problem at each time step:
```
min  0.5 * {control_cost} * u² + (K @ error) * u
s.t. -1 ≤ u ≤ 1
```
where K = [{position_gain}, {velocity_gain}] are the feedback gains.
**Final State:** Position = {states[-1, 0]:.3f}, Velocity = {states[-1, 1]:.3f}
**Key Features:**
- Guaranteed constraint satisfaction (control always in [-1, 1])
- Interpretable structure (quadratic cost + linear feedback)
- Can be trained via RL for complex tasks
"""
return fig, description
# Create Gradio interface
iface = gr.Interface(
fn=run_qp_rnn_demo,
inputs=[
gr.Slider(0.1, 10.0, value=3.0, label="Position Gain (Kp)",
info="Higher values = faster position correction"),
gr.Slider(0.1, 5.0, value=1.5, label="Velocity Gain (Kv)",
info="Higher values = more damping"),
gr.Slider(0.1, 50.0, value=10.0, label="Control Cost",
info="Higher values = less aggressive control"),
gr.Slider(-5.0, 5.0, value=2.0, label="Initial Position"),
gr.Slider(-2.0, 2.0, value=0.0, label="Initial Velocity"),
gr.Slider(-3.0, 3.0, value=0.0, label="Target Position"),
gr.Slider(1.0, 10.0, value=5.0, label="Simulation Time (s)")
],
outputs=[
gr.Plot(label="Simulation Results"),
gr.Markdown(label="Analysis")
],
title="QP-RNN: Quadratic Programming Recurrent Neural Network Demo",
description="""
This interactive demo shows how QP-RNN controllers work for a simple double integrator system.
**What is QP-RNN?**
- Combines Model Predictive Control structure with Deep Reinforcement Learning
- Learns to solve a parameterized Quadratic Program (QP) to generate control actions
- Provides theoretical guarantees (constraint satisfaction, stability verification)
**Try adjusting the parameters** to see how they affect control performance!
Paper: [MPC-Inspired Reinforcement Learning for Verifiable Model-Free Control](https://arxiv.org/abs/2312.05332)
""",
examples=[
[3.0, 1.5, 10.0, 2.0, 0.0, 0.0, 5.0], # Default
[5.0, 2.0, 5.0, 2.0, 0.0, 0.0, 5.0], # Aggressive
[1.0, 0.5, 20.0, 2.0, 0.0, 0.0, 5.0], # Conservative
[3.0, 0.1, 10.0, 2.0, 0.0, 0.0, 5.0], # Underdamped
[3.0, 3.0, 10.0, 2.0, 0.0, 0.0, 5.0], # Overdamped
],
cache_examples=True
)
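
# Headless sanity check (optional): the demo function can also be called
# directly, e.g. with the first example row above, without launching the UI:
#
#     fig, analysis = run_qp_rnn_demo(3.0, 1.5, 10.0, 2.0, 0.0, 0.0, 5.0)
#     fig.savefig("qp_rnn_demo.png")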
if __name__ == "__main__":
iface.launch()