|
|
|
|
|
|
|
import logging
import os
import re
from typing import List, Dict
|
|
|
|
|
# Environment tweaks applied before torch/transformers/unsloth are imported
# below, presumably so those libraries see them when they initialise.
# Limit OpenMP to a single thread.
os.environ.update(OMP_NUM_THREADS="1")
# Drop any offline flag so the model/tokenizer downloads further down can reach the Hub.
if "HF_HUB_OFFLINE" in os.environ:
    del os.environ["HF_HUB_OFFLINE"]
|
|
|
|
|
import unsloth |
|
|
|
import torch |
|
from transformers import AutoTokenizer, BitsAndBytesConfig, pipeline |
|
from peft import PeftModel |
|
from langchain.memory import ConversationBufferMemory |
|
|
|
|
|
|
|
|
|
# Hub repo with the LoRA adapter; the tokenizer is also loaded from here.
REPO = "ThomasBasil/bitext-qlora-tinyllama"
# Base checkpoint that the adapter in REPO is applied on top of.
BASE = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Sampling settings forwarded to every text-generation call.
GEN_KW = {
    "max_new_tokens": 160,
    "do_sample": True,
    "top_p": 0.9,
    "temperature": 0.7,
    "repetition_penalty": 1.1,
    "no_repeat_ngram_size": 4,
}
|
|
|
# bitsandbytes settings: 4-bit NF4 weights with double quantisation and
# float16 as the compute dtype.
bnb_cfg = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

# Key under which the LangChain memory exposes the conversation history.
MEMORY_KEY = "chat_history"
|
|
|
|
|
|
|
|
|
# Tokenizer published alongside the adapter (slow/Python implementation).
tokenizer = AutoTokenizer.from_pretrained(REPO, use_fast=False)
# Reuse EOS as the padding token when the tokenizer defines none.
if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# Pad on the left (decoder-only generation), truncate on the right.
tokenizer.padding_side, tokenizer.truncation_side = "left", "right"
|
|
|
|
|
# Load the BASE checkpoint in 4-bit through Unsloth. The tokenizer Unsloth
# returns is discarded; `tokenizer` (loaded from REPO) is used everywhere.
# NOTE(review): `quantization_config` is passed together with
# `load_in_4bit=True`; depending on the Unsloth version it may build its own
# BitsAndBytesConfig and either ignore this one or reject the duplicate
# keyword — confirm against the pinned Unsloth release.
model, _ = unsloth.FastLanguageModel.from_pretrained(
    model_name=BASE,
    load_in_4bit=True,
    quantization_config=bnb_cfg,
    device_map="auto",
    trust_remote_code=True,
)
# Switch the base model to Unsloth's inference mode.
# NOTE(review): this runs before the PEFT wrapper below is added — verify the
# inference patches still apply when generating through the PeftModel.
unsloth.FastLanguageModel.for_inference(model)

# Attach the LoRA adapter weights from REPO and put the model in eval mode.
model = PeftModel.from_pretrained(model, REPO)
model.eval()
|
|
|
|
|
# Generation pipeline over the adapter-wrapped model. With
# return_full_text=False, "generated_text" holds only the new continuation,
# not the prompt.
chat_pipe = pipeline(
    task="text-generation",
    tokenizer=tokenizer,
    model=model,
    return_full_text=False,
    trust_remote_code=True,
)
|
|
|
|
|
|
|
|
|
from transformers import TextClassificationPipeline |
|
|
|
# Blocklist used twice: turned into banned token sequences for generation
# (BAD_WORD_IDS) and used as the keyword check in is_sexual_or_toxic.
SEXUAL_TERMS = [
    # Single words.
    "sex", "sexual", "porn", "nsfw", "fetish",
    "kink", "bdsm", "nude", "naked", "anal",
    "blowjob", "handjob", "cum", "breast", "boobs",
    "vagina", "penis", "semen", "ejaculate",
    "doggy", "missionary", "cowgirl", "69", "kamasutra",
    "dominatrix", "submissive", "spank",
] + [
    # Multi-word phrases.
    "sex position", "have sex", "make love", "how to flirt", "dominant in bed",
]
|
|
|
def _bad_words_ids(tok, terms: List[str]) -> List[List[int]]:
    """Encode blocked terms as token-id sequences for ``bad_words_ids``.

    Each term is encoded twice, as-is and with a leading space, since subword
    tokenizers can encode a word differently after a space. Empty encodings
    are skipped and duplicates collapse.

    Args:
        tok: Tokenizer called as ``tok(text, add_special_tokens=False)``; the
            result must expose ``input_ids``.
        terms: Words or phrases to ban.

    Returns:
        Unique token-id lists, in no particular order.
    """
    variants = (form for term in terms for form in (term, " " + term))
    encoded = (tuple(tok(form, add_special_tokens=False).input_ids) for form in variants)
    unique = {seq for seq in encoded if seq}
    return [list(seq) for seq in unique]
