import functools
import gc
import json
import logging
import os
from pathlib import Path

try:
    import spaces
except ModuleNotFoundError:
    # Running outside HF Spaces: provide a stand-in so `spaces.GPU` acts as a no-op decorator.
    spaces = lambda: None
    spaces.GPU = lambda fn: fn

import gradio as gr
import tiktoken
import torch
from openai import OpenAI
from transformers import AutoModelForCausalLM, AutoTokenizer

from igcs import grounding
from igcs.entities import Doc, Selection
from igcs.utils import log
from igcs.utils.diskcache import disk_cache

logger = logging.getLogger("igcs-demo")

_EXAMPLES_DIR = Path(__file__).parent

# In this demo we store only a single document, although multi-document input is possible.
# Taken from https://en.wikipedia.org/wiki/Barack_Obama
with open(_EXAMPLES_DIR / "barack_obama_wiki.txt", encoding="utf8") as fp:
    DEFAULT_TEXT = fp.read().strip()

# This is the global default document used throughout the demo.
DEFAULT_PROMPTS = (
    "Select content that details Obama's initiatives",
    "Select content that discusses Obama's personal life",
    "Select content that details Obama's education",
    "Select content with Obama's financial data",
)

# see src/igcs/prompting.py for more info
PROMPT_TEMPLATE = (
    "Given the following document(s), {selection_instruction}. "
    "Output the exact text phrases from the given document(s) as a valid json array of strings. Do not change the copied text.\n\n"
    "Document #0:\n{doc.text}\n"
)

MODELS_LIST = [
    # local models:
    ("====== IGCS Fine-tuned SLMs ======", None),
    ("Qwen2.5-3b-GenCS-union (local)", "shmuelamar/Qwen2.5-3b-GenCS-union"),
    ("Qwen2.5-3b-GenCS-majority (local)", "shmuelamar/Qwen2.5-3b-GenCS-majority"),
    ("Qwen2.5-7b-GenCS-union (local)", "shmuelamar/Qwen2.5-7b-GenCS-union"),
    ("Qwen2.5-7b-GenCS-majority (local)", "shmuelamar/Qwen2.5-7b-GenCS-majority"),
    ("Llama-3-8B-GenCS-union (local)", "shmuelamar/Llama-3-8B-GenCS-union"),
    ("Llama-3-8B-GenCS-majority (local)", "shmuelamar/Llama-3-8B-GenCS-majority"),
    ("SmolLM2-1.7B-GenCS-union (local)", "shmuelamar/SmolLM2-1.7B-GenCS-union"),
    ("SmolLM2-1.7B-GenCS-majority (local)", "shmuelamar/SmolLM2-1.7B-GenCS-majority"),
    ("====== Zero-shot SLMs ======", None),
    ("Qwen/Qwen2.5-3B-Instruct (local)", "Qwen/Qwen2.5-3B-Instruct"),
    ("Qwen/Qwen2.5-7B-Instruct (local)", "Qwen/Qwen2.5-7B-Instruct"),
    ("meta-llama/Meta-Llama-3-8B-Instruct (local)", "meta-llama/Meta-Llama-3-8B-Instruct"),
    ("HuggingFaceTB/SmolLM2-1.7B-Instruct (local)", "HuggingFaceTB/SmolLM2-1.7B-Instruct"),
    ("====== API-based Models (OpenRouter) ======", None),
    ("qwen/qwen3-14b (API)", "api:qwen/qwen3-14b:free"),
    ("moonshotai/kimi-k2 (API)", "api:moonshotai/kimi-k2:free"),
    ("deepseek/deepseek-chat-v3-0324 (API)", "api:deepseek/deepseek-chat-v3-0324:free"),
    ("meta-llama/llama-3.3-70b-instruct (API)", "api:meta-llama/llama-3.3-70b-instruct:free"),
    ("meta-llama/llama-3.1-405b-instruct (API)", "api:meta-llama/llama-3.1-405b-instruct:free"),
]

DEFAULT_MODEL = MODELS_LIST[1][1]
MAX_INPUT_TOKENS = 4500
MAX_PROMPT_TOKENS = 256

INTRO_TEXT = """
## 🚀 Welcome to the IGCS Live Demo!

This is a demo for the paper titled [**“A Unifying Scheme for Extractive Content Selection Tasks”**][arxiv-paper] — try Instruction‑Guided Content Selection on **any**
text or code: use the demo text or upload your own document, enter an instruction, choose a model, and hit **Submit** to see the most relevant spans highlighted!

🔍 Learn more in our [paper][arxiv-paper] and explore the full [GitHub repo](https://github.com/shmuelamar/igcs) ⭐. Enjoy! 🎉

[arxiv-paper]: http://arxiv.org/abs/2507.16922 "A Unifying Scheme for Extractive Content Selection Tasks"
"""


@spaces.GPU  # allocate a GPU on ZeroGPU Spaces; a no-op when running locally (see shim above)
def completion(prompt: str, model_id: str):
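    """Generate a completion with a locally loaded Hugging Face model using greedy decoding."""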
    # load model and tokenizer
    logger.info(f"loading local model and tokenizer for {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map="auto")
    logger.info(f"done loading {model_id}")

    # tokenize
    input_ids = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # MPS (on Mac) requires manual attention mask
    attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=model.device)

    logger.info(f"generating completion with model_id: {model.name_or_path} and prompt: {prompt!r}")
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=2048,
        # eos_token_id=[tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]],
        do_sample=False,
        top_k=None,
        top_p=None,
        temperature=None,
    )

    # decode response
    resp = tokenizer.decode(outputs[0][input_ids.shape[-1] :], skip_special_tokens=True)

    # cleanup memory
    del model, tokenizer
    torch.cuda.empty_cache()
    gc.collect()
    return resp


def completion_openrouter(prompt: str, model_id: str):
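    """Generate a completion via the OpenRouter API (deterministic, temperature 0)."""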
logger.info(f"calling openrouter with model_id: {model_id} and prompt: {prompt!r}") | |
client = load_openrouter_client() | |
resp = client.chat.completions.create( | |
model=model_id, | |
messages=[{"role": "user", "content": prompt}], | |
temperature=0.0, | |
) | |
return resp.choices[0].message.content | |
def load_openrouter_client():
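    """Create an OpenAI-compatible client pointed at OpenRouter, authenticated via OPENROUTER_API_KEY."""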
logger.info(f"connecting to OpenRouter") | |
return OpenAI( | |
base_url="https://openrouter.ai/api/v1", | |
api_key=os.environ.get("OPENROUTER_API_KEY"), | |
) | |
@disk_cache  # NOTE: assumes disk_cache is usable as a bare decorator for disk-backed memoization of popular example prompts
def get_completion_cache(*, prompt: str, model_id: str) -> str:
    return get_completion(prompt=prompt, model_id=model_id)


def get_completion(*, prompt: str, model_id: str):
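    """Dispatch to OpenRouter for `api:`-prefixed model ids, otherwise run the model locally."""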
    if model_id.startswith("api:"):
        return completion_openrouter(prompt, model_id.removeprefix("api:"))
    else:
        resp = completion(prompt, model_id)
        return resp


TIKTOKEN_TOKENIZER = tiktoken.encoding_for_model("gpt-4")


def count_tokens(text: str) -> int:
    return len(TIKTOKEN_TOKENIZER.encode(text))


def perform_igcs(
    doc: Doc, selection_instruction: str, model_id: str
) -> tuple[list[Selection] | None, str]:
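    """Prompt the chosen model with the IGCS template, then parse and ground its selections in `doc`."""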
logger.info(f"performing selection with {selection_instruction!r} using {model_id!r}") | |
prompt = PROMPT_TEMPLATE.format(doc=doc, selection_instruction=selection_instruction) | |
# For the example inputs - we cache from disk as they are more popular | |
if doc.text == DEFAULT_TEXT and selection_instruction in DEFAULT_PROMPTS: | |
logger.info("using disk_cache mode") | |
resp = get_completion_cache(prompt=prompt, model_id=model_id) | |
else: | |
resp = get_completion(prompt=prompt, model_id=model_id) | |
logger.info(f"Got response from model: {model_id}: {resp!r}") | |
# First, parse the selections as json array of strings | |
selection_spans = grounding.parse_selection(resp) | |
# Next, ground them to specific character positions in the source documents | |
selections = grounding.ground_selections(selection_spans, docs=[doc]) | |
logger.info(f"model selections: {selections!r}") | |
return selections, resp | |
def convert_selections_to_gradio_highlights(selections, doc) -> list[tuple[str, str | None]]:
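    """Convert grounded selections into (text, label) pairs for gr.HighlightedText; hallucinated spans are grouped under a separate header."""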
    pos = 0
    highlights = []

    # add hallucinations outside the text itself:
    if any(sel.doc_id == -1 for sel in selections):
        highlights.append(
            ("\n\nHallucinated selections (not found in the document):\n\n", "hallucination")
        )
    for sel in selections:
        if sel.doc_id != -1:  # not hallucination
            continue
        highlights.append((sel.content + "\n", "hallucination"))

    selections.sort(key=lambda sel: (sel.end_pos, sel.start_pos))
    for sel in selections:
        if sel.doc_id == -1:
            continue  # hallucination

        if pos < sel.start_pos:
            highlights.append((doc.text[pos : sel.start_pos], None))  # outside selection
        elif pos >= sel.end_pos:
            continue  # two selections overlap - we only display the first.

        highlights.append(
            (doc.text[sel.start_pos : sel.end_pos], sel.metadata["mode"])
        )  # the selection
        pos = sel.end_pos

    if pos < len(doc.text):
        highlights.append((doc.text[pos:], None))  # end of the text
    return highlights


def process_igcs_request(selection_instruction: str, model_id: str, doc_data: list[dict]):
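    """Gradio callback: validate the inputs, run content selection, and return highlights plus raw outputs."""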
    if model_id is None:
        raise gr.Error("Please select a valid model from the list.")

    doc_text = "".join(
        [doc["token"] for doc in doc_data if doc["class_or_confidence"] != "hallucination"]
    )
    if count_tokens(doc_text) > MAX_INPUT_TOKENS:
        raise gr.Error(
            f"File too large! Currently only up to {MAX_INPUT_TOKENS} tokens are supported."
        )
    if count_tokens(selection_instruction) > MAX_PROMPT_TOKENS:
        raise gr.Error(f"Prompt is too long! Only up to {MAX_PROMPT_TOKENS} tokens are supported.")

    # Perform content selection
    # TODO: cache examples
    doc = Doc(id=0, text=doc_text)
    selections, model_resp = perform_igcs(doc, selection_instruction, model_id)
    if selections is None:
        raise gr.Error(
            "Cannot parse selections, the model response is invalid. Please try another instruction or model."
        )

    # Post-process selections for display as highlighted spans
    highlights = convert_selections_to_gradio_highlights(selections, doc)
    selections_text = json.dumps([s.model_dump(mode="json") for s in selections], indent=2)
    return highlights, model_resp, selections_text


def get_app() -> gr.Blocks:
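    """Build the Gradio Blocks UI: instruction and model pickers, action buttons, raw-output panels, and the highlighted document."""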
    with gr.Blocks(title="Instruction-guided content selection", theme="ocean", head="") as app:
        with gr.Row():
            gr.Markdown(INTRO_TEXT)

        with gr.Row(equal_height=True):
            with gr.Column(scale=2, min_width=300):
                prompt_text = gr.Dropdown(
                    label="Content Selection Instruction:",
                    info='Choose an existing instruction or write a short one, starting with "Select content" or "Select code".',
                    value=DEFAULT_PROMPTS[0],
                    choices=DEFAULT_PROMPTS,
                    multiselect=False,
                    allow_custom_value=True,
                )
            with gr.Column(scale=1, min_width=200):
                model_selector = gr.Dropdown(
                    label="Choose a Model",
                    info="Choose a model from the predefined list below.",
                    value=DEFAULT_MODEL,
                    choices=MODELS_LIST,
                    multiselect=False,
                    allow_custom_value=False,
                )

        with gr.Row():
            submit_button = gr.Button("Submit", variant="primary")
            upload_button = gr.UploadButton("Upload a text or code file", file_count="single")
            reset_button = gr.Button("Default text")

        with gr.Row():
            with gr.Accordion("Detailed response", open=False):
                model_resp_text = gr.Code(
                    label="Model's raw response",
                    interactive=False,
                    value="No response yet",
                    lines=5,
                    language="json",
                )
                model_selections_text = gr.Code(
                    label="Grounded selections",
                    interactive=False,
                    value="No response yet",
                    lines=10,
                    language="json",
                )

        with gr.Row():
            highlighted_text = gr.HighlightedText(
                label="Selected Content",
                value=[(DEFAULT_TEXT, None), ("", "exact_match")],
                combine_adjacent=False,
                show_legend=True,
                interactive=False,
                color_map={
                    "exact_match": "lightgreen",
                    "normalized_match": "green",
                    "fuzzy_match": "yellow",
                    "hallucination": "red",
                },
            )

        def upload_file(filepath):
            with open(filepath, "r", encoding="utf8") as fp:
                text = fp.read().strip()
            if count_tokens(text) > MAX_INPUT_TOKENS:
                raise gr.Error(
                    f"File too large! Currently only up to {MAX_INPUT_TOKENS} tokens are supported."
                )
            return [(text, None), ("", "exact_match")]

        def reset_text(*args):
            return [(DEFAULT_TEXT, None), ("", "exact_match")]

        upload_button.upload(upload_file, upload_button, outputs=[highlighted_text])
        submit_button.click(
            process_igcs_request,
            inputs=[prompt_text, model_selector, highlighted_text],
            outputs=[highlighted_text, model_resp_text, model_selections_text],
        )
        reset_button.click(reset_text, reset_button, outputs=[highlighted_text])

    return app


if __name__ == "__main__":
    log.init()
    logger.info("starting app")

    app = get_app()
    app.queue()
    app.launch()
    logger.info("done")