# Updated generator.py with proper function order
from dataclasses import dataclass
from typing import List, Tuple
import torch
import torchaudio
import logging
import os
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
from tokenizers.processors import TemplateProcessing
from app.models import Segment
from app.text_normalizer import clean_text_for_tts
from app.text_normalizer import TextNormalizer

# Set up logging
logger = logging.getLogger(__name__)

# Import the CSM watermarking code
try:
    from app.watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark
except ImportError:
    # Define stubs for watermarking if the module is not available
    CSM_1B_GH_WATERMARK = "CSM1B"

    def load_watermarker(device="cpu"):
        return None

    def watermark(watermarker, audio, sample_rate, key):
        return audio, sample_rate

def load_llama3_tokenizer():
    """
    Load tokenizer for Llama 3.2, using unsloth's open version
    instead of the gated meta-llama version.
    """
    try:
        # Use the unsloth version which is not gated
        tokenizer_name = "unsloth/Llama-3.2-1B"
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        bos = tokenizer.bos_token
        eos = tokenizer.eos_token
        tokenizer._tokenizer.post_processor = TemplateProcessing(
            single=f"{bos}:0 $A:0 {eos}:0",
            pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
            special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
        )
        logger.info("Successfully loaded tokenizer from unsloth/Llama-3.2-1B")
        return tokenizer
    except Exception as e:
        logger.error(f"Error loading tokenizer from unsloth: {e}")
        # Fall back to a simpler tokenizer if needed
        try:
            from transformers import GPT2Tokenizer
            logger.warning("Falling back to GPT2Tokenizer")
            tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            tokenizer.pad_token = tokenizer.eos_token
            return tokenizer
        except Exception as fallback_e:
            logger.error(f"Fallback tokenizer also failed: {fallback_e}")
            raise RuntimeError("Could not load any suitable tokenizer")
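
# Illustrative note (a sketch, not verified output; applies to the unsloth path above,
# not the GPT-2 fallback): the TemplateProcessing template makes encode() wrap the
# input with the tokenizer's BOS/EOS markers, so speaker-tagged prompts can be
# tokenized directly, e.g.
#   tok = load_llama3_tokenizer()
#   ids = tok.encode("[0]Hello there.")  # first id is tok.bos_token_id, last is tok.eos_token_id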

class Generator:
    """Generator class for CSM-1B model."""

    def __init__(self, model):
        """Initialize generator with model."""
        self._model = model
        self._model.setup_caches(1)
        self._text_tokenizer = load_llama3_tokenizer()
        device = next(model.parameters()).device
        # Load Mimi codec for audio tokenization
        try:
            logger.info("Loading Mimi audio codec...")
            from huggingface_hub import hf_hub_download
            # First try to import from moshi
            try:
                from moshi.models import loaders
                DEFAULT_REPO = loaders.DEFAULT_REPO
                MIMI_NAME = loaders.MIMI_NAME
                get_mimi = loaders.get_mimi
            except ImportError:
                logger.warning("moshi.models.loaders not found, using fallback")
                # Fallback values if moshi.models.loaders is not available
                DEFAULT_REPO = "kyutai/mimi"
                MIMI_NAME = "mimi-december.pt"

                # Fallback function to load mimi
                def get_mimi(checkpoint_path, device):
                    from moshi.models.vqvae_model import MiMiModule
                    checkpoint = torch.load(checkpoint_path, map_location=device)
                    model = MiMiModule.init_from_checkpoint(checkpoint, device=device)
                    return model
            mimi_weight = hf_hub_download(DEFAULT_REPO, MIMI_NAME)
            mimi = get_mimi(mimi_weight, device=device)
            mimi.set_num_codebooks(32)
            self._audio_tokenizer = mimi
            self.sample_rate = mimi.sample_rate
            logger.info(f"Mimi codec loaded successfully with sample rate {self.sample_rate}")
        except Exception as e:
            logger.error(f"Error loading Mimi codec: {e}")
            self._audio_tokenizer = None
            self.sample_rate = 24000  # Default sample rate
            logger.warning(f"Using fallback sample rate: {self.sample_rate}")
            raise RuntimeError(f"Failed to load Mimi codec: {e}")
        try:
            self._watermarker = load_watermarker(device=device)
            logger.info("Watermarker loaded successfully")
        except Exception as e:
            logger.warning(f"Error loading watermarker: {e}. Watermarking will be disabled.")
            self._watermarker = None
        self.device = device
        # Optimize for CUDA throughput
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            torch.cuda.empty_cache()
            logger.info("CUDA optimizations enabled")

    def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a text segment."""
        frame_tokens = []
        frame_masks = []
        # Strip any voice instructions in square brackets to avoid them being read out
        text = self._clean_text_input(text)
        text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
        text_frame = torch.zeros(len(text_tokens), 33).long()
        text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
        text_frame[:, -1] = torch.tensor(text_tokens)
        text_frame_mask[:, -1] = True
        frame_tokens.append(text_frame.to(self.device))
        frame_masks.append(text_frame_mask.to(self.device))
        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _clean_text_input(self, text: str) -> str:
        """Clean and normalize text for TTS."""
        return clean_text_for_tts(text)

    def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize audio."""
        if self._audio_tokenizer is None:
            raise RuntimeError("Audio tokenizer not initialized")
        frame_tokens = []
        frame_masks = []
        # (K, T)
        audio = audio.to(self.device)
        audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
        # add EOS frame
        eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
        audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
        audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
        audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
        audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
        audio_frame_mask[:, :-1] = True
        frame_tokens.append(audio_frame)
        frame_masks.append(audio_frame_mask)
        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a segment of text and audio."""
        text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
        audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
        return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)

    def generate_quick(
        self,
        text: str,
        speaker: int,
        context: List[Segment],
        max_audio_length_ms: float = 2000,  # Short for quick generation
        temperature: float = 0.7,  # Lower for more predictable output
        topk: int = 20,  # Lower for faster top-k sampling
    ) -> torch.Tensor:
        """Generate audio quickly for real-time streaming."""
        # Similar to generate() but optimized for speed
        self._model.reset_caches()
        # Convert max_audio_length_ms to frames - limit for faster generation
        max_audio_frames = min(int(max_audio_length_ms / 80), 128)  # Smaller limit
        # Process text
        cleaned_text = clean_text_for_tts(text)
        # Prepare tokens
        tokens, tokens_mask = [], []
        # Add context segments (limited to 1 for speed)
        if context:
            segment_tokens, segment_tokens_mask = self._tokenize_segment(context[0])
            tokens.append(segment_tokens)
            tokens_mask.append(segment_tokens_mask)
        # Add text tokens
        gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
        tokens.append(gen_segment_tokens)
        tokens_mask.append(gen_segment_tokens_mask)
        prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
        prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
        # Generate with a larger batch size for fewer iterations
        curr_tokens = prompt_tokens.unsqueeze(0)
        curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
        curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
        # Use larger batch size
        batch_size = 64  # Generate more frames at once
        all_samples = []
        for start_idx in range(0, max_audio_frames, batch_size):
            end_idx = min(start_idx + batch_size, max_audio_frames)
            batch_frames = end_idx - start_idx
            samples_batch = []
            for _ in range(batch_frames):
                sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
                samples_batch.append(sample)
                if torch.all(sample == 0):
                    break
                curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
                curr_tokens_mask = torch.cat(
                    [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
                ).unsqueeze(1)
                curr_pos = curr_pos[:, -1:] + 1
            all_samples.extend(samples_batch)
            if len(samples_batch) < batch_frames:
                break
        if not all_samples:
            return torch.zeros(10, device=self.device)  # Return short empty audio
        # Decode audio
        audio = self._audio_tokenizer.decode(torch.stack(all_samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
        return audio

    def generate(
        self,
        text: str,
        speaker: int,
        context: List[Segment],
        max_audio_length_ms: float = 90_000,
        temperature: float = 0.9,
        topk: int = 50,
    ) -> torch.Tensor:
        """Generate audio from text."""
        if self._audio_tokenizer is None:
            raise RuntimeError("Audio tokenizer not initialized")
        # Start timing (CUDA events only work when generating on a GPU)
        use_cuda_timing = self.device.type == "cuda" and torch.cuda.is_available()
        if use_cuda_timing:
            start_time = torch.cuda.Event(enable_timing=True)
            end_time = torch.cuda.Event(enable_timing=True)
            start_time.record()
        self._model.reset_caches()
        # Convert max_audio_length_ms to frames - this controls the maximum generation length
        max_audio_frames = min(int(max_audio_length_ms / 80), 1024)  # Limit to a reasonable size
        max_seq_len = 2048 - max_audio_frames
        # Check if text is long and should be split
        if len(text) > 200:
            logger.info(f"Long text detected ({len(text)} chars), processing in segments")
            sentences = TextNormalizer.split_into_sentences(text)
            logger.info(f"Split into {len(sentences)} segments")
            # Process sentences individually and concatenate the results
            all_audio_segments = []
            # Use the first sentence to establish the voice
            first_sentence = sentences[0]
            cleaned_text = clean_text_for_tts(first_sentence)
            # Generate the first segment
            tokens, tokens_mask = [], []
            # Add context segments for the first sentence
            for segment in context:
                segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
                tokens.append(segment_tokens)
                tokens_mask.append(segment_tokens_mask)
            # Add first sentence tokens
            gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
            tokens.append(gen_segment_tokens)
            tokens_mask.append(gen_segment_tokens_mask)
            prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
            prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
            # Check context size and truncate if needed
            if prompt_tokens.size(0) >= max_seq_len:
                logger.warning(f"Inputs too long ({prompt_tokens.size(0)} tokens), truncating to {max_seq_len - 50}")
                prompt_tokens = prompt_tokens[-max_seq_len + 50:]
                prompt_tokens_mask = prompt_tokens_mask[-max_seq_len + 50:]
            # Generate first sentence audio
            curr_tokens = prompt_tokens.unsqueeze(0)
            curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
            curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
            # Generate first segment
            first_segment_samples = []
            for start_idx in range(0, max_audio_frames, 32):
                end_idx = min(start_idx + 32, max_audio_frames)
                batch_frames = end_idx - start_idx
                samples_batch = []
                for _ in range(batch_frames):
                    sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
                    samples_batch.append(sample)
                    if torch.all(sample == 0):
                        break
                    curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
                    curr_tokens_mask = torch.cat(
                        [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
                    ).unsqueeze(1)
                    curr_pos = curr_pos[:, -1:] + 1
                first_segment_samples.extend(samples_batch)
                if len(samples_batch) < batch_frames:
                    break
            if not first_segment_samples:
                raise RuntimeError("No audio generated for first segment")
            # Decode first segment
            first_segment_audio = self._audio_tokenizer.decode(
                torch.stack(first_segment_samples).permute(1, 2, 0)
            ).squeeze(0).squeeze(0)
            all_audio_segments.append(first_segment_audio)
            # Now process remaining sentences using the first as context
            for i, sentence in enumerate(sentences[1:], 1):
                logger.info(f"Generating segment {i+1}/{len(sentences)}")
                cleaned_text = clean_text_for_tts(sentence)
                # Create a context segment from the previous generation
                prev_segment = Segment(
                    speaker=speaker,
                    text=sentences[i-1],
                    audio=all_audio_segments[-1]
                )
                # Generate with this segment as context (tokenize the previous segment once)
                prev_tokens, prev_tokens_mask = self._tokenize_segment(prev_segment)
                segment_tokens, segment_tokens_mask = [prev_tokens], [prev_tokens_mask]
                # Add current segment tokens
                current_tokens, current_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
                segment_tokens.append(current_tokens)
                segment_tokens_mask.append(current_tokens_mask)
                segment_prompt_tokens = torch.cat(segment_tokens, dim=0).long().to(self.device)
                segment_prompt_tokens_mask = torch.cat(segment_tokens_mask, dim=0).bool().to(self.device)
                # Check length and truncate if needed
                if segment_prompt_tokens.size(0) >= max_seq_len:
                    segment_prompt_tokens = segment_prompt_tokens[-max_seq_len + 50:]
                    segment_prompt_tokens_mask = segment_prompt_tokens_mask[-max_seq_len + 50:]
                # Generate audio for this segment
                curr_tokens = segment_prompt_tokens.unsqueeze(0)
                curr_tokens_mask = segment_prompt_tokens_mask.unsqueeze(0)
                curr_pos = torch.arange(0, segment_prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
                # Generate segment
                segment_samples = []
                for start_idx in range(0, max_audio_frames, 32):
                    end_idx = min(start_idx + 32, max_audio_frames)
                    batch_frames = end_idx - start_idx
                    samples_batch = []
                    # Throwaway loop variable so the enumerate index `i` is not shadowed
                    for _ in range(batch_frames):
                        sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
                        samples_batch.append(sample)
                        if torch.all(sample == 0):
                            break
                        curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
                        curr_tokens_mask = torch.cat(
                            [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
                        ).unsqueeze(1)
                        curr_pos = curr_pos[:, -1:] + 1
                    segment_samples.extend(samples_batch)
                    if len(samples_batch) < batch_frames:
                        break
                if not segment_samples:
                    logger.warning(f"No audio generated for segment {i+1}")
                    continue
                # Decode segment
                segment_audio = self._audio_tokenizer.decode(
                    torch.stack(segment_samples).permute(1, 2, 0)
                ).squeeze(0).squeeze(0)
                all_audio_segments.append(segment_audio)
            # Combine all segments with small pauses
            pause_samples = int(0.3 * self.sample_rate)  # 300ms pause
            pause = torch.zeros(pause_samples, device=self.device)
            audio_parts = []
            for i, segment_audio in enumerate(all_audio_segments):
                audio_parts.append(segment_audio)
                if i < len(all_audio_segments) - 1:
                    audio_parts.append(pause)
            audio = torch.cat(audio_parts)
            logger.info(f"Combined {len(all_audio_segments)} segments into final audio")
        else:
            # For shorter text, standard processing
            tokens, tokens_mask = [], []
            # Add context segments
            for segment in context:
                segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
                tokens.append(segment_tokens)
                tokens_mask.append(segment_tokens_mask)
            # Process text
            cleaned_text = clean_text_for_tts(text)
            gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
            tokens.append(gen_segment_tokens)
            tokens_mask.append(gen_segment_tokens_mask)
            prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
            prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
            # Check context size
            if prompt_tokens.size(0) >= max_seq_len:
                logger.warning(f"Inputs too long ({prompt_tokens.size(0)} tokens), truncating to {max_seq_len - 50}")
                prompt_tokens = prompt_tokens[-max_seq_len + 50:]
                prompt_tokens_mask = prompt_tokens_mask[-max_seq_len + 50:]
            # Generate audio - optimized batch generation
            curr_tokens = prompt_tokens.unsqueeze(0)
            curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
            curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
            # Using optimized batch generation
            batch_size = 32  # Generate this many frames at once
            all_samples = []
            for start_idx in range(0, max_audio_frames, batch_size):
                end_idx = min(start_idx + batch_size, max_audio_frames)
                batch_frames = end_idx - start_idx
                samples_batch = []
                for _ in range(batch_frames):
                    sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
                    samples_batch.append(sample)
                    if torch.all(sample == 0):
                        break
                    curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
                    curr_tokens_mask = torch.cat(
                        [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
                    ).unsqueeze(1)
                    curr_pos = curr_pos[:, -1:] + 1
                all_samples.extend(samples_batch)
                if len(samples_batch) < batch_frames:
                    logger.info(f"Early stopping at frame {start_idx + len(samples_batch)}/{max_audio_frames}")
                    break
            if not all_samples:
                raise RuntimeError("No audio generated - model produced empty output")
            # Decode audio
            audio = self._audio_tokenizer.decode(torch.stack(all_samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
        # Apply watermark
        if self._watermarker is not None:
            try:
                audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
                audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)
            except Exception as e:
                logger.warning(f"Error applying watermark: {e}. Continuing without watermark.")
        # Record execution time (only when CUDA timing events were started)
        if use_cuda_timing:
            end_time.record()
            torch.cuda.synchronize()
            execution_ms = start_time.elapsed_time(end_time)
            audio_length_ms = (audio.shape[0] / self.sample_rate) * 1000
            # Calculate real-time factor (RTF)
            rtf = execution_ms / audio_length_ms
            logger.info(f"Audio generated in {execution_ms:.2f}ms, length: {audio_length_ms:.2f}ms, RTF: {rtf:.2f}x")
        return audio

# Define helper functions for multi-GPU support
def _manual_device_map(model, state_dict, strategy="balanced"):
    """Apply manual device mapping for multi-GPU setups.

    Args:
        model: The model to map
        state_dict: Model state dict
        strategy: Mapping strategy ('balanced', 'sequential')

    Returns:
        Model with weights distributed across GPUs
    """
    num_gpus = torch.cuda.device_count()
    if num_gpus <= 1:
        # No need for mapping with a single GPU
        model.load_state_dict(state_dict)
        model = model.to("cuda")
        return model
    logger.info(f"Applying manual {strategy} device mapping across {num_gpus} GPUs")
    # Get all layer names from state dict
    layer_names = [name for name in state_dict.keys() if "layers" in name]
    backbone_layers = [name for name in layer_names if "backbone.layers" in name]
    decoder_layers = [name for name in layer_names if "decoder.layers" in name]
    # Count number of backbone and decoder layers
    backbone_layer_indices = set()
    for name in backbone_layers:
        parts = name.split('.')
        if len(parts) > 2:
            try:
                backbone_layer_indices.add(int(parts[2]))
            except ValueError:
                pass
    decoder_layer_indices = set()
    for name in decoder_layers:
        parts = name.split('.')
        if len(parts) > 2:
            try:
                decoder_layer_indices.add(int(parts[2]))
            except ValueError:
                pass
    num_backbone_layers = len(backbone_layer_indices)
    num_decoder_layers = len(decoder_layer_indices)
    # Create device map
    device_map = {}
    if strategy == "balanced":
        # Distribute layers evenly across GPUs
        layers_per_gpu = (num_backbone_layers + num_decoder_layers) // num_gpus
        remainder = (num_backbone_layers + num_decoder_layers) % num_gpus
        # Assign backbone layers
        for i in backbone_layer_indices:
            gpu_idx = min(i // layers_per_gpu, num_gpus - 1)
            device_map[f"backbone.layers.{i}"] = f"cuda:{gpu_idx}"
        # Assign decoder layers
        for i in decoder_layer_indices:
            gpu_idx = min((i + num_backbone_layers) // layers_per_gpu, num_gpus - 1)
            device_map[f"decoder.layers.{i}"] = f"cuda:{gpu_idx}"
    elif strategy == "sequential":
        # Fill each GPU sequentially
        # Backbone layers on first GPU(s)
        backbone_per_gpu = max(1, num_backbone_layers // ((num_gpus + 1) // 2))
        for i in backbone_layer_indices:
            gpu_idx = min(i // backbone_per_gpu, (num_gpus + 1) // 2 - 1)
            device_map[f"backbone.layers.{i}"] = f"cuda:{gpu_idx}"
        # Decoder layers on remaining GPU(s)
        decoder_per_gpu = max(1, num_decoder_layers // (num_gpus - (num_gpus + 1) // 2 + 1))
        for i in decoder_layer_indices:
            gpu_idx = min(i // decoder_per_gpu + (num_gpus + 1) // 2 - 1, num_gpus - 1)
            device_map[f"decoder.layers.{i}"] = f"cuda:{gpu_idx}"
    # Assign embeddings and other components
    device_map["text_embeddings"] = "cuda:0"
    device_map["audio_embeddings"] = "cuda:0"
    device_map["projection"] = "cuda:0"
    device_map["codebook0_head"] = "cuda:0"
    device_map["audio_head"] = "cuda:0"
    # Load state dict with device mapping
    model.load_state_dict(state_dict)
    # Move model parts to assigned devices
    for name, device in device_map.items():
        if "backbone.layers" in name:
            layer_idx = int(name.split('.')[-1])
            if hasattr(model.backbone, 'layers') and layer_idx < len(model.backbone.layers):
                model.backbone.layers[layer_idx] = model.backbone.layers[layer_idx].to(device)
        elif "decoder.layers" in name:
            layer_idx = int(name.split('.')[-1])
            if hasattr(model.decoder, 'layers') and layer_idx < len(model.decoder.layers):
                model.decoder.layers[layer_idx] = model.decoder.layers[layer_idx].to(device)
        elif hasattr(model, name):
            setattr(model, name, getattr(model, name).to(device))
    logger.info(f"Model distributed across GPUs with {strategy} strategy")
    return model

def load_csm_1b(ckpt_path: str = "ckpt.pt", device: str = "cuda", device_map: str = None) -> Generator:
    """Load CSM-1B model and create generator with performance optimizations.

    Args:
        ckpt_path: Path to model checkpoint
        device: Device to load model on ('cuda', 'cpu', or specific CUDA device)
        device_map: Optional device mapping strategy ('auto', 'balanced', 'sequential', or None)

    Returns:
        Generator instance with optimized settings
    """
    try:
        # Import models module for CSM
        from app.torchtune_models import Model, ModelArgs
        # Create model
        model_args = ModelArgs(
            backbone_flavor="llama-1B",
            decoder_flavor="llama-100M",
            text_vocab_size=128256,
            audio_vocab_size=2051,
            audio_num_codebooks=32,
        )
        # Load model
        logger.info(f"Loading CSM-1B model from {ckpt_path} with device={device}, device_map={device_map}")
        # Check for CUDA availability
        cuda_available = device == "cuda" and torch.cuda.is_available()
        # Set up torch for optimized inference
        if cuda_available:
            # Check if we should enable TF32 (faster but slightly less precise)
            enable_tf32 = os.environ.get("ENABLE_TF32", "true").lower() == "true"
            if enable_tf32:
                logger.info("Enabling TF32 for faster matrix multiplications")
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True
            # Check for available precision modes
            use_bfloat16 = torch.cuda.is_bf16_supported()
            use_float16 = not use_bfloat16 and torch.cuda.is_available()  # Fall back to float16
            if use_bfloat16:
                dtype = torch.bfloat16
                logger.info("Using bfloat16 precision for faster inference")
            elif use_float16:
                dtype = torch.float16
                logger.info("Using float16 precision for faster inference")
            else:
                dtype = torch.float32
                logger.info("Using float32 precision (mixed precision not available)")
            # Enable Flash Attention if available
            try:
                import flash_attn
                if os.environ.get("ENABLE_FLASH_ATTN", "true").lower() == "true":
                    logger.info("Flash Attention detected - enabling for faster attention")
                    os.environ["PYTORCH_FLASH_ATTENTION_ENABLED"] = "1"
            except ImportError:
                logger.info("Flash Attention not available (install flash-attn for faster inference)")
        else:
            # CPU-only mode
            dtype = torch.float32
            logger.info("Using CPU mode with float32 precision")
        # Check for quantization
        enable_quantization = os.environ.get("ENABLE_QUANTIZATION", "false").lower() == "true"
        is_quantized = False
        # Check for multi-GPU setup
        if device_map and torch.cuda.device_count() > 1:
            logger.info(f"Using device_map={device_map} across {torch.cuda.device_count()} GPUs")
            # Create model with device map
            model = Model(model_args)
            # Load state dict
            state_dict = torch.load(ckpt_path, map_location='cpu')
            # Try quantization before device mapping if enabled
            if enable_quantization and cuda_available:
                try:
                    from bitsandbytes.nn import Linear8bitLt

                    def replace_with_8bit(model):
                        """Replace linear layers with 8-bit quantized versions"""
                        for name, module in model.named_modules():
                            if isinstance(module, torch.nn.Linear) and module.out_features > 256:
                                parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
                                parent = model
                                if parent_name:
                                    for attr in parent_name.split('.'):
                                        parent = getattr(parent, attr)
                                child_name = name.rsplit('.', 1)[1] if '.' in name else name
                                setattr(parent, child_name, Linear8bitLt.from_float(module))
                        return model

                    logger.info("Applying 8-bit quantization to linear layers")
                    model = replace_with_8bit(model)
                    is_quantized = True
                except ImportError:
                    logger.warning("bitsandbytes not available, skipping quantization")
            # Apply device mapping
            if device_map == "auto":
                # Use accelerate for automatic device mapping
                try:
                    from accelerate import init_empty_weights, load_checkpoint_and_dispatch
                    # Initialize empty model
                    with init_empty_weights():
                        empty_model = Model(model_args)
                    # Load and dispatch model across GPUs
                    model = load_checkpoint_and_dispatch(
                        empty_model,
                        ckpt_path,
                        device_map="auto",
                        no_split_module_classes=["TransformerLayer"],
                        # Offload to CPU if very large model
                        offload_folder="offload" if os.environ.get("OFFLOAD_TO_CPU", "false").lower() == "true" else None
                    )
                    logger.info("Model loaded with automatic device mapping")
                except ImportError:
                    logger.warning("accelerate package not found, falling back to manual device mapping")
                    model = _manual_device_map(model, state_dict, "balanced")
                except Exception as mapping_error:
                    logger.error(f"Auto device mapping failed: {mapping_error}, falling back to manual")
                    model = _manual_device_map(model, state_dict, "balanced")
            else:
                # Manual device mapping
                model = _manual_device_map(model, state_dict, device_map or "balanced")
        else:
            # Single GPU or CPU setup
            # Try quantization before loading if enabled (GPU only)
            if enable_quantization and cuda_available and not is_quantized:
                try:
                    # First load to CPU for quantization
                    model = Model(model_args).to("cpu")
                    state_dict = torch.load(ckpt_path, map_location="cpu")
                    model.load_state_dict(state_dict)
                    from bitsandbytes.nn import Linear8bitLt

                    def replace_with_8bit(model):
                        """Replace linear layers with 8-bit quantized versions"""
                        for name, module in model.named_modules():
                            if isinstance(module, torch.nn.Linear) and module.out_features > 256:
                                parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
                                parent = model
                                if parent_name:
                                    for attr in parent_name.split('.'):
                                        parent = getattr(parent, attr)
                                child_name = name.rsplit('.', 1)[1] if '.' in name else name
                                setattr(parent, child_name, Linear8bitLt.from_float(module))
                        return model

                    logger.info("Applying 8-bit quantization to linear layers")
                    model = replace_with_8bit(model)
                    model = model.to(device=device)
                    is_quantized = True
                except ImportError:
                    logger.warning("bitsandbytes not available, loading without quantization")
                    # Load the standard way
                    model = Model(model_args).to(device=device, dtype=dtype)
                    state_dict = torch.load(ckpt_path, map_location=device)
                    model.load_state_dict(state_dict)
                except Exception as quant_error:
                    logger.error(f"Quantization failed: {quant_error}, loading without quantization")
                    # Load the standard way
                    model = Model(model_args).to(device=device, dtype=dtype)
                    state_dict = torch.load(ckpt_path, map_location=device)
                    model.load_state_dict(state_dict)
            else:
                # Standard load without quantization
                model = Model(model_args).to(device=device, dtype=dtype)
                state_dict = torch.load(ckpt_path, map_location=device)
                model.load_state_dict(state_dict)
        # Apply torch.compile if available (PyTorch 2.0+)
        compile_mode = os.environ.get("TORCH_COMPILE_MODE", "none")
        if hasattr(torch, 'compile') and compile_mode != "none" and cuda_available:
            try:
                logger.info(f"Using torch.compile with mode '{compile_mode}' for faster inference")
                if compile_mode == "default":
                    model = torch.compile(model)
                else:
                    model = torch.compile(model, mode=compile_mode)
            except Exception as compile_error:
                logger.warning(f"Torch compile failed (requires PyTorch 2.0+): {compile_error}")
        # Try to optimize CUDA graphs for faster inference (advanced)
        use_cuda_graphs = os.environ.get("USE_CUDA_GRAPHS", "false").lower() == "true"
        if use_cuda_graphs and cuda_available and hasattr(torch.cuda, 'CUDAGraph'):
            try:
                logger.info("Setting up CUDA graphs for repeated inference patterns")
                # This requires custom integration inside the model's forward method
                # Just flagging that CUDA graphs should be used
                model.use_cuda_graphs = True
            except Exception as cuda_graph_error:
                logger.warning(f"CUDA graphs setup failed: {cuda_graph_error}")
                model.use_cuda_graphs = False
        # Set optimal settings for CUDA context
        if cuda_available:
            # Set benchmark mode for hardware-specific optimizations
            torch.backends.cudnn.benchmark = True
            # Clean up CUDA cache before creating generator
            torch.cuda.empty_cache()
            # Ensure all CUDA work is completed to avoid launch delays
            torch.cuda.synchronize()
        # Create generator
        logger.info("Creating generator with optimized settings")
        generator = Generator(model)
        # Log memory usage if on CUDA
        if cuda_available:
            memory_allocated = torch.cuda.memory_allocated() / (1024**3)
            memory_reserved = torch.cuda.memory_reserved() / (1024**3)
            logger.info(f"Model loaded, CUDA memory: {memory_allocated:.2f}GB allocated, {memory_reserved:.2f}GB reserved")
        logger.info(f"Generator created successfully: precision={dtype}, quantized={is_quantized}")
        return generator
    except Exception as e:
        logger.error(f"Failed to load CSM-1B model: {e}")
        import traceback
        logger.error(traceback.format_exc())
        raise
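

# Minimal usage sketch (illustrative only; it assumes a CSM-1B checkpoint exists at the
# default "ckpt.pt" path and that the app.* modules imported above are available). It
# shows the intended call pattern for load_csm_1b() and Generator.generate(); adjust the
# checkpoint path, device, and output filename for your setup.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    gen = load_csm_1b(ckpt_path="ckpt.pt", device="cuda" if torch.cuda.is_available() else "cpu")
    # Generate speech for speaker 0 with no conditioning context
    audio = gen.generate(
        text="Hello from the CSM-1B generator.",
        speaker=0,
        context=[],
        max_audio_length_ms=10_000,
    )
    # torchaudio expects a (channels, samples) tensor; move to CPU and float32 before saving
    torchaudio.save("example_output.wav", audio.unsqueeze(0).cpu().to(torch.float32), gen.sample_rate)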