import os import torch from fastapi import FastAPI, File, UploadFile, HTTPException, Body from fastapi.responses import JSONResponse,RedirectResponse from transformers import AutoTokenizer, AutoModelForCausalLM from transformers.cache_utils import DynamicCache , StaticCache from pydantic import BaseModel from typing import Optional import uvicorn import tempfile from time import time from pyngrok import ngrok os.environ["HF_HOME"] = "/app/hf_cache" #os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache" # Add necessary serialization safety torch.serialization.add_safe_globals([DynamicCache]) torch.serialization.add_safe_globals([set]) #These lines allow PyTorch to serialize and deserialize these objects without raising errors, # #ensuring compatibility and functionality during cache saving/loading. # Minimal generate function for token-by-token generation def generate(model, input_ids, past_key_values, max_new_tokens=50): """ This function performs token-by-token text generation using a pre-trained language model. Purpose: To generate new text based on input tokens, without loading the full context repeatedly Process: It takes a model, input IDs, and cached key-values, then generates new tokens one by one up to the specified maximum Performance: Uses the cached key-values for efficiency and returns only the newly generated tokens """ device = model.model.embed_tokens.weight.device origin_len = input_ids.shape[-1]#Stores the length of the input sequence (number of tokens) before text generation begins./return only the newly input_ids = input_ids.to(device)#same device as the model. output_ids = input_ids.clone()#will be updated during the generation process to include newly generated tokens. next_token = input_ids#the token that will process in the next iteration. with torch.no_grad(): for _ in range(max_new_tokens): out = model( input_ids=next_token, past_key_values=past_key_values, use_cache=True ) logits = out.logits[:, -1, :]#Extracts the logits for the last token token = torch.argmax(logits, dim=-1, keepdim=True)#highest predicted probability as the next token. output_ids = torch.cat([output_ids, token], dim=-1)#add the newly generated token past_key_values = out.past_key_values next_token = token.to(device) if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id: break return output_ids[:, origin_len:] # Return just the newly generated part def get_kv_cache(model, tokenizer, prompt): """ This function creates a key-value cache for a given prompt. Purpose: To pre-compute and store the model's internal representations (key-value states) for a prompt Process: Encodes the prompt, runs it through the model, and captures the resulting cache Returns: The cache object and the original prompt length for future reference """ # Encode prompt device = model.model.embed_tokens.weight.device input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) cache = DynamicCache() # it grows as text is generated # Run the model to populate the KV cache: with torch.no_grad(): _ = model( input_ids=input_ids, past_key_values=cache, use_cache=True ) return cache, input_ids.shape[-1] def clean_up(cache, origin_len): # Make a deep copy of the cache first new_cache = DynamicCache() for i in range(len(cache.key_cache)): new_cache.key_cache.append(cache.key_cache[i].clone()) new_cache.value_cache.append(cache.value_cache[i].clone()) # Remove any tokens appended to the original knowledge for i in range(len(new_cache.key_cache)): new_cache.key_cache[i] = new_cache.key_cache[i][:, :, :origin_len, :] new_cache.value_cache[i] = new_cache.value_cache[i][:, :, :origin_len, :] return new_cache #os.environ["TRANSFORMERS_OFFLINE"] = "1" #os.environ["HF_HUB_OFFLINE"] = "1" # Path to your local model # Initialize model and tokenizer def load_model_and_tokenizer(): model_name = "Locutusque/TinyMistral-248M" #"tiiuae/falcon-rw-1b" #"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" #"facebook/opt-125m" # Load tokenizer and model from disk (without trust_remote_code) tokenizer = AutoTokenizer.from_pretrained(model_name ) if torch.cuda.is_available(): # Load model on GPU if CUDA is available model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float16, device_map="auto" # Automatically map model layers to GPU ) else: # Load model on CPU if no GPU is available model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float32, # Use float32 for compatibility with CPU low_cpu_mem_usage=True # Reduce memory usage on CPU ) return model, tokenizer # Create FastAPI app app = FastAPI(title="DeepSeek QA with KV Cache API") # Initialize model and tokenizer at startup model, tokenizer = load_model_and_tokenizer() # Global variables to store the cache, origin length, and model/tokenizer cache_store = {} class QueryRequest(BaseModel): query: str max_new_tokens: Optional[int] = 150 def clean_response(response_text): """ Clean up model response by removing redundant tags, repetitions, and formatting issues. """ # First, try to extract just the answer content between tags if they exist import re # Try to extract content between assistant tags if present assistant_pattern = re.compile(r'<\|assistant\|>\s*(.*?)(?:<\/\|assistant\|>|<\|user\|>|<\|system\|>)', re.DOTALL) matches = assistant_pattern.findall(response_text) if matches: # Return the first meaningful assistant response for match in matches: cleaned = match.strip() if cleaned and not cleaned.startswith("<|") and len(cleaned) > 5: return cleaned # If no proper match found, try more aggressive cleaning # Remove all tag markers completely cleaned = re.sub(r'<\|.*?\|>', '', response_text) cleaned = re.sub(r'<\/\|.*?\|>', '', cleaned) # Remove duplicate lines (common in generated responses) lines = cleaned.strip().split('\n') unique_lines = [] for line in lines: line = line.strip() if line and line not in unique_lines: unique_lines.append(line) result = '\n'.join(unique_lines) # Final cleanup - remove any trailing system/user markers result = re.sub(r'<\/?\|.*?\|>\s*$', '', result) return result.strip() @app.post("/upload-document_to_create_KV_cache") async def upload_document(file: UploadFile = File(...)): """Upload a document and create KV cache for it""" t1 = time() # Save the uploaded file temporarily with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file: temp_file_path = temp_file.name content = await file.read() temp_file.write(content) try: # Read the document with open(temp_file_path, "r", encoding="utf-8") as f: doc_text = f.read() # Create system prompt with document context system_prompt = f""" <|system|> Answer concisely and precisely, You are an assistant who provides concise factual answers. <|user|> Context: {doc_text} Question: """.strip() # Create KV cache cache, origin_len = get_kv_cache(model, tokenizer, system_prompt) # Generate a unique ID for this document/cache cache_id = f"cache_{int(time())}" # Store the cache and origin_len cache_store[cache_id] = { "cache": cache, "origin_len": origin_len, "doc_preview": doc_text[:500] + "..." if len(doc_text) > 500 else doc_text } # Clean up the temporary file os.unlink(temp_file_path) t2 = time() return { "cache_id": cache_id, "message": "Document uploaded and cache created successfully", "doc_preview": cache_store[cache_id]["doc_preview"], "time_taken": f"{t2 - t1:.4f} seconds" } except Exception as e: # Clean up the temporary file in case of error if os.path.exists(temp_file_path): os.unlink(temp_file_path) raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}") @app.post("/generate_answer_from_cache/{cache_id}") async def generate_answer(cache_id: str, request: QueryRequest): """Generate an answer to a question based on the uploaded document""" t1 = time() # Check if the document/cache exists if cache_id not in cache_store: raise HTTPException(status_code=404, detail="Document not found. Please upload it first.") try: # Get a clean copy of the cache current_cache = clean_up( cache_store[cache_id]["cache"], cache_store[cache_id]["origin_len"] ) # Prepare input with just the query full_prompt = f""" <|user|> Question: {request.query} <|assistant|> """.strip() input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids # Generate response output_ids = generate(model, input_ids, current_cache, max_new_tokens=request.max_new_tokens) response = tokenizer.decode(output_ids[0], skip_special_tokens=True) rep = clean_response(response) t2 = time() return { "query": request.query, "answer": rep, "time_taken": f"{t2 - t1:.4f} seconds" } except Exception as e: raise HTTPException(status_code=500, detail=f"Error generating answer: {str(e)}") @app.post("/save_cache/{cache_id}") async def save_cache(cache_id: str): """Save the cache for a document""" if cache_id not in cache_store: raise HTTPException(status_code=404, detail="Document not found. Please upload it first.") try: # Clean up the cache and save it cleaned_cache = clean_up( cache_store[cache_id]["cache"], cache_store[cache_id]["origin_len"] ) cache_path = f"{cache_id}_cache.pth" torch.save(cleaned_cache, cache_path) return { "message": f"Cache saved successfully as {cache_path}", "cache_path": cache_path } except Exception as e: raise HTTPException(status_code=500, detail=f"Error saving cache: {str(e)}") @app.post("/load_cache") async def load_cache(file: UploadFile = File(...)): """Load a previously saved cache""" with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as temp_file: temp_file_path = temp_file.name content = await file.read() temp_file.write(content) try: # Load the cache loaded_cache = torch.load(temp_file_path) # Generate a unique ID for this cache cache_id = f"loaded_cache_{int(time())}" # Store the cache (we don't have the original document text) cache_store[cache_id] = { "cache": loaded_cache, "origin_len": loaded_cache.key_cache[0].shape[-2], "doc_preview": "Loaded from cache file" } # Clean up the temporary file os.unlink(temp_file_path) return { "cache_id": cache_id, "message": "Cache loaded successfully" } except Exception as e: # Clean up the temporary file in case of error if os.path.exists(temp_file_path): os.unlink(temp_file_path) raise HTTPException(status_code=500, detail=f"Error loading cache: {str(e)}") @app.get("/list_of_caches") async def list_documents(): """List all uploaded documents/caches""" documents = {} for cache_id in cache_store: documents[cache_id] = { "doc_preview": cache_store[cache_id]["doc_preview"], "origin_len": cache_store[cache_id]["origin_len"] } return {"documents": documents} @app.get("/", include_in_schema=False) async def root(): return RedirectResponse(url="/docs") if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=7860)