# exemple3 / app.py
# Author: kouki321 — commit b5dd82d ("Update app.py"), verified
import os
import torch
from fastapi import FastAPI, File, UploadFile, HTTPException, Body
from fastapi.responses import JSONResponse,RedirectResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache , StaticCache
from pydantic import BaseModel
from typing import Optional
import uvicorn
import tempfile
from time import time
from pyngrok import ngrok
# Point the Hugging Face cache at /app/hf_cache.
# NOTE(review): huggingface_hub resolves HF_HOME into its cache-path constants when
# it is first imported, and the `transformers` import above has already imported
# it, so this assignment probably does not change where from_pretrained() caches
# files. Move it above the transformers import (or set it in the container
# environment) if the cache location matters.
os.environ["HF_HOME"] = "/app/hf_cache"
# os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"

# Allow-list DynamicCache (and set) for torch's restricted `weights_only`
# unpickler, so a DynamicCache written with torch.save can be read back with
# torch.load.
torch.serialization.add_safe_globals([DynamicCache, set])
# Minimal greedy, token-by-token generation on top of an existing KV cache.
def generate(model,
             input_ids,
             past_key_values,
             max_new_tokens=50):
    """Greedily generate up to ``max_new_tokens`` tokens, reusing a KV cache.

    The prompt ``input_ids`` is fed once on top of ``past_key_values``. After that,
    only the most recently generated token is fed back at each step, so the
    cached context is never re-encoded.

    Args:
        model: Causal LM whose forward accepts ``input_ids``, ``past_key_values``
            and ``use_cache`` and returns an object with ``.logits`` and
            ``.past_key_values``.
        input_ids: Prompt token ids with a single sequence (batch size 1 is
            required by the ``token.item()`` EOS check below).
        past_key_values: Cache to extend. The model may mutate it in place, so
            pass a copy if the original must be preserved.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        Tensor holding only the newly generated ids (including the EOS token
        when one was produced), at most ``max_new_tokens`` of them.
    """
    # get_input_embeddings() is available on every Hugging Face PreTrainedModel,
    # unlike `model.model.embed_tokens`, which only exists on Llama/Mistral-style
    # architectures (OPT and Falcon lay their modules out differently).
    device = model.get_input_embeddings().weight.device

    # config.eos_token_id may be None, a single id, or a list of ids.
    eos = model.config.eos_token_id
    if eos is None:
        eos_ids = set()
    elif isinstance(eos, int):
        eos_ids = {eos}
    else:
        eos_ids = set(eos)

    origin_len = input_ids.shape[-1]  # prompt length, used to slice off the prompt at the end
    input_ids = input_ids.to(device)
    output_ids = input_ids.clone()  # prompt + generated tokens, grown each step
    next_token = input_ids  # first step feeds the whole prompt
    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True,
            )
            logits = out.logits[:, -1, :]  # scores for the position after the last input token
            token = torch.argmax(logits, dim=-1, keepdim=True)  # greedy choice
            output_ids = torch.cat([output_ids, token], dim=-1)
            past_key_values = out.past_key_values
            next_token = token.to(device)
            if token.item() in eos_ids:
                break
    return output_ids[:, origin_len:]  # only the newly generated part
def get_kv_cache(model, tokenizer, prompt):
    """Encode ``prompt`` once and return a KV cache populated with its states.

    The prompt is tokenized, moved to the model's input device, and run through
    a single no-grad forward pass with a fresh ``DynamicCache``. The forward
    output itself is discarded; only the cache it filled is kept.

    Args:
        model: Hugging Face causal LM.
        tokenizer: Tokenizer matching ``model``.
        prompt: Text whose key/value states should be pre-computed.

    Returns:
        ``(cache, prompt_len)``: the populated cache and the number of prompt
        tokens, which is later used to truncate the cache back to this prompt.
    """
    # Same device lookup as in generate(): get_input_embeddings() works for any
    # PreTrainedModel, not only architectures exposing `model.model.embed_tokens`.
    device = model.get_input_embeddings().weight.device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    cache = DynamicCache()  # starts empty and is filled by the forward pass below
    with torch.no_grad():
        model(
            input_ids=input_ids,
            past_key_values=cache,
            use_cache=True,
        )
    return cache, input_ids.shape[-1]
