# Spaces: Running  (Hugging Face Space status banner captured together with this file)
# -*- coding: utf-8 -*- | |
import typing | |
import types # fusion of forward() of Wav2Vec2 | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import os | |
import torch | |
import torch.nn as nn | |
from transformers import Wav2Vec2Processor | |
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model | |
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel | |
import audiofile | |
from tts import StyleTTS2 | |
import audresample | |
import json | |
import re | |
import unicodedata | |
import textwrap | |
import nltk | |
from num2words import num2words | |
from num2word_greek.numbers2words import convert_numbers | |
from audionar import VitsModel, VitsTokenizer | |
# Fetch the NLTK tokenizer data packages 'punkt' and 'punkt_tab' into the
# working directory and add that directory to NLTK's search path.
# NOTE(review): nothing in this file calls nltk directly; presumably an imported
# package (e.g. `tts`) needs this data - confirm before removing.
for _nltk_pkg in ('punkt', 'punkt_tab'):
    nltk.download(_nltk_pkg, download_dir='./')
del _nltk_pkg
nltk.data.path.append('.')
# Initial device; it is reassigned further down, before any model is used.
device = 'cpu'
def fix_vocals(text, lang='ron'):
    """Apply per-language pronunciation tweaks and spell out math/currency symbols.

    Each language has a table of ``old -> new`` substring rewrites. They are
    applied longest key first. Among keys of equal length, dict insertion
    order decides, e.g. 'ţ'->'ț' runs before 'ț'->'ts'. Every rewrite operates
    on the output of the previous ones.

    The key ``'pi'`` is special-cased: it is replaced only when it stands
    alone, i.e. not preceded or followed by a letter. A plain substring
    replacement would split ordinary words such as "copil" or "spirit".

    Args:
        text (str): text to rewrite.
        lang (str): one of 'grc', 'ron', 'eng', 'deu', 'fra', 'hun',
            'rmc-script_latin'.

    Returns:
        str: the rewritten text, or *text* unchanged (with a printed warning)
        when *lang* has no table.
    """
    ron_replacements = {
        'ţ': 'ț',
        'ț': 'ts',
        'î': 'u',
        'â': 'a',
        'ş': 's',
        'w': 'oui',
        'k': 'c',
        'l': 'll',
        # Math symbols
        'sqrt': ' rădăcina pătrată din ',
        '^': ' la puterea ',
        '+': ' plus ',
        ' - ': ' minus ',  # only replace if standalone so to not say minus if is a-b-c
        '*': ' ori ',  # times
        '/': ' împărțit la ',  # divided by
        '=': ' egal cu ',  # equals
        'pi': ' pi ',
        '<': ' mai mic decât ',
        '>': ' mai mare decât ',
        '%': ' la sută ',  # percent (from previous)
        '(': ' paranteză deschisă ',
        ')': ' paranteză închisă ',
        '[': ' paranteză pătrată deschisă ',
        ']': ' paranteză pătrată închisă ',
        '{': ' acoladă deschisă ',
        '}': ' acoladă închisă ',
        '≠': ' nu este egal cu ',
        '≤': ' mai mic sau egal cu ',
        '≥': ' mai mare sau egal cu ',
        '≈': ' aproximativ ',
        '∞': ' infinit ',
        '€': ' euro ',
        '$': ' dolar ',
        '£': ' liră ',
        '&': ' și ',  # and
        '@': ' la ',  # at
        '#': ' diez ',  # hash
        '∑': ' sumă ',
        '∫': ' integrală ',
        '√': ' rădăcina pătrată a ',  # more generic square root
    }
    eng_replacements = {
        'wik': 'weaky',
        'sh': 'ss',
        'ch': 'ttss',
        'oo': 'oeo',
        # Math symbols for English
        'sqrt': ' square root of ',
        '^': ' to the power of ',
        '+': ' plus ',
        ' - ': ' minus ',
        '*': ' times ',
        ' / ': ' divided by ',
        '=': ' equals ',
        'pi': ' pi ',
        '<': ' less than ',
        '>': ' greater than ',
        # Additional common math symbols from previous list
        '%': ' percent ',
        '(': ' open parenthesis ',
        ')': ' close parenthesis ',
        '[': ' open bracket ',
        ']': ' close bracket ',
        '{': ' open curly brace ',
        '}': ' close curly brace ',
        '∑': ' sum ',
        '∫': ' integral ',
        '√': ' square root of ',
        '≠': ' not equals ',
        '≤': ' less than or equals ',
        '≥': ' greater than or equals ',
        '≈': ' approximately ',
        '∞': ' infinity ',
        '€': ' euro ',
        '$': ' dollar ',
        '£': ' pound ',
        '&': ' and ',
        '@': ' at ',
        '#': ' hash ',
    }
    serbian_replacements = {
        'rn': 'rrn',
        'ć': 'č',
        'c': 'č',
        'đ': 'd',
        'j': 'i',
        'l': 'lll',
        'w': 'v',
        # https://huggingface.co/facebook/mms-tts-rmc-script_latin
        'sqrt': ' kvadratni koren iz ',
        '^': ' na stepen ',
        '+': ' plus ',
        ' - ': ' minus ',
        '*': ' puta ',
        ' / ': ' podeljeno sa ',
        '=': ' jednako ',
        'pi': ' pi ',
        '<': ' manje od ',
        '>': ' veće od ',
        '%': ' procenat ',
        '(': ' otvorena zagrada ',
        ')': ' zatvorena zagrada ',
        '[': ' otvorena uglasta zagrada ',
        ']': ' zatvorena uglasta zagrada ',
        '{': ' otvorena vitičasta zagrada ',
        '}': ' zatvorena vitičasta zagrada ',
        '∑': ' suma ',
        '∫': ' integral ',
        '√': ' kvadratni koren ',
        '≠': ' nije jednako ',
        '≤': ' manje ili jednako od ',
        '≥': ' veće ili jednako od ',
        '≈': ' približno ',
        '∞': ' beskonačnost ',
        '€': ' evro ',
        '$': ' dolar ',
        '£': ' funta ',
        '&': ' i ',
        '@': ' et ',
        '#': ' taraba ',
        # Other rewrites that were tried and disabled:
        # 'l': 'le', 'ij': 'i', 'ji': 'i',
        # 'služ': 'sloooozz',  # 'službeno'
        # 'suver': 'siuveeerra',  # 'suverena'
        # 'država': 'dirrezav',  # 'država'
        # 'iči': 'ici',  # 'Graniči'
        # 's ': 'se',  # a s with space
        # 'q': 'ku', 'w': 'aou', 'z': 's', "š": "s", 'th': 'ta', 'v': 'vv',
        # "đ": "ď", "lj": "ľ", "nj": "ň", "ž": "z",
    }
    deu_replacements = {
        'sch': 'sh',
        'ch': 'kh',
        'ie': 'ee',
        'ei': 'ai',
        'ä': 'ae',
        'ö': 'oe',
        'ü': 'ue',
        'ß': 'ss',
        # Math symbols for German
        'sqrt': ' Quadratwurzel aus ',
        '^': ' hoch ',
        '+': ' plus ',
        ' - ': ' minus ',
        '*': ' mal ',
        ' / ': ' geteilt durch ',
        '=': ' gleich ',
        'pi': ' pi ',
        '<': ' kleiner als ',
        '>': ' größer als ',
        # Additional common math symbols from previous list
        '%': ' prozent ',
        '(': ' Klammer auf ',
        ')': ' Klammer zu ',
        '[': ' eckige Klammer auf ',
        ']': ' eckige Klammer zu ',
        '{': ' geschweifte Klammer auf ',
        '}': ' geschweifte Klammer zu ',
        '∑': ' Summe ',
        '∫': ' Integral ',
        '√': ' Quadratwurzel ',
        '≠': ' ungleich ',
        '≤': ' kleiner oder gleich ',
        '≥': ' größer oder gleich ',
        '≈': ' ungefähr ',
        '∞': ' unendlich ',
        '€': ' euro ',
        '$': ' dollar ',
        '£': ' pfund ',
        '&': ' und ',
        '@': ' at ',  # 'Klammeraffe' is also common but 'at' is simpler
        '#': ' raute ',
    }
    fra_replacements = {
        # French specific phonetic replacements (add as needed)
        # e.g., 'ç': 's', 'é': 'e', etc.
        'w': 'v',
        # Math symbols for French
        'sqrt': ' racine carrée de ',
        '^': ' à la puissance ',
        '+': ' plus ',
        ' - ': ' moins ',  # tiré ;
        '*': ' fois ',
        ' / ': ' divisé par ',
        '=': ' égale ',
        'pi': ' pi ',
        '<': ' inférieur à ',
        '>': ' supérieur à ',
        # Add more common math symbols as needed for French
        '%': ' pour cent ',
        '(': ' parenthèse ouverte ',
        ')': ' parenthèse fermée ',
        '[': ' crochet ouvert ',
        ']': ' crochet fermé ',
        '{': ' accolade ouverte ',
        '}': ' accolade fermée ',
        '∑': ' somme ',
        '∫': ' intégrale ',
        '√': ' racine carrée ',
        '≠': ' n\'égale pas ',
        '≤': ' inférieur ou égal à ',
        '≥': ' supérieur ou égal à ',
        '≈': ' approximativement ',
        '∞': ' infini ',
        '€': ' euro ',
        '$': ' dollar ',
        '£': ' livre ',
        '&': ' et ',
        '@': ' arobase ',
        '#': ' dièse ',
    }
    hun_replacements = {
        # Hungarian specific phonetic replacements (add as needed)
        # e.g., 'á': 'a', 'é': 'e', etc.
        'ch': 'ts',
        'cs': 'tz',
        'g': 'gk',
        'w': 'v',
        'z': 'zz',
        # Math symbols for Hungarian
        'sqrt': ' négyzetgyök ',
        '^': ' hatvány ',
        '+': ' plusz ',
        ' - ': ' mínusz ',
        '*': ' szorozva ',
        ' / ': ' osztva ',
        '=': ' egyenlő ',
        'pi': ' pi ',
        '<': ' kisebb mint ',
        '>': ' nagyobb mint ',
        # Add more common math symbols as needed for Hungarian
        '%': ' százalék ',
        '(': ' nyitó zárójel ',
        ')': ' záró zárójel ',
        '[': ' nyitó szögletes zárójel ',
        ']': ' záró szögletes zárójel ',
        '{': ' nyitó kapcsos zárójel ',
        '}': ' záró kapcsos zárójel ',
        '∑': ' szumma ',
        '∫': ' integrál ',
        '√': ' négyzetgyök ',
        '≠': ' nem egyenlő ',
        '≤': ' kisebb vagy egyenlő ',
        '≥': ' nagyobb vagy egyenlő ',
        '≈': ' körülbelül ',
        '∞': ' végtelen ',
        '€': ' euró ',
        '$': ' dollár ',
        '£': ' font ',
        '&': ' és ',
        '@': ' kukac ',
        '#': ' kettőskereszt ',
    }
    grc_replacements = {
        # Ancient Greek: math symbols (literal translations)
        'sqrt': ' τετραγωνικὴ ῥίζα ',
        '^': ' εἰς τὴν δύναμιν ',
        '+': ' σὺν ',
        ' - ': ' χωρὶς ',
        '*': ' πολλάκις ',
        ' / ': ' διαιρέω ',
        '=': ' ἴσον ',
        'pi': ' πῖ ',
        '<': ' ἔλαττον ',
        '>': ' μεῖζον ',
        # Add more common math symbols as needed for Ancient Greek
        '%': ' τοῖς ἑκατόν ',  # tois hekaton - 'of the hundred'
        '(': ' ἀνοικτὴ παρένθεσις ',
        ')': ' κλειστὴ παρένθεσις ',
        '[': ' ἀνοικτὴ ἀγκύλη ',
        ']': ' κλειστὴ ἀγκύλη ',
        '{': ' ἀνοικτὴ σγουρὴ ἀγκύλη ',
        '}': ' κλειστὴ σγουρὴ ἀγκύλη ',
        '∑': ' ἄθροισμα ',
        '∫': ' ὁλοκλήρωμα ',
        '√': ' τετραγωνικὴ ῥίζα ',
        '≠': ' οὐκ ἴσον ',
        '≤': ' ἔλαττον ἢ ἴσον ',
        '≥': ' μεῖζον ἢ ἴσον ',
        '≈': ' περίπου ',
        '∞': ' ἄπειρον ',
        '€': ' εὐρώ ',
        '$': ' δολάριον ',
        '£': ' λίρα ',
        '&': ' καὶ ',
        '@': ' ἀτ ',  # at
        '#': ' δίεση ',  # hash
    }
    # Select the appropriate replacement dictionary based on the language
    replacements_map = {
        'grc': grc_replacements,
        'ron': ron_replacements,
        'eng': eng_replacements,
        'deu': deu_replacements,
        'fra': fra_replacements,
        'hun': hun_replacements,
        'rmc-script_latin': serbian_replacements,
    }
    current_replacements = replacements_map.get(lang)
    if not current_replacements:
        # If the language is not supported, return the original text
        print(f"Warning: Language '{lang}' not supported for text replacement. Returning original text.")
        return text
    # 'pi' must not be preceded/followed by a letter ([^\W\d_] == any Unicode letter),
    # so digits ("2pi") and punctuation still count as boundaries.
    standalone_pi = re.compile(r'(?<![^\W\d_])pi(?![^\W\d_])')
    # Longest keys first, so e.g. 'sqrt' or 'sch' are consumed before their
    # shorter substrings; sorted() is stable, keeping dict order within a length.
    sorted_replacements = sorted(current_replacements.items(), key=lambda item: len(item[0]), reverse=True)
    for old, new in sorted_replacements:
        if old == 'pi':
            text = standalone_pi.sub(lambda _match: new, text)
        else:
            text = text.replace(old, new)
    return text
