# Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch

from wenet.osum_echat.init_llmasr import init_llmasr
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.cmvn import GlobalCMVN
from wenet.transformer.ctc import CTC
from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder
from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder
from wenet.whisper.whisper import Whisper
from wenet.utils.cmvn import load_cmvn
from wenet.utils.checkpoint import load_checkpoint, load_trained_modules

# Registries mapping config strings to their implementation classes.
WENET_ENCODER_CLASSES = {
    "transformer": TransformerEncoder,
    "conformer": ConformerEncoder,
}

WENET_DECODER_CLASSES = {
    "transformer": TransformerDecoder,
    "bitransformer": BiTransformerDecoder,
}

WENET_CTC_CLASSES = {
    "ctc": CTC,
}

WENET_MODEL_CLASSES = {
    "asr_model": ASRModel,
    "whisper": Whisper,
}


def init_speech_model(args, configs):
    """Build the encoder, decoder and CTC from `configs` and assemble them
    into an ASR model."""
    # TODO(xcsong): Forcefully read the 'cmvn' attribute.
    if configs.get('cmvn', None) == 'global_cmvn':
        mean, istd = load_cmvn(configs['cmvn_conf']['cmvn_file'],
                               configs['cmvn_conf']['is_json_cmvn'])
        global_cmvn = GlobalCMVN(
            torch.from_numpy(mean).float(),
            torch.from_numpy(istd).float())
    else:
        global_cmvn = None

    input_dim = configs['input_dim']
    vocab_size = configs['output_dim']

    encoder_type = configs.get('encoder', 'conformer')
    decoder_type = configs.get('decoder', 'bitransformer')
    ctc_type = configs.get('ctc', 'ctc')

    encoder = WENET_ENCODER_CLASSES[encoder_type](
        input_dim,
        global_cmvn=global_cmvn,
        **configs['encoder_conf'],
        **configs['encoder_conf']['efficient_conf']
        if 'efficient_conf' in configs['encoder_conf'] else {})

    decoder = WENET_DECODER_CLASSES[decoder_type](vocab_size,
                                                  encoder.output_size(),
                                                  **configs['decoder_conf'])

    ctc = WENET_CTC_CLASSES[ctc_type](
        vocab_size,
        encoder.output_size(),
        blank_id=configs['ctc_conf']['ctc_blank_id']
        if 'ctc_conf' in configs else 0)

    model_type = configs.get('model', 'asr_model')
    model = WENET_MODEL_CLASSES[model_type](
        vocab_size=vocab_size,
        encoder=encoder,
        decoder=decoder,
        ctc=ctc,
        special_tokens=configs.get('tokenizer_conf',
                                   {}).get('special_tokens', None),
        **configs['model_conf'])
    return model, configs


def init_model(args, configs):
    """Initialize a model from `configs`, then optionally restore weights
    from a checkpoint, pre-trained modules, or a LoRA checkpoint."""
    model_type = configs.get('model', 'asr_model')
    configs['model'] = model_type

    if model_type == "osum_echat":
        is_inference = configs.get('is_inference', False)
        model = init_llmasr(args, configs, is_inference=is_inference)
        return model, configs
    else:
        model, configs = init_speech_model(args, configs)

    # If a checkpoint is specified, load some info from it.
    if hasattr(args, 'checkpoint') and args.checkpoint is not None:
        infos = load_checkpoint(model, args.checkpoint)
    elif hasattr(args, 'enc_init') and args.enc_init is not None:
        infos = load_trained_modules(model, args)
    else:
        infos = {}
    if configs.get('init_step', False):
        infos = {}
    configs["init_infos"] = infos

    if hasattr(args, 'use_lora') and args.use_lora:
        if hasattr(args, 'lora_ckpt_path') and args.lora_ckpt_path:
            load_checkpoint(model, args.lora_ckpt_path)

    # Try to tie some weights.
    if hasattr(model, 'tie_or_clone_weights'):
        if not hasattr(args, 'jit'):
            args.jit = True  # i.e. export onnx/jit/ipex
        model.tie_or_clone_weights(args.jit)

    if int(os.environ.get('RANK', 0)) == 0:
        print(configs)

    return model, configs
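

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It shows how init_model() is typically driven from a WeNet-style YAML
# config plus an argparse Namespace. The config path and flag names below
# are hypothetical assumptions, not defined by this file.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse

    import yaml

    parser = argparse.ArgumentParser(description='init_model smoke test')
    # Hypothetical default path; point this at a real WeNet training config.
    parser.add_argument('--config', default='conf/train_conformer.yaml')
    parser.add_argument('--checkpoint', default=None,
                        help='optional model checkpoint to restore')
    args = parser.parse_args()

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    # Builds the model and, if --checkpoint is given, restores its weights;
    # configs comes back augmented with 'model' and 'init_infos'.
    model, configs = init_model(args, configs)
    print(model)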