# Configuration dataclasses for the speech-to-speech (s2s) model pipeline:
# vocabulary layout, model/encoder selection, PEFT, training, data, decoding,
# FSDP, and logging settings.
from dataclasses import dataclass, field
from typing import Optional, List
import os

# Checkpoint locations: weight-file name, local cache directory, the joined
# local path, and the Hugging Face repo the checkpoint is fetched from.
CKPT_NAME = "model.pt"
CKPT_LOCAL_DIR = "model_ckpts"
CKPT_PATH = os.path.join(CKPT_LOCAL_DIR, CKPT_NAME)
CKPT_REPO = "xcczach/mini-omni"
@dataclass
class VocabConfig:
    """Vocabulary layout for the combined text + audio token space.

    Base sizes are configurable; the special-token ids and padded sizes are
    derived in ``__post_init__`` (``field(init=False)`` keeps them out of the
    generated ``__init__``).  Note ``total_vocabsize`` (181120) equals
    ``padded_text_vocabsize`` (152000) + ``total_audio_vocabsize`` (29120)
    for the defaults.
    """

    text_vocabsize: int = 151936
    text_specialtokens: int = 64
    audio_vocabsize: int = 4096
    audio_specialtokens: int = 64
    total_vocabsize: int = 181120
    code_layer: int = 7  # number of audio code layers stacked into the vocab

    # Derived fields — computed below, not passed to __init__.
    padded_text_vocabsize: int = field(init=False)
    padded_audio_vocabsize: int = field(init=False)
    total_audio_vocabsize: int = field(init=False)
    eot: int = field(init=False)  # end of text token
    pad_t: int = field(init=False)  # padding text token
    input_t: int = field(init=False)  # input text token
    answer_t: int = field(init=False)  # answer text token
    asr: int = field(init=False)  # ASR token
    eoa: int = field(init=False)  # end of audio token
    pad_a: int = field(init=False)  # padding audio token
    input_a: int = field(init=False)  # input audio token
    answer_a: int = field(init=False)  # answer audio token
    split: int = field(init=False)  # split token

    def __post_init__(self):
        # Padded sizes: base vocabulary plus its special-token budget.
        self.padded_text_vocabsize = self.text_vocabsize + self.text_specialtokens
        self.padded_audio_vocabsize = self.audio_vocabsize + self.audio_specialtokens
        self.total_audio_vocabsize = self.padded_audio_vocabsize * self.code_layer
        # Text special tokens occupy the first slots after the text vocab.
        self.eot = self.text_vocabsize
        self.pad_t = self.text_vocabsize + 1
        self.input_t = self.text_vocabsize + 2
        self.answer_t = self.text_vocabsize + 3
        self.asr = self.text_vocabsize + 4
        # Audio special tokens occupy the first slots after the audio vocab.
        self.eoa = self.audio_vocabsize
        self.pad_a = self.audio_vocabsize + 1
        self.input_a = self.audio_vocabsize + 2
        self.answer_a = self.audio_vocabsize + 3
        self.split = self.audio_vocabsize + 4
@dataclass
class TTSAdapterConfig:
    """Transformer hyper-parameters for the TTS adapter.

    ``rope_n_elem`` (number of head dimensions rotated by RoPE) is derived in
    ``__post_init__`` from ``rotary_percentage * head_size``.
    """

    add_qkv_bias: Optional[bool] = True
    bias: bool = False
    gelu_approximate: Optional[str] = None
    head_size: Optional[int] = 64
    intermediate_size: Optional[int] = 4864
    lm_head_bias: bool = False
    mlp_class_name: str = "GptNeoxMLP"
    n_layer: int = 6
    n_head: int = 14
    n_embd: int = 896
    n_query_groups: Optional[int] = 2  # grouped-query attention groups
    norm_class_name: str = "RMSNorm"
    norm_eps: float = 1e-6
    parallel_residual: bool = False
    rotary_percentage: float = 1  # fraction of head dims using rotary embedding
    shared_attention_norm: bool = False

    def __post_init__(self):
        self.rope_n_elem = int(self.rotary_percentage * self.head_size)
@dataclass
class ModelConfig:
    """Model selection and wiring: LLM, audio encoder, projector, and codec."""

    file: str = "model/slam_model_s2s.py:model_factory"  # factory entry point
    llm_name: str = "qwen2-0.5b"
    llm_path: str = "Qwen/Qwen2-0.5B"
    llm_type: str = "decoder_only"
    llm_dim: int = 896  # LLM hidden size
    encoder_name: Optional[str] = "whisper"
    encoder_ds_rate: int = 2
    encoder_path: Optional[str] = "small"
    encoder_dim: int = 768
    encoder_projector: str = "linear"
    encoder_projector_ds_rate: int = 5
    modal: str = "audio"
    normalize: Optional[bool] = field(
        default=False,
        metadata={"help": "whether input is normalized, used for models such as wavlm"},
    )
    encoder_type: str = field(
        default="finetune",
        metadata={
            "help": "whether model is only pretrained or finetuned, used for models such as hubert"
        },
    )
    # default_factory so each config instance owns its nested config objects.
    vocab_config: VocabConfig = field(default_factory=VocabConfig)
    codec_decode: bool = True
    codec_decoder_type: str = "SNAC"
    codec_decoder_path: Optional[str] = "hubertsiuzdak/snac_24khz"
    tts_adapter: bool = False
    tts_adapter_config: TTSAdapterConfig = field(default_factory=TTSAdapterConfig)
