File size: 12,882 Bytes
93d3bfa
238d37f
f3b040f
93d3bfa
 
938032f
77b14f6
238d37f
7d9bb79
238d37f
7d9bb79
 
238d37f
938032f
238d37f
93d3bfa
f3b040f
 
77b14f6
f3b040f
238d37f
 
 
 
 
938032f
238d37f
7ceb07f
 
 
 
 
 
938032f
 
238d37f
7ceb07f
 
 
238d37f
85c8b2b
77b14f6
238d37f
 
 
 
 
 
a58eed0
 
 
 
 
93d3bfa
238d37f
8eb6be4
7ceb07f
 
 
 
 
f3b040f
8eb6be4
7ceb07f
238d37f
a58eed0
93d3bfa
77b14f6
238d37f
85c8b2b
7ceb07f
 
 
 
 
ae5323d
816e617
238d37f
 
 
938032f
7ceb07f
938032f
238d37f
938032f
 
 
238d37f
938032f
 
238d37f
da2916f
238d37f
 
938032f
238d37f
da2916f
238d37f
 
938032f
238d37f
938032f
 
238d37f
938032f
238d37f
 
 
ae5323d
938032f
238d37f
 
 
 
ae5323d
238d37f
938032f
 
238d37f
 
938032f
 
238d37f
 
 
 
938032f
ae5323d
238d37f
da2916f
938032f
238d37f
 
938032f
7ceb07f
938032f
 
77b14f6
238d37f
 
 
 
 
 
 
7ceb07f
938032f
da2916f
 
938032f
 
 
7ceb07f
938032f
 
7ceb07f
 
 
 
 
238d37f
 
7ceb07f
cd7eb0b
7ceb07f
938032f
238d37f
da2916f
238d37f
 
cd7eb0b
 
da2916f
f3b040f
 
85c8b2b
f3b040f
 
938032f
77b14f6
7ceb07f
 
cd7eb0b
7ceb07f
da2916f
 
 
938032f
 
7ceb07f
816e617
238d37f
f3b040f
 
77b14f6
da2916f
238d37f
da2916f
 
 
238d37f
 
 
 
da2916f
 
238d37f
 
 
 
938032f
238d37f
da2916f
938032f
 
 
77b14f6
938032f
238d37f
938032f
 
 
 
 
 
238d37f
938032f
 
 
 
238d37f
 
 
816e617
 
93d3bfa
938032f
 
 
238d37f
 
da2916f
 
 
 
 
938032f
 
 
 
 
77b14f6
f3b040f
238d37f
938032f
77b14f6
 
 
 
238d37f
77b14f6
cd7eb0b
 
 
238d37f
 
 
 
 
 
 
cd7eb0b
 
 
 
 
 
 
 
 
238d37f
816e617
 
7ceb07f
 
 
ae5323d
cd7eb0b
 
 
 
 
ae5323d
238d37f
938032f
77b14f6
938032f
77b14f6
938032f
77b14f6
938032f
77b14f6
7ceb07f
da2916f
7ceb07f
 
 
 
816e617
77b14f6
f3b040f
238d37f
da2916f
f3b040f
 
 
816e617
238d37f
 
 
 
 
 
 
f3b040f
cd7eb0b
 
7ceb07f
238d37f
7ceb07f
 
cd7eb0b
 
7ceb07f
 
 
cd7eb0b
 
85c8b2b
238d37f
938032f
238d37f
 
77b14f6
816e617
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
# ── SLM_CService.py ───────────────────────────────────────────────────────────
# Customer-support-only chatbot with strict NSFW blocking + robust FSM + proper reset.

import os
import re
from typing import List, Dict

# Keep OpenMP logs quiet
os.environ["OMP_NUM_THREADS"] = "1"
# Ensure we don't accidentally force offline mode
os.environ.pop("HF_HUB_OFFLINE", None)

# ── Import order matters: Unsloth should come before transformers/peft.
import unsloth  # noqa: E402

