|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" Dasheng model configuration""" |
|
|
|
from transformers import PretrainedConfig |
|
|
|
# Map of released checkpoint ids to their hosted config.json URLs.
# All checkpoints live under the "mispeech" namespace on the Hugging Face Hub,
# so the entries are generated from the model-size suffixes.
DASHENG_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    f"mispeech/{model}": f"https://huggingface.co/mispeech/{model}/resolve/main/config.json"
    for model in ("dasheng-base", "dasheng-0.6B", "dasheng-1.2B")
}
|
|
|
|
|
class DashengConfig(PretrainedConfig):
    """Configuration class to store the configuration of a Dasheng model.

    Selects the encoder hyper-parameters (embedding dim, depth, attention
    heads) from the chosen checkpoint size, then merges in any user-supplied
    overrides to form ``self.encoder_kwargs``.
    """

    model_type = "dasheng"

    def __init__(
        self,
        name: str = "dasheng-base",
        loss: str = "BCELoss",
        **kwargs,
    ):
        r"""
        Configuration class for the Dasheng model.

        Args:
            name (str, *optional*):
                Can be "dasheng-base", "dasheng-0.6B", or "dasheng-1.2B". Default to "dasheng-base".
            loss (str, *optional*):
                Name of the loss function to use. Can be any loss in `nn.modules.loss`. Default to "BCELoss".
            kwargs (dict, *optional*):
                Additional keyword arguments, see `dasheng_model.modeling_dasheng.DashengFeatureExtractor` and `dasheng_model.modeling_dasheng.AudioTransformerMAE_Encoder` for more details.

        Raises:
            ValueError: If ``name`` is not one of the recognized model sizes.
        """
        super().__init__(**kwargs)

        # Encoder architecture presets per released checkpoint size.
        size_presets = {
            "dasheng-base": {"embed_dim": 768, "depth": 12, "num_heads": 12},
            "dasheng-0.6B": {"embed_dim": 1280, "depth": 32, "num_heads": 16},
            "dasheng-1.2B": {"embed_dim": 1536, "depth": 40, "num_heads": 24},
        }
        try:
            preset = size_presets[name]
        except KeyError:
            # Suppress the KeyError context; callers only need the ValueError.
            raise ValueError(f"Unrecognized model name: {name}") from None
        self.name = name

        # Defaults shared by every model size; fresh dicts each call, so the
        # mutable list values are never shared between config instances.
        encoder_kwargs = dict(
            target_length=1008, patch_size=[64, 4], patch_stride=[64, 4], **preset
        )

        # User kwargs take precedence over the size preset and the defaults.
        # NOTE: the previous intermediate `encoder_kwargs.update(...)` over the
        # key intersection was dead code — this merge already applies the same
        # overrides — so it has been removed.
        self.encoder_kwargs = {**encoder_kwargs, **kwargs}

        self.loss = loss
|
|