"""Generate chord progressions with a pre-trained Keras model.

At import time this module loads ``./content/model.h5`` and
``./content/tokenizer.pkl`` (paths resolved against the current working
directory) and exposes :class:`Generator`, which encodes/decodes chord
symbols and continues a chord progression by sampling from the model.
"""
import tensorflow as tf
import numpy as np  # NOTE(review): not used in this file — confirm before removing.
import joblib
import os
import sys

# Make the ./lib directory importable. It is resolved against the current
# working directory, not against this file's location.
module_path = os.path.abspath('./lib')
if module_path not in sys.path:
    sys.path.append(module_path)

# NOTE(review): Parser is not used in this file — confirm before removing.
from MusicXMLParser import MusicXMLDocument as Parser

# Loaded once at import time and shared by every Generator instance.
model = tf.keras.models.load_model('./content/model.h5')
# NOTE(review): presumably a Keras ``Tokenizer`` (it is used via
# ``word_index`` and ``sequences_to_texts``). joblib.load unpickles the file,
# so the file must come from a trusted source.
tokenizer = joblib.load("./content/tokenizer.pkl")


class Generator:
    """Encode, decode and generate chord symbols using the module-level model."""

    # Pitch-class number (C = 0 ... B = 11) for each root spelling,
    # including enharmonic spellings such as Cb, E#, Fb and B#.
    NOTES = {
        "Cb": 11,
        "C": 0,
        "C#": 1,
        "Db": 1,
        "D": 2,
        "D#": 3,
        "Eb": 3,
        "E": 4,
        "E#": 5,
        "Fb": 4,
        "F": 5,
        "F#": 6,
        "Gb": 6,
        "G": 7,
        "G#": 8,
        "Ab": 8,
        "A": 9,
        "A#": 10,
        "Bb": 10,
        "B": 11,
        "B#": 0,
    }

    # Canonical (flat-preferring) spelling for each pitch class; -1 maps to "".
    NUM_TO_NOTE = {
        0: "C",
        1: "Db",
        2: "D",
        3: "Eb",
        4: "E",
        5: "F",
        6: "Gb",
        7: "G",
        8: "Ab",
        9: "A",
        10: "Bb",
        11: "B",
        -1: "",
    }

    # Quality code -> chord-symbol suffix: 1 -> "^7" (major seventh),
    # 2 -> "-7" (minor seventh), 3 -> "7" (dominant seventh); -1 maps to "".
    NUM_TO_QUALITY = {
        1: "^7",
        2: "-7",
        3: "7",
        -1: "",
    }

    def decode_tokens(self, numpy_array):
        """Convert model token ids back into text via the module tokenizer.

        Args:
            numpy_array: Iterable of integer token ids.

        Returns:
            A list with one decoded string per id, as produced by
            ``tokenizer.sequences_to_texts``.
        """
        # Wrap each id in its own one-element sequence so every id is decoded
        # to a separate string instead of being joined into one text.
        return tokenizer.sequences_to_texts([[token] for token in numpy_array])

    def decode_chords(self, chords):
        """Render ``[root, quality]`` pairs as chord symbols such as ``"Db-7"``.

        Args:
            chords: Iterable of sequences whose first two items are a pitch
                class (key of ``NUM_TO_NOTE``) and a quality code (key of
                ``NUM_TO_QUALITY``).

        Returns:
            A list of chord-symbol strings.

        Raises:
            KeyError: If a root or quality code is not in the lookup tables.
        """
        return [
            Generator.NUM_TO_NOTE[chord[0]] + Generator.NUM_TO_QUALITY[chord[1]]
            for chord in chords
        ]

    def encode_chords(self, chords):
        """Encode chord symbols as ``[root, quality]`` integer pairs.

        The root is taken from the first two characters when they form a
        spelling in ``NOTES`` (e.g. ``"Bb"``, ``"F#"``), otherwise from the
        first character. The quality is picked by substring tests, first
        match wins:

        * ``"^"`` or ``"maj"``              -> 1
        * ``"-"`` or ``"m"``                -> 2
        * ``"7"``, ``"alt"`` or ``"11"``    -> 3
        * anything else (e.g. a bare triad) -> 1

        Empty, ``None`` or whitespace-only entries are skipped.

        Args:
            chords: Iterable of chord-symbol strings.

        Returns:
            A list of ``[root, quality]`` lists.

        Raises:
            KeyError: If the root spelling is not in ``NOTES``.
        """
        res = []
        for chord in chords:
            chord = (chord or "").strip()
            if not chord:
                continue

            # Check a two-character root first, regardless of chord length, so
            # accidentals are honored even on bare roots like "Bb" or "F#".
            # Slicing is safe for one-character chords ("C"[:2] == "C").
            prefix = chord[:2]
            if prefix in Generator.NOTES:
                root = Generator.NOTES[prefix]
            else:
                root = Generator.NOTES[chord[0]]

            # "maj" must be tested before "m", otherwise major chords would be
            # classified as minor.
            if "^" in chord or "maj" in chord:
                quality = 1
            elif "-" in chord or "m" in chord:
                quality = 2
            elif "7" in chord or "alt" in chord or "11" in chord:
                quality = 3
            else:
                quality = 1

            res.append([root, quality])
        return res

    def generateChords(self, chordsBefore, numGenerate=10, temperature=0.1):
        """Continue a chord progression by sampling from the model.

        Args:
            chordsBefore: Seed chord symbols. They are normalized through
                ``encode_chords``/``decode_chords`` so their spelling matches
                the one used for tokenizer lookups.
            numGenerate: Number of chords to generate (coerced with ``int``).
            temperature: Sampling temperature, must be > 0. Lower values make
                sampling greedier. The default 0.1 is the value this method
                previously hard-coded.

        Returns:
            The generated token ids decoded by ``decode_tokens``, one string
            per generated chord.

        Raises:
            ValueError: If ``temperature`` is not positive or no usable seed
                chord is given.
            KeyError: If a normalized seed chord is not in
                ``tokenizer.word_index``.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")

        seed = self.decode_chords(self.encode_chords(chordsBefore))
        if not seed:
            raise ValueError("chordsBefore must contain at least one chord")

        input_eval = tf.expand_dims([tokenizer.word_index[chord] for chord in seed], 0)

        new_chords = []
        for _ in range(int(numGenerate)):
            # NOTE(review): assumes the model returns (1, steps, vocab)
            # unnormalized logits. tf.random.categorical needs 2-D logits,
            # hence the squeeze of the batch axis.
            predictions = tf.squeeze(model(input_eval), 0) / temperature
            # Draw one sample per step and keep the one for the last step.
            predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()
            # NOTE(review): only the newest id is fed back. This presumably
            # relies on a stateful RNN to carry earlier context, and no state
            # is reset between calls here — confirm this is intended.
            input_eval = tf.expand_dims([predicted_id], 0)
            new_chords.append(predicted_id)

        return self.decode_tokens(new_chords)