import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, pipeline
from peft import PeftModel
from langchain.memory import ConversationBufferMemory

# ==============================
# Config
# ==============================
REPO = "ThomasBasil/bitext-qlora-tinyllama"   # your adapter + tokenizer live at repo root
BASE = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"   # base model

# Sampling parameters forwarded to the pipeline on every call (kept separate
# from pipeline construction so they are passed at call time, see chat_pipe).
GEN_KW = dict(  # generation params (passed at call time)
    max_new_tokens=160,      # hard cap on reply length
    do_sample=True,          # sampled (non-greedy) decoding
    top_p=0.9,               # nucleus-sampling cutoff
    temperature=0.7,
    repetition_penalty=1.1,  # discourage verbatim loops
    no_repeat_ngram_size=4,  # forbid repeating any 4-gram
)

bnb_cfg = BitsAndBytesConfig(  # 4-bit QLoRA-style loading (needs GPU)
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,  # T4/A10G-friendly
)

# Memory key FIX: use the same key for saving & reading history
# (memory construction and _lc_to_messages both reference this constant).
MEMORY_KEY = "chat_history"

# ==============================
# Load tokenizer & model
# ==============================
# Tokenizer comes from the adapter repo (REPO), not the base model, so any
# added/special tokens match what the adapter was trained with.
tokenizer = AutoTokenizer.from_pretrained(REPO, use_fast=False)
# Reuse EOS as the pad token when none is defined (needed for padded batches).
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
tokenizer.truncation_side = "right"

# Unsloth returns (model, tokenizer) -> unpack; its tokenizer is discarded in
# favor of the adapter-repo tokenizer loaded above.
model, _ = unsloth.FastLanguageModel.from_pretrained(
    model_name=BASE,
    load_in_4bit=True,
    quantization_config=bnb_cfg,
    device_map="auto",
    trust_remote_code=True,
)
unsloth.FastLanguageModel.for_inference(model)  # switch Unsloth to inference mode

# Attach your PEFT adapter from repo root
model = PeftModel.from_pretrained(model, REPO)
model.eval()

# Text-generation pipeline (pass GEN_KW at call time, not as generate_kwargs).
# return_full_text=False strips the prompt from the pipeline's output.
chat_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    trust_remote_code=True,
    return_full_text=False,
)

# ==============================
# Moderation (strict)
# ==============================
from transformers import TextClassificationPipeline

# Keyword blocklist used twice: for input/output screening in
# is_sexual_or_toxic() and for generation-time token banning via
# _bad_words_ids() / bad_words_ids.
SEXUAL_TERMS = [
    # single words
    "sex","sexual","porn","nsfw","fetish","kink","bdsm","nude","naked","anal",
    "blowjob","handjob","cum","breast","boobs","vagina","penis","semen","ejaculate",
    "doggy","missionary","cowgirl","69","kamasutra","dominatrix","submissive","spank",
    # phrases
    "sex position","have sex","make love","how to flirt","dominant in bed",
]

def _bad_words_ids(tok, terms: List[str]) -> List[List[int]]:
    """Build bad_words_ids for generation; include both 'term' and ' term' variants."""
    ids = set()
    for t in terms:
        for v in (t, " " + t):
            toks = tok(v, add_special_tokens=False).input_ids
            if toks:
                ids.add(tuple(toks))
    return [list(t) for t in ids]

# Pre-computed token-id sequences handed to generate() as bad_words_ids.
BAD_WORD_IDS = _bad_words_ids(tokenizer, SEXUAL_TERMS)

# Lightweight classifiers (optional but helpful defense-in-depth).
# NOTE(review): both models are fetched from the Hub at import time; runtime
# failures are deliberately swallowed inside is_sexual_or_toxic().
nsfw_cls: TextClassificationPipeline = pipeline(
    "text-classification",
    model="eliasalbouzidi/distilbert-nsfw-text-classifier",
    truncation=True,
)
toxicity_cls: TextClassificationPipeline = pipeline(
    "text-classification",
    model="unitary/toxic-bert",
    truncation=True,
    return_all_scores=True,  # per-label scores for the multi-label toxicity model
)

