# Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from wenet.osum_echat.init_llmasr import init_llmasr
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.cmvn import GlobalCMVN
from wenet.transformer.ctc import CTC
from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder
from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder
from wenet.whisper.whisper import Whisper
from wenet.utils.cmvn import load_cmvn
from wenet.utils.checkpoint import load_checkpoint, load_trained_modules
WENET_ENCODER_CLASSES = {
    "transformer": TransformerEncoder,
    "conformer": ConformerEncoder,
}

WENET_DECODER_CLASSES = {
    "transformer": TransformerDecoder,
    "bitransformer": BiTransformerDecoder,
}

WENET_CTC_CLASSES = {
    "ctc": CTC,
}

WENET_MODEL_CLASSES = {
    "asr_model": ASRModel,
    "whisper": Whisper,
}
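# A minimal sketch of the config keys these registries consume. The exact
# schema comes from each recipe's YAML, so the values below are illustrative
# assumptions rather than a canonical OSUM-EChat config:
#
#   encoder: conformer         # key into WENET_ENCODER_CLASSES
#   decoder: bitransformer     # key into WENET_DECODER_CLASSES
#   ctc: ctc                   # key into WENET_CTC_CLASSES
#   model: asr_model           # key into WENET_MODEL_CLASSES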
def init_speech_model(args, configs):
    """Build a conventional speech model (encoder + decoder + CTC).

    Returns the constructed model together with ``configs``. ``args`` is
    accepted for interface symmetry with ``init_model`` but is not used here.
    """
# TODO(xcsong): Forcefully read the 'cmvn' attribute.
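    # When 'cmvn' is 'global_cmvn', 'cmvn_conf' must provide the stats file
    # and its format: {'cmvn_file': <path>, 'is_json_cmvn': <bool>}.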
if configs.get('cmvn', None) == 'global_cmvn':
mean, istd = load_cmvn(configs['cmvn_conf']['cmvn_file'],
configs['cmvn_conf']['is_json_cmvn'])
global_cmvn = GlobalCMVN(
torch.from_numpy(mean).float(),
torch.from_numpy(istd).float())
else:
global_cmvn = None
input_dim = configs['input_dim']
vocab_size = configs['output_dim']
encoder_type = configs.get('encoder', 'conformer')
decoder_type = configs.get('decoder', 'bitransformer')
ctc_type = configs.get('ctc', 'ctc')
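    # Look up each component class in the registries above; an optional
    # 'efficient_conf' sub-dict is flattened into the encoder kwargs.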
encoder = WENET_ENCODER_CLASSES[encoder_type](
input_dim,
global_cmvn=global_cmvn,
**configs['encoder_conf'],
**configs['encoder_conf']['efficient_conf']
if 'efficient_conf' in configs['encoder_conf'] else {})
decoder = WENET_DECODER_CLASSES[decoder_type](vocab_size,
encoder.output_size(),
**configs['decoder_conf'])
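    # The CTC blank id falls back to 0 when no 'ctc_conf' section is given.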
ctc = WENET_CTC_CLASSES[ctc_type](
vocab_size,
encoder.output_size(),
blank_id=configs['ctc_conf']['ctc_blank_id']
if 'ctc_conf' in configs else 0)
model_type = configs.get('model', 'asr_model')
model = WENET_MODEL_CLASSES[model_type](
vocab_size=vocab_size,
encoder=encoder,
decoder=decoder,
ctc=ctc,
special_tokens=configs.get('tokenizer_conf',
{}).get('special_tokens', None),
**configs['model_conf'])
return model, configs
def init_model(args, configs):
    """Build the model requested by ``configs['model']``.

    The speech-LLM variant ("osum_echat") is delegated to ``init_llmasr``;
    every other model type goes through ``init_speech_model``, followed by
    optional checkpoint / LoRA loading and weight tying.
    """
    model_type = configs.get('model', 'asr_model')
    configs['model'] = model_type
    if model_type == "osum_echat":
        is_inference = configs.get('is_inference', False)
        model = init_llmasr(args, configs, is_inference=is_inference)
        return model, configs
    else:
        model, configs = init_speech_model(args, configs)
    # If a checkpoint is specified, load weights and training info from it.
if hasattr(args, 'checkpoint') and args.checkpoint is not None:
infos = load_checkpoint(model, args.checkpoint)
elif hasattr(args, 'enc_init') and args.enc_init is not None:
infos = load_trained_modules(model, args)
else:
infos = {}
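    # 'init_step' discards the restored step/epoch info so training restarts
    # its bookkeeping from scratch.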
if configs.get('init_step', False):
infos = {}
configs["init_infos"] = infos
if hasattr(args, 'use_lora') and args.use_lora:
if hasattr(args, 'lora_ckpt_path') and args.lora_ckpt_path:
load_checkpoint(model, args.lora_ckpt_path)
    # Try to tie some weights
if hasattr(model, 'tie_or_clone_weights'):
if not hasattr(args, 'jit'):
args.jit = True # i.e. export onnx/jit/ipex
model.tie_or_clone_weights(args.jit)
if int(os.environ.get('RANK', 0)) == 0:
print(configs)
return model, configs
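

if __name__ == '__main__':
    # A minimal smoke test, not part of the library API. The config below is
    # an illustrative stand-in for a recipe YAML (hypothetical values), and
    # SimpleNamespace stands in for the argparse namespace that the real
    # training/inference scripts pass in.
    from types import SimpleNamespace

    demo_configs = {
        'input_dim': 80,
        'output_dim': 100,
        'encoder': 'transformer',
        'encoder_conf': {
            'output_size': 64,
            'attention_heads': 2,
            'linear_units': 128,
            'num_blocks': 2,
        },
        'decoder': 'transformer',
        'decoder_conf': {
            'attention_heads': 2,
            'linear_units': 128,
            'num_blocks': 2,
        },
        'ctc': 'ctc',
        'model': 'asr_model',
        'model_conf': {'ctc_weight': 0.3},
    }
    demo_args = SimpleNamespace(checkpoint=None)
    demo_model, demo_configs = init_model(demo_args, demo_configs)
    print(type(demo_model).__name__)  # expected: ASRModel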