Spaces:
Running
on
Zero
Running
on
Zero
File size: 13,447 Bytes
2fda2ca cf4739b 2fda2ca b7ff1e4 2fda2ca b7ff1e4 2fda2ca b00876e 2fda2ca 9126461 2fda2ca 9126461 61c4504 b7ff1e4 61c4504 9126461 61c4504 2fda2ca 9126461 2fda2ca b00876e 2fda2ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 |
import functools
import gc
import json
import logging
import os
from pathlib import Path
try:
import spaces
except ModuleNotFoundError:
spaces = lambda: None
spaces.GPU = lambda fn: fn
import gradio as gr
import tiktoken
import torch
from openai import OpenAI
from transformers import AutoModelForCausalLM, AutoTokenizer
from igcs import grounding
from igcs.entities import Doc, Selection
from igcs.utils import log
from igcs.utils.diskcache import disk_cache
# Logger shared by every function in this demo module.
logger = logging.getLogger("igcs-demo")

# Directory containing this file; the bundled sample document lives next to it.
_EXAMPLES_DIR = Path(__file__).parent

# The demo keeps a single document, although multi-document input is possible.
# The sample text is taken from https://en.wikipedia.org/wiki/Barack_Obama
# and is the document shown by default in the UI.
DEFAULT_TEXT = (_EXAMPLES_DIR / "barack_obama_wiki.txt").read_text(encoding="utf8").strip()
# Example instructions offered in the instruction dropdown. Requests that pair
# one of these with the default document are served through the disk cache.
DEFAULT_PROMPTS = (
    "Select content that details Obama's initiatives",
    "Select content that discusses Obama's personal life",
    "Select content that details Obama's education",
    "Select content with Obama's financial data",
)
# see src/igcs/prompting.py for more info
# Formatted with `doc` (its `.text` is inserted) and `selection_instruction`.
PROMPT_TEMPLATE = (
    "Given the following document(s), {selection_instruction}. "
    "Output the exact text phrases from the given document(s) as a valid json array of strings. Do not change the copied text.\n\n"
    "Document #0:\n{doc.text}\n"
)
# (label, value) pairs for the model dropdown.
# - value None  -> a section header row, rejected when submitted.
# - "api:..."   -> routed to OpenRouter (the "api:" prefix is stripped first).
# - otherwise   -> a model id loaded locally with transformers.
MODELS_LIST = [
    # local models:
    ("====== IGCS Fine-tuned SLMs ======", None),
    ("Qwen2.5-3b-GenCS-union (local)", "shmuelamar/Qwen2.5-3b-GenCS-union"),
    ("Qwen2.5-3b-GenCS-majority (local)", "shmuelamar/Qwen2.5-3b-GenCS-majority"),
    ("Qwen2.5-7b-GenCS-union (local)", "shmuelamar/Qwen2.5-7b-GenCS-union"),
    ("Qwen2.5-7b-GenCS-majority (local)", "shmuelamar/Qwen2.5-7b-GenCS-majority"),
    ("Llama-3-8B-GenCS-union (local)", "shmuelamar/Llama-3-8B-GenCS-union"),
    ("Llama-3-8B-GenCS-majority (local)", "shmuelamar/Llama-3-8B-GenCS-majority"),
    ("SmolLM2-1.7B-GenCS-union (local)", "shmuelamar/SmolLM2-1.7B-GenCS-union"),
    ("SmolLM2-1.7B-GenCS-majority (local)", "shmuelamar/SmolLM2-1.7B-GenCS-majority"),
    ("====== Zero-shot SLMs ======", None),
    ("Qwen/Qwen2.5-3B-Instruct (local)", "Qwen/Qwen2.5-3B-Instruct"),
    ("Qwen/Qwen2.5-7B-Instruct (local)", "Qwen/Qwen2.5-7B-Instruct"),
    ("meta-llama/Meta-Llama-3-8B-Instruct (local)", "meta-llama/Meta-Llama-3-8B-Instruct"),
    ("HuggingFaceTB/SmolLM2-1.7B-Instruct (local)", "HuggingFaceTB/SmolLM2-1.7B-Instruct"),
    ("====== API-based Models (OpenRouter) ======", None),
    ("qwen/qwen3-14b (API)", "api:qwen/qwen3-14b:free"),
    ("moonshotai/kimi-k2 (API)", "api:moonshotai/kimi-k2:free"),
    ("deepseek/deepseek-chat-v3-0324 (API)", "api:deepseek/deepseek-chat-v3-0324:free"),
    ("meta-llama/llama-3.3-70b-instruct (API)", "api:meta-llama/llama-3.3-70b-instruct:free"),
    ("meta-llama/llama-3.1-405b-instruct (API)", "api:meta-llama/llama-3.1-405b-instruct:free"),
]
# Index 0 is a header row, so the default is the first selectable model.
DEFAULT_MODEL = MODELS_LIST[1][1]
# Upper bounds (in tokens, as measured by count_tokens) for the document and
# for the selection instruction respectively.
MAX_INPUT_TOKENS = 4500
MAX_PROMPT_TOKENS = 256
# Markdown shown at the top of the page.
INTRO_TEXT = """
## 🚀 Welcome to the IGCS Live Demo!
This is a demo for the paper titled [**“A Unifying Scheme for Extractive Content Selection Tasks”**][arxiv-paper] — try Instruction‑Guided Content Selection on **any**
text or code: use the demo text or upload your document, enter an instruction, choose a model, and hit **Submit** to see the most relevant spans highlighted!
🔍 Learn more in our [paper][arxiv-paper] and explore the full [GitHub repo](https://github.com/shmuelamar/igcs) ⭐. Enjoy! 🎉
[arxiv-paper]: http://arxiv.org/abs/2507.16922 "A Unifying Scheme for Extractive Content Selection Tasks"
"""
@spaces.GPU
def completion(prompt: str, model_id: str) -> str:
    """Generate a greedy completion for ``prompt`` with a locally loaded model.

    The tokenizer and model for ``model_id`` are loaded on every call and
    released again before returning, so no model stays resident between
    requests.

    Args:
        prompt: The full user prompt; it is wrapped in a chat template together
            with a fixed system message.
        model_id: Model identifier passed to ``from_pretrained``.

    Returns:
        The decoded text of the newly generated tokens only (the prompt tokens
        are sliced off), with special tokens removed.
    """
    # Pre-bind every name that may hold model weights or tensors so the
    # ``finally`` block can drop them no matter where a failure happens.
    model = tokenizer = input_ids = attention_mask = outputs = None
    try:
        # load model and tokenizer
        logger.info(f"loading local model and tokenizer for {model_id}")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16
        model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=dtype, device_map="auto"
        )
        logger.info(f"done loading {model_id}")

        # tokenize
        input_ids = tokenizer.apply_chat_template(
            [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
            ],
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)

        # MPS (on Mac) requires manual attention mask
        attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=model.device)

        logger.info(
            f"generating completion with model_id: {model.name_or_path} and prompt: {prompt!r}"
        )
        # Greedy decoding; the sampling knobs are explicitly unset.
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=2048,
            do_sample=False,
            top_k=None,
            top_p=None,
            temperature=None,
        )

        # decode response: keep only the tokens generated after the prompt
        return tokenizer.decode(outputs[0][input_ids.shape[-1] :], skip_special_tokens=True)
    finally:
        # cleanup memory, also when loading or generation failed.
        # References must be dropped and collected *before* emptying the CUDA
        # cache, otherwise the cached blocks are still in use and stay allocated.
        del model, tokenizer, input_ids, attention_mask, outputs
        gc.collect()
        torch.cuda.empty_cache()