def is_sexual_or_toxic(text: str) -> bool:
    """Return True when *text* should be refused (sexual/NSFW or toxic).

    Defense-in-depth: a keyword gate over SEXUAL_TERMS first, then two
    optional classifier passes.  Classifier errors are swallowed on purpose
    so a missing or broken model degrades to keyword-only filtering rather
    than crashing the bot.

    FIX: the keyword gate now matches on word boundaries.  The original used
    plain substring tests, so innocent inputs were refused ("analysis"
    contains "anal", "document" contains "cum", and any order number
    containing "69" tripped the "69" term).
    """
    t = (text or "").lower()
    # \b-anchored alternation over the banned vocabulary.  The `re` module
    # caches compiled patterns, so rebuilding the string per call is cheap.
    if SEXUAL_TERMS:
        pattern = r"\b(?:" + "|".join(map(re.escape, SEXUAL_TERMS)) + r")\b"
        if re.search(pattern, t):
            return True
    try:
        res = nsfw_cls(t)[0]
        if (res.get("label","").lower() == "nsfw") and float(res.get("score",0)) > 0.60:
            return True
    except Exception:
        pass  # best effort — see docstring
    try:
        scores = toxicity_cls(t)[0]
        if any(s["score"] > 0.60 and s["label"].lower() in
               {"toxic","severe_toxic","obscene","threat","insult","identity_hate"} for s in scores):
            return True
    except Exception:
        pass  # best effort — see docstring
    return False

# Canned refusal used both for NSFW/toxic input and for post-generation checks.
REFUSAL = ("Sorry, I can’t help with that. I’m only for store support "
           "(orders, shipping, ETA, tracking, returns, warranty, account).")

# ==============================
# Memory + Globals
# ==============================
# Conversation buffer shared by all turns; wiped by reset_state().
memory = ConversationBufferMemory(
    memory_key=MEMORY_KEY,      # ← FIX: explicit memory key
    return_messages=True,
)

# System prompt prepended to every LLM call (see _lc_to_messages).
SYSTEM_PROMPT = (
    "You are a customer-support assistant for our store. Only handle account, "
    "orders, shipping, delivery ETA, tracking links, returns/refunds, warranty, and store policy. "
    "If a request is out of scope or sexual/NSFW, refuse briefly and offer support options. "
    "Be concise and professional."
)

# Substring allowlist for the support-only guard in chat_with_memory().
ALLOWED_KEYWORDS = (
    "order","track","status","delivery","shipping","ship","eta","arrive",
    "refund","return","exchange","warranty","guarantee","policy","account","billing",
    "address","cancel","help","support","agent","human"
)

# Robust order detection, covering:
# - "#67890" / "# 67890"
# - "order 67890", "order no. 67890", "order number 67890", "order id 67890"
ORDER_RX = re.compile(
    r"(?:#\s*([\d]{3,12})|order(?:\s*(?:no\.?|number|id))?\s*#?\s*([\d]{3,12}))",
    flags=re.I,
)

def extract_order(text: str):
    """Return the first order number found in *text*, or None."""
    match = ORDER_RX.search(text) if text else None
    if match is None:
        return None
    # Whichever alternation branch matched carries the digits.
    return match.group(1) or match.group(2)

def _tracking_url(o):
    """Public tracking page for order *o* (shared by the ETA/track/link replies)."""
    return f"https://track.example.com/{o}"

def handle_status(o):
    """Canned in-transit status reply for order *o*."""
    return f"Order #{o} is in transit and should arrive in 3–5 business days."

def handle_eta(o):
    """Delivery-window reply, including a tracking link."""
    return f"Delivery for order #{o} typically takes 3–5 days; you can track it at {_tracking_url(o)}"