|
|
|
# Token sequences the generator is forbidden to emit.
BAD_WORD_IDS = _bad_words_ids(tok=tokenizer, terms=SEXUAL_TERMS)

# Secondary safety classifiers consulted by is_sexual_or_toxic.
nsfw_cls: TextClassificationPipeline = pipeline(
    task="text-classification",
    truncation=True,
    model="eliasalbouzidi/distilbert-nsfw-text-classifier",
)
# NOTE(review): `return_all_scores` is deprecated in recent transformers in
# favour of `top_k=None`, which nests single-string results differently;
# is_sexual_or_toxic indexes the result with [0] — keep the two in sync.
toxicity_cls: TextClassificationPipeline = pipeline(
    task="text-classification",
    model="unitary/toxic-bert",
    return_all_scores=True,
    truncation=True,
)
|
|
|
# Toxicity-classifier labels that count as a hit.
_TOXIC_LABELS = frozenset(
    {"toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"}
)
# A classifier score strictly above this value counts as a positive detection.
_CLASSIFIER_THRESHOLD = 0.60


def matches_blocked_term(text: str, terms: List[str]) -> bool:
    """Return True if any of *terms* occurs in *text* as a whole word or phrase.

    Matching is case-insensitive and anchored on word boundaries (an optional
    trailing "s" covers simple plurals), so short terms do not fire inside
    unrelated words: "cum" does not match "document", "anal" does not match
    "analysis", and "69" does not match the order number "#16945".

    Args:
        text: Text to scan; empty or ``None`` never matches.
        terms: Words or multi-word phrases; empty entries are ignored.
    """
    alternatives = [re.escape(term) for term in terms if term]
    if not text or not alternatives:
        return False
    pattern = r"(?<!\w)(?:" + "|".join(alternatives) + r")s?(?!\w)"
    return re.search(pattern, text, flags=re.IGNORECASE) is not None


def is_sexual_or_toxic(text: str) -> bool:
    """Return True if *text* should be refused as sexual, NSFW or toxic.

    Checks, in order:
    1. the SEXUAL_TERMS blocklist (whole-word match);
    2. the NSFW classifier (label "nsfw" with score above the threshold);
    3. the toxicity classifier (any label in ``_TOXIC_LABELS`` above the threshold).

    Classifier failures are logged and treated as "not flagged", so a broken
    model never blocks the chat. Empty or ``None`` text is never flagged.
    """
    raw = (text or "").strip()
    if not raw:
        return False

    if matches_blocked_term(raw, SEXUAL_TERMS):
        return True

    try:
        res = nsfw_cls(raw)[0]
        if (res.get("label", "").lower() == "nsfw"
                and float(res.get("score", 0)) > _CLASSIFIER_THRESHOLD):
            return True
    except Exception:
        logging.getLogger(__name__).warning(
            "NSFW classifier failed; skipping that check", exc_info=True)

    try:
        scores = toxicity_cls(raw)[0]
        # Legacy return_all_scores output gives a list of label dicts here;
        # accept a single dict too in case the pipeline output shape changes.
        if isinstance(scores, dict):
            scores = [scores]
        if any(float(s["score"]) > _CLASSIFIER_THRESHOLD
               and s["label"].lower() in _TOXIC_LABELS for s in scores):
            return True
    except Exception:
        logging.getLogger(__name__).warning(
            "Toxicity classifier failed; skipping that check", exc_info=True)

    return False
|
|
|
# Canned reply for blocked (sexual/NSFW/toxic) input or model output.
# Fixed: the apostrophes were mis-encoded ("canβt", "Iβm").
REFUSAL = ("Sorry, I can’t help with that. I’m only for store support "
           "(orders, shipping, ETA, tracking, returns, warranty, account).")