def completion_openrouter(prompt: str, model_id: str):
    """Request a completion for ``prompt`` from OpenRouter.

    Args:
        prompt: Text sent as a single user message.
        model_id: OpenRouter model name (without the ``api:`` prefix).

    Returns:
        The message content of the first returned choice.
    """
    logger.info(f"calling openrouter with model_id: {model_id} and prompt: {prompt!r}")
    chat_messages = [{"role": "user", "content": prompt}]
    openrouter = load_openrouter_client()
    # temperature 0.0 asks the API for the least random output it offers.
    response = openrouter.chat.completions.create(
        model=model_id,
        messages=chat_messages,
        temperature=0.0,
    )
    first_choice = response.choices[0]
    return first_choice.message.content
def load_openrouter_client():
    """Build an OpenAI-compatible client that talks to OpenRouter.

    The API key is read from the ``OPENROUTER_API_KEY`` environment variable.

    Returns:
        An ``OpenAI`` client whose base URL points at OpenRouter.

    Raises:
        RuntimeError: If ``OPENROUTER_API_KEY`` is missing or empty.
    """
    api_key = os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        # Fail fast: passing api_key=None makes the OpenAI client fall back to
        # OPENAI_API_KEY, which would then be sent to a third-party endpoint.
        raise RuntimeError(
            "OPENROUTER_API_KEY environment variable is not set; "
            "API-based models are unavailable."
        )
    logger.info("connecting to OpenRouter")
    return OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=api_key,
    )
@disk_cache(cache_dir=_EXAMPLES_DIR / "models-cache")
def get_completion_cache(*, prompt: str, model_id: str) -> str:
    """Return the completion for ``prompt``/``model_id`` via the disk-cache wrapper.

    The call is forwarded unchanged to ``get_completion``; the ``disk_cache``
    decorator is configured with the ``models-cache`` directory next to this
    file (presumably persisting results there - see igcs.utils.diskcache).
    """
    completion_text = get_completion(prompt=prompt, model_id=model_id)
    return completion_text