def _num2words(text='01234', lang=None):
    """Spell out the number in *text* for *lang*.

    Ancient Greek ('grc') goes through ``convert_numbers``; every other code is
    handed to ``num2words``.
    """
    if lang != 'grc':
        # num2words' signature requires the language as the keyword argument `lang`.
        return num2words(text, lang=lang)
    return convert_numbers(text)
def transliterate_number(number_string,
                         lang=None):
    """Replace every number in *number_string* with its spelling in *lang*.

    * Integers are spelled whole.
    * Decimals become "<integer> <comma word> <digit> <digit> ...".
    * Scientific notation (``1.5e-3``) becomes
      "<base> <times-ten-to-the-power phrase> <exponent>".

    A number that cannot be converted (unsupported language, too large, ...)
    is left untouched.

    Args:
        number_string (str): free text that may contain numbers.
        lang (str | None): language code as used by the TTS code in this file
            ('rmc-script_latin', 'ron', 'hun', 'deu', 'fra', 'grc', ...). Codes
            without a dedicated branch are truncated to their first two letters
            for num2words. ``None`` is treated as English.

    Returns:
        str: the text with numbers spelled out.
    """
    if lang is None:
        lang = 'en'
    if lang == 'rmc-script_latin':
        lang = 'sr'
        exponential_pronoun = ' puta deset na stepen od '
        comma = ' tačka '
    elif lang == 'ron':
        lang = 'ro'
        exponential_pronoun = ' ori zece la puterea '
        comma = ' virgulă '
    elif lang == 'hun':
        lang = 'hu'
        exponential_pronoun = ' tízszer a erejéig '
        comma = ' virgula '
    elif lang == 'deu':
        exponential_pronoun = ' mal zehn hoch '
        comma = ' komma '
    elif lang == 'fra':
        lang = 'fr'
        exponential_pronoun = ' fois dix puissance '
        comma = ' virgule '
    elif lang == 'grc':
        exponential_pronoun = ' εις την δυναμην του '
        comma = ' κομμα '
    else:
        lang = lang[:2]
        exponential_pronoun = ' times ten to the power of '
        comma = ' point '

    def replace_number(match):
        # Groups: 1 = non-digit prefix, 2 = the number, 5 = non-digit suffix.
        prefix = match.group(1) or ""
        number_part = match.group(2)
        suffix = match.group(5) or ""
        try:
            if 'e' in number_part.lower():
                base, exponent = number_part.lower().split('e')
                words = _num2words(base, lang=lang) + exponential_pronoun + _num2words(exponent, lang=lang)
            elif '.' in number_part:
                integer_part, decimal_part = number_part.split('.')
                # Decimals are read digit by digit after the comma word.
                words = _num2words(integer_part, lang=lang) + comma + " ".join(
                    [_num2words(digit, lang=lang) for digit in decimal_part])
            else:
                words = _num2words(number_part, lang=lang)
        except (ValueError, NotImplementedError, OverflowError):
            # NotImplementedError: num2words has no converter for `lang`;
            # OverflowError: number too large for the converter.
            return match.group(0)  # Return original if conversion fails
        return prefix + words + suffix

    pattern = r'([^\d]*)(\d+(\.\d+)?([Ee][+-]?\d+)?)([^\d]*)'
    return re.sub(pattern, replace_number, number_string)