def clean_up(cache, origin_len):
    """Return an independent copy of ``cache`` truncated to ``origin_len`` positions.

    Used to get a pristine document cache before each query (and before saving),
    discarding any tokens appended after the original prompt.

    Args:
        cache: Cache exposing per-layer ``key_cache`` / ``value_cache`` lists of
            tensors whose third axis (dim 2) is the sequence axis.
        origin_len: Number of leading sequence positions to keep.

    Returns:
        A new ``DynamicCache`` whose tensors are clones, so mutating it (e.g. by
        generation) never touches ``cache``.
    """
    # NOTE(review): relies on the `key_cache` / `value_cache` list attributes of
    # DynamicCache; confirm they still exist in the pinned transformers version.
    new_cache = DynamicCache()
    for keys, values in zip(cache.key_cache, cache.value_cache):
        # Slice first, then clone. The copy then owns only the first origin_len
        # positions. Cloning the full tensor and slicing afterwards would keep the
        # whole buffer alive, and torch.save would write all of it to disk.
        new_cache.key_cache.append(keys[:, :, :origin_len, :].clone())
        new_cache.value_cache.append(values[:, :, :origin_len, :].clone())
    return new_cache
# Uncomment to run in offline mode:
# os.environ["TRANSFORMERS_OFFLINE"] = "1"
# os.environ["HF_HUB_OFFLINE"] = "1"


def load_model_and_tokenizer():
    """Load the causal LM and its tokenizer by Hugging Face repo id.

    With CUDA available, the model is loaded in float16 and ``device_map="auto"``
    decides layer placement. Otherwise it is loaded on the CPU in float32 with
    ``low_cpu_mem_usage=True``. The tokenizer is loaded first in both cases.

    Returns:
        ``(model, tokenizer)``.
    """
    model_name = "Locutusque/TinyMistral-248M"
    # Alternative checkpoints (commented out):
    #   "tiiuae/falcon-rw-1b"
    #   "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    #   "facebook/opt-125m"

    # Loaded without trust_remote_code.
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if torch.cuda.is_available():
        # Half precision on GPU; "auto" maps layers onto the available device(s).
        load_kwargs = dict(torch_dtype=torch.float16, device_map="auto")
    else:
        # float32 for CPU compatibility; low_cpu_mem_usage reduces memory while loading.
        load_kwargs = dict(torch_dtype=torch.float32, low_cpu_mem_usage=True)

    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    return model, tokenizer
# FastAPI application object.
# NOTE(review): the title mentions DeepSeek, but load_model_and_tokenizer()
# currently loads TinyMistral-248M.
app = FastAPI(title="DeepSeek QA with KV Cache API")

# Load the model and tokenizer once, at import time; every request shares them.
model, tokenizer = load_model_and_tokenizer()

# In-memory registry: cache_id -> {"cache", "origin_len", "doc_preview"}.
# Held only in process memory, so its contents do not survive a restart.
cache_store: dict = {}
class QueryRequest(BaseModel):
    # Request body for POST /generate_answer_from_cache/{cache_id}.
    # Documented with comments rather than a docstring because pydantic would
    # publish a class docstring as the JSON-schema description.

    # The user's question; inserted into the "<|user|> Question: ..." prompt.
    query: str
    # Upper bound on generated tokens. Declared Optional, so an explicit JSON null
    # is accepted and arrives as None; generate() passes it to range(), which
    # fails on None unless the endpoint substitutes a default.
    max_new_tokens: Optional[int] = 150
