File size: 12,032 Bytes
cd55d4c
4bdd3e8
59cb9c6
6bd44a2
1918df0
 
 
 
 
 
 
 
6bd44a2
b18c401
1918df0
 
 
b18c401
 
 
1918df0
 
 
6bd44a2
 
 
 
1918df0
 
 
b18c401
 
 
 
 
 
 
1918df0
 
6bd44a2
 
1918df0
 
 
 
af7a582
1918df0
 
 
 
fdc5025
1918df0
 
 
 
 
b18c401
1918df0
 
 
 
 
b18c401
 
 
 
 
 
 
 
 
 
 
 
 
1918df0
 
 
b18c401
 
 
 
4bdd3e8
b18c401
4bdd3e8
b18c401
 
 
 
4bdd3e8
b18c401
4bdd3e8
b18c401
4bdd3e8
b18c401
 
 
 
 
 
 
1918df0
b18c401
 
 
 
 
 
1918df0
 
b18c401
1918df0
 
 
fdc5025
1918df0
 
 
 
 
 
 
af7a582
2d4b08f
b18c401
2d4b08f
 
 
af7a582
1918df0
 
 
 
b18c401
fdc5025
1918df0
 
 
 
b18c401
fdc5025
1918df0
 
 
 
b18c401
fdc5025
1918df0
 
 
 
b18c401
fdc5025
1918df0
 
 
 
b18c401
fdc5025
1918df0
 
 
 
b18c401
 
 
 
 
 
 
 
 
 
 
 
 
fdc5025
1918df0
 
af7a582
1918df0
 
 
 
 
 
 
 
 
 
 
 
 
b18c401
 
 
1918df0
 
 
 
 
b18c401
 
 
1918df0
fdc5025
af7a582
 
b18c401
af7a582
b18c401
 
fdc5025
1918df0
fdc5025
1918df0
 
 
b18c401
 
4bdd3e8
b18c401
 
4bdd3e8
 
0b65c40
b18c401
 
4bdd3e8
0b65c40
59cb9c6
b18c401
 
0b65c40
4bdd3e8
0b65c40
 
59cb9c6
b18c401
 
4bdd3e8
59cb9c6
0b65c40
 
 
0d42a10
4bdd3e8
0b65c40
 
 
 
 
0d42a10
4bdd3e8
0b65c40
 
 
4bdd3e8
 
0b65c40
 
 
 
59cb9c6
0b65c40
 
 
 
59cb9c6
b18c401
 
1918df0
6bd44a2
b18c401
0b65c40
6bd44a2
0b65c40
 
4bdd3e8
0b65c40
 
 
1918df0
b478304
af7a582
1918df0
b18c401
0b65c40
 
 
 
 
 
 
af7a582
0d42a10
0b65c40
 
 
 
0d42a10
0b65c40
 
 
af7a582
0d42a10
0b65c40
b18c401
 
0b65c40
 
 
 
 
b18c401
 
 
0b65c40
4bdd3e8
0b65c40
4bdd3e8
0b65c40
 
 
 
 
 
4bdd3e8
0b65c40
b18c401
1918df0
 
4bdd3e8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334

# app.py
from music_llm_agent import MusicAnalysisAgent, AudioFeatureExtractor
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import warnings
import librosa
import pandas as pd
import tensorflow_hub as hub
import tensorflow as tf
import json
import os

warnings.filterwarnings("ignore")

# SECURITY: the Groq API key must be supplied via the environment (e.g. set
# GROQ_API_KEY in the deployment's secrets). A key literal was previously
# committed here; any key published in source is compromised and must be
# revoked/rotated — never restore a hard-coded default.
if not os.environ.get("GROQ_API_KEY"):
    print("WARNING: GROQ_API_KEY is not set; Groq-backed LLM analysis will fail.")