# Display names offered by the "TTS language" dropdown. audionar_tts() maps
# each one to an MMS language code by case-insensitive substring matching.
language_names = [
    'Ancient greek',
    'English',
    'Deutsch',
    'French',
    'Hungarian',
    'Romanian',
    'Serbian (Approx.)',
]
def audionar_tts(text=None,
                 lang='romanian'):
    """Synthesise *text* with a Meta MMS VITS voice and return the WAV path.

    Pipeline:
    1. script normalisation (``only_greek_or_only_latin``);
    2. numbers spelled out (``transliterate_number``);
    3. pronunciation / symbol rewrites (``fix_vocals``);
    4. VITS inference on chunks of at most 439 characters, concatenated.

    Args:
        text (str): text to speak.
        lang (str): language display name, matched case-insensitively by
            substring to an MMS code ('hun', 'rmc-script_latin', 'ron', 'deu',
            'fra', 'eng', 'grc'). Unrecognised names use their first word as
            the code.

    Returns:
        str: path of the written 16 kHz WAV file ('_speech.wav', shared by all
        requests).

    Raises:
        gr.Error: if *text* is empty, or becomes empty after normalisation.
    """
    # https://huggingface.co/dkounadis/artificial-styletts2/blob/main/msinference.py
    # The VITS model/tokenizer of the last-used language are kept in module globals.
    global cached_lang_code, cached_net_g, cached_tokenizer
    if text is None or not text.strip():
        raise gr.Error("No text submitted! Please type some text before generating audio.")
    lang = (lang or 'romanian').lower()
    # https://huggingface.co/spaces/mms-meta/MMS
    if 'hun' in lang:
        lang_code = 'hun'
    elif any(i in lang for i in ('ser', 'bosn', 'herzegov', 'montenegr', 'macedon')):
        # romani carpathian (has also Vlax) - cooler voice
        lang_code = 'rmc-script_latin'
    elif 'rom' in lang:
        lang_code = 'ron'
    elif 'ger' in lang or 'deu' in lang or 'allem' in lang:
        lang_code = 'deu'
    elif 'french' in lang:
        lang_code = 'fra'
    elif 'eng' in lang:
        lang_code = 'eng'
    elif 'ancient greek' in lang:
        lang_code = 'grc'
    else:
        lang_code = lang.split()[0].strip()  # latin & future option
    # LATIN / GRC / CYRILLIC: Greek letters if lang == 'grc', Latin otherwise
    text = only_greek_or_only_latin(text, lang=lang_code)
    # NUMERALS (^ in math expressions is left for fix_vocals)
    text = transliterate_number(text, lang=lang_code)
    # PRONUNCIATION
    text = fix_vocals(text, lang=lang_code)
    # VITS: (re)load only when the language changes. The cache key is written
    # after both loads succeed, so a failed download cannot leave the key
    # pointing at another language's model.
    if 'cached_lang_code' not in globals() or cached_lang_code != lang_code:
        cached_net_g = VitsModel.from_pretrained(f'facebook/mms-tts-{lang_code}').eval().to(device)
        cached_tokenizer = VitsTokenizer.from_pretrained(f'facebook/mms-tts-{lang_code}')
        cached_lang_code = lang_code
    net_g = cached_net_g
    tokenizer = cached_tokenizer
    chunks = textwrap.wrap(text, width=439)
    if not chunks:
        # e.g. input consisting only of characters mapped to '' (Cyrillic ъ / ь)
        raise gr.Error("Nothing left to speak after text normalisation.")
    total_audio = []
    for _t in chunks:
        inputs = tokenizer(_t, return_tensors="pt")
        with torch.no_grad():
            x = net_g(input_ids=inputs.input_ids.to(device),
                      attention_mask=inputs.attention_mask.to(device),
                      lang_code=lang_code,
                      )[0, :]
        total_audio.append(x)
        print(f'\n\n_______________________________ {_t} {x.shape=}')
    x = torch.cat(total_audio).cpu().numpy()
    tmp_file = '_speech.wav'
    audiofile.write(tmp_file, x, 16000)
    return tmp_file
