# GenAIDevTOProd's picture
# Update app.py
# eb5e0ca verified
# -*- coding: utf-8 -*-
"""app.py
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1NU6NHjan4eF9IVHR549tKLRVNUQ_dBD7
"""
import os
import torch
import numpy as np
import requests
import json
import gradio as gr
from dotenv import load_dotenv
import torch.nn as nn
# ---- Load env variables ----
# Read a local .env file (if one exists) into os.environ before looking the key
# up. load_dotenv was imported for this purpose but was never invoked, so a key
# kept only in .env was not picked up.
load_dotenv()

OPENROUTER_KEY = os.getenv("OPENROUTER_KEY")
if not OPENROUTER_KEY:
    # Fail fast at import time: every LLM explanation request needs this key.
    raise ValueError("OPENROUTER_KEY not set in environment variables.")
# ---- Blackjack Environment ----
import random
class BlackjackEnv:
    """Minimal single-player blackjack environment.

    A state is the tuple ``(dealer_upcard, player_total, usable_ace)`` where
    ``usable_ace`` is ``1`` when an ace in the player's hand is currently
    counted as 11 and ``0`` otherwise. Action ``1`` means hit; any other
    value means stick.
    """

    # Card values for the thirteen ranks of one suit: the ace is 1 and the
    # four ten-valued ranks (10, J, Q, K) are all 10, so a ten is drawn with
    # probability 4/13 rather than 1/10.
    CARD_VALUES = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10)

    # The dealer keeps drawing while below this total.
    DEALER_STAND_TOTAL = 17

    def __init__(self):
        """Create an environment with empty hands; call ``reset`` to deal."""
        self.dealer = []
        self.player = []
        self.usable_ace_player = False

    def draw_card(self):
        """Return the value of one card drawn with replacement."""
        return random.choice(self.CARD_VALUES)

    def sum_hand(self, hand):
        """Return ``(total, usable_ace)`` for ``hand``.

        An ace (value 1) is counted as 11 when doing so does not push the
        total above 21; ``usable_ace`` reports whether that happened.
        """
        total = sum(hand)
        if 1 in hand and total + 10 <= 21:
            return total + 10, True
        return total, False

    def reset(self):
        """Deal two cards to the player and one to the dealer.

        Returns:
            The initial state ``(dealer_upcard, player_total, usable_ace)``.
        """
        self.player = [self.draw_card(), self.draw_card()]
        self.dealer = [self.draw_card()]
        total, usable_ace = self.sum_hand(self.player)
        self.usable_ace_player = usable_ace
        return (self.dealer[0], total, int(usable_ace))

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, done)``.

        Hit (``action == 1``): the player draws one card. Going over 21 ends
        the hand with reward ``-1``; otherwise the hand continues with reward
        ``0``.

        Stick (any other action): the dealer draws until reaching at least
        ``DEALER_STAND_TOTAL``. The reward is ``1`` if the dealer busts or
        finishes below the player, ``-1`` if the dealer finishes above the
        player, and ``0`` on a tie. The hand is always over afterwards.
        """
        if action == 1:
            self.player.append(self.draw_card())
            total, usable_ace = self.sum_hand(self.player)
            # Keep the cached flag in sync so that a later stick does not
            # report the ace status of the pre-hit hand.
            self.usable_ace_player = usable_ace
            state = (self.dealer[0], total, int(usable_ace))
            if total > 21:
                return state, -1, True
            return state, 0, False

        player_total, usable_ace = self.sum_hand(self.player)
        self.usable_ace_player = usable_ace
        state = (self.dealer[0], player_total, int(usable_ace))

        # Play the dealer's hand on a copy so self.dealer keeps holding only
        # the visible upcard that the state tuple refers to.
        dealer_hand = list(self.dealer)
        dealer_total, _ = self.sum_hand(dealer_hand)
        while dealer_total < self.DEALER_STAND_TOTAL:
            dealer_hand.append(self.draw_card())
            dealer_total, _ = self.sum_hand(dealer_hand)

        if dealer_total > 21 or dealer_total < player_total:
            reward = 1
        elif dealer_total > player_total:
            reward = -1
        else:
            reward = 0
        return state, reward, True
# ---- QNetwork ----
class QNetwork(nn.Module):
    """Fully connected network that maps a state vector to per-action Q-values.

    The architecture is two hidden ``Linear`` + ``ReLU`` stages followed by a
    linear output layer with ``action_size`` units.
    """

    def __init__(self, state_size=3, hidden_size=128, action_size=2):
        super().__init__()
        layers = [
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size),
        ]
        # The attribute name and layer order determine the state_dict keys
        # (model.0.*, model.2.*, model.4.*), so they must stay as they are for
        # saved weights to keep loading.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for input tensor ``x``."""
        q_values = self.model(x)
        return q_values
