# Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch

from wenet.osum_echat.init_llmasr import init_llmasr
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.cmvn import GlobalCMVN
from wenet.transformer.ctc import CTC
from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder
from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder
from wenet.whisper.whisper import Whisper
from wenet.utils.cmvn import load_cmvn
from wenet.utils.checkpoint import load_checkpoint, load_trained_modules
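
# Registries mapping the string names read from the configs ('encoder',
# 'decoder', 'ctc', 'model') to the classes constructed in init_speech_model.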
WENET_ENCODER_CLASSES = {
    "transformer": TransformerEncoder,
    "conformer": ConformerEncoder,
}

WENET_DECODER_CLASSES = {
    "transformer": TransformerDecoder,
    "bitransformer": BiTransformerDecoder,
}

WENET_CTC_CLASSES = {
    "ctc": CTC,
}

WENET_MODEL_CLASSES = {
    "asr_model": ASRModel,
    "whisper": Whisper,
}


def init_speech_model(args, configs):
    # TODO(xcsong): Forcefully read the 'cmvn' attribute.
    if configs.get('cmvn', None) == 'global_cmvn':
        mean, istd = load_cmvn(configs['cmvn_conf']['cmvn_file'],
                               configs['cmvn_conf']['is_json_cmvn'])
        global_cmvn = GlobalCMVN(
            torch.from_numpy(mean).float(),
            torch.from_numpy(istd).float())
    else:
        global_cmvn = None

    input_dim = configs['input_dim']
    vocab_size = configs['output_dim']

    encoder_type = configs.get('encoder', 'conformer')
    decoder_type = configs.get('decoder', 'bitransformer')
    ctc_type = configs.get('ctc', 'ctc')

    encoder = WENET_ENCODER_CLASSES[encoder_type](
        input_dim,
        global_cmvn=global_cmvn,
        **configs['encoder_conf'],
        **configs['encoder_conf']['efficient_conf']
        if 'efficient_conf' in configs['encoder_conf'] else {})

    decoder = WENET_DECODER_CLASSES[decoder_type](vocab_size,
                                                  encoder.output_size(),
                                                  **configs['decoder_conf'])

    ctc = WENET_CTC_CLASSES[ctc_type](
        vocab_size,
        encoder.output_size(),
        blank_id=configs['ctc_conf']['ctc_blank_id']
        if 'ctc_conf' in configs else 0)

    model_type = configs.get('model', 'asr_model')
    model = WENET_MODEL_CLASSES[model_type](
        vocab_size=vocab_size,
        encoder=encoder,
        decoder=decoder,
        ctc=ctc,
        special_tokens=configs.get('tokenizer_conf',
                                   {}).get('special_tokens', None),
        **configs['model_conf'])
    return model, configs
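
# A minimal configs dict accepted by init_speech_model reads roughly like the
# sketch below; the keys are the ones accessed above, the values are purely
# illustrative and not defaults shipped with this repo:
#
#   configs = {
#       'input_dim': 80,
#       'output_dim': 5002,
#       'encoder': 'conformer',
#       'encoder_conf': {...},
#       'decoder': 'bitransformer',
#       'decoder_conf': {...},
#       'ctc_conf': {'ctc_blank_id': 0},
#       'model': 'asr_model',
#       'model_conf': {...},
#   }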


def init_model(args, configs):
    model_type = configs.get('model', 'asr_model')
    configs['model'] = model_type

    if model_type == "osum_echat":
        is_inference = configs.get('is_inference', False)
        model = init_llmasr(args, configs, is_inference=is_inference)
        return model
    else:
        model, configs = init_speech_model(args, configs)

    # If a checkpoint is specified, load some info from it
    if hasattr(args, 'checkpoint') and args.checkpoint is not None:
        infos = load_checkpoint(model, args.checkpoint)
    elif hasattr(args, 'enc_init') and args.enc_init is not None:
        infos = load_trained_modules(model, args)
    else:
        infos = {}
    if configs.get('init_step', False):
        infos = {}
    configs["init_infos"] = infos

    if hasattr(args, 'use_lora') and args.use_lora:
        if hasattr(args, 'lora_ckpt_path') and args.lora_ckpt_path:
            load_checkpoint(model, args.lora_ckpt_path)
    print(configs)

    # Try to tie some weights
    if hasattr(model, 'tie_or_clone_weights'):
        if not hasattr(args, 'jit'):
            args.jit = True  # i.e. export onnx/jit/ipex
        model.tie_or_clone_weights(args.jit)

    if int(os.environ.get('RANK', 0)) == 0:
        print(configs)

    return model, configs
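
# Hedged usage sketch: one way a training or recognition script might call
# init_model. The module path, YAML layout, and argparse fields below are
# assumptions for illustration, not guaranteed by this file.
#
#   import argparse
#   import yaml
#   from wenet.utils.init_model import init_model
#
#   with open('train.yaml') as fin:
#       configs = yaml.safe_load(fin)
#   args = argparse.Namespace(checkpoint='model.pt', jit=False)
#   model, configs = init_model(args, configs)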