# -- Speech-analysis configuration ------------------------------------------
# Inference device: the first CUDA GPU (index 0) when available, else the CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = "cpu"
duration = 2  # seconds of each uploaded file read by recognize()
age_gender_model_name = "audeering/wav2vec2-large-robust-6-ft-age-gender"
expression_model_name = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
class AgeGenderHead(nn.Module):
    r"""Age-gender model head.

    dropout -> dense (hidden_size -> hidden_size) -> tanh -> dropout ->
    out_proj (hidden_size -> num_labels). Submodule names match the pretrained
    checkpoint, so they must not be renamed.
    """

    def __init__(self, config, num_labels):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(width, num_labels)

    def forward(self, features, **kwargs):
        activated = torch.tanh(self.dense(self.dropout(features)))
        return self.out_proj(self.dropout(activated))
class AgeGenderModel(Wav2Vec2PreTrainedModel):
    r"""Age-gender recognition model.

    A wav2vec2 backbone with two heads on the time-averaged hidden states: a
    1-output age head and a 3-output gender head (softmax-normalised).
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = AgeGenderHead(config, 1)
        self.gender = AgeGenderHead(config, 3)
        self.init_weights()

    def forward(
            self,
            frozen_cnn7,
    ):
        # self.wav2vec2 is called with precomputed CNN features, so only the
        # projection + Transformer layers run here.
        pooled = self.wav2vec2(frozen_cnn7=frozen_cnn7).mean(dim=1)
        age_logits = self.age(pooled)
        gender_probs = self.gender(pooled).softmax(dim=1)
        return pooled, age_logits, gender_probs
# AgeGenderModel.forward() is switched to accept computed frozen CNN7 features from ExpressioNmodel | |
def _forward( | |
self, | |
frozen_cnn7=None, # CNN7 fetures of wav2vec2 calc. from CNN7 feature extractor (once) | |
attention_mask=None): | |
if attention_mask is not None: | |
# compute reduced attention_mask corresponding to feature vectors | |
attention_mask = self._get_feature_vector_attention_mask( | |
frozen_cnn7.shape[1], attention_mask, add_adapter=False | |
) | |
hidden_states, _ = self.wav2vec2.feature_projection(frozen_cnn7) | |
hidden_states = self.wav2vec2.encoder( | |
hidden_states, | |
attention_mask=attention_mask, | |
output_attentions=None, | |
output_hidden_states=None, | |
return_dict=None, | |
)[0] | |
return hidden_states | |
def _forward_and_cnn7( | |
self, | |
input_values, | |
attention_mask=None): | |
frozen_cnn7 = self.wav2vec2.feature_extractor(input_values) | |
frozen_cnn7 = frozen_cnn7.transpose(1, 2) | |
if attention_mask is not None: | |
# compute reduced attention_mask corresponding to feature vectors | |
attention_mask = self.wav2vec2._get_feature_vector_attention_mask( | |
frozen_cnn7.shape[1], attention_mask, add_adapter=False | |
) | |
hidden_states, _ = self.wav2vec2.feature_projection(frozen_cnn7) # grad=True non frozen | |
hidden_states = self.wav2vec2.encoder( | |
hidden_states, | |
attention_mask=attention_mask, | |
output_attentions=None, | |
output_hidden_states=None, | |
return_dict=None, | |
)[0] | |
return hidden_states, frozen_cnn7 #feature_proj is trainable thus we have to access the frozen_cnn7 before projection layer | |
class ExpressionHead(nn.Module):
    r"""Expression model head.

    dropout -> dense (hidden_size -> hidden_size) -> tanh -> dropout ->
    out_proj (hidden_size -> config.num_labels). Submodule names match the
    pretrained checkpoint.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(width, config.num_labels)

    def forward(self, features, **kwargs):
        activated = torch.tanh(self.dense(self.dropout(features)))
        return self.out_proj(self.dropout(activated))