# ---- Load model ----
model = QNetwork()
model_path = "qnetwork_blackjack_weights.pth"
# map_location="cpu" remaps tensors that were saved from a GPU session so the
# checkpoint also loads on a CPU-only host; the app feeds the model CPU tensors.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
# Inference only: switch the module to evaluation mode.
model.eval()

env = BlackjackEnv()
# ---- LLM Explanation ----
def explain_action(state, action, timeout=30):
    """Ask the OpenRouter chat API to explain the chosen blackjack action.

    Args:
        state: Indexable ``(dealer_card, player_total, usable_ace)``.
        action: ``1`` for Hit; any other value is described as Stick.
        timeout: Seconds to wait for the HTTP request before giving up.
            Without a timeout a stalled connection would block the caller
            indefinitely.

    Returns:
        The explanation text on success. On a non-200 response or any
        request/parsing failure, a human-readable error string is returned
        instead of raising, so the UI can always display something.
    """
    action_name = "Hit" if action == 1 else "Stick"
    prompt = (
        "\n"
        f"You are a blackjack strategy explainer. The player has a total of {state[1]}.\n"
        f"The dealer is showing {state[0]}. Usable ace: {bool(state[2])}.\n"
        f"The DQN model chose to {action_name}.\n"
        "Explain why this action makes sense in 2-3 sentences.\n"
    )
    headers = {
        "Authorization": f"Bearer {OPENROUTER_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "mistralai/mistral-7b-instruct",
        "messages": [
            {"role": "system", "content": "You explain blackjack strategies clearly."},
            {"role": "user", "content": prompt},
        ],
    }
    try:
        response = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=timeout,
        )
        if response.status_code == 200:
            return response.json()['choices'][0]['message']['content']
        return f"LLM error: {response.status_code} - {response.text}"
    except (requests.RequestException, KeyError, IndexError, TypeError, ValueError) as e:
        # Network failures (including timeouts) and malformed response bodies
        # are reported as text rather than propagated.
        return f"LLM call failed: {str(e)}"
# ---- Gradio App ----
def play_hand():
    """Deal a fresh hand, query the Q-network once, and explain its choice.

    Returns:
        A list of six strings: player sum, dealer card, usable-ace flag,
        chosen action name, the raw Q-values, and the LLM explanation.
    """
    initial_state = env.reset()
    dealer_card, player_sum, usable_ace = initial_state

    # Add a leading batch dimension before feeding the state to the network.
    batch = torch.tensor(initial_state, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        q_values = model(batch)
    action = torch.argmax(q_values).item()

    if action == 1:
        action_name = "Hit"
    else:
        action_name = "Stick"

    explanation = explain_action(initial_state, action)
    q_value_text = str(q_values.numpy().tolist())

    return [
        str(player_sum),
        str(dealer_card),
        str(bool(usable_ace)),
        action_name,
        q_value_text,
        explanation,
    ]
# Gradio UI: no inputs, one text box per value returned by play_hand (same order).
demo = gr.Interface(
    fn=play_hand,
    inputs=[],
    outputs=[
        gr.Textbox(label=label)
        for label in (
            "Player Sum",
            "Dealer Card",
            "Usable Ace",
            "DQN Action",
            "Q-values",
            "LLM Explanation",
        )
    ],
    title="🧠 Blackjack Tutor: DQN + LLM",
    description="Play a hand of blackjack. See how a Deep Q Network plays, and get a natural language explanation from Mistral-7B via OpenRouter.",
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()

# NOTE(review): this looks like leftover debugging output (prints the entries
# of the current working directory) — confirm whether it is still wanted.
import os
print(os.listdir("."))