def clean_response(response_text):
    """Clean up a raw model response.

    Strategy:
      1. If the text contains ``<|assistant|> ... <|user|>`` (or a
         ``</|assistant|>`` / ``<|system|>`` terminator) spans, return the first
         stripped span that is longer than 5 characters and does not start with
         ``<|``.
      2. Otherwise, strip every ``<|...|>`` and ``</|...|>`` tag, then drop blank
         lines and repeated lines (keeping the first occurrence, in order), and
         finally remove any tag left dangling at the very end.

    Args:
        response_text: Decoded model output.

    Returns:
        The cleaned answer text (possibly empty).
    """
    import re

    assistant_pattern = re.compile(
        r'<\|assistant\|>\s*(.*?)(?:<\/\|assistant\|>|<\|user\|>|<\|system\|>)',
        re.DOTALL,
    )
    for match in assistant_pattern.findall(response_text):
        candidate = match.strip()
        # len > 5 already rules out the empty string.
        if len(candidate) > 5 and not candidate.startswith("<|"):
            return candidate

    # No usable tagged answer: remove all tag markers.
    cleaned = re.sub(r'<\|.*?\|>', '', response_text)
    cleaned = re.sub(r'<\/\|.*?\|>', '', cleaned)

    # Drop blank and repeated lines, keeping first occurrences in order.
    # The set gives O(1) membership tests; scanning the output list made this O(n^2).
    seen = set()
    unique_lines = []
    for line in cleaned.strip().split('\n'):
        line = line.strip()
        if line and line not in seen:
            seen.add(line)
            unique_lines.append(line)
    result = '\n'.join(unique_lines)

    # Final cleanup: remove a tag left at the very end.
    result = re.sub(r'<\/?\|.*?\|>\s*$', '', result)
    return result.strip()
@app.post("/upload-document_to_create_KV_cache")
async def upload_document(file: UploadFile = File(...)):
    """Upload a document and create KV cache for it"""
    # The uploaded bytes are decoded as UTF-8 and embedded in a fixed
    # system/user prompt ending in "Question:". That prompt is run through the
    # model once (get_kv_cache) and the cache is stored in cache_store under a
    # new id.
    # Responses: 400 if the upload is not valid UTF-8; 500 for any other
    # failure while building the cache.
    # NOTE(review): model inference runs synchronously inside this async endpoint,
    # so it blocks the event loop while the cache is built.
    import uuid

    t1 = time()
    content = await file.read()
    try:
        # Decode in memory instead of round-tripping through a temp file (which
        # could leak on errors). Translating \r\n and lone \r to \n matches
        # Python's text-mode universal-newline reading.
        doc_text = content.decode("utf-8").replace("\r\n", "\n").replace("\r", "\n")
    except UnicodeDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Document must be UTF-8 text: {e}") from e

    try:
        # Create system prompt with document context
        system_prompt = f"""
<|system|>
Answer concisely and precisely, You are an assistant who provides concise factual answers.
<|user|>
Context:
{doc_text}
Question:
""".strip()
        cache, origin_len = get_kv_cache(model, tokenizer, system_prompt)

        # A timestamp alone collides when two uploads land in the same second,
        # silently overwriting the earlier cache; the random suffix keeps ids unique.
        cache_id = f"cache_{int(time())}_{uuid.uuid4().hex[:8]}"
        doc_preview = doc_text[:500] + "..." if len(doc_text) > 500 else doc_text
        cache_store[cache_id] = {
            "cache": cache,
            "origin_len": origin_len,
            "doc_preview": doc_preview,
        }
        t2 = time()
        return {
            "cache_id": cache_id,
            "message": "Document uploaded and cache created successfully",
            "doc_preview": doc_preview,
            "time_taken": f"{t2 - t1:.4f} seconds",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}") from e