def handle_track(o):
    """Direct tracking-link reply."""
    return f"Track order #{o} here: {_tracking_url(o)}"

def handle_link(o):
    """Resend-the-tracking-link reply."""
    return f"Here’s the latest tracking link for order #{o}: {_tracking_url(o)}"

def handle_return_policy(_=None):
    """Static 30-day return-policy blurb; the argument is ignored."""
    return ("Our return policy allows returns of unused items in original packaging within 30 days of receipt. "
            "Would you like me to connect you with a human agent?")

def handle_warranty_policy(_=None):
    """Static 1-year warranty blurb; the argument is ignored."""
    return ("We provide a 1-year limited warranty against manufacturing defects. "
            "Within 30 days you can return or exchange; afterwards, warranty service applies. "
            "Need help starting a claim?")

def handle_cancel(o=None):
    """Confirmation reply for a cancellation request on order *o*."""
    return (f"I’ve submitted a cancellation request for order #{o}. If it has already shipped, "
            "we’ll process a return/refund once it’s back. You’ll receive a confirmation email shortly.")

def handle_gratitude(_=None):
    """Reply to a thank-you; the argument is ignored."""
    return "You’re welcome! Anything else I can help with?"

def handle_escalation(_=None):
    """Offer a human-agent handoff; the argument is ignored."""
    return "I can connect you with a human agent. Would you like me to do that?"
def handle_ask_action(o):
    """Acknowledge a newly captured order number and prompt for a next action.

    FIX: the original literal contained the mojibake sequence "β€”" where an
    em dash was intended; the dash is restored here.
    """
    return f"I’ve saved order #{o}. What would you like to do — status, ETA, tracking link, or cancel?"

# >>> state that must reset <<<
# stored_order:   last order number captured from the user (string), if any.
# pending_intent: an order-requiring intent ("status"/"eta"/"track"/"link"/
#                 "cancel") waiting for the user to supply an order number.
stored_order   = None
pending_intent = None

def reset_state():
    """Clear conversation memory and FSM globals (app.py Reset button hook).

    Always returns True, even when clearing the buffer fails.
    """
    global stored_order, pending_intent
    stored_order, pending_intent = None, None
    try:
        memory.clear()
    except Exception:
        # Best effort: a broken/missing memory buffer must not block a reset.
        pass
    return True

# ==============================
# Chat templating helpers
# ==============================
def _lc_to_messages() -> List[Dict[str, str]]:
    """Convert the LangChain buffer into HF chat-template messages.

    The system prompt always leads; each stored message becomes a
    user/assistant turn based on its LangChain ``type`` attribute.
    """
    history = memory.load_memory_variables({}).get(MEMORY_KEY, []) or []  # same key as save
    turns = [
        {
            "role": "user" if getattr(msg, "type", "") == "human" else "assistant",
            "content": getattr(msg, "content", ""),
        }
        for msg in history
    ]
    return [{"role": "system", "content": SYSTEM_PROMPT}, *turns]

def _generate_reply(user_input: str) -> str:
    """Run one generation turn through the HF pipeline.

    The prompt is built with the tokenizer's chat template so the model
    respects system/user roles; banned tokens are blocked during decoding.
    """
    messages = _lc_to_messages()
    messages.append({"role": "user", "content": user_input})
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    outputs = chat_pipe(
        prompt,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        bad_words_ids=BAD_WORD_IDS,  # block sexual tokens at generation time
        **GEN_KW,
    )
    return outputs[0]["generated_text"].strip()

