import os
import re
import tempfile
from time import time
from typing import Optional

# Set the Hugging Face cache location before transformers is imported,
# otherwise the default cache directory has already been resolved.
os.environ["HF_HOME"] = "/app/hf_cache"

import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException, Body
from fastapi.responses import JSONResponse, RedirectResponse
from pydantic import BaseModel
from pyngrok import ngrok
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache, StaticCache

# Allow torch.load (which defaults to weights_only=True in recent PyTorch)
# to deserialize saved DynamicCache objects.
torch.serialization.add_safe_globals([DynamicCache, set])

|
def generate(model, input_ids, past_key_values, max_new_tokens=50):
    """
    Greedy token-by-token generation on top of a pre-computed KV cache.

    Purpose: generate new text for the given input tokens without re-encoding the full context.
    Process: feeds the input through the model once, then appends one greedily chosen token
    at a time, reusing and extending the cached key-value states.
    Returns: only the newly generated token ids (the original input is stripped off).
    """
    device = model.model.embed_tokens.weight.device
    origin_len = input_ids.shape[-1]
    input_ids = input_ids.to(device)
    output_ids = input_ids.clone()
    next_token = input_ids

    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True
            )
            logits = out.logits[:, -1, :]
            # Greedy decoding: pick the most likely next token.
            token = torch.argmax(logits, dim=-1, keepdim=True)
            output_ids = torch.cat([output_ids, token], dim=-1)
            past_key_values = out.past_key_values
            next_token = token.to(device)

            # Stop early once the model emits its end-of-sequence token.
            if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
                break

    return output_ids[:, origin_len:]
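
# Minimal end-to-end sketch of how these helpers compose (illustrative only; assumes
# `model` and `tokenizer` come from load_model_and_tokenizer() below, and the prompt
# and question strings are placeholders):
#
#   cache, origin_len = get_kv_cache(model, tokenizer, "<|system|>\nContext: ...\nQuestion:")
#   cache = clean_up(cache, origin_len)          # fresh prompt-only copy per question
#   ids = tokenizer("Question: ...", return_tensors="pt").input_ids
#   new_ids = generate(model, ids, cache, max_new_tokens=50)
#   print(tokenizer.decode(new_ids[0], skip_special_tokens=True))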
|
|
|
def get_kv_cache(model, tokenizer, prompt):
    """
    Build a key-value cache for a given prompt.

    Purpose: pre-compute and store the model's internal key-value states for a prompt.
    Process: encodes the prompt, runs a single forward pass, and captures the resulting cache.
    Returns: the cache object and the prompt length in tokens, for later trimming.
    """
    device = model.model.embed_tokens.weight.device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    cache = DynamicCache()

    with torch.no_grad():
        _ = model(
            input_ids=input_ids,
            past_key_values=cache,
            use_cache=True
        )
    return cache, input_ids.shape[-1]
|
|
|
def clean_up(cache, origin_len):
    """
    Return a fresh copy of a cache trimmed back to the original prompt length.

    Every layer's key/value tensors are cloned and then cropped along the
    sequence dimension to the first `origin_len` positions, so generation can
    reuse the prompt context without mutating or growing the stored cache.
    """
    new_cache = DynamicCache()
    for i in range(len(cache.key_cache)):
        new_cache.key_cache.append(cache.key_cache[i].clone())
        new_cache.value_cache.append(cache.value_cache[i].clone())

    for i in range(len(new_cache.key_cache)):
        new_cache.key_cache[i] = new_cache.key_cache[i][:, :, :origin_len, :]
        new_cache.value_cache[i] = new_cache.value_cache[i][:, :, :origin_len, :]
    return new_cache

|
def load_model_and_tokenizer():
    """Load the causal LM and its tokenizer, using fp16 on GPU and fp32 on CPU."""
    model_name = "Locutusque/TinyMistral-248M"

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if torch.cuda.is_available():
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True
        )
    return model, tokenizer

|
app = FastAPI(title="DeepSeek QA with KV Cache API")

# Load the model once at startup; all document caches are kept in memory.
model, tokenizer = load_model_and_tokenizer()

cache_store = {}


class QueryRequest(BaseModel):
    query: str
    max_new_tokens: Optional[int] = 150

|
def clean_response(response_text):
    """
    Clean up a model response by removing chat-template tags, repeated lines,
    and trailing formatting artifacts.
    """
    # Prefer text captured inside an <|assistant|> ... block.
    assistant_pattern = re.compile(
        r'<\|assistant\|>\s*(.*?)(?:<\/\|assistant\|>|<\|user\|>|<\|system\|>)',
        re.DOTALL
    )
    matches = assistant_pattern.findall(response_text)

    if matches:
        for match in matches:
            cleaned = match.strip()
            if cleaned and not cleaned.startswith("<|") and len(cleaned) > 5:
                return cleaned

    # Fall back to stripping all template tags.
    cleaned = re.sub(r'<\|.*?\|>', '', response_text)
    cleaned = re.sub(r'<\/\|.*?\|>', '', cleaned)

    # Drop exact duplicate lines while preserving order.
    lines = cleaned.strip().split('\n')
    unique_lines = []
    for line in lines:
        line = line.strip()
        if line and line not in unique_lines:
            unique_lines.append(line)

    result = '\n'.join(unique_lines)

    # Remove any dangling tag left at the very end.
    result = re.sub(r'<\/?\|.*?\|>\s*$', '', result)

    return result.strip()
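
# Illustrative example (hypothetical input, not from the original code): a raw decode such as
#   "<|assistant|> Paris is the capital of France.<|user|>"
# is reduced by clean_response to "Paris is the capital of France."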
|
@app.post("/upload-document_to_create_KV_cache") |
|
async def upload_document(file: UploadFile = File(...)): |
|
"""Upload a document and create KV cache for it""" |
|
t1 = time() |
|
|
|
|
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file: |
|
temp_file_path = temp_file.name |
|
content = await file.read() |
|
temp_file.write(content) |
|
|
|
try: |
|
|
|
with open(temp_file_path, "r", encoding="utf-8") as f: |
|
doc_text = f.read() |
|
|
|
|
|
system_prompt = f""" |
|
<|system|> |
|
Answer concisely and precisely, You are an assistant who provides concise factual answers. |
|
<|user|> |
|
Context: |
|
{doc_text} |
|
Question: |
|
""".strip() |
|
|
|
|
|
cache, origin_len = get_kv_cache(model, tokenizer, system_prompt) |
|
|
|
|
|
cache_id = f"cache_{int(time())}" |
|
|
|
|
|
cache_store[cache_id] = { |
|
"cache": cache, |
|
"origin_len": origin_len, |
|
"doc_preview": doc_text[:500] + "..." if len(doc_text) > 500 else doc_text |
|
} |
|
|
|
|
|
os.unlink(temp_file_path) |
|
|
|
t2 = time() |
|
|
|
return { |
|
"cache_id": cache_id, |
|
"message": "Document uploaded and cache created successfully", |
|
"doc_preview": cache_store[cache_id]["doc_preview"], |
|
"time_taken": f"{t2 - t1:.4f} seconds" |
|
} |
|
|
|
except Exception as e: |
|
|
|
if os.path.exists(temp_file_path): |
|
os.unlink(temp_file_path) |
|
raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}") |
|
|
|
@app.post("/generate_answer_from_cache/{cache_id}") |
|
async def generate_answer(cache_id: str, request: QueryRequest): |
|
"""Generate an answer to a question based on the uploaded document""" |
|
t1 = time() |
|
|
|
|
|
if cache_id not in cache_store: |
|
raise HTTPException(status_code=404, detail="Document not found. Please upload it first.") |
|
|
|
try: |
|
|
|
current_cache = clean_up( |
|
cache_store[cache_id]["cache"], |
|
cache_store[cache_id]["origin_len"] |
|
) |
|
|
|
|
|
full_prompt = f""" |
|
<|user|> |
|
Question: {request.query} |
|
<|assistant|> |
|
""".strip() |
|
|
|
input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids |
|
|
|
|
|
output_ids = generate(model, input_ids, current_cache, max_new_tokens=request.max_new_tokens) |
|
response = tokenizer.decode(output_ids[0], skip_special_tokens=True) |
|
rep = clean_response(response) |
|
t2 = time() |
|
|
|
return { |
|
"query": request.query, |
|
"answer": rep, |
|
"time_taken": f"{t2 - t1:.4f} seconds" |
|
} |
|
|
|
except Exception as e: |
|
raise HTTPException(status_code=500, detail=f"Error generating answer: {str(e)}") |
|
|
|
@app.post("/save_cache/{cache_id}") |
|
async def save_cache(cache_id: str): |
|
"""Save the cache for a document""" |
|
if cache_id not in cache_store: |
|
raise HTTPException(status_code=404, detail="Document not found. Please upload it first.") |
|
|
|
try: |
|
|
|
cleaned_cache = clean_up( |
|
cache_store[cache_id]["cache"], |
|
cache_store[cache_id]["origin_len"] |
|
) |
|
|
|
cache_path = f"{cache_id}_cache.pth" |
|
torch.save(cleaned_cache, cache_path) |
|
|
|
return { |
|
"message": f"Cache saved successfully as {cache_path}", |
|
"cache_path": cache_path |
|
} |
|
|
|
except Exception as e: |
|
raise HTTPException(status_code=500, detail=f"Error saving cache: {str(e)}") |
|
|
|
@app.post("/load_cache") |
|
async def load_cache(file: UploadFile = File(...)): |
|
"""Load a previously saved cache""" |
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as temp_file: |
|
temp_file_path = temp_file.name |
|
content = await file.read() |
|
temp_file.write(content) |
|
|
|
try: |
|
|
|
loaded_cache = torch.load(temp_file_path) |
|
|
|
|
|
cache_id = f"loaded_cache_{int(time())}" |
|
|
|
|
|
cache_store[cache_id] = { |
|
"cache": loaded_cache, |
|
"origin_len": loaded_cache.key_cache[0].shape[-2], |
|
"doc_preview": "Loaded from cache file" |
|
} |
|
|
|
|
|
os.unlink(temp_file_path) |
|
|
|
return { |
|
"cache_id": cache_id, |
|
"message": "Cache loaded successfully" |
|
} |
|
|
|
except Exception as e: |
|
|
|
if os.path.exists(temp_file_path): |
|
os.unlink(temp_file_path) |
|
raise HTTPException(status_code=500, detail=f"Error loading cache: {str(e)}") |
|
|
|
@app.get("/list_of_caches") |
|
async def list_documents(): |
|
"""List all uploaded documents/caches""" |
|
documents = {} |
|
for cache_id in cache_store: |
|
documents[cache_id] = { |
|
"doc_preview": cache_store[cache_id]["doc_preview"], |
|
"origin_len": cache_store[cache_id]["origin_len"] |
|
} |
|
|
|
return {"documents": documents} |
|
|
|
@app.get("/", include_in_schema=False) |
|
async def root(): |
|
return RedirectResponse(url="/docs") |
|
|
|
if __name__ == "__main__": |
|
uvicorn.run(app, host="0.0.0.0", port=7860) |
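
# Minimal client sketch (illustrative; assumes the server is reachable at
# http://localhost:7860, that the `requests` package is installed, and that a
# plain-text file named "doc.txt" exists; none of these are part of this module):
#
#   import requests
#
#   base = "http://localhost:7860"
#   with open("doc.txt", "rb") as f:
#       up = requests.post(f"{base}/upload-document_to_create_KV_cache", files={"file": f}).json()
#   cache_id = up["cache_id"]
#   ans = requests.post(
#       f"{base}/generate_answer_from_cache/{cache_id}",
#       json={"query": "What is the document about?", "max_new_tokens": 100},
#   ).json()
#   print(ans["answer"])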