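"""Gradio demo for the ZamAI-mT5-Pashto model.

Serves a text2text-generation UI, exposes a lightweight /health endpoint for
liveness checks, and supports "echo"/"useless" modes that bypass the model.
"""
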
import os
import logging
import threading
import http.server
import socketserver
import importlib
import time
from functools import lru_cache
from typing import Optional

import gradio as gr
import torch
from transformers import AutoTokenizer, pipeline


# ---------------- Configuration ----------------
MODEL_ID = os.getenv("MODEL_ID", "tasal9/ZamAI-mT5-Pashto")
CACHE_DIR = os.getenv("HF_HOME", None)  # optional cache dir for transformers
HEALTH_PORT = int(os.getenv("HEALTH_PORT", "8080"))
GRADIO_HOST = os.getenv("GRADIO_HOST", "0.0.0.0")
GRADIO_PORT = int(os.getenv("GRADIO_PORT", "7860"))
DEFAULT_MAX_NEW_TOKENS = int(os.getenv("DEFAULT_MAX_NEW_TOKENS", "128"))
ECHO_MODE = os.getenv("ECHO_MODE", "off").lower()  # default env; UI can override at runtime
OFFLINE_FLAG = os.getenv("OFFLINE", "0").lower() in {"1", "true", "yes"}
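# With OFFLINE=1, HF_HUB_OFFLINE forces huggingface_hub/transformers to resolve
# models and tokenizers from the local cache only, never from the network.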
if OFFLINE_FLAG:
    os.environ["HF_HUB_OFFLINE"] = "1"

def _log_cache_env():
    try:
        import huggingface_hub as _hub
        hub_cache = getattr(_hub.constants, 'HF_HUB_CACHE', None)
    except Exception:
        hub_cache = None
    logging.info(
        "Cache config: HF_HOME=%s TRANSFORMERS_CACHE=%s HF_HUB_OFFLINE=%s hub_cache=%s",
        os.getenv("HF_HOME"), os.getenv("TRANSFORMERS_CACHE"), os.getenv("HF_HUB_OFFLINE"), hub_cache
    )

# ---------------- Logging ----------------
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(level=LOG_LEVEL, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("zamai-app")

# Call only after logging is configured: invoking logging.info() earlier would
# implicitly configure the root logger at WARNING level, dropping this message
# and making the explicit basicConfig above a no-op.
_log_cache_env()

# Metrics storage for last real generation
LAST_METRICS: dict[str, float | int | str | None] = {
    "latency_sec": None,
    "input_tokens": None,
    "output_tokens": None,
    "num_sequences": None,
    "mode": None,
}
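# Updated by predict() after each real generation and shown in the UI via the
# "Refresh Status" button.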


# ---------------- Utilities ----------------
SAMPLE_INSTRUCTIONS = [
    "په پښتو کې د خپل نوم او د عمر معلومات ولیکئ.",
    "د هوا د حالت په اړه لنډ راپور ورکړئ.",
    "په پښتو کې یوه لنډه کیسه ولیکئ چې د ښوونځي د ژوند په اړه وي.",
    "د خپلو ملګرو لپاره د یوې کوچنۍ پیغام ولیکئ.",
    "په پښتو کې د خپل خوښې خواړه تشریح کړئ او ووایاست ولې یې خوښوی.",
    "د خپلې سیمې د تاریخي ځایونو په اړه لنډ معلومات ورکړئ.",
    "یو ورځني کارنامه ولیکئ چې په کور کې څه کارونه ترسره کوئ."
]


def _start_health_server(port: int):
    """Start a tiny HTTP server that responds 200 to /health on a background thread."""
    class HealthHandler(http.server.SimpleHTTPRequestHandler):
        def do_GET(self):
            if self.path == "/health":
                self.send_response(200)
                self.send_header("Content-type", "text/plain")
                self.end_headers()
                self.wfile.write(b"ok")
            else:
                self.send_response(404)
                self.end_headers()

    def _serve():
        try:
            with socketserver.TCPServer(("", int(port)), HealthHandler) as httpd:
                logger.info("Health endpoint listening on port %s", port)
                httpd.serve_forever()
        except Exception as e:
            logger.exception("Health server failed: %s", e)

    t = threading.Thread(target=_serve, daemon=True)
    t.start()
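    # Quick liveness check (assuming the default HEALTH_PORT of 8080):
    #   curl -fsS http://localhost:8080/health   ->  ok
    # Any other path returns 404.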


def _detect_device() -> int:
    # Return the device id for the transformers pipeline: 0 for the first CUDA GPU, -1 for CPU.
    try:
        if torch.cuda.is_available():
            logger.info("CUDA available; using GPU device 0")
            return 0
    except Exception:
        logger.debug("torch.cuda check failed; falling back to CPU")
    return -1


# ---------------- Generator (cached) ----------------
@lru_cache(maxsize=1)
def get_generator(model_id: str = MODEL_ID, cache_dir: Optional[str] = CACHE_DIR):
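    """Build and cache a text2text-generation pipeline for the given model.

    lru_cache(maxsize=1) means the tokenizer and model are loaded once per
    process; subsequent predict() calls reuse the same pipeline instance.
    """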
    device = _detect_device()
    logger.info("Loading tokenizer and model: %s (device=%s)", model_id, device)

    tokenizer = None
    local_model_path = None
    try:
        hf = importlib.import_module("huggingface_hub")
        snapshot_download = getattr(hf, "snapshot_download", None)
        if snapshot_download:
            try:
                logger.info("Attempting to snapshot_download model %s to cache_dir=%s", model_id, cache_dir)
                local_model_path = snapshot_download(repo_id=model_id, cache_dir=cache_dir, repo_type="model")
                if local_model_path:
                    local_model_path = str(local_model_path)
                    logger.info("Model snapshot downloaded to %s", local_model_path)
            except Exception as e:
                logger.warning("snapshot_download failed for %s: %s", model_id, e)
                local_model_path = None
    except Exception:
        logger.debug("huggingface_hub not available; falling back to AutoTokenizer.from_pretrained")

    try:
        if local_model_path:
            tokenizer = AutoTokenizer.from_pretrained(local_model_path, use_fast=False, cache_dir=cache_dir)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, cache_dir=cache_dir)
        logger.info("Loaded tokenizer for %s", model_id)
    except Exception as e2:
        logger.exception("Failed to load tokenizer for %s: %s", model_id, e2)
        raise

    # Prefer the locally downloaded snapshot (if any) so the pipeline does not
    # re-resolve the model from the Hub; fall back to the model id otherwise.
    gen = pipeline(
        "text2text-generation",
        model=local_model_path or model_id,
        tokenizer=tokenizer,
        device=device,
    )
    return gen


def predict(instruction: str,
            input_text: str,
            max_new_tokens: int,
            num_beams: int,
            do_sample: bool,
            temperature: float,
            top_p: float,
            num_return_sequences: int,
            mode: str):
    """Generate text using the cached pipeline and return output or error message."""
    if not instruction or not instruction.strip():
        return "⚠️ مهرباني وکړئ یوه لارښوونه ولیکئ."

    def build_prompt() -> str:
        base = instruction.strip()
        if input_text and input_text.strip():
            return base + "\n" + input_text.strip()
        return base

    prompt = build_prompt()
    active_mode = (mode or "").strip().lower() or ECHO_MODE
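    # "echo" and "useless" bypass the model entirely (useful for smoke-testing
    # the UI without downloading weights); anything else triggers real generation.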
    if active_mode in ("echo", "useless"):
        if active_mode == "echo":
            return f"### Prompt\n\n````\n{prompt}\n````\n\n### Output\n\n````\n{prompt}\n````"
        return f"### Prompt\n\n````\n{prompt}\n````\n\n### Output\n\nThis is a useless placeholder response."

    allowed_keys = {"max_new_tokens", "num_beams", "do_sample", "temperature", "top_p", "num_return_sequences"}

    start = time.time()
    try:
        gen = get_generator()
        raw_kwargs = {
            "max_new_tokens": int(max_new_tokens),
            "num_beams": int(num_beams) if not do_sample else 1,
            "do_sample": bool(do_sample),
            "temperature": float(temperature),
            "top_p": float(top_p),
            "num_return_sequences": max(1, int(num_return_sequences)),
        }
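        # temperature/top_p only affect sampling; with do_sample=False the
        # pipeline typically warns and ignores them. num_beams is forced to 1
        # above whenever sampling is enabled.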
        gen_kwargs = {k: v for k, v in raw_kwargs.items() if k in allowed_keys}
        outputs = gen(prompt, **gen_kwargs)

        if not isinstance(outputs, list):
            outputs = [outputs]
        texts = []
        for out in outputs:
            if isinstance(out, dict):
                text = out.get("generated_text", "").strip()
            else:
                text = str(out).strip()
            if text:
                texts.append(text)
        if not texts:
            LAST_METRICS.update({
                "latency_sec": round(time.time() - start, 3),
                "input_tokens": None,
                "output_tokens": 0,
                "num_sequences": 0,
                "mode": active_mode,
            })
            return f"### Prompt\n\n````\n{prompt}\n````\n\n### Output\n\n⚠️ No response generated."
        joined = "\n\n---\n\n".join(texts)

        # Basic token counting via whitespace split (approximate)
        input_tokens = len(prompt.split())
        output_tokens = sum(len(t.split()) for t in texts)
        LAST_METRICS.update({
            "latency_sec": round(time.time() - start, 3),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "num_sequences": len(texts),
            "mode": active_mode,
        })
        metrics_md = f"\n\n### Metrics\n- Latency: {LAST_METRICS['latency_sec']}s\n- Input tokens (approx): {input_tokens}\n- Output tokens (approx): {output_tokens}\n- Sequences: {len(texts)}"
        return f"### Prompt\n\n````\n{prompt}\n````\n\n### Output\n\n{joined}{metrics_md}"
    except Exception as e:
        logger.exception("Generation failed: %s", e)
        return f"⚠️ Generation failed: {e}"


def build_ui():
    with gr.Blocks() as demo:
        device_label = "GPU" if _detect_device() != -1 else "CPU"
        gr.Markdown(
            f"""
            # ZamAI mT5 Pashto Demo
            اپلیکیشن  **ZamAI-mT5-Pashto** د پښتو لارښوونو لپاره.  
            **Device:** {device_label}  |  **Env Mode:** {ECHO_MODE}  |  **Offline:** {os.getenv('HF_HUB_OFFLINE','0')}  
            که د موډ بدلول غواړئ لاندې د Mode selector څخه استفاده وکړئ.
            """
        )
        with gr.Row():
            with gr.Column(scale=2):
                instruction_dropdown = gr.Dropdown(
                    choices=SAMPLE_INSTRUCTIONS,
                    label="نمونې لارښوونې",
                    value=SAMPLE_INSTRUCTIONS[0],
                    interactive=True,
                )
                instruction_textbox = gr.Textbox(
                    lines=3,
                    placeholder="دلته لارښوونه ولیکئ...",
                    label="لارښوونه",
                )
                input_text = gr.Textbox(lines=2, placeholder="اختیاري متن...", label="متن")
                output = gr.Markdown(label="ځواب")
                generate_btn = gr.Button("جوړول", variant="primary")
                mode_selector = gr.Dropdown(
                    choices=["off", "echo", "useless"],
                    value=ECHO_MODE,
                    label="Mode (off=real, echo=return prompt, useless=fixed)",
                    interactive=True,
                )
                status_box = gr.Markdown(value="Loading status pending...", label="Status")
                refresh_status = gr.Button("Refresh Status")

            with gr.Column(scale=1):
                gr.Markdown("### د تولید تنظیمات")
                max_new_tokens = gr.Slider(16, 512, value=DEFAULT_MAX_NEW_TOKENS, step=1, label="اعظمي نوي ټوکنونه (max_new_tokens)")
                num_beams = gr.Slider(1, 8, value=2, step=1, label="شمیر شعاعونه (num_beams)")
                do_sample = gr.Checkbox(label="نمونې فعال کړئ (do_sample)", value=True)
                temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="تودوخه (temperature)")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
                num_return_sequences = gr.Slider(1, 4, value=1, step=1, label="د راګرځېدونکو تسلسلو شمېر")

        # Copy the selected sample instruction into the editable instruction textbox.
        instruction_dropdown.change(lambda x: x, inputs=instruction_dropdown, outputs=instruction_textbox)

        def refresh():
            base = f"**Device:** {'GPU' if _detect_device() != -1 else 'CPU'} | **Offline:** {os.getenv('HF_HUB_OFFLINE','0')} | **Env Mode:** {ECHO_MODE}"
            if LAST_METRICS.get('latency_sec') is not None:
                base += (f"<br>**Last Gen:** latency={LAST_METRICS['latency_sec']}s, "
                         f"in≈{LAST_METRICS['input_tokens']}, out≈{LAST_METRICS['output_tokens']}, seqs={LAST_METRICS['num_sequences']}")
            return base

        refresh_status.click(fn=refresh, inputs=None, outputs=status_box)

        generate_btn.click(
            fn=predict,
            inputs=[instruction_textbox, input_text, max_new_tokens, num_beams, do_sample, temperature, top_p, num_return_sequences, mode_selector],
            outputs=output,
        )

        # Model load banner shown after interface loads (async)
        def _post_load():
            return "✅ Model interface ready. If this is the first run and model wasn't cached, initial generation may still warm up."
        demo.load(_post_load, inputs=None, outputs=status_box)

    return demo


if __name__ == "__main__":
    logger.info("Starting ZamAI mT5 Pashto Demo (model=%s)", MODEL_ID)
    try:
        _start_health_server(HEALTH_PORT)
    except Exception:
        logger.exception("Failed to start health server")

    demo = build_ui()
    # launch() blocks until the server shuts down, so emit the cache log first.
    logger.info("HF_HOME=%s TRANSFORMERS_CACHE=%s", os.getenv("HF_HOME"), os.getenv("TRANSFORMERS_CACHE"))
    demo.launch(server_name=GRADIO_HOST, server_port=GRADIO_PORT)