# Hugging Face Spaces status banner: "Running on Zero".
# This Space is largely adapted from the work of Aritra Roy Gosthipaty
# (see https://huggingface.co/spaces/ariG23498/kv-press/blob/main/app.py).
import spaces | |
import requests | |
import gradio as gr | |
from bs4 import BeautifulSoup | |
from transformers import pipeline | |
from kvpress import ( | |
ExpectedAttentionPress, | |
KnormPress, | |
RandomPress, | |
SnapKVPress, | |
StreamingLLMPress, | |
TOVAPress, | |
) | |
# Registry of the selectable press classes. The keys are the names offered in
# the "Select Press" dropdown and looked up again in process_request.
press_dict = dict(
    ExpectedAttentionPress=ExpectedAttentionPress,
    KnormPress=KnormPress,
    RandomPress=RandomPress,
    SnapKVPress=SnapKVPress,
    StreamingLLMPress=StreamingLLMPress,
    TOVAPress=TOVAPress,
)
# One "kv-press-text-generation" pipeline per checkpoint, keyed by checkpoint
# name. Both pipelines are created at import time on device "cuda:0".
pipe_dict = {
    ckpt: pipeline("kv-press-text-generation", model=ckpt, device="cuda:0", torch_dtype="auto")
    for ckpt in ("meta-llama/Meta-Llama-3.1-8B-Instruct", "Qwen/Qwen2.5-7B-Instruct-1M")
}
def process_request(url, question, press_name, pipe_name, compression_ratio):
    """Answer a question about a web article with a compressed KV cache.

    Downloads ``url``, joins the text of every ``<p>`` element into a single
    context string, and runs the selected pipeline on that context with the
    selected press.

    Args:
        url: Address of the article to download.
        question: Question forwarded to the pipeline.
        press_name: Key of ``press_dict`` selecting the press class.
        pipe_name: Key of ``pipe_dict`` selecting the loaded pipeline.
        compression_ratio: Passed to the press constructor and used to
            estimate the token count after compression.

    Returns:
        Always a 3-tuple ``(text, num_tokens, num_tokens_after)`` so that it
        matches the three Gradio outputs bound to this handler. On success
        ``text`` is the predicted answer; on any failure ``text`` is an error
        message and both counts are ``-1``.
    """
    # NOTE(review): `spaces` is imported by this file but no `@spaces.GPU`
    # decorator is visible on this function -- confirm whether one is needed.
    if press_name not in press_dict:
        return f"Invalid press selected: {press_name}", -1, -1
    if pipe_name not in pipe_dict:
        return f"Invalid model selected: {pipe_name}", -1, -1

    # Fetch the Wikipedia article. The timeout keeps a stalled server from
    # blocking the handler forever; raise_for_status turns 4xx/5xx responses
    # into errors instead of feeding an error page to the model.
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        return f"Error fetching the Wikipedia article: {e}", -1, -1

    try:
        # Parse the Wikipedia HTML and keep only paragraph text.
        soup = BeautifulSoup(response.content, "html.parser")
        context = "".join(p.text for p in soup.find_all("p")) + "\n\n"
        if not context.strip():
            return "Error: no paragraph text found at the given URL.", -1, -1

        # Initialize the press
        press = press_dict[press_name](compression_ratio)
        pipe = pipe_dict[pipe_name]

        num_tokens = pipe.tokenizer(context, return_tensors="pt")["input_ids"].shape[1]
        pred_answer = pipe(context, question=question, press=press)["answer"]
        return pred_answer, num_tokens, int(num_tokens * (1 - compression_ratio))
    except Exception as e:
        # UI boundary: report any failure in the output box rather than crash.
        if "CUDA out of memory" in str(e):
            return (
                "Error: CUDA out of memory. Try using a smaller article or a lower compression ratio.",
                -1,
                -1,
            )
        return str(e), -1, -1
def gradio_interface():
    """Assemble the demo's Gradio Blocks UI and return it without launching.

    The layout is an introductory Markdown panel, a row with the URL and
    question boxes, a row with the model/press/ratio selectors, the three
    result fields, a submit button wired to ``process_request``, and a set of
    clickable examples.
    """
    # Each example fills, in order: URL, question, press name, compression ratio.
    example_rows = [
        [
            "https://en.wikipedia.org/wiki/Nvidia",
            "Complete this sentence: In May 2017, the program had 1,300 companies. As of March 2018, there were ",
            "ExpectedAttentionPress",
            0.5,
        ],
        [
            "https://en.wikipedia.org/wiki/Hugging_Face",
            "What was the original name of the transformers library ?",
            "ExpectedAttentionPress",
            0.5,
        ],
        [
            "https://en.wikipedia.org/wiki/World_Chess_Championship_2024",
            "On which move did the world chess championship end?",
            "ExpectedAttentionPress",
            0.5,
        ],
    ]

    with gr.Blocks() as app:
        gr.Markdown(
            """
            # Wikipedia Article Question Answering with kvpress
            This demo answers questions about any given Wikipedia article.
            Under the hood, [kvpress](https://github.com/NVIDIA/kvpress) *compresses the key-value (KV) cache* associated with the article, helping reduce memory usage and accelerate decoding.
            **How to use:**
            1. Enter a Wikipedia article URL
            2. Type your question
            3. Select a model, a press and the desired compression ratio
            4. Press "Submit" to see the answer, along with token statistics before and after compression
            """
        )

        # Free-text inputs.
        with gr.Row():
            url_box = gr.Textbox(label="Wikipedia Article URL", placeholder="Enter the Wikipedia article URL here")
            question_box = gr.Textbox(label="Question", placeholder="Type your question here")

        # Model, press and compression settings.
        with gr.Row():
            model_dropdown = gr.Dropdown(
                choices=list(pipe_dict),
                value="meta-llama/Meta-Llama-3.1-8B-Instruct",
                label="Select Model",
            )
            press_dropdown = gr.Dropdown(
                choices=list(press_dict),
                value="ExpectedAttentionPress",
                label="Select Press",
            )
            ratio_slider = gr.Slider(minimum=0.0, maximum=0.9, step=0.1, value=0.5, label="Compression Ratio")

        # Result fields, in the order process_request returns them.
        answer_box = gr.Textbox(label="Output", lines=10)
        tokens_before = gr.Number(label="Number of tokens before compression", interactive=False)
        tokens_after = gr.Number(label="Number of tokens after compression", interactive=False)

        run_button = gr.Button("Submit")

        # The examples do not set the model; the dropdown keeps its current value.
        gr.Examples(
            examples=example_rows,
            inputs=[url_box, question_box, press_dropdown, ratio_slider],
        )

        handler_inputs = [url_box, question_box, press_dropdown, model_dropdown, ratio_slider]
        handler_outputs = [answer_box, tokens_before, tokens_after]
        run_button.click(process_request, inputs=handler_inputs, outputs=handler_outputs)

    return app
if __name__ == "__main__":
    # Build the Blocks app and start the Gradio server when run as a script.
    demo = gradio_interface()
    demo.launch()