|
|
|
|
|
|
|
|
|
# Unbounded conversation buffer. return_messages=True yields message objects
# (read by _lc_to_messages) instead of one formatted string.
# NOTE(review): ConversationBufferMemory is deprecated in recent LangChain
# releases — confirm the pinned version still ships it.
memory = ConversationBufferMemory(
    return_messages=True,
    memory_key=MEMORY_KEY,
)
|
|
|
# System message prepended to every model prompt (see _lc_to_messages).
SYSTEM_PROMPT = " ".join((
    "You are a customer-support assistant for our store.",
    "Only handle account, orders, shipping, delivery ETA, tracking links, "
    "returns/refunds, warranty, and store policy.",
    "If a request is out of scope or sexual/NSFW, refuse briefly and offer support options.",
    "Be concise and professional.",
))
|
|
|
# Vocabulary marking a message as store-related for the off-topic gate in
# chat_with_memory.
ALLOWED_KEYWORDS = tuple(
    "order track status delivery shipping ship eta arrive "
    "refund return exchange warranty guarantee policy account billing "
    "address cancel help support agent human".split()
)
|
|
|
|
|
|
|
|
|
# Order-number pattern. Two forms are accepted:
#   * "#12345";
#   * "order 12345", "order no. 12345", "order number: 12345", "Order ID #12345".
# The number must be 3-12 digits and must not be followed by another digit,
# so a longer number is rejected rather than silently truncated.
# "order" must start a word, so "border 123" does not match.
ORDER_RX = re.compile(
    r"(?:#\s*(\d{3,12})|\border(?:\s*(?:no\.?|number|id))?[\s:#]*(\d{3,12}))(?!\d)",
    flags=re.I,
)


def extract_order(text: str):
    """Return the first order number mentioned in *text* as a digit string.

    Returns ``None`` when *text* is empty/``None`` or contains no order number
    in a form accepted by ``ORDER_RX``.
    """
    if not text:
        return None
    m = ORDER_RX.search(text)
    if m is None:
        return None
    # Exactly one of the two alternatives captured the digits.
    return m.group(1) or m.group(2)
|
|
|
# Canned replies for the rule-based intents. Fixed: curly apostrophes and
# dashes were mis-encoded as "β" in the original strings.


def handle_status(o):
    """Status reply for order *o*."""
    return f"Order #{o} is in transit and should arrive in 3–5 business days."


def handle_eta(o):
    """Delivery-estimate reply with a tracking URL for order *o*."""
    return f"Delivery for order #{o} typically takes 3–5 days; you can track it at https://track.example.com/{o}"


def handle_track(o):
    """Tracking-URL reply for order *o*."""
    return f"Track order #{o} here: https://track.example.com/{o}"


def handle_link(o):
    """Re-sent tracking-link reply for order *o*."""
    return f"Here’s the latest tracking link for order #{o}: https://track.example.com/{o}"


def handle_return_policy(_=None):
    """Return-policy text; the argument is accepted for a uniform signature and ignored."""
    return ("Our return policy allows returns of unused items in original packaging within 30 days of receipt. "
            "Would you like me to connect you with a human agent?")


def handle_warranty_policy(_=None):
    """Warranty-policy text; the argument is ignored."""
    return ("We provide a 1-year limited warranty against manufacturing defects. "
            "Within 30 days you can return or exchange; afterwards, warranty service applies. "
            "Need help starting a claim?")


def handle_cancel(o=None):
    """Cancellation confirmation for order *o*.

    With no order number (the default), refer to "your order" instead of
    printing "order #None".
    """
    target = f"order #{o}" if o else "your order"
    return (f"I’ve submitted a cancellation request for {target}. If it has already shipped, "
            "we’ll process a return/refund once it’s back. You’ll receive a confirmation email shortly.")


def handle_gratitude(_=None):
    """Reply to a thank-you message."""
    return "You’re welcome! Anything else I can help with?"


def handle_escalation(_=None):
    """Offer a hand-off to a human agent."""
    return "I can connect you with a human agent. Would you like me to do that?"


def handle_ask_action(o):
    """Acknowledge order *o* and ask which action the user wants."""
    return f"I’ve saved order #{o}. What would you like to do — status, ETA, tracking link, or cancel?"