@dataclass
class PeftConfig:
    """Parameter-efficient fine-tuning (LoRA) settings."""

    peft_method: str = "lora"  # None , llama_adapter, prefix
    r: int = 8  # LoRA rank
    lora_alpha: int = 32
    # default_factory avoids the shared-mutable-default pitfall.
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
    bias: str = "none"
    task_type: str = "CAUSAL_LM"
    lora_dropout: float = 0.05
    inference_mode: bool = False
@dataclass
class TrainConfig:
    """Training-loop configuration: distribution strategy, schedule, freezing."""

    model_name: str = "s2s"
    enable_ddp: bool = False
    enable_deepspeed: bool = False
    enable_fsdp: bool = False
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: int = 4
    batching_strategy: str = field(
        default="custom", metadata={"help": "alternative: padding"}
    )
    context_length: int = 4096
    gradient_accumulation_steps: int = 1
    num_epochs: int = 1
    num_workers_dataloader: int = 2
    warmup_steps: int = 1000
    total_steps: int = 100000
    validation_interval: int = 1000
    lr: float = 1e-4
    weight_decay: float = 0.0
    gamma: float = 0.85  # LR scheduler decay factor
    seed: int = 42
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: int = 1
    use_peft: bool = False
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = (
        "PATH/to/save/FSDP/model"  # will be used if using FSDP
    )
    dist_checkpoint_folder: str = "fine-tuned"  # will be used if using FSDP
    save_optimizer: bool = False  # will be used if using FSDP
    # Enable SDPA from PyTorch Accelerated Transformers, making use of Flash
    # Attention and xFormers memory-efficient kernels.
    use_fast_kernels: bool = False
    run_test_during_validation: bool = False
    run_test_during_validation_file: str = "test.wav"
    run_test_during_validation_prompt: str = "<|S2S|>"
    freeze_llm: bool = field(
        default=True,
        metadata={
            "help": "whether to freeze llm when finetuning, should be true when use peft finetuning"
        },
    )
    freeze_encoder: bool = True
    train_embed_only: bool = False
    train_audio_embed_only: bool = False
    task_type: str = "s2s"
@dataclass
class DataConfig:
    """Dataset selection, paths, and audio preprocessing options."""

    dataset: str = "speech_dataset_s2s"
    file: str = "examples/s2s/speech_dataset_s2s.py:get_speech_dataset"  # loader entry point
    train_data_path: Optional[str] = None
    val_data_path: Optional[str] = None
    train_split: str = "train"
    test_split: str = "validation"
    prompt: Optional[str] = None
    data_path: Optional[str] = None
    max_words: Optional[int] = None
    max_mel: Optional[float] = None
    fix_length_audio: int = -1  # -1 means no fixed audio length
    inference_mode: bool = True
    input_type: str = field(
        default="mel",
        metadata={"help": "Use raw when input is wav, mel when for whisper"},
    )
    mel_size: int = field(
        default=80, metadata={"help": "80 for whisper large v1 and v2, 128 for v3"}
    )
    normalize: Optional[bool] = field(
        default=False,
        metadata={"help": "whether input is normalized, used for models such as wavlm"},
    )
    seed: int = 42
    manifest_format: str = field(
        default="datasets", metadata={"help": "alternative: jsonl"}
    )
    split_size: float = 0.1  # validation fraction when splitting
    vocab_config: VocabConfig = field(default_factory=VocabConfig)
    load_from_cache_file: bool = False
    task_type: str = "s2s"
@dataclass
class DecodeConfig:
    """Generation/decoding hyper-parameters (HF ``generate``-style knobs)."""

    do_sample: bool = False
    max_new_tokens: int = 300
    min_length: int = 10
    temperature: float = 1.0
    top_k: int = 50
    top_p: float = 0.9
    num_beams: int = 1
    num_return_sequences: int = 1
    num_samples: int = 1
    max_time: float = 0.0
    repetition_penalty: float = 1.0
    length_penalty: float = 1.0
    early_stopping: bool = False
    no_repeat_ngram_size: int = 0
    # default_factory so every instance gets its own (empty) list.
    bad_words_ids: List = field(default_factory=list)
    num_beam_groups: int = 1
    diversity_penalty: float = 0.0
    task_type: str = "s2s"
    decode_text_only: bool = False
@dataclass
class FSDPConfig:
    """FSDP sharding and checkpointing options."""

    mixed_precision: bool = True
    use_fp16: bool = False
    # Set NO_SHARD when using DDP; FULL_SHARD is the usual FSDP alternative.
    sharding_strategy: str = "NO_SHARD"
    # SHARDED_STATE_DICT saves one file per rank and allows resizing the
    # world-size on restore.
    checkpoint_type: str = "SHARDED_STATE_DICT"
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"
@dataclass
class LogConfig:
    """Logging destinations: wandb identifiers and on-disk log file."""

    use_wandb: bool = False
    wandb_dir: str = "/valleblob/v-wenxichen/exp/wandb_log"
    wandb_entity_name: str = "project_name"
    wandb_project_name: str = "project_name"
    wandb_exp_name: str = "exp_name"
    log_file: str = "/valleblob/v-wenxichen/exp/log/test.log"
    log_interval: int = 10  # steps between log lines
    online_output_dir: Optional[str] = None
@dataclass
class InferenceConfig:
    """Top-level bundle of the sub-configs needed to run inference."""

    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    decode_config: DecodeConfig = field(default_factory=DecodeConfig)