class ExpressionModel(Wav2Vec2PreTrainedModel):
    r"""Speech expression model.

    A wav2vec2 backbone plus an ExpressionHead applied to the time-averaged
    hidden states. Its inner ``wav2vec2`` is expected to return
    ``(hidden_states, cnn_features)``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = ExpressionHead(config)
        self.init_weights()

    def forward(self, input_values):
        hidden_states, frozen_cnn7 = self.wav2vec2(input_values)
        pooled = hidden_states.mean(dim=1)
        return pooled, self.classifier(pooled), frozen_cnn7
# Load models from the hub and move them to the inference device chosen above.
# process_func() sends its input tensors to `device`, so the weights must live
# there too (otherwise a CUDA host fails with a device mismatch).
# eval() makes the inference mode explicit (dropout off).
age_gender_model = AgeGenderModel.from_pretrained(age_gender_model_name).to(device).eval()
expression_processor = Wav2Vec2Processor.from_pretrained(expression_model_name)
expression_model = ExpressionModel.from_pretrained(expression_model_name).to(device).eval()
# Emotion Calc. CNN features: swap each inner wav2vec2.forward for a custom one.
# The expression model also returns its CNN features; the age/gender model
# consumes them directly instead of recomputing them.
age_gender_model.wav2vec2.forward = types.MethodType(_forward, age_gender_model)
expression_model.wav2vec2.forward = types.MethodType(_forward_and_cnn7, expression_model)
def process_func(x: np.ndarray, sampling_rate: int) -> typing.Tuple[str, dict, str]:
    """Estimate age, gender and expression of *x* and plot the expression.

    Args:
        x: audio samples at *sampling_rate* (recognize() passes 16 kHz audio).
        sampling_rate: sampling rate of *x* in Hz.

    Returns:
        tuple: ``("<age> years", {"female": p, "male": p, "child": p}, path)``.
        The age head's output is multiplied by 100 and rounded; *path* is the
        PNG holding the arousal/dominance/valence plot.
    """
    # batch audio as (1, n_samples)
    y = expression_processor(x, sampling_rate=sampling_rate)
    y = y['input_values'][0]
    y = y.reshape(1, -1)
    y = torch.from_numpy(y).to(device)
    # run through expression model; its CNN features feed the age/gender model
    with torch.no_grad():
        _, logits_expression, frozen_cnn7 = expression_model(y)
        _, logits_age, logits_gender = age_gender_model(frozen_cnn7=frozen_cnn7)
    # Plot A/D/V values (.item() also moves the scalar to the host)
    plot_expression(logits_expression[0, 0].item(),
                    logits_expression[0, 1].item(),
                    logits_expression[0, 2].item())
    expression_file = "expression.png"
    plt.savefig(expression_file)
    # plot_expression() opens a new figure on every call; close it once it is
    # saved, otherwise pyplot keeps every figure (and its memory) alive.
    plt.close()
    return (
        f"{round(100 * logits_age[0, 0].item())} years",  # age
        {
            "female": logits_gender[0, 0].item(),
            "male": logits_gender[0, 1].item(),
            "child": logits_gender[0, 2].item(),
        },
        expression_file,
    )
def recognize(input_file):
    """Gradio handler: analyse the first `duration` seconds of *input_file*.

    Raises:
        gr.Error: when no file was submitted.
    """
    if input_file is None:
        raise gr.Error(
            "No audio file submitted! "
            "Please upload or record an audio file "
            "before submitting your request."
        )
    target_rate = 16000  # sampling rate expected by the analysis models
    signal, sampling_rate = audiofile.read(input_file, duration=duration)
    resampled = audresample.resample(signal, sampling_rate, target_rate)
    return process_func(resampled, target_rate)
def explode(data):
    """Spread an array's elements apart by inserting a zero between neighbours on every axis.

    An axis of length ``s`` becomes ``2 * s - 1`` long: the original values sit
    at the even indices and every odd index is 0/False. plot_expression() uses
    this to leave visible gaps between voxels. Works for arrays of any rank;
    a zero-length axis stays zero-length.

    Args:
        data (array_like): input array.

    Returns:
        numpy.ndarray: the expanded array, same dtype as *data*.
    """
    data = np.asarray(data)
    new_shape = tuple(max(2 * size - 1, 0) for size in data.shape)
    exploded = np.zeros(new_shape, dtype=data.dtype)
    exploded[(slice(None, None, 2),) * data.ndim] = data
    return exploded
def plot_expression(arousal, dominance, valence):
    """Plot arousal / dominance / valence as one highlighted cell of a 3-D voxel grid.

    The scores are clipped to [0, 0.99] and binned into ``N_PIX`` (5) cells per
    axis. The cell containing the estimate is drawn with alpha 0.22 and every
    other cell with alpha 0.01. Dominance is flipped (``0.994 - dominance``)
    before binning, and its tick labels are reversed to match.

    The new figure becomes pyplot's current figure (process_func() saves it via
    ``plt.savefig``). It is also returned, so callers can save or close it directly.

    Args:
        arousal (float): arousal score, presumably in [0, 1].
        dominance (float): dominance score, presumably in [0, 1].
        valence (float): valence score, presumably in [0, 1].

    Returns:
        matplotlib.figure.Figure: the figure that was drawn.
    """
    N_PIX = 5
    # Tiny random background values; after the alpha clip below these cells are almost transparent.
    _h = np.random.rand(N_PIX, N_PIX, N_PIX) * 1e-3
    adv = np.array([arousal, .994 - dominance, valence]).clip(0, .99)
    arousal, dominance, valence = (adv * N_PIX).astype(np.int64)  # find voxel
    _h[arousal, dominance, valence] = .22
    filled = np.ones((N_PIX, N_PIX, N_PIX), dtype=bool)
    # upscale the above voxel image, leaving gaps
    filled_2 = explode(filled)
    # Shrink the gaps: shift the corner coordinates so the odd (gap) cells are thin slabs
    x, y, z = np.indices(np.array(filled_2.shape) + 1).astype(float) // 2
    x[1::2, :, :] += 1
    y[:, 1::2, :] += 1
    z[:, :, 1::2] += 1
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    # RGBA per (exploded) cell: the colour comes from the 'cool' colormap of the
    # raw value, and the alpha is the same value clipped to [.01, .74].
    f_2 = np.ones([2 * N_PIX - 1,
                   2 * N_PIX - 1,
                   2 * N_PIX - 1, 4], dtype=np.float64)
    f_2[:, :, :, 3] = explode(_h)
    cm = plt.get_cmap('cool')
    f_2[:, :, :, :3] = cm(f_2[:, :, :, 3])[..., :3]
    f_2[:, :, :, 3] = f_2[:, :, :, 3].clip(.01, .74)
    ecolors_2 = f_2
    ax.voxels(x, y, z, filled_2, facecolors=f_2, edgecolors=.006 * ecolors_2)
    ax.set_aspect('equal')
    ax.set_zticks([0, N_PIX])
    ax.set_xticks([0, N_PIX])
    ax.set_yticks([0, N_PIX])
    ax.set_zticklabels([f'{n/N_PIX:.2f}' for n in ax.get_zticks()])
    ax.set_zlabel('valence', fontsize=10, labelpad=0)
    ax.set_xticklabels([f'{n/N_PIX:.2f}' for n in ax.get_xticks()])
    ax.set_xlabel('arousal', fontsize=10, labelpad=7)
    # Dominance was flipped before binning, so its labels count down.
    ax.set_yticklabels([f'{1-n/N_PIX:.2f}' for n in ax.get_yticks()], rotation=90)
    ax.set_ylabel('dominance', fontsize=10, labelpad=10)
    ax.grid(False)
    ax.plot([N_PIX, N_PIX], [0, N_PIX + .2], [N_PIX, N_PIX], 'g', linewidth=1)
    ax.plot([0, N_PIX], [N_PIX, N_PIX + .24], [N_PIX, N_PIX], 'k', linewidth=1)
    # Edges of the top face that the voxels do not draw
    ax.plot([0, 0], [0, N_PIX], [N_PIX, N_PIX], 'darkred', linewidth=1)
    ax.plot([0, N_PIX], [0, 0], [N_PIX, N_PIX], 'darkblue', linewidth=1)
    # Set pane colors after plotting the lines
    ax.xaxis.set_pane_color((0.8, 0.8, 0.8, 0.5))
    ax.yaxis.set_pane_color((0.8, 0.8, 0.8, 0.5))
    ax.zaxis.set_pane_color((0.8, 0.8, 0.8, 0.0))
    # Restore the limits to prevent the plot from expanding
    ax.set_xlim(0, N_PIX)
    ax.set_ylim(0, N_PIX)
    ax.set_zlim(0, N_PIX)
    return fig
# TTS
# Reference voices for StyleTTS2, sorted so the voice buttons keep a stable
# order (os.listdir order is arbitrary). Only *.wav files are listed because the
# UI derives each button label from the file stem, and update_selected_voice()
# rebuilds the path as 'wav/<label>.wav'.
VOICES = sorted(f'wav/{vox}' for vox in os.listdir('wav') if vox.endswith('.wav'))
_tts = StyleTTS2().to('cpu')
def only_greek_or_only_latin(text, lang='grc'):
    """Convert *text* to a single script: Greek when ``lang == 'grc'``, Latin otherwise.

    The text is lower-cased first. Mapped sequences are matched greedily,
    longest key first; characters not found in any mapping are preserved as is.

    * ``lang == 'grc'``: Latin and Cyrillic letters are transliterated to Greek.
    * any other *lang*: Greek and Cyrillic letters are transliterated to Latin.
      Diacritics (tonos, breathings, dialytika, ...) are removed from Greek
      letters first, so accented Greek is transliterated too. Accented Latin
      characters (e.g. 'É', 'ü') are kept, in lowercase form ('é', 'ü').

    Args:
        text (str): input text.
        lang (str): language code; only 'grc' selects Greek output.

    Returns:
        str: the converted string in the target script.
    """
    # --- Mapping Dictionaries (keys lowercase, since the input is lower-cased) ---
    latin_to_greek_map = {
        'a': 'α', 'b': 'β', 'g': 'γ', 'd': 'δ', 'e': 'ε',
        'ch': 'τσο',  # Example of a multi-character Latin sequence
        'z': 'ζ', 'h': 'χ', 'i': 'ι', 'k': 'κ', 'l': 'λ',
        'm': 'μ', 'n': 'ν', 'x': 'ξ', 'o': 'ο', 'p': 'π',
        'v': 'β', 'sc': 'σκ', 'r': 'ρ', 's': 'σ', 't': 'τ',
        'u': 'ου', 'f': 'φ', 'c': 'σ', 'w': 'β', 'y': 'γ',
    }
    greek_to_latin_map = {
        'ου': 'ou',  # Prioritize common diphthongs/digraphs
        'α': 'a', 'β': 'v', 'γ': 'g', 'δ': 'd', 'ε': 'e',
        'ζ': 'z', 'η': 'i', 'θ': 'th', 'ι': 'i', 'κ': 'k',
        'λ': 'l', 'μ': 'm', 'ν': 'n', 'ξ': 'x', 'ο': 'o',
        'π': 'p', 'ρ': 'r', 'σ': 's', 'τ': 't', 'υ': 'y',  # 'y' is a common transliteration for upsilon
        'φ': 'f', 'χ': 'ch', 'ψ': 'ps', 'ω': 'o',
        'ς': 's',  # Final sigma
    }
    cyrillic_to_latin_map = {
        'а': 'a', 'б': 'b', 'в': 'v', 'г': 'g', 'д': 'd', 'е': 'e', 'ё': 'yo', 'ж': 'zh',
        'з': 'z', 'и': 'i', 'й': 'y', 'к': 'k', 'л': 'l', 'м': 'm', 'н': 'n', 'о': 'o',
        'п': 'p', 'р': 'r', 'с': 's', 'т': 't', 'у': 'u', 'ф': 'f', 'х': 'kh', 'ц': 'ts',
        'ч': 'ch', 'ш': 'sh', 'щ': 'shch', 'ъ': '', 'ы': 'y', 'ь': '', 'э': 'e', 'ю': 'yu',
        'я': 'ya',
    }
    # Direct Cyrillic to Greek mapping based on phonetic similarity (approximate).
    # All keys are Cyrillic letters (к л п р с т were previously typed as Greek).
    cyrillic_to_greek_map = {
        'а': 'α', 'б': 'β', 'в': 'β', 'г': 'γ', 'д': 'δ', 'е': 'ε', 'ё': 'ιο', 'ж': 'ζ',
        'з': 'ζ', 'и': 'ι', 'й': 'ι', 'к': 'κ', 'л': 'λ', 'м': 'μ', 'н': 'ν', 'о': 'ο',
        'п': 'π', 'р': 'ρ', 'с': 'σ', 'т': 'τ', 'у': 'ου', 'ф': 'φ', 'х': 'χ', 'ц': 'τσ',
        'ч': 'τσ',  # or τζ depending on desired sound
        'ш': 'σ', 'щ': 'σ',  # approximations
        'ъ': '', 'ы': 'ι', 'ь': '', 'э': 'ε', 'ю': 'ιου',
        'я': 'ια',
    }

    def _transliterate(source, mapping):
        # Greedy scan: at each position try the longest keys first; unmapped
        # characters are copied through unchanged.
        keys = sorted(mapping, key=len, reverse=True)
        out = []
        pos = 0
        while pos < len(source):
            for key in keys:
                if source.startswith(key, pos):
                    out.append(mapping[key])
                    pos += len(key)
                    break
            else:
                out.append(source[pos])
                pos += 1
        return ''.join(out)

    def _strip_greek_diacritics(s):
        # NFD splits precomposed letters into base + combining marks. Drop only the
        # marks that follow a Greek base letter; NFC then recomposes everything
        # else, so Latin 'é' or Cyrillic 'й' come back unchanged.
        kept = []
        after_greek = False
        for ch in unicodedata.normalize('NFD', s):
            if unicodedata.combining(ch):
                if after_greek:
                    continue
            else:
                after_greek = 'GREEK' in unicodedata.name(ch, '')
            kept.append(ch)
        return unicodedata.normalize('NFC', ''.join(kept))

    lowercased_text = text.lower()
    if lang == 'grc':
        return _transliterate(lowercased_text, {**latin_to_greek_map, **cyrillic_to_greek_map})
    # Default to Latin output; Cyrillic keys win on any overlap with Greek keys.
    return _transliterate(_strip_greek_diacritics(lowercased_text),
                          {**greek_to_latin_map, **cyrillic_to_latin_map})
def other_tts(text='Hallov worlds Far over the',
              ref_s='wav/af_ZA_google-nwu_0184.wav'):
    """Synthesise *text* with StyleTTS2 in the voice of reference file *ref_s*.

    The text is first converted to Latin script. Returns the path of a
    2-channel, 24 kHz WAV file ('_speech.wav').
    """
    latin_text = only_greek_or_only_latin(text, lang='eng')
    mono = _tts.inference(latin_text, ref_s=ref_s)[0:1, 0, :]
    # Two copies at slightly different gains form the stereo pair.
    stereo = torch.cat([.99 * mono,
                        .94 * mono], 0).cpu().numpy()
    # Shared file name: concurrent clients overwrite each other's output.
    out_path = '_speech.wav'
    audiofile.write(out_path, stereo, 24000)
    return out_path
def update_selected_voice(voice_filename):
    """Turn a voice button label (the file stem) back into its path under ``wav/``."""
    return ''.join(('wav/', voice_filename, '.wav'))
# Markdown intro for the "Speech Analysis" tab, linking both model cards on the hub.
description = "".join([
    "Estimate **age**, **gender**, and **expression** ",
    "of the speaker contained in an audio file or microphone recording. \n",
    f"The model [{age_gender_model_name}]",
    f"(https://huggingface.co/{age_gender_model_name}) ",
    "recognises age and gender, ",
    f"whereas [{expression_model_name}]",
    f"(https://huggingface.co/{expression_model_name}) ",
    "recognises the expression dimensions arousal, dominance, and valence. ",
])
# Custom CSS for the voice-selection buttons (.cool-button) and their rows (.cool-row).
# dedent() strips the uniform 4-space source indentation from the stylesheet.
css_buttons = textwrap.dedent("""
    .cool-button {
    background-color: #1a2a40; /* Slightly lighter dark blue */
    color: white;
    padding: 15px 32px;
    text-align: center;
    font-size: 16px;
    border-radius: 12px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.4);
    transition: all 0.3s ease-in-out;
    border: none;
    cursor: pointer;
    }
    .cool-button:hover {
    background-color: #1a2a40; /* Slightly lighter dark blue */
    transform: scale(1.05);
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.4);
    }
    .cool-row {
    margin-bottom: 10px;
    }
    """)
with gr.Blocks(theme='huggingface', css=css_buttons) as demo:
    # ---- Tab 1: StyleTTS2 with a selectable reference voice ----
    with gr.Tab(label="other TTS"):
        # Path of the reference voice passed to other_tts(); set by the voice buttons.
        selected_voice = gr.State(value='wav/en_US_m-ailabs_mary_ann.wav')
        with gr.Row():
            voice_info = gr.Markdown(f'Vox = `{selected_voice.value}`')
        # Main input and output components
        with gr.Row():
            text_input = gr.Textbox(
                label="Enter text for TTS:",
                placeholder="Type your message here...",
                lines=4,
                value="Farover the misty mountains cold too dungeons deep and caverns old.",
            )
        generate_button = gr.Button("Generate Audio", variant="primary")
        output_audio = gr.Audio(label="TTS Output")
        # One button per reference voice, 7 buttons per row.
        with gr.Column():
            voice_buttons = []
            for row_start in range(0, len(VOICES), 7):
                with gr.Row(elem_classes=["cool-row"]):
                    for voice_path in VOICES[row_start:row_start + 7]:
                        # Label is the file stem, e.g. 'wav/foo.wav' -> 'foo'.
                        voice_name = os.path.splitext(os.path.basename(voice_path))[0]
                        button = gr.Button(voice_name, elem_classes=["cool-button"])
                        # Bind voice_name as a default argument so every button keeps its
                        # own value (a plain closure would only see the last loop value).
                        # One event updates both the selected path and the label.
                        button.click(
                            fn=lambda v=voice_name: (update_selected_voice(v), f'Vox = `{v}`'),
                            inputs=None,
                            outputs=[selected_voice, voice_info],
                        )
                        voice_buttons.append(button)
        generate_button.click(
            fn=other_tts,
            inputs=[text_input, selected_voice],
            outputs=output_audio,
        )
    # ---- Tab 2: age / gender / expression estimation ----
    with gr.Tab(label="Speech Analysis"):
        with gr.Row():
            with gr.Column():
                gr.Markdown(description)
                # Named audio_input (not `input`) to avoid shadowing the builtin.
                audio_input = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Audio input",
                    min_length=0.025,  # seconds
                )
                gr.Examples(
                    [
                        "wav/female-46-neutral.wav",
                        "wav/female-20-happy.wav",
                        "wav/male-60-angry.wav",
                        "wav/male-27-sad.wav",
                    ],
                    [audio_input],
                    label="Examples from CREMA-D, ODbL v1.0 license",
                )
                gr.Markdown("Only the first two seconds of the audio will be processed.")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_age = gr.Textbox(label="Age")
                output_gender = gr.Label(label="Gender")
                output_expression = gr.Image(label="Expression")
        outputs = [output_age, output_gender, output_expression]
        submit_btn.click(recognize, audio_input, outputs)
    # ---- Tab 3: Meta MMS VITS multilingual TTS ----
    with gr.Tab("audionar TTS"):
        with gr.Row():
            mms_text_input = gr.Textbox(
                lines=4,
                value='Η γρηγορη καφετι αλεπου πειδαει πανω απο τον τεμπελη σκυλο.',
                label="Type text for TTS"
            )
            lang_dropdown = gr.Dropdown(
                choices=language_names,
                label="TTS language",
                value="Ancient greek",
            )
        # Create a button to trigger the TTS function
        tts_button = gr.Button("Generate Audio")
        # Create the output audio component
        audio_output = gr.Audio(label="Generated Audio")
        # Link the button click event to audionar_tts()
        tts_button.click(
            fn=audionar_tts,
            inputs=[mms_text_input, lang_dropdown],
            outputs=audio_output
        )
demo.launch(debug=True)