| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import argparse |
| | from model import SALMONN |
| |
|
# Prompt sent to the model when the user submits an empty prompt line.
DEFAULT_PROMPT = 'First describe the music in general in terms of mood, theme, tempo, melody, instruments and chord progression. Then provide a detailed music analysis by describing each functional segment and its time boundaries.'

# Example of a prompt that additionally carries genre/BPM/boundary hints.
# Kept for reference only; the interactive loop never sends it.
EXAMPLE_PROMPT_WITH_HINTS = 'This is a Pop music of 69 beat-per-minute (BPM). First describe the music in general in terms of mood, theme, tempo, melody, instruments and chord progression. Then provide a detailed music analysis by describing each functional segment and its time boundaries. Please note that the music boundaries are [0, 41, 58, 83, 100].'


def parse_args(argv=None):
    """Parse the command-line options of the interactive inference script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda")
    # NOTE(review): the default file name is spelled "salomnn"; confirm it
    # matches the checkpoint file that is actually distributed.
    parser.add_argument("--ckpt_path", type=str, default='./salomnn_7b.bin')
    parser.add_argument("--whisper_path", type=str, default='whisper-large-v2')
    parser.add_argument("--beats_path", type=str, default='BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt')
    parser.add_argument("--vicuna_path", type=str, default='vicuna-7b-v1.5')
    # NOTE(review): --low_resource is accepted but not forwarded to SALMONN;
    # the constructor signature is not visible from this script, so wiring it
    # through needs to be confirmed against model.py.
    parser.add_argument("--low_resource", action='store_true', default=False)
    parser.add_argument("--debug", action="store_true", default=False)
    return parser.parse_args(argv)


def read_request(read=input, default_prompt=DEFAULT_PROMPT):
    """Ask the user for one (wav path, prompt) pair.

    Args:
        read: Callable with the signature of ``input`` used to query the user;
            injectable so the function can be exercised without a terminal.
        default_prompt: Prompt substituted when the user enters an empty line.

    Returns:
        ``(wav_path, prompt)``, or ``None`` when the input stream is closed
        (EOF) or the user interrupts with Ctrl-C, signalling the caller to
        stop.
    """
    try:
        wav_path = read("Your Wav Path:\n")
        prompt = read("Your Prompt:\n")
    except (EOFError, KeyboardInterrupt):
        return None
    return wav_path, prompt or default_prompt


def main(argv=None):
    """Load the model once, then answer prompts about audio files in a loop.

    Args:
        argv: Argument list forwarded to ``parse_args``; ``None`` means
            ``sys.argv[1:]``.
    """
    args = parse_args(argv)

    # Honour --device instead of always moving the model to the default CUDA
    # device; with the default value "cuda" the placement is unchanged.
    model = SALMONN(
        ckpt=args.ckpt_path,
        whisper_path=args.whisper_path,
        beats_path=args.beats_path,
        vicuna_path=args.vicuna_path
    ).to(torch.float16).to(args.device)
    model.eval()

    while True:
        print("=====================================")
        request = read_request()
        if request is None:
            # stdin closed or Ctrl-C: leave the loop instead of crashing.
            break
        wav_path, prompt = request
        try:
            print("Output:")
            # Only the generated text is used, so gradients are not needed.
            with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
                print(model.generate(wav_path, prompt=prompt, repetition_penalty=1.5, num_beams=10, top_p=.7, temperature=.2)[0])
        except Exception as e:
            # Deliberately broad: one bad file or prompt must not end the
            # session, so report the error and ask for the next request.
            print(e)
            if args.debug:
                import pdb

                pdb.set_trace()


if __name__ == "__main__":
    main()
| |
|