|
|
|
|
|
# Per-session routing state shared with chat_with_memory.
stored_order = None    # last order number the user gave (digit string), or None
pending_intent = None  # order action waiting for an order number, or None


def reset_state():
    """Clear conversation memory and the routing globals.

    Called by app.py's Reset button. Always returns True. A failure to clear
    the LangChain memory is logged, not raised, so the reset itself never
    errors.
    """
    global stored_order, pending_intent
    stored_order = None
    pending_intent = None
    try:
        memory.clear()
    except Exception:
        logging.getLogger(__name__).warning(
            "Could not clear conversation memory during reset", exc_info=True)
    return True
|
|
|
|
|
|
|
|
|
def _lc_to_messages() -> List[Dict[str, str]]:
    """Convert the LangChain buffer into chat-template messages.

    The list starts with the system prompt and is followed by every stored
    turn. LangChain "human" messages map to role "user"; any other message
    type maps to "assistant".
    """
    stored = memory.load_memory_variables({}).get(MEMORY_KEY, []) or []
    conversation = [
        {
            "role": "user" if getattr(msg, "type", "") == "human" else "assistant",
            "content": getattr(msg, "content", ""),
        }
        for msg in stored
    ]
    return [{"role": "system", "content": SYSTEM_PROMPT}] + conversation
|
|
|
def _generate_reply(user_input: str) -> str:
    """Generate a free-form answer with the fine-tuned chat model.

    The prompt is built from the system prompt, the stored conversation and
    *user_input*, rendered with the tokenizer's chat template.

    ConversationBufferMemory keeps every turn, so the prompt could outgrow the
    model's context window. To prevent that, the oldest user/assistant pairs
    are dropped until the prompt plus ``max_new_tokens`` fits. The system
    prompt and the current message are always kept.

    Generation uses ``GEN_KW`` and bans ``BAD_WORD_IDS``.

    Returns:
        The stripped generated text. The pipeline is built with
        ``return_full_text=False``, so the prompt is not echoed.
    """
    history = _lc_to_messages()
    system, turns = history[:1], history[1:]
    user_msg = {"role": "user", "content": user_input}

    # Prompt budget = context length minus the tokens reserved for the answer.
    # 2048 is only a fallback for configs that do not expose the length.
    config = getattr(model, "config", None)
    context_len = getattr(config, "max_position_embeddings", None) or 2048
    budget = max(context_len - int(GEN_KW.get("max_new_tokens", 0)), 1)

    while True:
        messages = system + turns + [user_msg]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        if not turns or len(tokenizer(prompt).input_ids) <= budget:
            break
        # chat_with_memory saves one user/assistant pair per exchange: drop the oldest pair.
        turns = turns[2:]

    out = chat_pipe(
        prompt,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        bad_words_ids=BAD_WORD_IDS,
        **GEN_KW,
    )[0]["generated_text"]
    return out.strip()
|
|
|
|
|
|
|
|
|
# Intents that act on one specific order and therefore need an order number.
_ORDER_ACTIONS = ("status", "eta", "track", "link", "cancel")

# Used only while an order number has been requested: a bare run of 3-12
# digits (same length range as ORDER_RX) is then accepted as the number.
_BARE_ORDER_RX = re.compile(r"(?<!\d)(\d{3,12})(?!\d)")


def _mentions(low: str, phrases, *, whole_word: bool = False) -> bool:
    """Return True if any phrase occurs in *low* starting at a word boundary.

    Plain substring tests misfire on word fragments: "hi" in "something",
    "eta" in "details", "ship" in "relationship". Anchoring at a word start
    avoids that while still matching inflections such as "shipping" or
    "cancellation". With ``whole_word=True`` the phrase must also end at a
    word boundary.
    """
    tail = r"\b" if whole_word else ""
    return any(re.search(r"\b" + re.escape(p) + tail, low) for p in phrases)


def _detect_intent(low: str) -> str:
    """Map a lower-cased message to an intent name, or ``"fallback"``."""
    if _mentions(low, ("status", "where is my order")):
        return "status"
    if _mentions(low, ("how long", "delivery time")) or _mentions(low, ("eta",), whole_word=True):
        return "eta"
    # Checked before "track": every "tracking link" request also contains "tracking".
    if _mentions(low, ("link", "resend")):
        return "link"
    if _mentions(low, ("track", "where is my package")):
        return "track"
    if _mentions(low, ("cancel", "abort order")):
        return "cancel"
    if _mentions(low, ("warranty", "guarantee")):
        return "warranty_policy"
    # Checked before the generic "policy" so "return policy" gets the return answer.
    if _mentions(low, ("return", "refund")):
        return "return_policy"
    if _mentions(low, ("policy",)):
        return "warranty_policy"
    return "fallback"