@functools.lru_cache(maxsize=2048)
def get_completion(*, prompt: str, model_id: str):
    """Dispatch a completion request to OpenRouter or to a local model.

    Model ids starting with ``api:`` are sent to OpenRouter with that prefix
    removed; every other id is run locally. Results are memoized in-process
    for up to 2048 distinct (prompt, model_id) pairs.
    """
    api_prefix = "api:"
    if not model_id.startswith(api_prefix):
        return completion(prompt, model_id)
    remote_model = model_id.removeprefix(api_prefix)
    return completion_openrouter(prompt, remote_model)
# Tokenizer used only to measure input sizes against the demo's token limits.
TIKTOKEN_TOKENIZER = tiktoken.encoding_for_model("gpt-4")


def count_tokens(text: str) -> int:
    """Return the number of tokens in ``text`` according to ``TIKTOKEN_TOKENIZER``.

    Args:
        text: Arbitrary user-supplied text (document or instruction).

    Returns:
        The length of the encoded token sequence.
    """
    # User text may legitimately contain strings such as "<|endoftext|>".
    # By default tiktoken raises ValueError on those; disallowed_special=()
    # makes it encode them as ordinary text instead.
    return len(TIKTOKEN_TOKENIZER.encode(text, disallowed_special=()))
def perform_igcs(
    doc: Doc, selection_instruction: str, model_id: str
) -> tuple[list[Selection] | None, str]:
    """Run instruction-guided content selection on ``doc``.

    Builds the prompt from ``PROMPT_TEMPLATE``, obtains the model response,
    parses it into selection spans and grounds those spans in ``doc``.

    Args:
        doc: The document to select from.
        selection_instruction: The user's selection instruction.
        model_id: Value of the chosen model entry (local id or ``api:`` id).

    Returns:
        A pair ``(selections, raw_response)``.
    """
    logger.info(f"performing selection with {selection_instruction!r} using {model_id!r}")
    prompt = PROMPT_TEMPLATE.format(doc=doc, selection_instruction=selection_instruction)

    # For the example inputs - we cache from disk as they are more popular
    is_example_request = doc.text == DEFAULT_TEXT and selection_instruction in DEFAULT_PROMPTS
    if is_example_request:
        logger.info("using disk_cache mode")
    fetch_completion = get_completion_cache if is_example_request else get_completion
    resp = fetch_completion(prompt=prompt, model_id=model_id)
    logger.info(f"Got response from model: {model_id}: {resp!r}")

    # Step 1: parse the response as a json array of strings.
    parsed_spans = grounding.parse_selection(resp)
    # Step 2: ground each span to character positions in the source document.
    selections = grounding.ground_selections(parsed_spans, docs=[doc])
    logger.info(f"model selections: {selections!r}")
    return selections, resp
def convert_selections_to_gradio_highlights(selections, doc) -> list[tuple[str, str | None]]:
pos = 0
highlights = []
# add hallucinations outside the text itself:
if any(sel.doc_id == -1 for sel in selections):
highlights.append(
("\n\nHallucinated selections (not found in the document):\n\n", "hallucination")
)
for sel in selections:
if sel.doc_id != -1: # not hallucination
continue
highlights.append((sel.content + "\n", "hallucination"))
selections.sort(key=lambda sel: (sel.end_pos, sel.start_pos))
for sel in selections:
if sel.doc_id == -1:
continue # hallucination
if pos < sel.start_pos:
highlights.append((doc.text[pos : sel.start_pos], None)) # outside selection
elif pos >= sel.end_pos:
continue # two selections overlap - we only display the first.
highlights.append(
(doc.text[sel.start_pos : sel.end_pos], sel.metadata["mode"])
) # the selection
pos = sel.end_pos
if pos + 1 < len(doc.text):
highlights.append((doc.text[pos:], None)) # end of the text
return highlights
def process_igcs_request(selection_instruction: str, model_id: str, doc_data: list[dict]):
    """Handle a Submit click: validate the inputs, run selection, format the outputs.

    Args:
        selection_instruction: Instruction text from the dropdown.
        model_id: Value of the chosen model entry; ``None`` for header rows.
        doc_data: Current highlighted-text value: dicts with a ``"token"`` and a
            ``"class_or_confidence"`` key. Pieces labelled ``"hallucination"``
            are not part of the document and are skipped.

    Returns:
        A triple ``(highlights, raw_model_response, selections_json)``.

    Raises:
        gr.Error: For an invalid model choice, over-long document or
            instruction, or a response that could not be parsed.
    """
    if model_id is None:
        raise gr.Error("Please select a valid model from the list.")

    # Rebuild the document from the displayed pieces, leaving out hallucinations.
    document_pieces = (
        piece["token"] for piece in doc_data if piece["class_or_confidence"] != "hallucination"
    )
    doc_text = "".join(document_pieces)

    if count_tokens(doc_text) > MAX_INPUT_TOKENS:
        raise gr.Error(
            f"File too large! currently only up-to {MAX_INPUT_TOKENS} tokens are supported"
        )
    if count_tokens(selection_instruction) > MAX_PROMPT_TOKENS:
        raise gr.Error(f"Prompt is too long! only supports up-to {MAX_PROMPT_TOKENS} tokens.")

    # Perform content selection
    # TODO: cache examples
    doc = Doc(id=0, text=doc_text)
    selections, model_resp = perform_igcs(doc, selection_instruction, model_id)
    if selections is None:
        raise gr.Error(
            "Cannot parse selections, model response is invalid. please try another instruction or model."
        )

    # Post-process selections for display as highlighted spans
    highlights = convert_selections_to_gradio_highlights(selections, doc)
    serialized_selections = [s.model_dump(mode="json") for s in selections]
    selections_text = json.dumps(serialized_selections, indent=2)
    return highlights, model_resp, selections_text