class EnhancedInstrumentDetector:
    """Detect broad instrument categories in an audio file.

    Uses Google's YAMNet audio-event classifier (loaded from TensorFlow Hub)
    and buckets its 521 AudioSet classes into a handful of coarse categories.
    """

    def __init__(self):
        # Load the pretrained YAMNet model and its class-name map.
        # Both downloads happen on first use and are cached by TF afterwards.
        self.yamnet_model = hub.load('https://tfhub.dev/google/yamnet/1')
        map_path = tf.keras.utils.get_file(
            'yamnet_class_map.csv',
            'https://raw.githubusercontent.com/tensorflow/models/master/research/audioset/yamnet/yamnet_class_map.csv'
        )
        class_map = pd.read_csv(map_path)
        self.class_names = class_map['display_name'].tolist()
        # Coarse category -> YAMNet display names that belong to it.
        # NOTE(review): a few names appear under several categories
        # ('Synthesizer' in Piano and Electronic, 'Violin' in High-pitched and
        # Strings). index_to_category is a plain dict, so the LAST category in
        # iteration order silently wins for those classes — confirm intended.
        self.category_map = {
            'Percussion': ['Drum', 'Snare drum', 'Hi-hat', 'Drum kit', 'Tabla', 'Tambourine', 'Percussion'],
            'Vocals': ['Singing', 'Voice', 'Vocal music', 'Vocalist', 'Singer'],
            'High-pitched': ['Flute', 'Violin', 'Trumpet', 'Saxophone', 'Clarinet'],
            'Piano': ['Piano', 'Keyboard (musical)', 'Electric piano', 'Synthesizer'],
            'Guitar': ['Guitar', 'Acoustic guitar', 'Electric guitar', 'Guitar strum', 'Bass guitar'],
            'Electronic': ['Synthesizer', 'Electronic music', 'Techno', 'House music'],
            'Strings': ['Violin', 'Cello', 'Viola', 'String section', 'Orchestral music']
        }
        # YAMNet class index -> coarse category, for fast per-frame lookup.
        self.index_to_category = {
            i: cat for cat, names in self.category_map.items()
            for i, name in enumerate(self.class_names) if name in names
        }

    def detect_instruments(self, file, threshold=0.1):
        """Return {category: percentage} of frame-level detections in *file*.

        Args:
            file: path to an audio file readable by librosa.
            threshold: minimum per-frame YAMNet class score to count as a
                detection (default 0.1, the original hard-coded value).

        Returns:
            Dict mapping category name to its share (in percent) of all
            detections; empty dict when nothing crosses the threshold.
        """
        # YAMNet expects 16 kHz mono waveforms.
        waveform, sr = librosa.load(file, sr=16000)
        scores, _, _ = self.yamnet_model(waveform)
        scores_np = scores.numpy()
        counts = {cat: 0 for cat in self.category_map}
        # Vectorized per-class count: one column test per mapped class instead
        # of a Python loop over every (frame, class) cell.
        if scores_np.size:
            for idx, cat in self.index_to_category.items():
                counts[cat] += int(np.count_nonzero(scores_np[:, idx] > threshold))
        total = sum(counts.values())
        if total == 0:
            return {}
        return {k: (v / total) * 100 for k, v in counts.items() if v > 0}

    def visualize_results(self, proportions):
        """Render a bar chart of detected category percentages.

        Args:
            proportions: dict as returned by detect_instruments().

        Returns:
            A matplotlib Figure (with a placeholder message when empty).
        """
        fig, ax = plt.subplots(figsize=(8, 5))
        if not proportions:
            ax.text(0.5, 0.5, "No instruments detected", ha='center', va='center')
            return fig
        labels = list(proportions.keys())
        values = list(proportions.values())

        # One distinct viridis shade per category.
        colors = plt.cm.viridis(np.linspace(0.1, 0.9, len(labels)))
        bars = ax.bar(labels, values, color=colors)

        # Annotate each bar with its percentage.
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height + 1,
                    f'{height:.1f}%', ha='center', va='bottom', fontweight='bold')

        ax.set_ylabel('% presence', fontsize=12)
        ax.set_title('Detected Instrument Categories', fontsize=14, fontweight='bold')
        ax.grid(axis='y', linestyle='--', alpha=0.7)
        fig.tight_layout()
        return fig

def compute_virality_score(energy, tempo):
    """Heuristic virality score in [0, 100] from track energy and tempo (BPM).

    Tempo (normalized against 180 BPM) is weighted 0.6 and energy (normalized
    against 0.35) 0.4; high-energy fast tracks get a 15% boost and moderately
    energetic mid-tempo tracks a 5% boost. The combined ratio is capped at 1.0
    before scaling to percent.
    """
    boost = 1.0
    if energy > 0.22 and tempo > 115:
        boost = 1.15
    elif energy > 0.15 and 100 <= tempo <= 115:
        boost = 1.05
    return min(((tempo / 180.0) * 0.6 + (energy / 0.35) * 0.4) * boost, 1.0) * 100


