# -*- coding: utf-8 -*-
"""app.py

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1NU6NHjan4eF9IVHR549tKLRVNUQ_dBD7
"""
import os
import json
import random

import numpy as np
import requests
import torch
import torch.nn as nn
import gradio as gr
from dotenv import load_dotenv
# ---- Load environment variables ----
load_dotenv()  # read a local .env file, if present
OPENROUTER_KEY = os.getenv("OPENROUTER_KEY")
if not OPENROUTER_KEY:
    raise ValueError("OPENROUTER_KEY not set in environment variables.")
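# For local runs, the .env file read above is assumed to hold the single key
# this script looks up; the value below is a placeholder, not a real key:
#
#   OPENROUTER_KEY=your-openrouter-api-key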
# ---- Blackjack Environment ----
class BlackjackEnv:
    """Simplified blackjack environment: cards are drawn uniformly from 1-10
    and, on stick, the dealer draws exactly one extra card before the hands
    are compared."""

    def __init__(self):
        self.dealer = []
        self.player = []
        self.usable_ace_player = False

    def draw_card(self):
        return random.randint(1, 10)

    def sum_hand(self, hand):
        # Count an ace as 11 whenever doing so does not bust the hand.
        total = sum(hand)
        ace = 1 in hand
        if ace and total + 10 <= 21:
            return total + 10, True
        return total, False

    def reset(self):
        self.player = [self.draw_card(), self.draw_card()]
        self.dealer = [self.draw_card()]
        total, usable_ace = self.sum_hand(self.player)
        self.usable_ace_player = usable_ace
        # State: (dealer's visible card, player total, usable-ace flag)
        return (self.dealer[0], total, int(usable_ace))

    def step(self, action):
        if action == 1:  # Hit: draw a card; busting ends the hand with reward -1
            self.player.append(self.draw_card())
            total, usable_ace = self.sum_hand(self.player)
            if total > 21:
                return (self.dealer[0], total, int(usable_ace)), -1, True
            return (self.dealer[0], total, int(usable_ace)), 0, False
        else:  # Stick: dealer draws one card, then the hands are compared
            dealer_hand = self.dealer + [self.draw_card()]
            dealer_total, _ = self.sum_hand(dealer_hand)
            player_total, _ = self.sum_hand(self.player)
            if dealer_total < player_total:
                return (self.dealer[0], player_total, int(self.usable_ace_player)), 1, True
            elif dealer_total > player_total:
                return (self.dealer[0], player_total, int(self.usable_ace_player)), -1, True
            else:
                return (self.dealer[0], player_total, int(self.usable_ace_player)), 0, True
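# Sketch (not wired into the app): how the environment above is meant to be
# driven. reset() returns (dealer card, player total, usable-ace flag), then
# step() is called until the hand ends. The random policy is purely illustrative.
def _demo_random_hand(demo_env):
    state = demo_env.reset()
    reward, done = 0, False
    while not done:
        action = random.randint(0, 1)  # 1 = hit, 0 = stick
        state, reward, done = demo_env.step(action)
    return state, reward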
# ---- QNetwork ----
class QNetwork(nn.Module):
    """Small MLP mapping the 3-dimensional state to Q-values for the two
    actions (index 0 = stick, index 1 = hit)."""

    def __init__(self, state_size=3, hidden_size=128, action_size=2):
        super(QNetwork, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size)
        )

    def forward(self, x):
        return self.model(x)
# ---- Load model ----
model = QNetwork()
model_path = "qnetwork_blackjack_weights.pth"
# map_location ensures the weights load even when no GPU is available
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
model.eval()

env = BlackjackEnv()
# ---- LLM Explanation ----
def explain_action(state, action):
    prompt = f"""
You are a blackjack strategy explainer. The player has a total of {state[1]}.
The dealer is showing {state[0]}. Usable ace: {bool(state[2])}.
The DQN model chose to {'Hit' if action == 1 else 'Stick'}.
Explain why this action makes sense in 2-3 sentences.
"""
    headers = {
        "Authorization": f"Bearer {OPENROUTER_KEY}",
        "Content-Type": "application/json"
    }
    data = {
        "model": "mistralai/mistral-7b-instruct",
        "messages": [
            {"role": "system", "content": "You explain blackjack strategies clearly."},
            {"role": "user", "content": prompt}
        ]
    }
    try:
        response = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            data=json.dumps(data),
            timeout=30,
        )
        if response.status_code == 200:
            return response.json()['choices'][0]['message']['content']
        return f"LLM error: {response.status_code} - {response.text}"
    except Exception as e:
        return f"LLM call failed: {str(e)}"
# ---- Gradio App ----
def play_hand():
    state = env.reset()
    state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        q_values = model(state_tensor)
        action = torch.argmax(q_values).item()
    explanation = explain_action(state, action)
    action_name = "Hit" if action == 1 else "Stick"
    dealer_card, player_sum, usable_ace = state
    return [
        str(player_sum),
        str(dealer_card),
        str(bool(usable_ace)),
        action_name,
        str(q_values.numpy().tolist()),
        explanation
    ]
demo = gr.Interface(
    fn=play_hand,
    inputs=[],
    outputs=[
        gr.Textbox(label="Player Sum"),
        gr.Textbox(label="Dealer Card"),
        gr.Textbox(label="Usable Ace"),
        gr.Textbox(label="DQN Action"),
        gr.Textbox(label="Q-values"),
        gr.Textbox(label="LLM Explanation")
    ],
    title="🧠 Blackjack Tutor: DQN + LLM",
    description="Play a hand of blackjack. See how a Deep Q-Network plays, and get a natural-language explanation from Mistral-7B via OpenRouter."
)
if __name__ == "__main__":
    demo.launch()