def get_app() -> gr.Blocks:
    """Build the Gradio UI of the demo and wire up its event handlers.

    Returns:
        The configured ``gr.Blocks`` application (not yet queued or launched).
    """
    with gr.Blocks(title="Instruction-guided content selection", theme="ocean", head="") as app:
        with gr.Row():
            gr.Markdown(INTRO_TEXT)

        with gr.Row(equal_height=True):
            with gr.Column(scale=2, min_width=300):
                prompt_text = gr.Dropdown(
                    label="Content Selection Instruction:",
                    info='Choose an existing instruction or write a short one, starting with "Select content" or "Select code".',
                    value=DEFAULT_PROMPTS[0],
                    choices=DEFAULT_PROMPTS,
                    multiselect=False,
                    allow_custom_value=True,
                )
            with gr.Column(scale=1, min_width=200):
                model_selector = gr.Dropdown(
                    label="Choose a Model",
                    info="Choose a model from the predefined list below.",
                    value=DEFAULT_MODEL,
                    choices=MODELS_LIST,
                    multiselect=False,
                    allow_custom_value=False,
                )

        with gr.Row():
            submit_button = gr.Button("Submit", variant="primary")
            upload_button = gr.UploadButton("Upload a text or code file", file_count="single")
            reset_button = gr.Button("Default text")

        with gr.Row():
            with gr.Accordion("Detailed response", open=False):
                model_resp_text = gr.Code(
                    label="Model's raw response",
                    interactive=False,
                    value="No response yet",
                    lines=5,
                    language="json",
                )
                model_selections_text = gr.Code(
                    label="Grounded selections",
                    interactive=False,
                    value="No response yet",
                    lines=10,
                    language="json",
                )

        with gr.Row():
            # The trailing empty "exact_match" piece is part of the initial value
            # (presumably so the legend is populated before the first run).
            highlighted_text = gr.HighlightedText(
                label="Selected Content",
                value=[(DEFAULT_TEXT, None), ("", "exact_match")],
                combine_adjacent=False,
                show_legend=True,
                interactive=False,
                color_map={
                    "exact_match": "lightgreen",
                    "normalized_match": "green",
                    "fuzzy_match": "yellow",
                    "hallucination": "red",
                },
            )

        def upload_file(filepath):
            """Load an uploaded file as the new, un-highlighted document."""
            try:
                with open(filepath, "r", encoding="utf8") as fp:
                    text = fp.read().strip()
            except UnicodeDecodeError as exc:
                # Binary or non-UTF-8 uploads get a readable message instead of
                # an unhandled decoding traceback.
                raise gr.Error(
                    "Cannot read the file: only UTF-8 encoded text or code files are supported"
                ) from exc
            if count_tokens(text) > MAX_INPUT_TOKENS:
                raise gr.Error(
                    f"File too large! currently only up-to {MAX_INPUT_TOKENS} tokens are supported"
                )
            return [(text, None), ("", "exact_match")]

        def reset_text(*args):
            """Restore the default document, discarding any highlights."""
            return [(DEFAULT_TEXT, None), ("", "exact_match")]

        upload_button.upload(upload_file, upload_button, outputs=[highlighted_text])
        submit_button.click(
            process_igcs_request,
            inputs=[prompt_text, model_selector, highlighted_text],
            outputs=[highlighted_text, model_resp_text, model_selections_text],
        )
        reset_button.click(reset_text, reset_button, outputs=[highlighted_text])
    return app
# Script entry point: build the UI, enable the request queue and serve it.
if __name__ == "__main__":
    # Project logging helper - presumably configures handlers/levels; see igcs.utils.log.
    log.init()
    logger.info("starting app")
    app = get_app()
    app.queue()
    app.launch()
    # Logged once launch() has returned.
    logger.info("done")
|