def plot_virality(energy, tempo):
    """Render a single-bar chart of the estimated virality score.

    Args:
        energy: track energy feature (same scale used by compute_virality_score).
        tempo: tempo in BPM.

    Returns:
        A matplotlib Figure with the score bar, reference lines, and legend.
    """
    fig, ax = plt.subplots(figsize=(8, 5))

    virality_score = compute_virality_score(energy, tempo)

    # Color bins: blue < 40 <= orange < 70 <= red.
    # NOTE(review): these bins do not match the 30/60/85 legend lines below —
    # presumably intentional styling, but worth confirming.
    if virality_score < 40:
        color = '#3498db'
    elif virality_score < 70:
        color = '#f39c12'
    else:
        color = '#e74c3c'

    ax.bar(['Virality Score'], [virality_score], color=color)
    ax.axhline(y=30, color='r', linestyle='--', alpha=0.3, label='Low Virality')
    ax.axhline(y=60, color='g', linestyle='--', alpha=0.3, label='Medium Virality')
    ax.axhline(y=85, color='b', linestyle='--', alpha=0.3, label='High Virality')
    ax.text(0, virality_score + 2, f'{virality_score:.1f}%', ha='center', fontsize=14, fontweight='bold')

    ax.set_ylim([0, 100])
    ax.set_ylabel('Virality Rate (%)', fontsize=12)
    ax.set_title('Estimated Virality Potential', fontsize=14, fontweight='bold')
    ax.grid(axis='y', linestyle='--', alpha=0.5)
    ax.legend(loc='upper right')

    fig.tight_layout()
    return fig

def process_audio(audio_path, llm_provider="groq"):
    """Run the full analysis pipeline for one uploaded track.

    Args:
        audio_path: filesystem path to the uploaded audio file.
        llm_provider: backend passed to MusicAnalysisAgent ("groq" default).

    Returns:
        A 5-tuple (instrument summary str, instrument Figure, virality text,
        LLM markdown, virality Figure). On failure the Figure slots are None
        and the text slots carry error messages, so the Gradio outputs always
        receive a value.
    """
    extractor = AudioFeatureExtractor(audio_path)
    if not extractor.extract_all_features():
        return "Failed to extract features.", None, "Virality prediction error", "LLM insight unavailable", None

    try:
        features_json = extractor.to_json()
        summary = json.loads(features_json)
        detector = EnhancedInstrumentDetector()
        proportions = detector.detect_instruments(audio_path)
        plot = detector.visualize_results(proportions)
        instruments = list(proportions.keys())

        try:
            agent = MusicAnalysisAgent(model="llama3-70b-8192", provider=llm_provider)
        except Exception as e:
            print(f"Error initializing MusicAnalysisAgent: {e}")
            return "Agent initialization error", None, "Could not initialize music analysis agent", "LLM unavailable", None

        def _ask(method, log_label, fallback_prefix):
            # Run one agent query; on failure, log it and return a readable
            # fallback string so a single failed query never aborts the run.
            try:
                return method(features_json)
            except Exception as e:
                print(f"{log_label}:", e)
                return f"{fallback_prefix}: {str(e)}"

        # Each query is independent and best-effort; labels/prefixes preserve
        # the exact messages of the original per-query handlers.
        virality_result = _ask(agent.analyze_song_features, "LLM analysis error", "LLM Analysis Error")
        improvement = _ask(agent.get_song_improvement_suggestions, "Improvement error", "Improvement Suggestion Error")
        workout_fit = _ask(agent.assess_workout_playlist_fit, "Workout Fit error", "Workout Fit Suggestion Error")
        marketing = _ask(agent.suggest_marketing_channels, "Marketing error", "Marketing Suggestion Error")
        genre = _ask(agent.recommend_genre_classification, "Genre classification error", "Genre Classification Error")
        mood_type = _ask(agent.recommend_music_category, "Music type error", "Music Type Unavailable")
        lyric_suggestions = _ask(agent.analyze_lyric_improvement, "Lyric suggestion error", "Lyric suggestions unavailable")
        commercial_potential = _ask(agent.analyze_commercial_potential, "Commercial potential analysis error", "Commercial potential analysis unavailable")

        llm_output = f"""
## 🎧 Instrument-Based Analysis
Detected Instruments: {', '.join(instruments) if instruments else 'None Detected'}

## 🎼 Genre Classification
{genre}

## πŸ’– Music Type (e.g., Romantic, Chill, Party)
{mood_type}

## πŸŽ› Overall Analysis
{virality_result}

## 🎚 Suggestions for Improvement
{improvement}

## πŸ–‹ Lyric Improvement Suggestions
{lyric_suggestions}

## πŸ‹οΈ Workout Fit Assessment
{workout_fit}

## πŸ“ˆ Marketing Recommendations
{marketing}

## πŸ’° Commercial Potential Analysis
{commercial_potential}
"""

        virality_text = f"""🎼 Track Overview
- File: {summary['file_name']}
- Tempo: {summary['tempo']:.2f} BPM
- Key: {summary['key']}
- Mood: {', '.join(summary['mood_indicators'])}
- Energy: {summary['energy']:.4f}"""

        virality_chart = plot_virality(summary['energy'], summary['tempo'])

        return ", ".join(instruments), plot, virality_text, llm_output, virality_chart
    except Exception as e:
        # Catch-all boundary: anything unexpected still yields a full 5-tuple.
        print("Full audio processing error:", e)
        return f"Error: {str(e)}", None, "Error", f"LLM Error: {str(e)}", None

