|
import os |
|
import gc |
|
import json |
|
import random |
|
import torch |
|
import asyncio |
|
import logging |
|
import time |
|
from typing import List, Dict, Any, Optional, Union, AsyncGenerator, Tuple |
|
from fastapi import FastAPI, HTTPException, Query, Request, Depends, status |
|
from fastapi.responses import StreamingResponse, PlainTextResponse, HTMLResponse, JSONResponse |
|
from fastapi.security import APIKeyHeader |
|
from pydantic import BaseModel, Field, ValidationError, validator |
|
from transformers import ( |
|
AutoConfig, AutoModelForCausalLM, AutoTokenizer, |
|
GenerationConfig, LogitsProcessorList, |
|
MinLengthLogitsProcessor, MaxLengthCriteria, |
|
StoppingCriteriaList, StoppingCriteria |
|
) |
|
import uvicorn |
|
from concurrent.futures import ThreadPoolExecutor |
|
import math |
|
import torch.nn.functional as F |
|
import copy |
|
|
|
app = FastAPI(title="Chatbot Profesional API", version="1.0.0") |
|
|
|
class StopSequenceCriteria(StoppingCriteria): |
|
def __init__(self, stop_sequences: List[str], tokenizer: AutoTokenizer): |
|
self.tokenizer = tokenizer |
|
self.stop_sequences_text = [] |
|
self.stop_sequence_ids = [] |
|
for seq in stop_sequences: |
|
if seq: |
|
encoded_ids = tokenizer.encode(seq, add_special_tokens=False) |
|
decoded_text = tokenizer.decode(encoded_ids, skip_special_tokens=True) |
|
if decoded_text: |
|
self.stop_sequences_text.append(decoded_text) |
|
self.stop_sequence_ids.append(encoded_ids) |
|
|
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: |
|
if not self.stop_sequence_ids: |
|
return False |
|
|
|
input_ids_list = input_ids[0].tolist() |
|
|
|
for stop_seq_ids in self.stop_sequence_ids: |
|
stop_len = len(stop_seq_ids) |
|
if len(input_ids_list) >= stop_len: |
|
if input_ids_list[-stop_len:] == stop_seq_ids: |
|
return True |
|
|
|
check_tail_len = 50 |
|
if self.stop_sequence_ids: |
|
max_stop_seq_token_len = max((len(seq) for seq in self.stop_sequence_ids), default=0) |
|
check_tail_len = max(check_tail_len, max_stop_seq_token_len + 10) |
|
|
|
tail_ids = input_ids_list[-min(check_tail_len, len(input_ids_list)):] |
|
tail_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True) |
|
|
|
for stop_seq_text in self.stop_sequences_text: |
|
if stop_seq_text and stop_seq_text in tail_text: |
|
return True |
|
|
|
return False |
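
# Sketch only (not used by the request flow below): how the criterion above could be
# attached when calling model.generate() directly. The stop strings are illustrative.
def _example_stop_criteria(tokenizer: AutoTokenizer) -> StoppingCriteriaList:
    return StoppingCriteriaList([StopSequenceCriteria(["\nUsuario:", "\nBot:"], tokenizer)])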
|
|
|
logging.getLogger("uvicorn").handlers.clear() |
|
logging.getLogger("uvicorn.error").handlers.clear() |
|
logging.getLogger("uvicorn.access").handlers.clear() |
|
logging.getLogger("uvicorn").propagate = False |
|
logging.getLogger("uvicorn.error").propagate = False |
|
logging.getLogger("uvicorn.access").propagate = False |
|
logging.getLogger("uvicorn").setLevel(logging.CRITICAL) |
|
logging.getLogger("uvicorn.error").setLevel(logging.CRITICAL) |
|
logging.getLogger("uvicorn.access").setLevel(logging.CRITICAL) |
|
logging.getLogger("fastapi").setLevel(logging.CRITICAL) |
|
logging.getLogger("transformers").setLevel(logging.CRITICAL) |
|
logging.getLogger().handlers.clear() |
|
logging.getLogger().addHandler(logging.NullHandler()) |
|
|
|
DEFAULT_MODEL_NAME = "hghghgkskdmskdms/xddd" |
|
MODEL_NAME = os.environ.get("MODEL_NAME", DEFAULT_MODEL_NAME) |
|
SYSTEM_PROMPT = os.environ.get("SYSTEM_PROMPT", "Eres un asistente profesional y servicial.") |
|
|
|
try: |
|
MAX_CONTEXT_TOKENS = int(os.environ.get("MAX_CONTEXT_TOKENS", 1024)) |
|
if MAX_CONTEXT_TOKENS <= 0: |
|
raise ValueError("MAX_CONTEXT_TOKENS must be positive.") |
|
except (ValueError, TypeError) as e: |
|
logging.error(f"Invalid MAX_CONTEXT_TOKENS environment variable: {os.environ.get('MAX_CONTEXT_TOKENS')}. Using default 1024. Error: {e}") |
|
MAX_CONTEXT_TOKENS = 1024 |
|
|
|
try: |
|
MAX_GENERATION_TOKENS = int(os.environ.get("MAX_GENERATION_TOKENS", 512)) |
|
if MAX_GENERATION_TOKENS <= 0: |
|
raise ValueError("MAX_GENERATION_TOKENS must be positive.") |
|
except (ValueError, TypeError) as e: |
|
logging.error(f"Invalid MAX_GENERATION_TOKENS environment variable: {os.environ.get('MAX_GENERATION_TOKENS')}. Using default 512. Error: {e}") |
|
MAX_GENERATION_TOKENS = 512 |
|
|
|
try: |
|
MAX_CONCURRENT_GENERATIONS = int(os.environ.get("MAX_CONCURRENT_GENERATIONS", 4)) |
|
if MAX_CONCURRENT_GENERATIONS <= 0: |
|
raise ValueError("MAX_CONCURRENT_GENERATIONS must be positive.") |
|
except (ValueError, TypeError) as e: |
|
logging.error(f"Invalid MAX_CONCURRENT_GENERATIONS environment variable: {os.environ.get('MAX_CONCURRENT_GENERATIONS')}. Using default 4. Error: {e}") |
|
MAX_CONCURRENT_GENERATIONS = 4 |
|
|
|
TRUST_REMOTE_CODE_ENV = os.environ.get("TRUST_REMOTE_CODE", "false").lower() == "true" |
|
TRUST_REMOTE_CODE = TRUST_REMOTE_CODE_ENV or (MODEL_NAME == DEFAULT_MODEL_NAME) |
|
ENABLE_FLASH_ATTENTION_2 = os.environ.get("ENABLE_FLASH_ATTENTION_2", "false").lower() == "true" |
|
TORCH_DTYPE_STR = os.environ.get("TORCH_DTYPE", "float32") |
|
TORCH_DTYPE = getattr(torch, TORCH_DTYPE_STR.lower(), torch.float32) |
|
if TORCH_DTYPE != torch.float32: |
|
logging.warning(f"Requested dtype {TORCH_DTYPE_STR} might not be fully performant on CPU. Using float32.") |
|
TORCH_DTYPE = torch.float32 |
|
|
|
API_KEY = os.environ.get("API_KEY") |
|
|
|
global_model = None |
|
global_tokenizer = None |
|
global_tokens: Dict[str, Optional[int]] = {} |
|
executor = ThreadPoolExecutor(max_workers=MAX_CONCURRENT_GENERATIONS) |
|
generation_semaphore = asyncio.Semaphore(MAX_CONCURRENT_GENERATIONS) |
|
|
|
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) |
|
|
|
async def get_api_key(api_key: str = Depends(api_key_header)): |
|
if API_KEY is None: |
|
return |
|
if api_key is None or api_key != API_KEY: |
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or missing API Key") |
|
return api_key |
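
# Illustrative client call (assumption: the service runs on uvicorn's default port 8000
# and the `requests` package is available; neither is configured by this module):
#
#   import requests
#   r = requests.post("http://localhost:8000/generate",
#                     headers={"X-API-Key": "<your key>"},   # only needed if API_KEY is set
#                     json={"input_text": "Hola", "stream": False})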
|
|
|
class GenerateRequest(BaseModel): |
|
input_text: str = Field(..., description="The input text from the user.", examples=["Hola, ¿cómo estás?"]) |
|
history: Optional[List[Dict[str, str]]] = Field(None, description="Conversation history.", examples=[[{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is the capital of France?"}, {"role": "assistant", "content": "The capital of France is Paris."}]]) |
|
stream: bool = Field(True, description="Whether to stream the response.") |
|
temperature: float = Field(1.0, ge=0.0, le=2.0, description="Controls the randomness.") |
|
top_k: int = Field(50, ge=0, description="Top-k filtering.") |
|
top_p: float = Field(1.0, ge=0.0, le=1.0, description="Top-p (nucleus) filtering.") |
|
repetition_penalty: float = Field(1.0, ge=0.0, description="Repetition penalty.") |
|
frequency_penalty: float = Field(0.0, ge=0.0, description="Frequency penalty.") |
|
presence_penalty: float = Field(0.0, ge=0.0, description="Presence penalty.") |
|
num_beams: int = Field(1, ge=1, description="Number of beams for beam search.") |
|
length_penalty: float = Field(1.0, ge=0.0, description="Length penalty.") |
|
no_repeat_ngram_size: int = Field(0, ge=0, description="No repeat ngram size.") |
|
early_stopping: bool = Field(False, description="Early stopping for beam search.") |
|
do_sample: bool = Field(True, description="Whether to use sampling.") |
|
use_mirostat: bool = Field(False, description="Whether to use Mirostat sampling.") |
|
mirostat_tau: float = Field(5.0, ge=0.0, description="Mirostat tau.") |
|
mirostat_eta: float = Field(0.1, ge=0.0, description="Mirostat eta.") |
|
max_new_tokens: int = Field(MAX_GENERATION_TOKENS, ge=1, description="Max new tokens.") |
|
system_prompt: Optional[str] = Field(None, description="Override the default system prompt.") |
|
seed: Optional[int] = Field(None, description="Random seed.") |
|
stop_sequences: Optional[List[str]] = Field(None, description="List of stop strings.", examples=[[".", "\nUsuario:"]]) |
|
tokenize_only: bool = Field(False, description="If true, only tokenize input.") |
|
strip_trailing_whitespace: bool = Field(False, description="Strip trailing whitespace.") |
|
remove_incomplete_sentences: bool = Field(False, description="Remove incomplete last sentence.") |
|
num_return_sequences: int = Field(1, ge=1, le=5, description="Number of sequences to return (non-streaming).") |
|
bad_words_ids: Optional[List[List[int]]] = Field(None, description="List of bad word token ids.", examples=[[[32000], [32001]]]) |
|
forced_bos_token_id: Optional[int] = Field(None, description="Forced BOS token id.") |
|
forced_eos_token_id: Optional[int] = Field(None, description="Forced EOS token id.") |
|
renormalize_logits: Optional[bool] = Field(None, description="Renormalize logits.") |
|
suppress_tokens: Optional[List[int]] = Field(None, description="Tokens to suppress.") |
|
begin_suppress_tokens: Optional[List[int]] = Field(None, description="Tokens to suppress at beginning.") |
|
end_suppress_tokens: Optional[List[int]] = Field(None, description="Tokens to suppress at end.") |
|
encoder_no_repeat_ngram_size: int = Field(0, ge=0, description="Encoder no repeat ngram size.") |
|
min_length: int = Field(0, ge=0, description="Minimum total length.") |
|
max_length: Optional[int] = Field(None, description="Maximum total length.") |
|
exponential_decay_length_penalty: Optional[Tuple[float, int, float]] = Field(None, description="Exponential decay length penalty.") |
|
use_cache: bool = Field(True, description="Use cache.") |
|
typical_p: float = Field(1.0, ge=0.0, le=1.0, description="Typical P sampling.") |
|
epsilon_cutoff: float = Field(0.0, ge=0.0, description="Epsilon cutoff for LTS.") |
|
eta_cutoff: float = Field(0.0, ge=0.0, description="Eta cutoff for LTS.") |
|
temperature_cutoff: Optional[float] = Field(None, ge=0.0, description="Temperature cutoff.") |
|
encoder_repetition_penalty: float = Field(1.0, ge=0.0, description="Encoder repetition penalty.") |
|
max_time: Optional[float] = Field(None, ge=0.0, description="Maximum time in seconds.") |
|
output_watermark: bool = Field(False, description="Output watermark.") |
|
remove_input_from_output: bool = Field(False, description="Remove input from output.") |
|
eos_token_id_override: Optional[int] = Field(None, description="Override EOS token id.") |
|
pad_token_id_override: Optional[int] = Field(None, description="Override PAD token id.") |
|
bos_token_id_override: Optional[int] = Field(None, description="Override BOS token id.") |
|
repetition_penalty_range: Optional[int] = Field(None, ge=0, description="Repetition penalty range.") |
|
diversity_penalty: float = Field(0.0, ge=0.0, description="Diversity penalty for diverse beam search.") |
|
num_beam_groups: int = Field(1, ge=1, description="Number of beam groups for diverse beam search.") |
|
return_dict_in_generate: bool = Field(False, description="Return dictionary from generate.") |
|
output_attentions: bool = Field(False, description="Output attentions.") |
|
output_hidden_states: bool = Field(False, description="Output hidden states.") |
|
output_scores: bool = Field(False, description="Output scores.") |
|
return_token_logprobs: bool = Field(False, description="Return token logprobs in stream.") |
|
return_text_from_sequence: bool = Field(True, description="Decode generated sequence to text.") |
|
length_normalization_factor: Optional[float] = Field(None, description="Length normalization factor for beam search.") |
|
min_new_tokens: int = Field(0, ge=0, description="Minimum number of new tokens.") |
|
do_normalize_logits: bool = Field(False, description="Normalize logits.") |
|
return_generation_inputs: bool = Field(False, description="Return generation inputs.") |
|
return_unused_generate_parameters: bool = Field(False, description="Return unused generate parameters.") |
|
use_fast_tokenizer: bool = Field(True, description="Use fast tokenizer if available.") |
|
model_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional model kwargs for generate.") |
|
tokenizer_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional tokenizer kwargs for encode.") |
|
return_only_text: bool = Field(False, description="If true, only return the generated text.") |
|
|
|
@validator('stop_sequences') |
|
def validate_stop_sequences(cls, v): |
|
if v is not None: |
|
if not all(isinstance(seq, str) for seq in v): |
|
raise ValueError('Each stop sequence must be a string') |
|
return v |
|
|
|
@validator('bad_words_ids') |
|
def validate_bad_words_ids(cls, v): |
|
if v is not None: |
|
if not all(isinstance(word_id_list, list) and all(isinstance(token_id, int) for token_id in word_id_list) for word_id_list in v): |
|
raise ValueError('bad_words_ids must be a list of lists of integers') |
|
return v |
|
|
|
@validator('exponential_decay_length_penalty') |
|
def validate_exponential_decay_length_penalty(cls, v): |
|
if v is not None: |
|
if not (isinstance(v, (list, tuple)) and len(v) == 3 and |
|
isinstance(v[0], (int, float)) and v[0] > 0 and |
|
isinstance(v[1], int) and v[1] >= 0 and |
|
isinstance(v[2], (int, float))): |
|
raise ValueError('exponential_decay_length_penalty must be a tuple/list of 3 numbers (decay_factor, start_index, threshold)') |
|
return v |
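
# Minimal request body for reference; every field above has a default, so only
# `input_text` is required. Values shown are illustrative.
#
#   {
#     "input_text": "Hola, ¿cómo estás?",
#     "stream": true,
#     "temperature": 0.8,
#     "max_new_tokens": 128,
#     "stop_sequences": ["\nUsuario:"]
#   }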
|
|
|
class TokenizeRequest(BaseModel): |
|
text: Union[str, List[str]] = Field(..., description="Text or list of texts to tokenize.") |
|
add_special_tokens: bool = Field(True, description="Whether to add special tokens.") |
|
is_split_into_words: bool = Field(False, description="Whether the input text is pre-tokenized.") |
|
return_token_type_ids: bool = Field(False, description="Whether to return token type IDs.") |
|
padding: Union[bool, str] = Field(False, description="Enable padding.") |
|
truncation: Union[bool, str] = Field(False, description="Enable truncation.") |
|
max_length: Optional[int] = Field(None, ge=1, description="Maximum length for padding and truncation.") |
|
return_tensors: Optional[str] = Field(None, description="The type of tensors to return.") |
|
return_attention_mask: Optional[bool] = Field(None, description="Whether to return the attention mask.") |
|
return_offsets_mapping: Optional[bool] = Field(None, description="Whether to return offsets mapping.") |
|
return_length: Optional[bool] = Field(None, description="Whether to return the length.") |
|
verbose: bool = Field(False, description="Verbose tokenizer output.") |
|
tokenizer_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional tokenizer kwargs.") |
|
|
|
class DecodeRequest(BaseModel): |
|
token_ids: List[int] = Field(..., description="List of token IDs to decode.", examples=[[1, 2, 3]]) |
|
skip_special_tokens: bool = Field(True, description="Skip special tokens.") |
|
clean_up_tokenization_spaces: bool = Field(True, description="Clean up spaces.") |
|
decode_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional decode kwargs.") |
|
|
|
class SystemPromptUpdateRequest(BaseModel): |
|
system_prompt: str = Field(..., description="The new global system prompt.") |
|
|
|
class ModelReloadRequest(BaseModel): |
|
model_name: Optional[str] = Field(None, description="New model name.") |
|
trust_remote_code: Optional[bool] = Field(None, description="Override trust_remote_code.") |
|
enable_flash_attention_2: Optional[bool] = Field(None, description="Override enable_flash_attention_2.") |
|
torch_dtype: Optional[str] = Field(None, description="Override torch_dtype.") |
|
model_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional model kwargs for from_pretrained().") |
|
tokenizer_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional tokenizer kwargs for from_pretrained().") |
|
|
|
def format_conversation(input_text: str, history: Optional[List[Dict[str, str]]], system_prompt: Optional[str]) -> str: |
|
full_history: List[Dict[str, str]] = [] |
|
used_system_prompt = system_prompt if system_prompt is not None else SYSTEM_PROMPT |
|
if not history or history[0].get("role") != "system" or history[0].get("content") != used_system_prompt: |
|
full_history.append({"role": "system", "content": used_system_prompt}) |
|
if history: |
|
full_history.extend(history) |
|
if not full_history or full_history[-1].get("role") != "user" or full_history[-1].get("content") != input_text: |
|
full_history.append({"role": "user", "content": input_text}) |
|
|
|
if global_tokenizer and hasattr(global_tokenizer, 'apply_chat_template') and global_tokenizer.chat_template: |
|
try: |
|
return global_tokenizer.apply_chat_template(full_history, tokenize=False, add_generation_prompt=True) |
|
except Exception as e: |
|
logging.error(f"Failed to apply chat template: {e}. Falling back to manual formatting.") |
|
pass |
|
formatted_text = "" |
|
for i, message in enumerate(full_history): |
|
if i == 0 and message["role"] == "system" and len(full_history) > 1 and full_history[1].get("role") == "system": |
|
continue |
|
if message["role"] == "system": |
|
formatted_text += f"{message['content'].strip()}\n\n" |
|
elif message["role"] == "user": |
|
formatted_text += f"Usuario: {message['content'].strip()}\n" |
|
elif message["role"] == "assistant": |
|
formatted_text += f"Bot: {message['content'].strip()}\n" |
|
if not formatted_text.endswith("Bot:"): |
|
formatted_text += "Bot:" |
|
return formatted_text.strip() |
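
# Example of the manual fallback (used only when the tokenizer has no chat template):
# for input_text="¿Capital de Francia?" with no history and the default system prompt,
# the returned prompt is:
#
#   Eres un asistente profesional y servicial.
#
#   Usuario: ¿Capital de Francia?
#   Bot: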
|
|
|
def truncate_encoded_ids(input_ids: torch.Tensor, max_length: int) -> torch.Tensor: |
|
if input_ids.shape[-1] > max_length: |
|
return input_ids[:, -max_length:] |
|
return input_ids |
|
|
|
def apply_seed(seed: Optional[int]): |
|
if seed is not None: |
|
torch.manual_seed(seed) |
|
random.seed(seed) |
|
if torch.cuda.is_available(): |
|
torch.cuda.manual_seed_all(seed) |
|
|
|
def get_stopping_criteria(req: GenerateRequest, initial_ids: torch.Tensor, tokenizer: AutoTokenizer) -> StoppingCriteriaList: |
|
criteria = StoppingCriteriaList() |
|
max_len_from_req = None |
|
if req.max_length is not None and req.max_length > 0: |
|
max_len_from_req = req.max_length |
|
elif req.max_new_tokens is not None and req.max_new_tokens > 0: |
|
max_len_from_req = initial_ids.shape[-1] + req.max_new_tokens |
|
else: |
|
max_len_from_req = initial_ids.shape[-1] + MAX_GENERATION_TOKENS |
|
if max_len_from_req is not None and max_len_from_req > 0: |
|
criteria.append(MaxLengthCriteria(max_len_from_req)) |
|
    # NOTE: MinLengthLogitsProcessor is a logits processor, not a StoppingCriteria;
    # appending it here would break StoppingCriteriaList.__call__. Minimum length is
    # instead enforced via a LogitsProcessorList where generate() is called
    # (see non_stream_generation_logic).
|
if req.stop_sequences: |
|
criteria.append(StopSequenceCriteria(req.stop_sequences, tokenizer)) |
|
return criteria |
|
|
|
def generate_next_token_sync( |
|
input_ids, |
|
past_key_values, |
|
gen_cfg: GenerationConfig, |
|
device: str |
|
) -> Tuple[torch.Tensor, Any, Optional[float], Optional[torch.Tensor], Any, Any]: |
|
    with torch.no_grad():
        outputs = global_model(
            input_ids, past_key_values=past_key_values,
            use_cache=gen_cfg.use_cache, return_dict=True,
            output_attentions=gen_cfg.output_attentions,
            output_hidden_states=gen_cfg.output_hidden_states,
        )
    logits = outputs.logits[:, -1, :]
    past = outputs.past_key_values
    # Note: a plain forward pass has no `scores` attribute (scores belong to generate());
    # per-token log-probabilities are computed from the logits below instead.
|
attentions = outputs.attentions if gen_cfg.output_attentions else None |
|
hidden_states = outputs.hidden_states if gen_cfg.output_hidden_states else None |
|
step_logits_for_criteria = logits.clone() |
|
if gen_cfg.do_normalize_logits: |
|
logits = F.log_softmax(logits, dim=-1) |
|
if gen_cfg.do_sample: |
|
        # `use_mirostat_mode` is a custom attribute that may not be present on the
        # GenerationConfig; getattr avoids an AttributeError when it was never set.
        if getattr(gen_cfg, "use_mirostat_mode", 0) == 1 and hasattr(global_model, 'mirostat_sample_logits'):
|
token = global_model.mirostat_sample_logits( |
|
logits=logits, |
|
temperature=gen_cfg.temperature, |
|
mirostat_tau=gen_cfg.mirostat_tau, |
|
mirostat_eta=gen_cfg.mirostat_eta |
|
).unsqueeze(0).to(device) |
|
else: |
|
            # Guard against temperature == 0.0 (allowed by the request schema), which
            # would otherwise divide by zero; a zero or 1.0 temperature leaves logits as-is.
            if gen_cfg.temperature and gen_cfg.temperature != 1.0:
                logits = logits / gen_cfg.temperature
|
if gen_cfg.temperature_cutoff is not None and gen_cfg.temperature_cutoff > 0: |
|
logits = torch.where(logits < gen_cfg.temperature_cutoff, torch.tensor(-float('Inf')).to(logits.device), logits) |
|
if gen_cfg.top_k: |
|
topk_values, topk_indices = torch.topk(logits, gen_cfg.top_k) |
|
logits[logits < topk_values[:, -1]] = -float('Inf') |
|
if gen_cfg.top_p < 1.0: |
|
sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True) |
|
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) |
|
sorted_indices_to_remove = cumulative_probs > gen_cfg.top_p |
|
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() |
|
sorted_indices_to_remove[..., 0] = False |
|
indices_to_remove = sorted_indices[sorted_indices_to_remove] |
|
logits[:, indices_to_remove] = -float('Inf') |
|
if gen_cfg.typical_p < 1.0: |
|
probs = torch.softmax(logits, dim=-1) |
|
entropy = torch.distributions.Categorical(probs).entropy() |
|
probs_sorted, indices_sorted = torch.sort(probs, dim=-1, descending=True) |
|
cumsum_probs_sorted = torch.cumsum(probs_sorted, dim=-1) |
|
mask = cumsum_probs_sorted < gen_cfg.typical_p * entropy.exp() |
|
indices_to_remove = indices_sorted[~mask] |
|
logits[:, indices_to_remove] = -float('Inf') |
|
            if gen_cfg.epsilon_cutoff is not None and gen_cfg.epsilon_cutoff > 0:
                probs = torch.softmax(logits, dim=-1)
                # The mask has the same shape as the logits, so index with it directly;
                # logits[:, mask] with a 2-D boolean mask raises an IndexError.
                mask = probs < gen_cfg.epsilon_cutoff
                logits[mask] = -float('Inf')
            if gen_cfg.eta_cutoff is not None and gen_cfg.eta_cutoff > 0:
                probs = torch.softmax(logits, dim=-1)
                mask = probs > gen_cfg.eta_cutoff
                logits[~mask] = -float('Inf')
|
probs = torch.softmax(logits, dim=-1) |
|
token = torch.multinomial(probs, 1) |
|
else: |
|
token = torch.argmax(logits, dim=-1, keepdim=True) |
|
token_logprob = None |
|
if gen_cfg.output_scores: |
|
log_probs = F.log_softmax(step_logits_for_criteria, dim=-1) |
|
if 0 <= token.squeeze().item() < log_probs.shape[-1]: |
|
token_logprob = float(log_probs[:, token.squeeze()].item()) |
|
else: |
|
token_logprob = None |
|
return token, past, token_logprob, step_logits_for_criteria, attentions, hidden_states |
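
# Sketch only (not wired into the server): the manual top-k/top-p filtering above could
# be replaced with the stock transformers logits warpers, which also handle batching and
# edge cases. Assumes a transformers version that exposes TopKLogitsWarper/TopPLogitsWarper.
def _filter_logits_with_warpers_example(input_ids: torch.Tensor, logits: torch.Tensor,
                                        top_k: int = 50, top_p: float = 0.9) -> torch.Tensor:
    from transformers import TopKLogitsWarper, TopPLogitsWarper
    warpers = LogitsProcessorList([TopKLogitsWarper(top_k=top_k), TopPLogitsWarper(top_p=top_p)])
    # The warped logits can then be softmaxed and sampled exactly as above.
    return warpers(input_ids, logits)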
|
|
|
def post_process_text(text: str, strip_trailing_whitespace: bool, remove_incomplete_sentences: bool) -> str: |
|
if strip_trailing_whitespace: |
|
text = text.rstrip() |
|
if remove_incomplete_sentences: |
|
for terminator in ['.', '!', '?', '\n']: |
|
last_terminator = text.rfind(terminator) |
|
if last_terminator != -1: |
|
text = text[:last_terminator + 1] |
|
break |
|
return text |
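
# For example, post_process_text("Hola. ¿Qué tal", False, True) returns "Hola." because
# everything after the last sentence terminator is dropped.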
|
|
|
async def stream_generation_logic(req: GenerateRequest, initial_ids: torch.Tensor, gen_cfg: GenerationConfig, device: str) -> AsyncGenerator[Union[str, Tuple[Dict[str, Any], str]], None]: |
|
past = None |
|
generated_tokens_count = 0 |
|
eos_token_id = req.eos_token_id_override if req.eos_token_id_override is not None else global_tokens.get("eos_token_id") |
|
pad_token_id = req.pad_token_id_override if req.pad_token_id_override is not None else global_tokens.get("pad_token_id", eos_token_id) |
|
stop_token_ids = {eos_token_id} if eos_token_id is not None else set() |
|
if pad_token_id is not None and pad_token_id != eos_token_id: |
|
stop_token_ids.add(pad_token_id) |
|
|
|
current_ids = initial_ids |
|
start_time = time.time() |
|
total_ids_list = initial_ids.tolist()[0] |
|
finish_reason = "unknown" |
|
|
|
stopping_criteria = get_stopping_criteria(req, initial_ids, global_tokenizer) |
|
|
|
last_step_logits = None |
|
accumulated_text_for_processing = "" |
|
|
|
try: |
|
while True: |
|
if generated_tokens_count >= req.max_new_tokens: |
|
finish_reason = "max_new_tokens" |
|
break |
|
if req.max_time is not None and (time.time() - start_time) > req.max_time: |
|
finish_reason = "time" |
|
break |
|
|
|
            # After the first step `current_ids` already holds only the last generated
            # token (see the end of the loop body), so it is always the right input here.
            input_ids_sync = current_ids
|
|
|
token, past, token_logprob, step_logits, attentions, hidden_states = await asyncio.to_thread( |
|
generate_next_token_sync, |
|
input_ids_sync, |
|
past, |
|
gen_cfg, |
|
device |
|
) |
|
last_step_logits = step_logits |
|
|
|
generated_token_id = token[0].item() |
|
total_ids_list.append(generated_token_id) |
|
|
|
text = global_tokenizer.decode([generated_token_id], skip_special_tokens=True) |
|
accumulated_text_for_processing += text |
|
|
|
if req.return_only_text: |
|
yield text |
|
else: |
|
chunk_payload: Dict[str, Any] = { |
|
"type": "token", |
|
"text": text, |
|
"token_id": generated_token_id, |
|
"generated_tokens_count": generated_tokens_count + 1, |
|
} |
|
if req.return_token_logprobs and token_logprob is not None: |
|
chunk_payload["logprob"] = token_logprob |
|
|
|
yield json.dumps(chunk_payload) + "\n" |
|
|
|
if generated_token_id in stop_token_ids: |
|
finish_reason = "eos_token" |
|
break |
|
|
|
current_full_ids_tensor = torch.tensor([total_ids_list], device=device) |
|
if stopping_criteria(current_full_ids_tensor, step_logits): |
|
finish_reason = "stopping_criteria" |
|
current_len = len(total_ids_list) |
|
initial_len = initial_ids.shape[-1] |
|
|
|
max_len_crit_met = any(isinstance(c, MaxLengthCriteria) for c in stopping_criteria) and \ |
|
( (req.max_new_tokens is not None and current_len >= (initial_len + req.max_new_tokens)) or |
|
(req.max_length is not None and current_len >= req.max_length) ) |
|
stop_seq_crit_met = any(isinstance(c, StopSequenceCriteria) for c in stopping_criteria) and req.stop_sequences and \ |
|
any(seq in global_tokenizer.decode(total_ids_list[initial_len:], skip_special_tokens=True) for seq in req.stop_sequences) |
|
|
|
if max_len_crit_met: |
|
if req.max_new_tokens is not None and current_len >= (initial_len + req.max_new_tokens): |
|
finish_reason = "max_new_tokens" |
|
elif req.max_length is not None and current_len >= req.max_length: |
|
finish_reason = "max_length" |
|
|
|
if stop_seq_crit_met: |
|
finish_reason = "stop_sequence" |
|
|
|
|
|
break |
|
|
|
|
|
current_ids = token |
|
generated_tokens_count += 1 |
|
|
|
final_text_raw = global_tokenizer.decode(total_ids_list[initial_ids.shape[-1]:], skip_special_tokens=True) |
|
if req.stop_sequences and finish_reason == "stop_sequence": |
|
for stop_seq in req.stop_sequences: |
|
if stop_seq and stop_seq in final_text_raw: |
|
final_text_raw = final_text_raw.split(stop_seq, 1)[0] |
|
break |
|
|
|
final_text_processed = post_process_text(final_text_raw, req.strip_trailing_whitespace, req.remove_incomplete_sentences) |
|
|
|
|
|
if not req.return_only_text: |
|
final_payload: Dict[str, Any] = { |
|
"type": "done", |
|
"total_prompt_tokens": initial_ids.shape[-1], |
|
"total_generated_tokens": generated_tokens_count, |
|
"total_sequence_tokens": len(total_ids_list), |
|
"final_text": final_text_processed, |
|
"finish_reason": finish_reason |
|
} |
|
yield json.dumps(final_payload) + "\n" |
|
|
|
|
|
except Exception as e: |
|
logging.exception("Streaming generation error:") |
|
if req.return_only_text: |
|
yield f"Error: {e}\n" |
|
else: |
|
error_payload = {"type": "error", "message": str(e)} |
|
yield json.dumps(error_payload) + "\n" |
|
|
|
finally: |
|
await cleanup(device) |
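
# When `return_only_text` is false, the generator above yields newline-delimited JSON.
# Illustrative chunks (token ids and counts are made up):
#
#   {"type": "token", "text": "Hola", "token_id": 1234, "generated_tokens_count": 1}
#   {"type": "done", "total_prompt_tokens": 12, "total_generated_tokens": 1,
#    "total_sequence_tokens": 13, "final_text": "Hola", "finish_reason": "eos_token"}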
|
|
|
|
|
async def non_stream_generation_logic(req: GenerateRequest, initial_ids: torch.Tensor, gen_cfg: GenerationConfig, device: str) -> Dict[str, Any]: |
|
try: |
|
        logits_processor_list = LogitsProcessorList()
        # min_length is enforced with MinLengthLogitsProcessor (a logits processor that
        # suppresses EOS until the minimum length is reached); see get_stopping_criteria.
        if req.min_length and req.min_length > 0:
            min_len_eos_id = req.eos_token_id_override if req.eos_token_id_override is not None else global_tokens.get("eos_token_id")
            if min_len_eos_id is not None:
                logits_processor_list.append(MinLengthLogitsProcessor(initial_ids.shape[-1] + req.min_length, min_len_eos_id))
|
|
|
stopping_criteria_list = get_stopping_criteria(req, initial_ids, global_tokenizer) |
|
|
|
|
|
with torch.no_grad(): |
|
out = global_model.generate( |
|
input_ids=initial_ids, |
|
generation_config=gen_cfg, |
|
return_dict_in_generate=True, |
|
output_scores=req.output_scores, |
|
output_attentions=req.output_attentions, |
|
output_hidden_states=req.output_hidden_states, |
|
num_return_sequences=req.num_return_sequences, |
|
bad_words_ids=req.bad_words_ids, |
|
suppress_tokens=req.suppress_tokens, |
|
begin_suppress_tokens=req.begin_suppress_tokens, |
|
                # `end_suppress_tokens` is not a transformers generate() argument, so the
                # request field of the same name is intentionally not forwarded here.
|
logits_processor=logits_processor_list if logits_processor_list else None, |
|
stopping_criteria=stopping_criteria_list if stopping_criteria_list else None, |
|
) |
|
|
|
generated_data = [] |
|
for i in range(req.num_return_sequences): |
|
if i >= len(out.sequences): |
|
break |
|
|
|
sequence = out.sequences[i] |
|
start_index = initial_ids.shape[-1] |
|
generated_ids_tensor = sequence[start_index:] |
|
full_sequence_ids = sequence.tolist() |
|
|
|
text = global_tokenizer.decode(generated_ids_tensor, skip_special_tokens=True) |
|
|
|
if req.stop_sequences: |
|
for stop_seq in req.stop_sequences: |
|
if stop_seq and stop_seq in text: |
|
text = text.split(stop_seq, 1)[0] |
|
break |
|
|
|
text = post_process_text(text, req.strip_trailing_whitespace, req.remove_incomplete_sentences) |
|
|
|
finish_reason = "length" |
|
eos_token_id = req.eos_token_id_override if req.eos_token_id_override is not None else global_tokens.get("eos_token_id") |
|
if len(generated_ids_tensor) > 0 and eos_token_id is not None and generated_ids_tensor[-1] == eos_token_id: |
|
finish_reason = "eos_token" |
|
elif len(generated_ids_tensor) >= gen_cfg.max_new_tokens: |
|
finish_reason = "max_new_tokens" |
|
elif req.max_length is not None and len(full_sequence_ids) >= req.max_length: |
|
finish_reason = "max_length" |
|
elif hasattr(out, 'max_time_exceeded') and out.max_time_exceeded: |
|
finish_reason = "time" |
|
|
|
if req.stop_sequences and finish_reason == "length": |
|
decoded_full_output = global_tokenizer.decode(full_sequence_ids, skip_special_tokens=True) |
|
if any(seq in decoded_full_output for seq in req.stop_sequences): |
|
finish_reason = "stop_sequence" |
|
|
|
|
|
item_data: Dict[str, Any] = { |
|
"text": text if req.return_text_from_sequence else None, |
|
"token_ids": generated_ids_tensor.tolist(), |
|
"generated_tokens_count": len(generated_ids_tensor), |
|
"finish_reason": finish_reason |
|
} |
|
if not req.remove_input_from_output: |
|
item_data["full_sequence_token_ids"] = full_sequence_ids |
|
|
|
if req.output_scores and hasattr(out, 'scores') and out.scores is not None: |
|
item_data["scores"] = "Scores output needs custom handling (complex structure)." |
|
|
|
if req.return_token_logprobs: |
|
item_data["token_logprobs"] = "Token logprobs require parsing scores output which is complex for batched/beamed generation." |
|
|
|
if req.output_attentions and hasattr(out, 'attentions') and out.attentions is not None: |
|
item_data["attentions"] = "Attentions output needs custom handling (too large)." |
|
if req.output_hidden_states and hasattr(out, 'hidden_states') and out.hidden_states is not None: |
|
item_data["hidden_states"] = "Hidden states output needs custom handling (too large)." |
|
if hasattr(out, 'watermark') and out.watermark is not None: |
|
item_data["watermark"] = out.watermark[i] if isinstance(out.watermark, list) and len(out.watermark) > i else out.watermark |
|
|
|
|
|
generated_data.append(item_data) |
|
|
|
|
|
response_payload: Dict[str, Any] = { |
|
"prompt_tokens": initial_ids.shape[-1], |
|
"generated_sequences": generated_data, |
|
} |
|
if req.num_return_sequences == 1 and generated_data: |
|
response_payload["total_tokens"] = response_payload["prompt_tokens"] + generated_data[0]["generated_tokens_count"] |
|
|
|
if req.return_dict_in_generate: |
|
raw_out_dict = {} |
|
for key in out.keys(): |
|
if key not in ['sequences', 'scores', 'attentions', 'hidden_states', 'past_key_values', 'watermark', 'sequences_scores']: |
|
value = out[key] |
|
if isinstance(value, torch.Tensor): |
|
raw_out_dict[key] = value.tolist() |
|
else: |
|
raw_out_dict[key] = value |
|
|
|
response_payload["raw_generate_output"] = raw_out_dict |
|
|
|
return response_payload |
|
|
|
except Exception as e: |
|
logging.exception("Non-streaming generation error:") |
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Generation error: {e}") |
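
# Shape of the non-streaming response built above, for reference (values illustrative):
#
#   {
#     "prompt_tokens": 12,
#     "generated_sequences": [
#       {"text": "...", "token_ids": [...], "generated_tokens_count": 20,
#        "finish_reason": "eos_token", "full_sequence_token_ids": [...]}
#     ],
#     "total_tokens": 32
#   }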
|
|
|
async def cleanup(device: str): |
|
if device == "cuda" and torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
gc.collect() |
|
|
|
|
|
@app.on_event("startup") |
|
async def load_model(): |
|
global global_model, global_tokenizer, global_tokens, MODEL_NAME, TRUST_REMOTE_CODE, ENABLE_FLASH_ATTENTION_2, TORCH_DTYPE, TORCH_DTYPE_STR, TRUST_REMOTE_CODE_ENV |
|
|
|
    cpu_count = os.cpu_count() or 1
    torch.set_num_threads(max(1, cpu_count // 2))
    torch.set_num_interop_threads(max(1, cpu_count // 4))

    # "fused" is not a valid value for torch.backends.cuda.preferred_linalg_backend
    # (accepted options are "default", "cusolver" and "magma"), so the linalg backend
    # is left at its default.
    torch.backends.cudnn.benchmark = torch.cuda.is_available()
|
|
|
try: |
|
TORCH_DTYPE = getattr(torch, TORCH_DTYPE_STR.lower(), torch.float32) |
|
if TORCH_DTYPE != torch.float32: |
|
logging.warning(f"Requested dtype {TORCH_DTYPE_STR} might not be fully performant on CPU. Using float32.") |
|
TORCH_DTYPE = torch.float32 |
|
except AttributeError: |
|
logging.warning(f"Invalid TORCH_DTYPE specified: {TORCH_DTYPE_STR}. Falling back to float32.") |
|
TORCH_DTYPE = torch.float32 |
|
|
|
current_model_name = MODEL_NAME |
|
current_trust_remote_code = TRUST_REMOTE_CODE_ENV or (current_model_name == DEFAULT_MODEL_NAME) |
|
device = "cpu" |
|
|
|
try: |
|
logging.info(f"Loading config for model: {current_model_name}") |
|
config = AutoConfig.from_pretrained(current_model_name, trust_remote_code=current_trust_remote_code) |
|
original_config = copy.deepcopy(config) |
|
|
|
logging.info(f"Modifying config for simplified model.") |
|
|
|
if hasattr(config, 'num_hidden_layers'): |
|
config.num_hidden_layers = 1 |
|
elif hasattr(config, 'num_layers'): |
|
config.num_layers = 1 |
|
|
|
if hasattr(config, 'bos_token_id'): |
|
config.bos_token_id = 1 |
|
|
|
if hasattr(config, 'do_sample'): |
|
config.do_sample = None |
|
|
|
if hasattr(config, 'eos_token_id'): |
|
config.eos_token_id = 2 |
|
|
|
if hasattr(config, 'head_dim'): |
|
config.head_dim = 96 |
|
|
|
if hasattr(config, 'hidden_size'): |
|
config.hidden_size = 192 |
|
|
|
if hasattr(config, 'initializer_range'): |
|
config.initializer_range = 0.02 |
|
|
|
if hasattr(config, 'intermediate_size'): |
|
config.intermediate_size = 512 |
|
|
|
if hasattr(config, 'max_position_embeddings'): |
|
config.max_position_embeddings = MAX_CONTEXT_TOKENS |
|
|
|
if hasattr(config, 'n_positions'): |
|
config.n_positions = MAX_CONTEXT_TOKENS |
|
|
|
if hasattr(config, 'seq_len'): |
|
config.seq_len = MAX_CONTEXT_TOKENS |
|
|
|
if hasattr(config, 'ctx'): |
|
config.ctx = MAX_CONTEXT_TOKENS |
|
|
|
if hasattr(config, 'n_ctx'): |
|
config.n_ctx = MAX_CONTEXT_TOKENS |
|
|
|
if hasattr(config, 'max_seq_length'): |
|
config.max_seq_length = MAX_CONTEXT_TOKENS |
|
|
|
if hasattr(config, 'max_sequence_length'): |
|
config.max_sequence_length = MAX_CONTEXT_TOKENS |
|
|
|
if hasattr(config, 'max_length'): |
|
config.max_length = MAX_CONTEXT_TOKENS |
|
|
|
if hasattr(config, 'block_size'): |
|
config.block_size = MAX_CONTEXT_TOKENS |
|
|
|
if hasattr(config, 'use_cache'): |
|
config.use_cache = False |
|
|
|
if hasattr(config, 'gradient_checkpointing'): |
|
config.gradient_checkpointing = True |
|
|
|
        if hasattr(config, 'torch_dtype'):
            # CUDA device properties expose no `has_bfloat16` attribute;
            # torch.cuda.is_bf16_supported() is the supported check.
            if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
                config.torch_dtype = 'bfloat16'
            else:
                config.torch_dtype = 'float16'

        if hasattr(config, 'use_bfloat16'):
            if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
                config.use_bfloat16 = True
            else:
                config.use_bfloat16 = False
|
|
|
if hasattr(config, 'attention_probs_dropout_prob'): |
|
config.attention_probs_dropout_prob = 0.1 |
|
|
|
if hasattr(config, 'hidden_dropout_prob'): |
|
config.hidden_dropout_prob = 0.1 |
|
|
|
if hasattr(config, 'layerdrop'): |
|
config.layerdrop = 0.1 |
|
|
|
if hasattr(config, 'layer_norm_eps'): |
|
config.layer_norm_eps = 1e-5 |
|
|
|
|
if hasattr(config, 'rotary_pct'): |
|
config.rotary_pct = 0.25 |
|
|
|
if hasattr(config, 'rotary_emb_base'): |
|
config.rotary_emb_base = 10000 |
|
|
|
if hasattr(config, 'position_embedding_type'): |
|
config.position_embedding_type = 'rotary' |
|
|
|
if hasattr(config, 'activation_function'): |
|
config.activation_function = 'gelu_new' |
|
|
|
if hasattr(config, 'vocab_size'): |
|
config.vocab_size = 32000 |
|
|
|
if hasattr(config, 'quantization_config'): |
|
if torch.cuda.is_available(): |
|
config.quantization_config = { |
|
'load_in_8bit': True, |
|
'load_in_4bit': False, |
|
'bnb_4bit_compute_dtype':'float16', |
|
'bnb_4bit_use_double_quant':True, |
|
'bnb_4bit_quant_type':'nf4' |
|
} |
|
else: |
|
logging.warning("Quantization config requested but CUDA not available. Skipping quantization config modification.") |
|
config.quantization_config = {} |
|
|
|
if hasattr(config, 'load_in_8bit'): |
|
if torch.cuda.is_available(): |
|
config.load_in_8bit = True |
|
else: |
|
config.load_in_8bit = False |
|
|
|
if hasattr(config, 'load_in_4bit'): |
|
if torch.cuda.is_available(): |
|
config.load_in_4bit = False |
|
else: |
|
config.load_in_4bit = False |
|
|
|
if hasattr(config, 'tie_word_embeddings'): |
|
config.tie_word_embeddings = True |
|
|
|
if hasattr(config, 'output_attentions'): |
|
config.output_attentions = False |
|
|
|
if hasattr(config, 'output_hidden_states'): |
|
config.output_hidden_states = False |
|
|
|
|
logging.info(f"Loading tokenizer for model: {current_model_name}") |
|
tokenizer_kwargs = {"config": original_config, "trust_remote_code": current_trust_remote_code} |
|
global_tokenizer = AutoTokenizer.from_pretrained(current_model_name, **tokenizer_kwargs) |
|
logging.info("Tokenizer loaded.") |
|
|
|
logging.info(f"Loading model: {current_model_name} with modified config and dtype {TORCH_DTYPE} onto {device}") |
|
|
|
model_kwargs = {"config": config, "torch_dtype": TORCH_DTYPE, "trust_remote_code": current_trust_remote_code} |
|
|
|
global_model = AutoModelForCausalLM.from_pretrained(current_model_name, **model_kwargs) |
|
global_model.to(device) |
|
|
|
try: |
|
global_model = torch.compile(global_model, mode="max-autotune") |
|
logging.info("Model compiled with torch.compile (max-autotune mode).") |
|
except Exception as e: |
|
logging.warning(f"Failed to compile model with torch.compile: {e}") |
|
pass |
|
|
|
global_model.eval() |
|
logging.info("Model loaded successfully.") |
|
|
|
global_tokens["eos_token_id"] = global_tokenizer.eos_token_id |
|
global_tokens["pad_token_id"] = global_tokenizer.pad_token_id |
|
if global_tokens["pad_token_id"] is None and global_tokens["eos_token_id"] is not None: |
|
global_tokens["pad_token_id"] = global_tokens["eos_token_id"] |
|
if global_model.config.pad_token_id is None: |
|
global_model.config.pad_token_id = global_tokens["pad_token_id"] |
|
elif global_tokens["pad_token_id"] is None and global_tokens["eos_token_id"] is None: |
|
logging.warning("Neither EOS nor PAD token is defined for this tokenizer/model.") |
|
if global_model.config.pad_token_id is None and global_tokens.get("pad_token_id") is not None: |
|
global_model.config.pad_token_id = global_tokens["pad_token_id"] |
|
|
|
except Exception as e: |
|
logging.exception("Failed to load model or tokenizer:") |
|
global_model = None |
|
global_tokenizer = None |
|
global_tokens = {} |
|
|
|
html_code = """ |
|
<!DOCTYPE html> |
|
<html lang="es"> |
|
<head> |
|
<meta charset="UTF-8" /> |
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> |
|
<title>Chatbot Profesional</title> |
|
<style> |
|
body { font-family: Arial, sans-serif; margin: 20px; } |
|
#chatbox { width: 100%; height: 400px; border: 1px solid #ccc; padding: 10px; overflow-y: scroll; margin-bottom: 10px; } |
|
#user-input { width: calc(100% - 100px); padding: 8px; box-sizing: border-box;} |
|
#send-btn { width: 90px; padding: 8px 0;} |
|
#input-area { display: flex;} |
|
</style> |
|
</head> |
|
<body> |
|
<h1>Chatbot Profesional (POST API)</h1> |
|
<div id="chatbox"></div> |
|
<div id="input-area"> |
|
<input type="text" id="user-input" placeholder="Escribe tu mensaje aquí..." autocomplete="off"/> |
|
<button id="send-btn">Enviar</button> |
|
</div> |
|
<script> |
|
const chatbox = document.getElementById('chatbox'); |
|
const userInput = document.getElementById('user-input'); |
|
const sendBtn = document.getElementById('send-btn'); |
|
|
|
let conversationHistory = []; |
|
const DEFAULT_SYSTEM_PROMPT = "Eres un asistente profesional y servicial."; |
|
let currentSystemPrompt = DEFAULT_SYSTEM_PROMPT; |
|
let botMessageElement = null; |
|
|
|
function appendMessage(sender, text, isStreaming = false) { |
|
let msg; |
|
if (isStreaming && botMessageElement) { |
|
botMessageElement.textContent += text; |
|
} else { |
|
msg = document.createElement('p'); |
|
msg.innerHTML = `<strong>${sender}:</strong> `; |
|
const textNode = document.createTextNode(text); |
|
msg.appendChild(textNode); |
|
chatbox.appendChild(msg); |
|
if (sender === 'Bot' && isStreaming) { |
|
botMessageElement = textNode; |
|
} else { |
|
botMessageElement = null; |
|
} |
|
} |
|
chatbox.scrollTop = chatbox.scrollHeight; |
|
} |
|
|
|
function updateHistory(role, content) { |
|
conversationHistory.push({ "role": role, "content": content }); |
|
const maxHistorySize = 10; |
|
if (conversationHistory.length > maxHistorySize * 2) { |
|
conversationHistory = conversationHistory.slice(-(maxHistorySize * 2)); |
|
} |
|
} |
|
|
|
async function sendMessage() { |
|
const text = userInput.value; |
|
if (!text) { |
|
return; |
|
} |
|
appendMessage('Usuario', text); |
|
updateHistory("user", text); |
|
userInput.value = ''; |
|
sendBtn.disabled = true; |
|
|
|
botMessageElement = null; |
|
|
|
const messagePayload = { |
|
input_text: text, |
|
history: conversationHistory, |
|
system_prompt: currentSystemPrompt, |
|
stream: true, |
|
temperature: 1.0, |
|
top_k: 50, |
|
top_p: 1.0, |
|
repetition_penalty: 1.0, |
|
frequency_penalty: 0.0, |
|
presence_penalty: 0.0, |
|
num_beams: 1, |
|
length_penalty: 1.0, |
|
no_repeat_ngram_size: 0, |
|
early_stopping: false, |
|
do_sample: true, |
|
use_mirostat: false, |
|
mirostat_tau: 5.0, |
|
mirostat_eta: 0.1, |
|
max_new_tokens: 512, |
|
num_return_sequences: 1, |
|
return_token_logprobs: true |
|
}; |
|
|
|
try { |
|
const response = await fetch('/generate', { |
|
method: 'POST', |
|
headers: { |
|
'Content-Type': 'application/json', |
|
// Add API Key header if needed |
|
// 'X-API-Key': 'YOUR_API_KEY_HERE' |
|
}, |
|
body: JSON.stringify(messagePayload), |
|
}); |
|
|
|
if (!response.ok) { |
|
const errorData = await response.json(); |
|
throw new Error(`API Error: ${response.status} ${response.statusText} - ${errorData.detail || errorData.error}`); |
|
} |
|
|
|
const reader = response.body.getReader(); |
|
const decoder = new TextDecoder(); |
|
let buffer = ''; |
|
let currentBotResponse = ""; |
|
|
|
while (true) { |
|
const { value, done } = await reader.read(); |
|
if (done) break; |
|
|
|
buffer += decoder.decode(value, { stream: true }); |
|
|
|
const lines = buffer.split('\n'); |
|
buffer = lines.pop(); |
|
|
|
for (const line of lines) { |
|
if (line.trim() === '') continue; |
|
try { |
|
const data = JSON.parse(line); |
|
if (data.type === 'token') { |
|
currentBotResponse += data.text; |
|
appendMessage('Bot', data.text, true); |
|
console.log('Token:', data.token_id, 'Text:', data.text, 'Logprob:', data.logprob); |
|
} else if (data.type === 'done') { |
|
console.log('Generation done', data); |
|
if (data.total_tokens !== undefined) { |
|
appendMessage('System', `Generated ${data.total_tokens} tokens. Finish reason: ${data.finish_reason}`); |
|
} |
|
if (data.final_text !== undefined) { |
|
updateHistory("assistant", data.final_text); |
|
} else if (currentBotResponse) { |
|
updateHistory("assistant", currentBotResponse); |
|
} |
|
|
|
} else if (data.type === 'error') { |
|
appendMessage('Error', data.message); |
|
currentBotResponse = ""; |
|
} |
|
} catch (e) { |
|
console.error('Failed to parse stream chunk:', e, line); |
|
appendMessage('Error', 'Failed to process stream.'); |
|
currentBotResponse = ""; |
|
reader.cancel(); |
|
return; |
|
} |
|
} |
|
} |
|
|
|
if (buffer.trim() !== '') { |
|
try { |
|
const data = JSON.parse(buffer); |
|
if (data.type === 'token') { |
|
currentBotResponse += data.text; |
|
appendMessage('Bot', data.text, true); |
|
console.log('Token:', data.token_id, 'Text:', data.text, 'Logprob:', data.logprob); |
|
} else if (data.type === 'done') { |
|
console.log('Generation done', data); |
|
if (data.total_tokens !== undefined) { |
|
appendMessage('System', `Generated ${data.total_tokens} tokens. Finish reason: ${data.finish_reason}`); |
|
} |
|
if (data.final_text !== undefined) { |
|
updateHistory("assistant", data.final_text); |
|
} else if (currentBotResponse) { |
|
updateHistory("assistant", currentBotResponse); |
|
} |
|
} else if (data.type === 'error') { |
|
appendMessage('Error', data.message); |
|
currentBotResponse = ""; |
|
} |
|
} catch (e) { |
|
console.error('Failed to parse remaining buffer:', e, buffer); |
|
appendMessage('Error', 'Failed to process remaining stream data.'); |
|
currentBotResponse = ""; |
|
} |
|
} |
|
|
|
|
|
if (currentBotResponse && !botMessageElement) { |
|
updateHistory("assistant", currentBotResponse); |
|
} |
|
botMessageElement = null; |
|
currentBotResponse = ""; |
|
|
|
|
|
} catch (error) { |
|
console.error('Send message error:', error); |
|
appendMessage('Error', error.message || 'An unknown error occurred.'); |
|
botMessageElement = null; |
|
currentBotResponse = ""; |
|
} finally { |
|
sendBtn.disabled = false; |
|
} |
|
} |
|
|
|
sendBtn.onclick = sendMessage; |
|
|
|
userInput.addEventListener('keypress', function(event) { |
|
if (event.key === 'Enter') { |
|
event.preventDefault(); |
|
sendMessage(); |
|
} |
|
}); |
|
|
|
|
|
</script> |
|
</body> |
|
</html> |
|
""" |
|
|
|
@app.get("/", response_class=HTMLResponse, summary="Interactive HTML interface") |
|
async def root(): |
|
return HTMLResponse(content=html_code) |
|
|
|
async def check_health(): |
|
model_loaded = global_model is not None |
|
tokenizer_loaded = global_tokenizer is not None |
|
status_data = { |
|
"model_loaded": model_loaded, |
|
"tokenizer_loaded": tokenizer_loaded, |
|
"status": "ok" if model_loaded and tokenizer_loaded else "loading model", |
|
"cuda_available": torch.cuda.is_available(), |
|
"cpu_cores": os.cpu_count(), |
|
"max_concurrent_generations": MAX_CONCURRENT_GENERATIONS, |
|
"currently_running_generations": MAX_CONCURRENT_GENERATIONS - generation_semaphore._value, |
|
"available_slots": generation_semaphore._value, |
|
} |
|
if torch.cuda.is_available(): |
|
device_count = torch.cuda.device_count() |
|
status_data["device_count"] = device_count |
|
status_data["devices"] = [] |
|
for i in range(device_count): |
|
try: |
|
device_status = { |
|
"id": i, |
|
"name": torch.cuda.get_device_name(i), |
|
"total_memory_mib": round(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024), 2), |
|
"allocated_memory_mib": round(torch.cuda.memory_allocated(i) / (1024 * 1024), 2), |
|
"cached_memory_mib": round(torch.cuda.memory_reserved(i) / (1024 * 1024), 2), |
|
} |
|
status_data["devices"].append(device_status) |
|
except Exception as e: |
|
logging.error(f"Error getting GPU memory info for device {i}: {e}") |
|
status_data["devices"].append({"id": i, "error": str(e)}) |
|
else: |
|
status_data["message"] = "CUDA not available. GPU resource info is not applicable." |
|
return status_data |
|
|
|
async def get_config_data(): |
|
torch_dtype_str_out = str(TORCH_DTYPE).split('.')[-1] if isinstance(TORCH_DTYPE, torch.dtype) else str(TORCH_DTYPE) |
|
return { |
|
"model_name": MODEL_NAME, |
|
"system_prompt_default": SYSTEM_PROMPT, |
|
"max_context_tokens": MAX_CONTEXT_TOKENS, |
|
"max_generation_tokens": MAX_GENERATION_TOKENS, |
|
"cuda_available": torch.cuda.is_available(), |
|
"model_loaded": global_model is not None, |
|
"tokenizer_loaded": global_tokenizer is not None, |
|
"max_concurrent_generations": MAX_CONCURRENT_GENERATIONS, |
|
"trust_remote_code_startup_env": TRUST_REMOTE_CODE_ENV, |
|
"trust_remote_code_effective": TRUST_REMOTE_CODE, |
|
"enable_flash_attention_2": ENABLE_FLASH_ATTENTION_2, |
|
"torch_dtype": torch_dtype_str_out, |
|
"eos_token_id": global_tokens.get("eos_token_id"), |
|
"pad_token_id": global_tokens.get("pad_token_id"), |
|
"bos_token_id": global_tokenizer.bos_token_id if global_tokenizer else None, |
|
"api_key_required": API_KEY is not None |
|
} |
|
|
|
async def get_model_info_data(): |
|
if global_model is None: |
|
return {"model_name": MODEL_NAME, "is_loaded": False, "message": "Model is not loaded."} |
|
try: |
|
config_dict = global_model.config.to_dict() |
|
        keys_to_remove = ['torch_dtype', '_attn_implementation', 'architectures', 'id2label', 'label2id']
|
for key in keys_to_remove: |
|
config_dict.pop(key, None) |
|
return { |
|
"model_name": MODEL_NAME, |
|
"is_loaded": True, |
|
"device": str(global_model.device), |
|
"torch_dtype": str(global_model.dtype), |
|
"config": config_dict |
|
} |
|
except Exception as e: |
|
logging.exception("Error getting model info:") |
|
return {"model_name": MODEL_NAME, "is_loaded": True, "error": f"Error getting model info: {e}"} |
|
|
|
async def internal_tokenize(text: Union[str, List[str]], add_special_tokens: bool = True, is_split_into_words: bool = False, return_token_type_ids: bool = False, padding: Union[bool, str] = False, truncation: Union[bool, str] = False, max_length: Optional[int] = None, return_tensors: Optional[str] = None, return_attention_mask: Optional[bool] = None, return_offsets_mapping: Optional[bool] = None, return_length: Optional[bool] = None, verbose: bool = False, tokenizer_kwargs: Optional[Dict[str, Any]] = None): |
|
if global_tokenizer is None: |
|
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Tokenizer is not loaded.") |
|
try: |
|
tokenizer_kwargs_final = tokenizer_kwargs or {} |
|
return_tensors_final = return_tensors if return_tensors is not None else None |
|
if return_tensors_final is None and (return_attention_mask or return_offsets_mapping or return_length): |
|
return_tensors_final = "pt" |
|
encoded = global_tokenizer( |
|
text, |
|
add_special_tokens=add_special_tokens, |
|
return_token_type_ids=return_token_type_ids, |
|
padding=padding, |
|
truncation=truncation, |
|
max_length=max_length, |
|
is_split_into_words=is_split_into_words, |
|
return_tensors=return_tensors_final, |
|
return_attention_mask=return_attention_mask, |
|
return_offsets_mapping=return_offsets_mapping, |
|
return_length=return_length, |
|
verbose=verbose, |
|
**tokenizer_kwargs_final |
|
) |
|
return encoded |
|
except Exception as e: |
|
logging.exception("Tokenization error:") |
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Tokenization error: {e}") |
|
|
|
async def internal_decode(token_ids: List[int], skip_special_tokens: bool = True, clean_up_tokenization_spaces: bool = True, decode_kwargs: Optional[Dict[str, Any]] = None): |
|
if global_tokenizer is None: |
|
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Tokenizer is not loaded.") |
|
try: |
|
decode_kwargs_final = decode_kwargs or {} |
|
text = global_tokenizer.decode( |
|
token_ids, |
|
skip_special_tokens=skip_special_tokens, |
|
clean_up_tokenization_spaces=clean_up_tokenization_spaces, |
|
**decode_kwargs_final |
|
) |
|
return {"text": text} |
|
except Exception as e: |
|
logging.exception("Decoding error:") |
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Decoding error: {e}") |
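
# Rough round-trip sketch for the two helpers above; they are coroutines, so they must
# be awaited from an async context (e.g. inside an endpoint):
#
#   enc = await internal_tokenize("Hola mundo", add_special_tokens=False)
#   dec = await internal_decode(enc["input_ids"])   # -> {"text": "Hola mundo"} (roughly)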
|
|
|
def update_global_system_prompt(new_prompt: str): |
|
global SYSTEM_PROMPT |
|
if new_prompt is not None: |
|
SYSTEM_PROMPT = new_prompt.strip() |
|
return {"status": "success", "message": "Global system prompt updated"} |
|
else: |
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="System prompt cannot be null") |
|
|
|
async def internal_reload_model(req: ModelReloadRequest): |
|
global global_model, global_tokenizer, global_tokens, MODEL_NAME, TRUST_REMOTE_CODE, ENABLE_FLASH_ATTENTION_2, TORCH_DTYPE, TORCH_DTYPE_STR, TRUST_REMOTE_CODE_ENV |
|
new_model_name = req.model_name if req.model_name else MODEL_NAME |
|
new_trust_remote_code = req.trust_remote_code if req.trust_remote_code is not None else (TRUST_REMOTE_CODE_ENV or (new_model_name == DEFAULT_MODEL_NAME)) |
|
new_enable_flash_attention_2 = req.enable_flash_attention_2 if req.enable_flash_attention_2 is not None else ENABLE_FLASH_ATTENTION_2 |
|
new_torch_dtype_str_req = req.torch_dtype if req.torch_dtype else TORCH_DTYPE_STR |
|
    try:
        new_torch_dtype = getattr(torch, new_torch_dtype_str_req.lower())
        # Validate before the CPU fallback, otherwise the isinstance branch is unreachable.
        if not isinstance(new_torch_dtype, torch.dtype):
            raise AttributeError
        if new_torch_dtype != torch.float32:
            logging.warning(f"Requested dtype {new_torch_dtype_str_req} might not be fully performant on CPU. Using float32.")
            new_torch_dtype = torch.float32
        new_torch_dtype_str = str(new_torch_dtype).split('.')[-1]
    except AttributeError:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid or unsupported torch_dtype: {new_torch_dtype_str_req}")
|
device = "cpu" |
|
async def _reload(): |
|
global global_model, global_tokenizer, global_tokens, MODEL_NAME, TRUST_REMOTE_CODE, ENABLE_FLASH_ATTENTION_2, TORCH_DTYPE, TORCH_DTYPE_STR |
|
logging.info(f"Attempting to load model: {new_model_name}") |
|
try: |
|
logging.info("Unloading current model...") |
|
await cleanup(device) |
|
if global_model is not None: |
|
del global_model |
|
global_model = None |
|
if global_tokenizer is not None: |
|
del global_tokenizer |
|
global_tokenizer = None |
|
global_tokens = {} |
|
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
|
gc.collect() |
|
logging.info("Current model unloaded.") |
|
logging.info(f"Loading config for model: {new_model_name}") |
|
config = AutoConfig.from_pretrained(new_model_name, trust_remote_code=new_trust_remote_code) |
|
original_config = copy.deepcopy(config) |
|
|
|
logging.info(f"Modifying config for simplified model.") |
|
|
|
config_modifications = { |
|
'num_hidden_layers': 1, |
|
'num_layers': 1, |
|
'bos_token_id': 1, |
|
'do_sample': None, |
|
'eos_token_id': 2, |
|
'head_dim': 96, |
|
'hidden_size': 192, |
|
'initializer_range': 0.02, |
|
'intermediate_size': 512, |
|
'max_position_embeddings': MAX_CONTEXT_TOKENS, |
|
'n_positions': MAX_CONTEXT_TOKENS, |
|
'seq_len': MAX_CONTEXT_TOKENS, |
|
'ctx': MAX_CONTEXT_TOKENS, |
|
'n_ctx': MAX_CONTEXT_TOKENS, |
|
'max_seq_length': MAX_CONTEXT_TOKENS, |
|
'max_sequence_length': MAX_CONTEXT_TOKENS, |
|
'max_length': MAX_CONTEXT_TOKENS, |
|
'block_size': MAX_CONTEXT_TOKENS, |
|
'use_cache': False, |
|
'gradient_checkpointing': True, |
|
'attention_probs_dropout_prob': 0.1, |
|
'hidden_dropout_prob': 0.1, |
|
'layerdrop': 0.1, |
|
'layer_norm_eps': 1e-5, |
|
'rotary_pct': 0.25, |
|
'rotary_emb_base': 10000, |
|
'position_embedding_type': 'rotary', |
|
'activation_function': 'gelu_new', |
|
'vocab_size': 32000, |
|
'tie_word_embeddings': True, |
|
'output_attentions': False, |
|
'output_hidden_states': False, |
|
} |
|
|
|
for attr, new_val in config_modifications.items(): |
|
if hasattr(config, attr): |
|
                    if attr == 'torch_dtype':
                        # CUDA device properties expose no `has_bfloat16` attribute;
                        # torch.cuda.is_bf16_supported() is the supported check.
                        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
                            setattr(config, attr, torch.bfloat16)
                        else:
                            setattr(config, attr, torch.float16)
                    elif attr == 'use_bfloat16':
                        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
                            setattr(config, attr, True)
                        else:
                            setattr(config, attr, False)
|
elif attr == 'quantization_config': |
|
if torch.cuda.is_available(): |
|
setattr(config, attr, new_val) |
|
else: |
|
logging.warning(f"Quantization config requested for '{attr}' but CUDA not available. Skipping modification.") |
|
else: |
|
setattr(config, attr, new_val) |
|
elif attr in ['num_hidden_layers', 'num_layers', 'max_position_embeddings', 'n_positions', 'seq_len', 'ctx', 'n_ctx', 'max_seq_length', 'max_sequence_length', 'max_length', 'block_size']: |
|
logging.warning(f"Could not find a standard parameter '{attr}' in config for {new_model_name}. Max context/layer logic might not be fully effective.") |
|
|
|
|
|
logging.info(f"Loading tokenizer for model: {new_model_name}") |
|
tokenizer_kwargs = {"config": original_config, "trust_remote_code": new_trust_remote_code} |
|
if req.tokenizer_kwargs: |
|
tokenizer_kwargs.update(req.tokenizer_kwargs) |
|
tokenizer = AutoTokenizer.from_pretrained(new_model_name, **tokenizer_kwargs) |
|
logging.info("Tokenizer loaded.") |
|
|
|
logging.info(f"Loading model: {new_model_name} with modified config and dtype {new_torch_dtype_str} onto {device}") |
|
model_kwargs = {"config": config, "torch_dtype": new_torch_dtype, "trust_remote_code": new_trust_remote_code} |
|
model = AutoModelForCausalLM.from_pretrained(new_model_name, **model_kwargs) |
|
model.to(device) |
|
|
|
try: |
|
model = torch.compile(model, mode="max-autotune") |
|
logging.info("New model compiled with torch.compile (max-autotune mode).") |
|
except Exception as e: |
|
logging.warning(f"Failed to compile new model with torch.compile: {e}") |
|
pass |
|
model.eval() |
|
logging.info("New model loaded successfully.") |
            global_model = model
            global_tokenizer = tokenizer
            global_tokens["eos_token_id"] = global_tokenizer.eos_token_id
            global_tokens["pad_token_id"] = global_tokenizer.pad_token_id
            if global_tokens["pad_token_id"] is None and global_tokens["eos_token_id"] is not None:
                global_tokens["pad_token_id"] = global_tokens["eos_token_id"]
                if global_model.config.pad_token_id is None:
                    global_model.config.pad_token_id = global_tokens["pad_token_id"]
            elif global_tokens["pad_token_id"] is None and global_tokens["eos_token_id"] is None:
                logging.warning("Neither an EOS nor a PAD token is defined for the new model.")
            if global_model.config.pad_token_id is None and global_tokens.get("pad_token_id") is not None:
                global_model.config.pad_token_id = global_tokens["pad_token_id"]

            MODEL_NAME = new_model_name
            TRUST_REMOTE_CODE = new_trust_remote_code
            ENABLE_FLASH_ATTENTION_2 = new_enable_flash_attention_2
            TORCH_DTYPE = new_torch_dtype
            TORCH_DTYPE_STR = new_torch_dtype_str

            logging.info(f"Model successfully reloaded to: {MODEL_NAME}")
            logging.info({"status": "success", "message": f"Model {new_model_name} loaded successfully."})
        except Exception as e:
            logging.exception(f"Failed to load model {new_model_name}:")
            global_model = None
            global_tokenizer = None
            global_tokens = {}
            logging.error({"status": "error", "message": f"Failed to load model {new_model_name}: {e}. Model is now unloaded."})

    asyncio.create_task(_reload())
    return {"status": "info", "message": f"Attempting to load model {new_model_name} in background. Check logs for status."}


async def internal_unload_model():
    global global_model, global_tokenizer, global_tokens
    device = "cpu"
    logging.info("Attempting to unload model.")
    try:
        await cleanup(device)
        if global_model is not None:
            del global_model
            global_model = None
        if global_tokenizer is not None:
            del global_tokenizer
            global_tokenizer = None
        global_tokens = {}
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        logging.info("Model unloaded successfully.")
        return {"status": "success", "message": "Model unloaded successfully."}
    except Exception as e:
        logging.exception("Failed to unload model:")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to unload model: {e}")
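
# internal_unload_model() is only a helper in this part of the file; it is not exposed
# as a route here. A minimal sketch of how it could be wired to an endpoint
# (hypothetical, reusing the existing get_api_key dependency) might look like:
#
#     @app.post("/unload", summary="Unload model", dependencies=[Depends(get_api_key)])
#     async def unload_endpoint():
#         return await internal_unload_model()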
@app.post("/generate", summary="Generate text", dependencies=[Depends(get_api_key)]) |
|
async def generate_endpoint(req: GenerateRequest): |
|
if global_model is None or global_tokenizer is None: |
|
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Model is not loaded. It may still be loading or failed to load.") |
|
device = "cpu" |
|
apply_seed(req.seed) |
|
try: |
|
initial_prompt_text = format_conversation(req.input_text, req.history, req.system_prompt) |
|
except Exception as e: |
|
logging.exception("Error formatting conversation:") |
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Error formatting conversation: {e}") |
|
try: |
|
tokenizer_encoding_kwargs = req.tokenizer_kwargs or {} |
|
|
|
encoded = global_tokenizer(initial_prompt_text, return_tensors="pt", add_special_tokens=False, **tokenizer_encoding_kwargs).to(device) |
|
initial_ids_before_trunc = encoded.input_ids |
|
initial_prompt_tokens_count_before_trunc = initial_ids_before_trunc.shape[-1] |
|
|
|
ids = truncate_encoded_ids(initial_ids_before_trunc, MAX_CONTEXT_TOKENS) |
|
current_prompt_tokens_count = ids.shape[-1] |
|
|
|
except Exception as e: |
|
logging.exception("Tokenizer error during encoding:") |
|
await cleanup(device) |
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Tokenizer encoding error: {e}") |
|
if req.tokenize_only: |
|
await cleanup(device) |
|
return JSONResponse({ |
|
"prompt_tokens_count": initial_prompt_tokens_count_before_trunc, |
|
"max_context_tokens": MAX_CONTEXT_TOKENS, |
|
"truncated": initial_prompt_tokens_count_before_trunc > MAX_CONTEXT_TOKENS, |
|
"input_text_processed": initial_prompt_text, |
|
"input_ids_truncated": ids.tolist()[0] |
|
}) |
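
    # (The tokenize_only branch above is a dry run: it reports prompt size and
    # truncation against MAX_CONTEXT_TOKENS without running generation.)

    # Reject non-streaming requests whose prompt plus requested new tokens cannot fit
    # into the combined budget before taking the generation semaphore.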
    total_capacity = MAX_CONTEXT_TOKENS + MAX_GENERATION_TOKENS
    total_requested_seq_len = current_prompt_tokens_count + req.max_new_tokens
    if not req.stream and total_requested_seq_len > total_capacity:
        await cleanup(device)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Requested sequence length ({total_requested_seq_len} tokens = {current_prompt_tokens_count} prompt + {req.max_new_tokens} new) exceeds model capacity ({total_capacity} tokens) and non-streaming is requested. Consider enabling streaming or reducing max_new_tokens."
        )
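
    # For illustration (values are hypothetical, not the configured defaults): with
    # MAX_CONTEXT_TOKENS = 1024 and MAX_GENERATION_TOKENS = 256, total_capacity is 1280,
    # so a 900-token prompt with max_new_tokens = 500 (1400 requested) is rejected for
    # non-streaming requests, while the same request with stream=True is allowed through.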

    async with generation_semaphore:
        try:
            gen_cfg = GenerationConfig(
                temperature=req.temperature,
                top_k=req.top_k,
                top_p=req.top_p,
                repetition_penalty=req.repetition_penalty,
                frequency_penalty=req.frequency_penalty,
                presence_penalty=req.presence_penalty,
                num_beams=req.num_beams if not req.stream else 1,
                length_penalty=req.length_penalty,
                no_repeat_ngram_size=req.no_repeat_ngram_size,
                early_stopping=req.early_stopping,
                do_sample=req.do_sample,
                use_mirostat_mode=1 if req.use_mirostat else 0,
                mirostat_tau=req.mirostat_tau,
                mirostat_eta=req.mirostat_eta,
                max_new_tokens=req.max_new_tokens,
                eos_token_id=req.eos_token_id_override if req.eos_token_id_override is not None else global_tokens.get("eos_token_id"),
                pad_token_id=req.pad_token_id_override if req.pad_token_id_override is not None else global_tokens.get("pad_token_id"),
                bos_token_id=req.bos_token_id_override if req.bos_token_id_override is not None else global_tokenizer.bos_token_id,
                num_return_sequences=req.num_return_sequences if not req.stream else 1,
                bad_words_ids=req.bad_words_ids,
                forced_bos_token_id=req.forced_bos_token_id,
                forced_eos_token_id=req.forced_eos_token_id,
                renormalize_logits=req.renormalize_logits,
                suppress_tokens=req.suppress_tokens,
                begin_suppress_tokens=req.begin_suppress_tokens,
                end_suppress_tokens=req.end_suppress_tokens,
                encoder_no_repeat_ngram_size=req.encoder_no_repeat_ngram_size,
                min_length=req.min_length,
                max_length=req.max_length,
                exponential_decay_length_penalty=req.exponential_decay_length_penalty,
                use_cache=req.use_cache,
                typical_p=req.typical_p,
                epsilon_cutoff=req.epsilon_cutoff,
                eta_cutoff=req.eta_cutoff,
                temperature_cutoff=req.temperature_cutoff,
                encoder_repetition_penalty=req.encoder_repetition_penalty,
                max_time=req.max_time,
                output_watermark=req.output_watermark,
                diversity_penalty=req.diversity_penalty,
                num_beam_groups=req.num_beam_groups if not req.stream else 1,
                length_normalization_factor=req.length_normalization_factor,
                min_new_tokens=req.min_new_tokens,
                do_normalize_logits=req.do_normalize_logits,
                output_scores=req.output_scores,
                output_attentions=req.output_attentions,
                output_hidden_states=req.output_hidden_states,
            )
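
            # GenerationConfig stores keyword arguments it does not recognize as extra
            # attributes. Standard options (temperature, top_p, num_beams, ...) are
            # consumed by model.generate() directly; non-standard ones passed here
            # (e.g. the mirostat fields, frequency/presence penalties,
            # temperature_cutoff) only have an effect if the streaming/non-streaming
            # generation logic reads them explicitly.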

            if req.stream:
                gen_cfg.use_cache = True
                gen_cfg.num_beams = 1
                gen_cfg.num_return_sequences = 1
                gen_cfg.num_beam_groups = 1
                return StreamingResponse(
                    stream_generation_logic(req, ids, gen_cfg, device),
                    media_type="text/plain" if req.return_only_text else "application/json"
                )
            else:
                response_payload = await non_stream_generation_logic(req, ids, gen_cfg, device)
                if req.return_only_text:
                    texts = [seq["text"] for seq in response_payload.get("generated_sequences", []) if seq.get("text") is not None]
                    if req.num_return_sequences == 1 and texts:
                        return PlainTextResponse(texts[0])
                    else:
                        return JSONResponse(texts)
                else:
                    return JSONResponse(response_payload)
        except Exception as e:
            logging.exception("Generation error:")
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Generation error: {e}")
        finally:
            await cleanup(device)


if __name__ == "__main__":
    uvicorn.run(
        app, host="0.0.0.0", port=7860,
        log_level="critical",
        access_log=False
    )
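
# Illustrative client call (an assumption for documentation purposes, not part of the
# original file). The authentication header name depends on how the get_api_key
# dependency is configured earlier in this file; "X-API-Key" below is only a guess.
#
#     curl -X POST "http://localhost:7860/generate" \
#          -H "Content-Type: application/json" \
#          -H "X-API-Key: <your-key>" \
#          -d '{"input_text": "Hola", "max_new_tokens": 64, "stream": false, "return_only_text": true}'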