Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import random | |
| from unidecode import unidecode | |
| from samplings import top_p_sampling, temperature_sampling | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
# Run the model on the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Markdown/HTML shown above the Gradio interface.
description = """
<div>
<a style="display:inline-block" href='https://github.com/sander-wood/text-to-music'><img src='https://img.shields.io/github/stars/sander-wood/text-to-music?style=social' /></a>
<a style="display:inline-block" href="https://arxiv.org/pdf/2211.11216.pdf"><img src="https://img.shields.io/badge/arXiv-2211.11216-b31b1b.svg"></a>
<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/sander-wood/text-to-music?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-md-dark.svg" alt="Duplicate Space"></a>
</div>
## ℹ️ How to use this demo?
1. Enter a query in the text box.
2. You can set the parameters (i.e., number of tunes, maximum length, top-p, temperature, and random seed) for the generation. (optional)
3. Click "Submit" and wait for the result.
4. The generated ABC notation can be converted to MIDI or PDF using [EasyABC](https://sourceforge.net/projects/easyabc/), you can also use this [online renderer](https://ldzhangyx.github.io/abc/) to render the ABC notation.
## ❕Notice
- The text box is case-sensitive.
- The demo is based on BART-base and fine-tuned on the Textune dataset (282,870 text-music pairs).
- The demo only supports English text as the input.
- The demo is still in the early stage, and the generated music is not perfect. If you have any suggestions, please feel free to contact me via [email](mailto:shangda@mail.ccom.edu.cn).
"""
# Each row follows the interface's input order:
#   [text, num_tunes, max_length, top_p, temperature, seed]
# max_length must stay within the "Max Length" slider's range (10-1000, step 10);
# the previous value of 1024 was outside that range.
examples = [
    ["This is a traditional Irish dance music.\nNote Length-1/8\nMeter-6/8\nKey-D", 3, 1000, 0.9, 1.0, 0],
    ["This is a jazz-swing lead sheet with chord and vocal.", 3, 1000, 0.9, 1.0, 0],
]
_MODEL_NAME = 'sander-wood/text-to-music'
# Populated lazily by _load_model(): keeps the tokenizer and model in memory so
# they are fetched and moved to `device` once per process, not once per request.
_model_cache = {}


def _load_model():
    """Return ``(tokenizer, model)``, loading them on first use only."""
    if not _model_cache:
        _model_cache['tokenizer'] = AutoTokenizer.from_pretrained(_MODEL_NAME)
        _model_cache['model'] = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_NAME).to(device)
    return _model_cache['tokenizer'], _model_cache['model']


def generate_abc(text, num_tunes, max_length, top_p, temperature, seed):
    """Generate ``num_tunes`` ABC-notation tunes conditioned on ``text``.

    Decoding is autoregressive, one token at a time: top-p filtering followed
    by temperature sampling (both from the ``samplings`` package).

    Args:
        text: Natural-language description of the desired music.
        num_tunes: Number of tunes to generate (coerced to ``int``).
        max_length: Token limit for the encoded input and for each generated
            tune (coerced to ``int``).
        top_p: Probability mass kept by top-p filtering.
        temperature: Sampling temperature.
        seed: Anything ``int()`` accepts makes sampling reproducible; any other
            value (e.g. the string ``"None"``) leaves sampling unseeded.

    Returns:
        One string containing every tune. Each tune starts with an
        ``X:<n>`` header line and ends with a newline.
    """
    try:
        seed = int(seed)
    except (TypeError, ValueError, OverflowError):
        # Non-numeric input such as the default "None" means "no fixed seed".
        seed = None
    # Slider values may arrive as floats; range() and the tokenizer need ints.
    num_tunes = int(num_tunes)
    max_length = int(max_length)

    print("Input Text:\n" + text)
    text = unidecode(text)
    tokenizer, model = _load_model()
    decoder_start_token_id = model.config.decoder_start_token_id
    eos_token_id = model.config.eos_token_id

    random.seed(seed)
    tune_parts = []
    with torch.no_grad():  # inference only: don't build autograd graphs
        input_ids = tokenizer(text,
                              return_tensors='pt',
                              truncation=True,
                              max_length=max_length)['input_ids'].to(device)
        # The encoder input never changes while decoding, so encode it once.
        encoder_outputs = model.get_encoder()(input_ids=input_ids)
        for n_idx in range(num_tunes):
            header = "X:" + str(n_idx + 1) + "\n"
            print("\n" + header, end="")
            tune_parts.append(header)
            decoder_input_ids = torch.tensor([[decoder_start_token_id]], device=device)
            for _ in range(max_length):
                if seed is not None:
                    # Derive a fresh but reproducible seed for every sampling step.
                    n_seed = random.randint(0, 1000000)
                    random.seed(n_seed)
                else:
                    n_seed = None
                outputs = model(encoder_outputs=encoder_outputs,
                                decoder_input_ids=decoder_input_ids)
                probs = torch.softmax(outputs.logits[0][-1], dim=-1).cpu().numpy()
                sampled_id = temperature_sampling(probs=top_p_sampling(probs,
                                                                       top_p=top_p,
                                                                       seed=n_seed,
                                                                       return_probs=True),
                                                  seed=n_seed,
                                                  temperature=temperature)
                if sampled_id == eos_token_id:
                    break
                decoder_input_ids = torch.cat(
                    (decoder_input_ids, torch.tensor([[sampled_id]], device=device)), 1)
                sampled_token = tokenizer.decode([sampled_id])
                print(sampled_token, end="")
                tune_parts.append(sampled_token)
            # Always terminate the tune, including when max_length is reached
            # without EOS, so the next "X:" header starts on its own line.
            tune_parts.append("\n")
    return "".join(tune_parts)
# Gradio UI. The top-level component classes (gr.Textbox, gr.Slider) replace the
# legacy gr.inputs / gr.outputs namespaces, which were removed in Gradio 4.
# Initial values are passed via `value=` instead of the old `default=`.
input_text = gr.Textbox(lines=5, label="Input Text", placeholder="Describe the music you want to generate ...")
input_num_tunes = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Tunes")
input_max_length = gr.Slider(minimum=10, maximum=1000, step=10, value=500, label="Max Length")
input_top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.9, label="Top P")
input_temperature = gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.0, label="Temperature")
# Free-text seed: generate_abc treats anything int() rejects (e.g. "None") as unseeded.
input_seed = gr.Textbox(lines=1, label="Seed (int)", value="None")
output_abc = gr.Textbox(label="Generated Tunes")
# The input order must match generate_abc's parameters and each row of `examples`.
gr.Interface(fn=generate_abc,
             inputs=[input_text, input_num_tunes, input_max_length, input_top_p, input_temperature, input_seed],
             outputs=output_abc,
             title="Textune: Generating Tune from Text",
             description=description,
             examples=examples).launch(debug=True)