# Custom CSS for the Gradio app: light-blue page and container backgrounds,
# card-style rounded panels with drop shadows, and pink buttons with a hover
# lift. Selectors (.gradio-container, .output-markdown, .block-container)
# target Gradio's generated markup; !important overrides the theme defaults.
custom_css = """
body {
    background-color: #87CEEB;
    color: #333333;
    font-family: 'Poppins', sans-serif;
}
.gradio-container {
    background-color: #a9d6f5;
    border-radius: 15px;
    box-shadow: 0 10px 25px rgba(0, 0, 0, 0.2);
}
.output-markdown {
    line-height: 1.7;
    background-color: #c3e1f7;
    padding: 15px;
    border-radius: 10px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
h2 {
    color: #555555;
    border-bottom: 2px solid #555555;
    padding-bottom: 10px;
    margin-top: 20px;
}
.block-container {
    background-color: #b8dcf7;
    border-radius: 10px;
    padding: 15px;
    margin-bottom: 15px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.block-container label {
    color: #555555;
    font-weight: bold;
}
button {
    background-color: #ffb6c1 !important;
    color: #333333 !important;
    border: none !important;
    padding: 10px 20px !important;
    border-radius: 8px !important;
    font-weight: bold !important;
    box-shadow: 0 4px 8px rgba(255, 182, 193, 0.3) !important;
    transition: all 0.3s ease !important;
}
button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 12px rgba(255, 182, 193, 0.4) !important;
}
"""

# Build the Gradio UI: one audio upload row, a provider dropdown + analyze
# button row, and three grouped output panels (instruments, virality, LLM
# insight). The analyze button drives process_audio().
with gr.Blocks(
    css=custom_css,
    theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="rose")
) as demo:
    gr.Markdown("""
    # 🎡 AI-Powered Music Analysis Suite

    Upload your track to discover its musical characteristics, commercial potential, and receive expert recommendations.
    """)
    
    # Input: audio file path is passed straight to process_audio.
    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="Upload Song (MP3/WAV)", show_label=True, elem_classes=["audio-component"])
    
    # Controls: LLM backend selector plus the trigger button.
    with gr.Row():
        with gr.Column(scale=3):
            llm_choice = gr.Dropdown(
                choices=["groq", "openai", "huggingface"], 
                label="Choose LLM Provider",
                value="groq"
            )
        with gr.Column(scale=1):
            analyze_btn = gr.Button("πŸ” Analyze Audio", variant="primary", size="lg")
    
    # Output panel 1: instrument names and their distribution chart.
    with gr.Group(elem_classes=["block-container"]):
        gr.Markdown("### 🎸 Detected Instruments")
        instrument_text = gr.Textbox(label="Instruments Found", interactive=False)
        instrument_plot = gr.Plot(label="Instrument Distribution")
    
    # Output panel 2: track summary text and virality score chart.
    with gr.Group(elem_classes=["block-container"]):
        gr.Markdown("### πŸ“Š Track Overview & Virality")
        virality_output = gr.Textbox(label="Audio Summary", lines=5, interactive=False)
        virality_plot = gr.Plot(label="Virality Score")
    
    # Output panel 3: the combined LLM markdown report.
    with gr.Group(elem_classes=["block-container"]):
        gr.Markdown("### πŸ€– AI Music Analysis")
        llm_output = gr.Markdown(label="AI Insight")
    
    # Wire the button: output order must match process_audio's return tuple
    # (instruments str, instrument fig, virality text, llm markdown, virality fig).
    analyze_btn.click(
        fn=process_audio, 
        inputs=[audio_input, llm_choice], 
        outputs=[instrument_text, instrument_plot, virality_output, llm_output, virality_plot]
    )
    
    gr.Markdown("""
    ---
    ### 🎧 About This Tool

    This AI-powered music analysis tool combines powerful audio feature extraction with LLM insights to help musicians, producers, and marketers understand their tracks better.

    **Features:**
    - Instrument detection with TensorFlow YAMNet
    - Audio feature extraction (tempo, key, energy, mood)
    - LLM-powered creative insights
    - Commercial potential analysis
    - Marketing recommendations

    Β© 2025 Music AI Suite | Built with TensorFlow Hub + LangChain + Gradio
    """)

# Launch the app only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()