igcs-demo / app.py
shmuelamar's picture
update readme with arxiv paper link
b00876e unverified
import functools
import gc
import json
import logging
import os
from pathlib import Path
try:
import spaces
except ModuleNotFoundError:
spaces = lambda: None
spaces.GPU = lambda fn: fn
import gradio as gr
import tiktoken
import torch
from openai import OpenAI
from transformers import AutoModelForCausalLM, AutoTokenizer
from igcs import grounding
from igcs.entities import Doc, Selection
from igcs.utils import log
from igcs.utils.diskcache import disk_cache
logger = logging.getLogger("igcs-demo")
_EXAMPLES_DIR = Path(__file__).parent
# In this simulation, we store only a single document although multi-document is possible.
# taken from https://en.wikipedia.org/wiki/Barack_Obama
with open(_EXAMPLES_DIR / "barack_obama_wiki.txt", encoding="utf8") as fp:
DEFAULT_TEXT = fp.read().strip()
# This is the global doc in this demo
DEFAULT_PROMPTS = (
"Select content that details Obama's initiatives",
"Select content that discusses Obama's personal life",
"Select content that details Obama's education",
"Select content with Obama's financial data",
)
# see src/igcs/prompting.py for more info
PROMPT_TEMPLATE = (
"Given the following document(s), {selection_instruction}. "
"Output the exact text phrases from the given document(s) as a valid json array of strings. Do not change the copied text.\n\n"
"Document #0:\n{doc.text}\n"
)
MODELS_LIST = [
# local models:
("====== IGCS Fine-tuned SLMs ======", None),
("Qwen2.5-3b-GenCS-union (local)", "shmuelamar/Qwen2.5-3b-GenCS-union"),
("Qwen2.5-3b-GenCS-majority (local)", "shmuelamar/Qwen2.5-3b-GenCS-majority"),
("Qwen2.5-7b-GenCS-union (local)", "shmuelamar/Qwen2.5-7b-GenCS-union"),
("Qwen2.5-7b-GenCS-majority (local)", "shmuelamar/Qwen2.5-7b-GenCS-majority"),
("Llama-3-8B-GenCS-union (local)", "shmuelamar/Llama-3-8B-GenCS-union"),
("Llama-3-8B-GenCS-majority (local)", "shmuelamar/Llama-3-8B-GenCS-majority"),
("SmolLM2-1.7B-GenCS-union (local)", "shmuelamar/SmolLM2-1.7B-GenCS-union"),
("SmolLM2-1.7B-GenCS-majority (local)", "shmuelamar/SmolLM2-1.7B-GenCS-majority"),
("====== Zero-shot SLMs ======", None),
("Qwen/Qwen2.5-3B-Instruct (local)", "Qwen/Qwen2.5-3B-Instruct"),
("Qwen/Qwen2.5-7B-Instruct (local)", "Qwen/Qwen2.5-7B-Instruct"),
("meta-llama/Meta-Llama-3-8B-Instruct (local)", "meta-llama/Meta-Llama-3-8B-Instruct"),
("HuggingFaceTB/SmolLM2-1.7B-Instruct (local)", "HuggingFaceTB/SmolLM2-1.7B-Instruct"),
("====== API-based Models (OpenRouter) ======", None),
("qwen/qwen3-14b (API)", "api:qwen/qwen3-14b:free"),
("moonshotai/kimi-k2 (API)", "api:moonshotai/kimi-k2:free"),
("deepseek/deepseek-chat-v3-0324 (API)", "api:deepseek/deepseek-chat-v3-0324:free"),
("meta-llama/llama-3.3-70b-instruct (API)", "api:meta-llama/llama-3.3-70b-instruct:free"),
("meta-llama/llama-3.1-405b-instruct (API)", "api:meta-llama/llama-3.1-405b-instruct:free"),
]
DEFAULT_MODEL = MODELS_LIST[1][1]
MAX_INPUT_TOKENS = 4500
MAX_PROMPT_TOKENS = 256
INTRO_TEXT = """
## 🚀 Welcome to the IGCS Live Demo!
This is a demo for the paper titled [**“A Unifying Scheme for Extractive Content Selection Tasks”**][arxiv-paper] — try Instruction‑Guided Content Selection on **any**
text or code: use the demo text or upload your document, enter an instruction, choose a model, and hit **Submit** to see the most relevant spans highlighted!
🔍 Learn more in our [paper][arxiv-paper] and explore the full [GitHub repo](https://github.com/shmuelamar/igcs) ⭐. Enjoy! 🎉
[arxiv-paper]: http://arxiv.org/abs/2507.16922 "A Unifying Scheme for Extractive Content Selection Tasks"
"""
@spaces.GPU
def completion(prompt: str, model_id: str):
# load model and tokenizer
logger.info(f"loading local model and tokenizer for {model_id}")
tokenizer = AutoTokenizer.from_pretrained(model_id)
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map="auto")
logger.info(f"done loading {model_id}")
# tokenize
input_ids = tokenizer.apply_chat_template(
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
],
add_generation_prompt=True,
return_tensors="pt",
).to(model.device)
# MPS (on Mac) requires manual attention mask
attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=model.device)
logger.info(f"generating completion with model_id: {model.name_or_path} and prompt: {prompt!r}")
outputs = model.generate(
input_ids,
attention_mask=attention_mask,
max_new_tokens=2048,
# eos_token_id=[tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]],
do_sample=False,
top_k=None,
top_p=None,
temperature=None,
)
# decode response
resp = tokenizer.decode(outputs[0][input_ids.shape[-1] :], skip_special_tokens=True)
# cleanup memory
del model, tokenizer
torch.cuda.empty_cache()
gc.collect()
return resp
def completion_openrouter(prompt: str, model_id: str):
logger.info(f"calling openrouter with model_id: {model_id} and prompt: {prompt!r}")
client = load_openrouter_client()
resp = client.chat.completions.create(
model=model_id,
messages=[{"role": "user", "content": prompt}],
temperature=0.0,
)
return resp.choices[0].message.content
def load_openrouter_client():
logger.info(f"connecting to OpenRouter")
return OpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=os.environ.get("OPENROUTER_API_KEY"),
)
@disk_cache(cache_dir=_EXAMPLES_DIR / "models-cache")
def get_completion_cache(*, prompt: str, model_id: str) -> str:
return get_completion(prompt=prompt, model_id=model_id)
@functools.lru_cache(maxsize=2048)
def get_completion(*, prompt: str, model_id: str):
if model_id.startswith("api:"):
return completion_openrouter(prompt, model_id.removeprefix("api:"))
else:
resp = completion(prompt, model_id)
return resp
TIKTOKEN_TOKENIZER = tiktoken.encoding_for_model("gpt-4")
def count_tokens(text: str) -> int:
return len(TIKTOKEN_TOKENIZER.encode(text))
def perform_igcs(
doc: Doc, selection_instruction: str, model_id: str
) -> tuple[list[Selection] | None, str]:
logger.info(f"performing selection with {selection_instruction!r} using {model_id!r}")
prompt = PROMPT_TEMPLATE.format(doc=doc, selection_instruction=selection_instruction)
# For the example inputs - we cache from disk as they are more popular
if doc.text == DEFAULT_TEXT and selection_instruction in DEFAULT_PROMPTS:
logger.info("using disk_cache mode")
resp = get_completion_cache(prompt=prompt, model_id=model_id)
else:
resp = get_completion(prompt=prompt, model_id=model_id)
logger.info(f"Got response from model: {model_id}: {resp!r}")
# First, parse the selections as json array of strings
selection_spans = grounding.parse_selection(resp)
# Next, ground them to specific character positions in the source documents
selections = grounding.ground_selections(selection_spans, docs=[doc])
logger.info(f"model selections: {selections!r}")
return selections, resp
def convert_selections_to_gradio_highlights(selections, doc) -> list[tuple[str, str | None]]:
pos = 0
highlights = []
# add hallucinations outside the text itself:
if any(sel.doc_id == -1 for sel in selections):
highlights.append(
("\n\nHallucinated selections (not found in the document):\n\n", "hallucination")
)
for sel in selections:
if sel.doc_id != -1: # not hallucination
continue
highlights.append((sel.content + "\n", "hallucination"))
selections.sort(key=lambda sel: (sel.end_pos, sel.start_pos))
for sel in selections:
if sel.doc_id == -1:
continue # hallucination
if pos < sel.start_pos:
highlights.append((doc.text[pos : sel.start_pos], None)) # outside selection
elif pos >= sel.end_pos:
continue # two selections overlap - we only display the first.
highlights.append(
(doc.text[sel.start_pos : sel.end_pos], sel.metadata["mode"])
) # the selection
pos = sel.end_pos
if pos + 1 < len(doc.text):
highlights.append((doc.text[pos:], None)) # end of the text
return highlights
def process_igcs_request(selection_instruction: str, model_id: str, doc_data: list[dict]):
if model_id is None:
raise gr.Error("Please select a valid model from the list.")
doc_text = "".join(
[doc["token"] for doc in doc_data if doc["class_or_confidence"] != "hallucination"]
)
if count_tokens(doc_text) > MAX_INPUT_TOKENS:
raise gr.Error(
f"File too large! currently only up-to {MAX_INPUT_TOKENS} tokens are supported"
)
if count_tokens(selection_instruction) > MAX_PROMPT_TOKENS:
raise gr.Error(f"Prompt is too long! only supports up-to {MAX_PROMPT_TOKENS} tokens.")
# Perform content selection
# TODO: cache examples
doc = Doc(id=0, text=doc_text)
selections, model_resp = perform_igcs(doc, selection_instruction, model_id)
if selections is None:
raise gr.Error(
"Cannot parse selections, model response is invalid. please try another instruction or model."
)
# Post-process selections for display as highlighted spans
highlights = convert_selections_to_gradio_highlights(selections, doc)
selections_text = json.dumps([s.model_dump(mode="json") for s in selections], indent=2)
return highlights, model_resp, selections_text
def get_app() -> gr.Interface:
with gr.Blocks(title="Instruction-guided content selection", theme="ocean", head="") as app:
with gr.Row():
gr.Markdown(INTRO_TEXT)
with gr.Row(equal_height=True):
with gr.Column(scale=2, min_width=300):
prompt_text = gr.Dropdown(
label="Content Selection Instruction:",
info='Choose an existing instruction or write a short one, starting with "Select content" or "Select code".',
value=DEFAULT_PROMPTS[0],
choices=DEFAULT_PROMPTS,
multiselect=False,
allow_custom_value=True,
)
with gr.Column(scale=1, min_width=200):
model_selector = gr.Dropdown(
label="Choose a Model",
info="Choose a model from the predefined list below.",
value=DEFAULT_MODEL,
choices=MODELS_LIST,
multiselect=False,
allow_custom_value=False,
)
with gr.Row():
submit_button = gr.Button("Submit", variant="primary")
upload_button = gr.UploadButton("Upload a text or code file", file_count="single")
reset_button = gr.Button("Default text")
with gr.Row():
with gr.Accordion("Detailed response", open=False):
model_resp_text = gr.Code(
label="Model's raw response",
interactive=False,
value="No response yet",
lines=5,
language="json",
)
model_selections_text = gr.Code(
label="Grounded selections",
interactive=False,
value="No response yet",
lines=10,
language="json",
)
with gr.Row():
highlighted_text = gr.HighlightedText(
label="Selected Content",
value=[(DEFAULT_TEXT, None), ("", "exact_match")],
combine_adjacent=False,
show_legend=True,
interactive=False,
color_map={
"exact_match": "lightgreen",
"normalized_match": "green",
"fuzzy_match": "yellow",
"hallucination": "red",
},
)
def upload_file(filepath):
with open(filepath, "r", encoding="utf8") as fp:
text = fp.read().strip()
if count_tokens(text) > MAX_INPUT_TOKENS:
raise gr.Error(
f"File too large! currently only up-to {MAX_INPUT_TOKENS} tokens are supported"
)
return [(text, None), ("", "exact_match")]
def reset_text(*args):
return [(DEFAULT_TEXT, None), ("", "exact_match")]
upload_button.upload(upload_file, upload_button, outputs=[highlighted_text])
submit_button.click(
process_igcs_request,
inputs=[prompt_text, model_selector, highlighted_text],
outputs=[highlighted_text, model_resp_text, model_selections_text],
)
reset_button.click(reset_text, reset_button, outputs=[highlighted_text])
return app
if __name__ == "__main__":
log.init()
logger.info("starting app")
app = get_app()
app.queue()
app.launch()
logger.info("done")