# SATEv1.5 / process_enni_dataset.py
# Author: Shuwei Hou
# initial_for_hf (commit 5806e12)
#!/usr/bin/env python3
"""
Script to process all MP3 files in ENNI SLI and TD datasets
Performs transcription and C-unit segmentation, then provides statistics
"""
import os
import glob
import json
import time
from pathlib import Path
from typing import Dict, List, Tuple
from transcription import translate_audio_file
from segmentation import segment_batchalign
from segmentation.segment import reorganize_transcription_c_unit
def find_mp3_files(base_paths: List[str]) -> Dict[str, List[str]]:
    """Recursively collect MP3 files under each base directory.

    Args:
        base_paths: Directories to scan; each becomes one dataset entry,
            keyed by its basename. Missing paths are skipped with a warning.

    Returns:
        Mapping of dataset name -> list of MP3 file paths (possibly empty).
    """
    discovered: Dict[str, List[str]] = {}
    for base_path in base_paths:
        # Skip (but report) directories that don't exist rather than failing.
        if not os.path.exists(base_path):
            print(f"Warning: Path does not exist: {base_path}")
            continue
        pattern = os.path.join(base_path, "**/*.mp3")
        matches = glob.glob(pattern, recursive=True)
        dataset_name = os.path.basename(base_path)
        discovered[dataset_name] = matches
        print(f"Found {len(matches)} MP3 files in {dataset_name}")
    return discovered
def process_single_audio(audio_path: str, device: str = "cuda") -> Tuple[int, int, bool]:
    """Transcribe one audio file and segment it into C-units.

    Args:
        audio_path: Path to the MP3 file to process.
        device: Inference device passed to the transcription model.

    Returns:
        Tuple of (cunit_count, ignored_boundary_count, success). On any
        failure the counts are zero and success is False; errors are
        printed rather than propagated so batch processing can continue.
    """
    try:
        print(f"\nProcessing: {os.path.basename(audio_path)}")
        # Step 1: ASR with word-level alignment enabled.
        _, session_id = translate_audio_file(
            model="mazeWhisper",
            audio_path=audio_path,
            device=device,
            enable_alignment=True,
            align_language="en",
        )
        # Step 2: reorganize the aligned transcript into C-units.
        cunits, ignored = reorganize_transcription_c_unit(
            session_id,
            segment_batchalign,
        )
        print(f" → {cunits} C-units, {ignored} ignored boundaries")
        return cunits, ignored, True
    except Exception as exc:
        # Best-effort batch processing: report and keep going.
        print(f" → Error processing {audio_path}: {str(exc)}")
        return 0, 0, False
def process_dataset(dataset_files: Dict[str, List[str]], device: str = "cuda") -> Dict[str, Dict]:
    """Run transcription + segmentation over every file of every dataset.

    Args:
        dataset_files: Mapping of dataset name -> list of audio paths
            (as produced by ``find_mp3_files``).
        device: Inference device forwarded to ``process_single_audio``.

    Returns:
        Mapping of dataset name -> statistics dict (counts, timings,
        and the list of files that failed).
    """
    results: Dict[str, Dict] = {}
    for dataset_name, file_list in dataset_files.items():
        banner = '=' * 60
        print(f"\n{banner}")
        print(f"Processing {dataset_name} dataset ({len(file_list)} files)")
        print(banner)
        stats = {
            'total_files': len(file_list),
            'processed_files': 0,
            'failed_files': 0,
            'total_cunits': 0,
            'total_ignored_boundaries': 0,
            'processing_times': [],
            'failed_files_list': [],
        }
        for i, audio_path in enumerate(file_list, 1):
            started = time.time()
            print(f"[{i}/{len(file_list)}] Processing: {os.path.basename(audio_path)}")
            cunits, ignored, ok = process_single_audio(audio_path, device)
            elapsed = time.time() - started
            stats['processing_times'].append(elapsed)
            if ok:
                stats['processed_files'] += 1
                stats['total_cunits'] += cunits
                stats['total_ignored_boundaries'] += ignored
            else:
                stats['failed_files'] += 1
                stats['failed_files_list'].append(audio_path)
            print(f" → Time: {elapsed:.2f}s")
        results[dataset_name] = stats
    return results