def _run_order_action(intent: str, order) -> str:
    """Dispatch an intent from ``_ORDER_ACTIONS`` to its canned handler."""
    handlers = {
        "status": handle_status,
        "eta": handle_eta,
        "track": handle_track,
        "link": handle_link,
        "cancel": handle_cancel,
    }
    return handlers[intent](order)


def _remember(user_text: str, reply: str) -> str:
    """Save the exchange to the LangChain memory and return *reply*."""
    memory.save_context({"input": user_text}, {"output": reply})
    return reply


def chat_with_memory(user_input: str) -> str:
    """Answer one user message, trying the rule-based routes before the model.

    Routing order:
    1. Empty input gets a greeting prompt; nothing is stored.
    2. Sexual or toxic input gets REFUSAL.
    3. Thanks get the gratitude reply.
    4. If an order number was requested earlier (``pending_intent``):
       - a number in this message (bare digits accepted) runs the action;
       - otherwise the number is asked for again;
       - if the user asks for something else, the pending request is dropped.
    5. An order number with no recognised action gets "what would you like to do?".
    6. Off-topic messages (no intent, no store keyword, no greeting) get a
       scope reminder.
    7. Order actions and warranty/return-policy questions get canned answers;
       an order action with no known order asks for the number.
    8. Anything else goes to the model, and its output is re-checked by the
       safety filter.

    Every non-empty exchange is saved to ``memory``. Updates the module
    globals ``stored_order`` and ``pending_intent``.
    """
    global stored_order, pending_intent

    ui = (user_input or "").strip()
    if not ui:
        return "How can I help with your order today?"

    # Empty memory means a new (or just reset) conversation: drop stale routing state.
    history = memory.load_memory_variables({}).get(MEMORY_KEY, []) or []
    if not history:
        stored_order = None
        pending_intent = None

    if is_sexual_or_toxic(ui):
        return _remember(ui, REFUSAL)

    low = ui.lower()

    if any(tok in low for tok in ("thank you", "thanks", "thx")):
        return _remember(ui, handle_gratitude())

    intent = _detect_intent(low)
    new_o = extract_order(ui)
    if pending_intent and not new_o:
        # We just asked for an order number, so a bare number counts as one.
        bare = _BARE_ORDER_RX.search(ui)
        new_o = bare.group(1) if bare else None
    if new_o:
        stored_order = new_o

    if pending_intent:
        if new_o:
            # An explicit action in this message wins over the one we were waiting on.
            action = intent if intent in _ORDER_ACTIONS else pending_intent
            pending_intent = None
            return _remember(ui, _run_order_action(action, stored_order))
        if intent == "fallback":
            return _remember(ui, "Got it—please share your order number (e.g., #12345).")
        # The user moved on to a different request: stop waiting and route it normally.
        pending_intent = None

    if new_o and intent == "fallback":
        return _remember(ui, handle_ask_action(stored_order))

    # Off-topic gate: only for messages with no recognised intent, so intent
    # phrases outside ALLOWED_KEYWORDS ("how long", "resend the link") still work.
    if (intent == "fallback"
            and not _mentions(low, ALLOWED_KEYWORDS)
            and not _mentions(low, ("hi", "hello", "hey"), whole_word=True)):
        return _remember(
            ui,
            "I’m for store support only (orders, shipping, returns, warranty, account). How can I help with those?",
        )

    if intent in _ORDER_ACTIONS:
        if stored_order:
            reply = _run_order_action(intent, stored_order)
        else:
            pending_intent = intent
            reply = "Sure—what’s your order number (e.g., #12345)?"
        return _remember(ui, reply)

    if intent == "warranty_policy":
        return _remember(ui, handle_warranty_policy())

    if intent == "return_policy":
        return _remember(ui, handle_return_policy())

    reply = _generate_reply(ui)
    if is_sexual_or_toxic(reply):
        reply = REFUSAL
    return _remember(ui, reply)
|
|