#!/usr/bin/env python3
"""
Gradio app for QP-RNN interactive demo.
Suitable for deployment on Hugging Face Spaces.
"""
import gradio as gr
import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt
from io import BytesIO
import base64


class MinimalQPRNN(torch.nn.Module):
    """Minimal QP-RNN for demonstration.

    Each call solves the 1-D box-constrained QP
        min_u 0.5 * P * u^2 + q * u,  s.t. -1 <= u <= 1,
    with q = K @ (state - reference); clamping the unconstrained
    minimizer -q / P to [-1, 1] is the exact solution.
    """

    def __init__(self, position_gain=3.0, velocity_gain=1.5, control_cost=10.0):
        super().__init__()
        self.P = torch.tensor([[control_cost]], dtype=torch.float32)
        self.K = torch.tensor([position_gain, velocity_gain], dtype=torch.float32)

    def forward(self, state, reference=None):
        if reference is None:
            reference = torch.zeros_like(state)
        error = state - reference
        # Linear term q = K @ error of the QP objective
        q = torch.sum(self.K * error, dim=-1, keepdim=True)
        u_unconstrained = -q / self.P  # unconstrained minimizer
        # Clamping projects onto the box constraint, which solves this 1-D QP exactly
        u = torch.clamp(u_unconstrained, -1.0, 1.0)
        return u
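

# Optional sanity check (illustrative sketch, not used by the app): for the 1-D
# box-constrained QP  min_u 0.5 * P * u^2 + q * u  s.t. -1 <= u <= 1, clamping the
# unconstrained minimizer -q / P to [-1, 1] is the exact solution, so a coarse grid
# search over the objective should agree with it. The helper name and default
# values below are illustrative only.
def _check_clamp_solves_qp(P=10.0, q=25.0):
    """Compare the closed-form clamped solution against a brute-force grid search."""
    u_closed_form = float(np.clip(-q / P, -1.0, 1.0))
    grid = np.linspace(-1.0, 1.0, 20001)
    objective = 0.5 * P * grid ** 2 + q * grid
    u_grid = float(grid[np.argmin(objective)])
    assert abs(u_closed_form - u_grid) < 1e-3, (u_closed_form, u_grid)
    return u_closed_form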


def simulate_system(position_gain, velocity_gain, control_cost,
                    initial_position, initial_velocity,
                    target_position, simulation_time):
    """Run simulation with given parameters."""
    # Create controller
    controller = MinimalQPRNN(position_gain, velocity_gain, control_cost)

    # Setup
    dt = 0.05
    T = int(simulation_time / dt)
    x0 = torch.tensor([initial_position, initial_velocity])
    x_ref = torch.tensor([target_position, 0.0])

    # Simulate
    states = [x0.numpy()]
    controls = []
    x = x0.clone()
    for t in range(T):
        u = controller(x, x_ref)
        # Double-integrator dynamics, explicit Euler step
        x_next = torch.zeros_like(x)
        x_next[0] = x[0] + x[1] * dt
        x_next[1] = x[1] + u.item() * dt
        states.append(x_next.numpy())
        controls.append(u.item())
        x = x_next

    return np.array(states), np.array(controls), dt
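

# Usage sketch (illustrative, not executed by the app): drive the double integrator
# from x = [2, 0] toward the origin with the default gains over 5 seconds, then
# inspect the final state.
#   states, controls, dt = simulate_system(3.0, 1.5, 10.0, 2.0, 0.0, 0.0, 5.0)
#   print(states[-1])  # final [position, velocity]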


def create_plots(states, controls, dt, target_position=0.0):
    """Create visualization plots."""
    time = np.arange(len(states)) * dt
    time_control = time[:-1]

    # Create figure with subplots
    fig = plt.figure(figsize=(12, 10))

    # Position subplot
    ax1 = plt.subplot(3, 2, 1)
    ax1.plot(time, states[:, 0], 'b-', linewidth=2)
    ax1.axhline(y=target_position, color='r', linestyle='--', alpha=0.5)
    ax1.set_ylabel('Position')
    ax1.set_title('Position vs Time')
    ax1.grid(True, alpha=0.3)

    # Velocity subplot
    ax2 = plt.subplot(3, 2, 2)
    ax2.plot(time, states[:, 1], 'g-', linewidth=2)
    ax2.axhline(y=0, color='r', linestyle='--', alpha=0.5)
    ax2.set_ylabel('Velocity')
    ax2.set_title('Velocity vs Time')
    ax2.grid(True, alpha=0.3)

    # Control subplot
    ax3 = plt.subplot(3, 2, 3)
    ax3.plot(time_control, controls, 'r-', linewidth=2)
    ax3.axhline(y=1, color='k', linestyle=':', alpha=0.5)
    ax3.axhline(y=-1, color='k', linestyle=':', alpha=0.5)
    ax3.set_ylabel('Control Input')
    ax3.set_xlabel('Time (s)')
    ax3.set_title('Control Input vs Time')
    ax3.grid(True, alpha=0.3)
    ax3.set_ylim(-1.2, 1.2)

    # Phase portrait
    ax4 = plt.subplot(3, 2, 4)
    ax4.plot(states[:, 0], states[:, 1], 'b-', linewidth=2)
    ax4.scatter([states[0, 0]], [states[0, 1]], color='green', s=100, marker='o', label='Start')
    ax4.scatter([states[-1, 0]], [states[-1, 1]], color='red', s=100, marker='x', label='End')
    ax4.set_xlabel('Position')
    ax4.set_ylabel('Velocity')
    ax4.set_title('Phase Portrait')
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    # Saturation summary: fraction of time the control sits at the box constraint
    ax5 = plt.subplot(3, 2, 5)
    time_saturated = np.sum(np.abs(controls) >= 0.99) / len(controls) * 100
    labels = ['Saturated', 'Unsaturated']
    sizes = [time_saturated, 100 - time_saturated]
    colors = ['red', 'blue']
    ax5.pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%')
    ax5.set_title('Control Saturation')

    # Metrics text
    ax6 = plt.subplot(3, 2, 6)
    ax6.axis('off')
    # Settling time: last instant the position error leaves a 0.05 tolerance band
    pos_err = np.abs(states[:, 0] - target_position)
    outside = np.where(pos_err > 0.05)[0]
    settling_time = (outside[-1] + 1) * dt if outside.size else 0.0
    metrics_text = f"""Performance Metrics:
Final Position Error: {pos_err[-1]:.4f}
Final Velocity: {states[-1, 1]:.4f}
Control Effort (L1): {np.sum(np.abs(controls)):.2f}
Control Effort (L2): {np.sqrt(np.sum(controls**2)):.2f}
Settling Time: ~{settling_time:.1f}s
Max |Position Error|: {np.max(pos_err):.2f}
"""
    ax6.text(0.1, 0.5, metrics_text, fontsize=12, verticalalignment='center',
             fontfamily='monospace', bbox=dict(boxstyle="round,pad=0.5", facecolor="lightgray"))

    plt.suptitle('QP-RNN Control Simulation Results', fontsize=16)
    plt.tight_layout()
    return fig


def run_qp_rnn_demo(position_gain, velocity_gain, control_cost,
                    initial_position, initial_velocity,
                    target_position, simulation_time):
    """Main function for Gradio interface."""
    # Run simulation
    states, controls, dt = simulate_system(
        position_gain, velocity_gain, control_cost,
        initial_position, initial_velocity,
        target_position, simulation_time
    )

    # Create plots
    fig = create_plots(states, controls, dt, target_position)

    # Create description
    description = f"""
### QP-RNN Control Results

The QP-RNN controller solves the following optimization problem at each time step:
```
min  0.5 * u² * {control_cost} + u * (K @ error)
s.t. -1 ≤ u ≤ 1
```
where K = [{position_gain}, {velocity_gain}] are the feedback gains.

**Final State:** Position = {states[-1, 0]:.3f}, Velocity = {states[-1, 1]:.3f}

**Key Features:**
- Guaranteed constraint satisfaction (control always in [-1, 1])
- Interpretable structure (quadratic cost + linear feedback)
- Can be trained via RL for complex tasks
"""
    return fig, description
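

# Headless usage sketch (illustrative, not required by the app): the demo function
# can be called directly for quick local testing without launching the UI. The
# output file name below is just an example.
#   fig, text = run_qp_rnn_demo(3.0, 1.5, 10.0, 2.0, 0.0, 0.0, 5.0)
#   fig.savefig("qp_rnn_demo.png")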


# Create Gradio interface
iface = gr.Interface(
    fn=run_qp_rnn_demo,
    inputs=[
        gr.Slider(0.1, 10.0, value=3.0, label="Position Gain (Kp)",
                  info="Higher values = faster position correction"),
        gr.Slider(0.1, 5.0, value=1.5, label="Velocity Gain (Kv)",
                  info="Higher values = more damping"),
        gr.Slider(0.1, 50.0, value=10.0, label="Control Cost",
                  info="Higher values = less aggressive control"),
        gr.Slider(-5.0, 5.0, value=2.0, label="Initial Position"),
        gr.Slider(-2.0, 2.0, value=0.0, label="Initial Velocity"),
        gr.Slider(-3.0, 3.0, value=0.0, label="Target Position"),
        gr.Slider(1.0, 10.0, value=5.0, label="Simulation Time (s)")
    ],
    outputs=[
        gr.Plot(label="Simulation Results"),
        gr.Markdown(label="Analysis")
    ],
    title="QP-RNN: Quadratic Programming Recurrent Neural Network Demo",
    description="""
This interactive demo shows how QP-RNN controllers work for a simple double integrator system.

**What is QP-RNN?**
- Combines Model Predictive Control structure with Deep Reinforcement Learning
- Learns to solve a parameterized Quadratic Program (QP) to generate control actions
- Provides theoretical guarantees (constraint satisfaction, stability verification)

**Try adjusting the parameters** to see how they affect control performance!

Paper: [MPC-Inspired Reinforcement Learning for Verifiable Model-Free Control](https://arxiv.org/abs/2312.05332)
""",
    examples=[
        [3.0, 1.5, 10.0, 2.0, 0.0, 0.0, 5.0],  # Default
        [5.0, 2.0, 5.0, 2.0, 0.0, 0.0, 5.0],   # Aggressive
        [1.0, 0.5, 20.0, 2.0, 0.0, 0.0, 5.0],  # Conservative
        [3.0, 0.1, 10.0, 2.0, 0.0, 0.0, 5.0],  # Underdamped
        [3.0, 3.0, 10.0, 2.0, 0.0, 0.0, 5.0],  # Overdamped
    ],
    cache_examples=True
)

if __name__ == "__main__":
    iface.launch()