def print_statistics(results: Dict[str, Dict]):
    """Print per-dataset and global statistics for a processing run.

    Args:
        results: Mapping of dataset name -> stats dict as produced by
            ``process_dataset``.

    Fix: guard every percentage against division by zero. A dataset
    directory can exist yet contain no MP3 files (``find_mp3_files``
    still registers it, and ``main`` only requires that *some* dataset
    be non-empty), which previously raised ZeroDivisionError here.
    """
    print(f"\n{'='*80}")
    print("COMPREHENSIVE STATISTICS")
    print(f"{'='*80}")
    total_files = 0
    total_processed = 0
    total_failed = 0
    total_cunits = 0
    total_ignored = 0
    for dataset_name, stats in results.items():
        print(f"\n{dataset_name.upper()} DATASET:")
        print(f" Total files: {stats['total_files']}")
        print(f" Successfully processed: {stats['processed_files']}")
        print(f" Failed: {stats['failed_files']}")
        # Avoid ZeroDivisionError for datasets with no files found.
        success_rate = (stats['processed_files'] / stats['total_files'] * 100) if stats['total_files'] else 0.0
        print(f" Success rate: {success_rate:.1f}%")
        print(f" Total C-units: {stats['total_cunits']}")
        print(f" Total ignored boundaries: {stats['total_ignored_boundaries']}")
        if stats['processing_times']:
            avg_time = sum(stats['processing_times']) / len(stats['processing_times'])
            print(f" Average processing time: {avg_time:.2f}s per file")
        if stats['processed_files'] > 0:
            avg_cunits = stats['total_cunits'] / stats['processed_files']
            print(f" Average C-units per file: {avg_cunits:.1f}")
        if stats['failed_files_list']:
            print(f" Failed files:")
            for failed_file in stats['failed_files_list']:
                print(f" - {os.path.basename(failed_file)}")
        total_files += stats['total_files']
        total_processed += stats['processed_files']
        total_failed += stats['failed_files']
        total_cunits += stats['total_cunits']
        total_ignored += stats['total_ignored_boundaries']
    print(f"\nGLOBAL STATISTICS:")
    print(f" Total files across all datasets: {total_files}")
    print(f" Total successfully processed: {total_processed}")
    print(f" Total failed: {total_failed}")
    # Same guard globally: results may be empty or all datasets empty.
    overall_rate = (total_processed / total_files * 100) if total_files else 0.0
    print(f" Overall success rate: {overall_rate:.1f}%")
    print(f" Total C-units generated: {total_cunits}")
    print(f" Total ignored boundaries: {total_ignored}")
    if total_processed > 0:
        print(f" Average C-units per processed file: {total_cunits/total_processed:.1f}")
        print(f" Average ignored boundaries per processed file: {total_ignored/total_processed:.1f}")
def save_results(results: Dict[str, Dict], output_file: str = "enni_processing_results.json"):
    """Persist run statistics as pretty-printed JSON.

    Args:
        results: Mapping of dataset name -> stats dict from ``process_dataset``.
        output_file: Destination path for the JSON report.

    The raw timing list is collapsed to an average and failed paths are
    reduced to basenames so the output stays compact and serializable.
    """
    clean_results = {
        dataset_name: {
            'total_files': stats['total_files'],
            'processed_files': stats['processed_files'],
            'failed_files': stats['failed_files'],
            'total_cunits': stats['total_cunits'],
            'total_ignored_boundaries': stats['total_ignored_boundaries'],
            'average_processing_time': (
                sum(stats['processing_times']) / len(stats['processing_times'])
                if stats['processing_times'] else 0
            ),
            'failed_files_list': [os.path.basename(f) for f in stats['failed_files_list']],
        }
        for dataset_name, stats in results.items()
    }
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(clean_results, f, indent=2, ensure_ascii=False)
    print(f"\nResults saved to: {output_file}")
def main():
    """Entry point: discover, confirm, process, and report on ENNI data.

    Interactive: asks for confirmation before processing. Results are
    printed and saved to ``enni_processing_results.json``.
    """
    # Hard-coded dataset locations for this environment.
    dataset_paths = [
        "/home/easgrad/shuweiho/workspace/volen/data/ENNI/SLI",
        "/home/easgrad/shuweiho/workspace/volen/data/ENNI/TD",
    ]
    print("ENNI Dataset Processing Script")
    print("="*50)
    print("Searching for MP3 files...")
    dataset_files = find_mp3_files(dataset_paths)
    # Bail out early when nothing was discovered anywhere.
    if not any(dataset_files.values()):
        print("No MP3 files found in the specified directories!")
        return
    total_files = sum(len(files) for files in dataset_files.values())
    print(f"\nTotal MP3 files found: {total_files}")
    # Require an explicit 'y' before launching a long-running batch job.
    response = input(f"\nProceed with processing {total_files} files? (y/N): ")
    if response.lower() != 'y':
        print("Processing cancelled.")
        return
    device = "cuda"  # Change to "cpu" if needed
    print(f"\nUsing device: {device}")
    started = time.time()
    results = process_dataset(dataset_files, device)
    wall_clock = time.time() - started
    print_statistics(results)
    print(f"\nTotal processing time: {wall_clock/60:.1f} minutes")
    save_results(results)
    print("\nProcessing complete!")
if __name__ == "__main__":
    main()