# ==============================
# Main entry
# ==============================
def chat_with_memory(user_input: str) -> str:
    """Handle one user turn: safety filter → deterministic order FSM → LLM fallback.

    Every reply (including refusals) is persisted to `memory` so later turns
    see the full conversation.  Mutates the module globals `stored_order` and
    `pending_intent`, which together form the order-lookup state machine.

    Fixes vs. the original:
      * "return policy" questions now reach the return-policy handler — the
        old elif chain matched the word "policy" first and answered with
        warranty info.
      * Greeting detection uses word boundaries, so e.g. "white" no longer
        matches "hi" as a substring.
      * Repaired mojibake em dashes ("β€”") in two reply strings.
    """
    global stored_order, pending_intent

    ui = (user_input or "").strip()
    if not ui:
        return "How can I help with your order today?"

    def _reply(text: str) -> str:
        # Persist the exchange so _lc_to_messages() can replay it later.
        memory.save_context({"input": ui}, {"output": text})
        return text

    # Intents that cannot be answered without an order number (single source
    # of truth for both the pending-intent short-circuit and step 7).
    order_handlers = {
        "status": handle_status,
        "eta":    handle_eta,
        "track":  handle_track,
        "link":   handle_link,
        "cancel": handle_cancel,
    }

    # Fresh-session guard: an empty buffer means the FSM globals are stale.
    hist = memory.load_memory_variables({}).get(MEMORY_KEY, []) or []
    if not hist:
        stored_order = None
        pending_intent = None

    # 1) Safety first — refuse NSFW/toxic input outright.
    if is_sexual_or_toxic(ui):
        return _reply(REFUSAL)

    low = ui.lower()

    # 2) Quick intent: gratitude.
    if any(tok in low for tok in ("thank you", "thanks", "thx")):
        return _reply(handle_gratitude())

    # 3) Pending-intent short-circuit (handles bare answers like "It's #26790").
    new_o = extract_order(ui)
    if pending_intent:
        if new_o:
            stored_order = new_o
            fn = order_handlers[pending_intent]
            pending_intent = None
            return _reply(fn(stored_order))
        # Still waiting for an order number.
        return _reply("Got it—please share your order number (e.g., #12345).")

    # 4) Order number given with no pending intent: save it and ask what to do.
    if new_o:
        stored_order = new_o
        return _reply(handle_ask_action(stored_order))

    # 5) Support-only guard: the message must be support-ish or a greeting.
    greeted = re.search(r"\b(?:hi|hello|hey)\b", low) is not None
    if not any(k in low for k in ALLOWED_KEYWORDS) and not greeted:
        return _reply("I’m for store support only (orders, shipping, returns, warranty, account). How can I help with those?")

    # 6) Deterministic intent classification (first match wins; "return" is
    #    deliberately tested before the warranty/policy bucket).
    if any(k in low for k in ("status", "where is my order", "check status")):
        intent = "status"
    elif any(k in low for k in ("how long", "eta", "delivery time")):
        intent = "eta"
    elif any(k in low for k in ("how can i track", "track my order", "where is my package", "tracking")):
        intent = "track"
    elif "tracking link" in low or "resend" in low or "link" in low:
        intent = "link"
    elif any(k in low for k in ("cancel", "cancellation", "abort order")):
        intent = "cancel"
    elif "return" in low:
        intent = "return_policy"
    elif any(k in low for k in ("warranty", "guarantee", "policy")):
        intent = "warranty_policy"
    else:
        intent = "fallback"

    # 7) Intents that need an order number: ask for one, or answer directly.
    if intent in order_handlers:
        if not stored_order:
            pending_intent = intent  # step 3 completes this on the next turn
            return _reply("Sure—what’s your order number (e.g., #12345)?")
        return _reply(order_handlers[intent](stored_order))

    # 8) Policy intents (no order number needed).
    if intent == "return_policy":
        return _reply(handle_return_policy())
    if intent == "warranty_policy":
        return _reply(handle_warranty_policy())

    # 9) LLM fallback (still on-topic), post-checked so generated text cannot
    #    leak NSFW/toxic content.
    reply = _generate_reply(ui)
    if is_sexual_or_toxic(reply):
        reply = REFUSAL
    return _reply(reply)