@app.post("/generate_answer_from_cache/{cache_id}")
async def generate_answer(cache_id: str, request: QueryRequest):
    """Generate an answer to a question based on the uploaded document"""
    # Works on a truncated copy of the stored document cache, so the stored cache
    # is never extended by this query's tokens.
    # Responses: 404 for an unknown cache_id; 500 for any failure during generation.
    # NOTE(review): generation runs synchronously inside this async endpoint,
    # blocking the event loop for its duration.
    t1 = time()
    entry = cache_store.get(cache_id)
    if entry is None:
        raise HTTPException(status_code=404, detail="Document not found. Please upload it first.")

    # QueryRequest.max_new_tokens is Optional, so an explicit JSON null arrives as
    # None; generate() would then fail on range(None). Fall back to the model default (150).
    max_new_tokens = request.max_new_tokens if request.max_new_tokens is not None else 150

    try:
        current_cache = clean_up(entry["cache"], entry["origin_len"])
        # The query continues the cached prompt, which ends with "Question:".
        full_prompt = f"""
<|user|>
Question: {request.query}
<|assistant|>
""".strip()
        input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
        output_ids = generate(model, input_ids, current_cache, max_new_tokens=max_new_tokens)
        response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        rep = clean_response(response)
        t2 = time()
        return {
            "query": request.query,
            "answer": rep,
            "time_taken": f"{t2 - t1:.4f} seconds",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating answer: {str(e)}") from e
@app.post("/save_cache/{cache_id}")
async def save_cache(cache_id: str):
    """Save the cache for a document"""
    # Writes a truncated clone of the stored cache to "<cache_id>_cache.pth".
    # This is a relative path, so the file lands in the server's current
    # working directory. Responses: 404 for an unknown id; 500 if saving fails.
    if cache_id not in cache_store:
        raise HTTPException(status_code=404, detail="Document not found. Please upload it first.")
    try:
        entry = cache_store[cache_id]
        snapshot = clean_up(entry["cache"], entry["origin_len"])
        cache_path = f"{cache_id}_cache.pth"
        torch.save(snapshot, cache_path)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving cache: {str(e)}")
    return {
        "message": f"Cache saved successfully as {cache_path}",
        "cache_path": cache_path,
    }
@app.post("/load_cache")
async def load_cache(file: UploadFile = File(...)):
    """Load a previously saved cache"""
    # Accepts a file written by /save_cache and registers it in cache_store
    # under a new id, which is returned. The original document text is not
    # available, so a fixed preview is stored.
    # Responses: 400 if the payload is not a non-empty DynamicCache; 500 if it
    # cannot be deserialized.
    import io
    import uuid

    content = await file.read()
    try:
        # SECURITY: the upload is untrusted. weights_only=True limits unpickling to
        # tensors/primitives plus the types registered via add_safe_globals,
        # instead of running arbitrary pickle code. It is only the default from
        # torch 2.6, so it is passed explicitly here.
        loaded_cache = torch.load(io.BytesIO(content), weights_only=True)
        if not isinstance(loaded_cache, DynamicCache) or len(loaded_cache.key_cache) == 0:
            raise HTTPException(status_code=400, detail="Uploaded file is not a non-empty DynamicCache.")
        # Sequence length = size of dim -2, the same axis clean_up() slices.
        origin_len = loaded_cache.key_cache[0].shape[-2]
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error loading cache: {str(e)}") from e

    # Random suffix avoids overwriting another cache loaded in the same second.
    cache_id = f"loaded_cache_{int(time())}_{uuid.uuid4().hex[:8]}"
    cache_store[cache_id] = {
        "cache": loaded_cache,
        "origin_len": origin_len,
        "doc_preview": "Loaded from cache file",
    }
    return {
        "cache_id": cache_id,
        "message": "Cache loaded successfully",
    }
@app.get("/list_of_caches")
async def list_documents():
    """List all uploaded documents/caches"""
    # Only lightweight metadata is returned; the cache tensors stay server-side.
    documents = {
        cid: {
            "doc_preview": entry["doc_preview"],
            "origin_len": entry["origin_len"],
        }
        for cid, entry in cache_store.items()
    }
    return {"documents": documents}
@app.get("/", include_in_schema=False)
async def root():
    # Send visitors of the bare URL to the interactive API docs.
    return RedirectResponse("/docs")
if __name__ == "__main__":
    # Serve on all network interfaces, port 7860.
    host, port = "0.0.0.0", 7860
    uvicorn.run(app, host=host, port=port)