File size: 5,012 Bytes
fa91c7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb5e0ca
fa91c7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# -*- coding: utf-8 -*-
"""app.py

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1NU6NHjan4eF9IVHR549tKLRVNUQ_dBD7
"""

import os
import torch
import numpy as np
import requests
import json
import gradio as gr
from dotenv import load_dotenv
import torch.nn as nn

# ---- Load env variables ----
# BUG FIX: load_dotenv was imported above but never called, so values in a
# local .env file were never loaded and the startup check always failed
# unless the key was exported in the shell.
load_dotenv()
OPENROUTER_KEY = os.getenv("OPENROUTER_KEY")
if not OPENROUTER_KEY:
    # Fail fast at startup rather than on the first LLM call.
    raise ValueError("OPENROUTER_KEY not set in environment variables.")

# ---- Blackjack Environment ----
import random

class BlackjackEnv:
    """Minimal blackjack environment for the DQN agent.

    Observations are tuples ``(dealer_showing_card, player_total, usable_ace)``.
    NOTE(review): this is a deliberately simplified game — cards are drawn
    uniformly from 1..10 (a real shoe favors 10-valued cards) and the dealer
    draws exactly one extra card instead of hitting to 17. Kept as-is because
    the shipped network weights were trained against these exact dynamics.
    """

    def __init__(self):
        self.dealer = []                 # dealer's cards; index 0 is the face-up card
        self.player = []                 # player's cards
        self.usable_ace_player = False   # True while an ace safely counts as 11

    def draw_card(self):
        """Draw a card value uniformly from 1..10 (infinite deck, no suits)."""
        return random.randint(1, 10)

    def sum_hand(self, hand):
        """Return ``(best_total, usable_ace)``, counting one ace as 11 when safe."""
        total = sum(hand)
        ace = 1 in hand
        if ace and total + 10 <= 21:
            return total + 10, True
        return total, False

    def reset(self):
        """Start a new hand and return the initial observation."""
        self.player = [self.draw_card(), self.draw_card()]
        self.dealer = [self.draw_card()]
        total, usable_ace = self.sum_hand(self.player)
        self.usable_ace_player = usable_ace
        return (self.dealer[0], total, int(usable_ace))

    def step(self, action):
        """Apply ``action`` (1 = hit, 0 = stick); return ``(obs, reward, done)``."""
        if action == 1:
            self.player.append(self.draw_card())
            total, usable_ace = self.sum_hand(self.player)
            # BUG FIX: refresh the cached ace flag after a hit so a later
            # stick reports the current usable-ace state, not the stale one
            # captured at reset().
            self.usable_ace_player = usable_ace
            if total > 21:
                return (self.dealer[0], total, int(usable_ace)), -1, True
            return (self.dealer[0], total, int(usable_ace)), 0, False
        else:
            dealer_hand = self.dealer + [self.draw_card()]
            dealer_total, _ = self.sum_hand(dealer_hand)
            player_total, _ = self.sum_hand(self.player)
            if dealer_total < player_total:
                return (self.dealer[0], player_total, int(self.usable_ace_player)), 1, True
            elif dealer_total > player_total:
                return (self.dealer[0], player_total, int(self.usable_ace_player)), -1, True
            else:
                return (self.dealer[0], player_total, int(self.usable_ace_player)), 0, True

# ---- QNetwork ----
class QNetwork(nn.Module):
    """MLP mapping a 3-dim blackjack state to Q-values for the two actions.

    The layer layout — and therefore the state_dict keys ``model.0`` ..
    ``model.4`` — must stay identical to the trained checkpoint on disk.
    """

    def __init__(self, state_size=3, hidden_size=128, action_size=2):
        super().__init__()
        layers = [
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values for a batch of states of shape (batch, state_size)."""
        return self.model(x)

# ---- Load model ----
model = QNetwork()
model_path = "qnetwork_blackjack_weights.pth"
# map_location lets CPU-only hosts (e.g. a Gradio Space) load a checkpoint
# that was saved on a GPU; weights_only=True restricts unpickling to plain
# tensors, which is all a state_dict needs and avoids arbitrary-code pickles.
model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
model.eval()  # inference only — disables dropout/batchnorm training behavior

env = BlackjackEnv()

# ---- LLM Explanation ----
def explain_action(state, action):
    """Ask an LLM via OpenRouter to explain the DQN's chosen action.

    Args:
        state: observation tuple ``(dealer_card, player_total, usable_ace)``.
        action: 0 = Stick, 1 = Hit.

    Returns:
        The explanation text on success, otherwise a human-readable error
        string (this function never raises — errors are shown in the UI).
    """
    prompt = f"""
You are a blackjack strategy explainer. The player has a total of {state[1]}.
The dealer is showing {state[0]}. Usable ace: {bool(state[2])}.
The DQN model chose to {'Hit' if action == 1 else 'Stick'}.
Explain why this action makes sense in 2-3 sentences.
"""

    headers = {
        "Authorization": f"Bearer {OPENROUTER_KEY}",
        "Content-Type": "application/json"
    }

    data = {
        "model": "mistralai/mistral-7b-instruct",
        "messages": [
            {"role": "system", "content": "You explain blackjack strategies clearly."},
            {"role": "user", "content": prompt}
        ]
    }

    try:
        # json= serializes the payload in one step; timeout= prevents the
        # request (and the whole Gradio worker) from hanging indefinitely.
        response = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            json=data,
            timeout=30,
        )
        if response.status_code == 200:
            return response.json()['choices'][0]['message']['content']
        return f"LLM error: {response.status_code} - {response.text}"
    except Exception as e:
        return f"LLM call failed: {str(e)}"

# ---- Gradio App ----
def play_hand():
    """Deal one hand, let the DQN pick Hit/Stick, and return UI fields.

    Returns a list of six strings: player sum, dealer card, usable-ace flag,
    chosen action, raw Q-values, and the LLM's explanation.
    """
    obs = env.reset()
    obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():  # inference only — no gradients needed
        q_values = model(obs_tensor)
    chosen = torch.argmax(q_values).item()

    dealer_card, player_sum, usable_ace = obs
    action_label = "Hit" if chosen == 1 else "Stick"
    explanation = explain_action(obs, chosen)

    return [
        str(player_sum),
        str(dealer_card),
        str(bool(usable_ace)),
        action_label,
        str(q_values.numpy().tolist()),
        explanation,
    ]

# Zero-input interface: each click deals a fresh hand via play_hand().
demo = gr.Interface(
    fn=play_hand,
    inputs=[],
    # One textbox per field returned by play_hand, in the same order.
    outputs=[
        gr.Textbox(label=label)
        for label in (
            "Player Sum",
            "Dealer Card",
            "Usable Ace",
            "DQN Action",
            "Q-values",
            "LLM Explanation",
        )
    ],
    title="🧠 Blackjack Tutor: DQN + LLM",
    description="Play a hand of blackjack. See how a Deep Q Network plays, and get a natural language explanation from Mistral-7B via OpenRouter.",
)

if __name__ == "__main__":
    demo.launch()

# NOTE(review): removed stray debug lines (`import os` / `print(os.listdir())`)
# that re-imported os (already imported at the top) and printed the working
# directory listing as a side effect every time the module was imported.