|
|
|
|
|
import argparse |
|
import json |
|
import os |
|
import re |
|
from pathlib import Path |
|
import sys |
|
|
|
# Absolute path of the directory that contains this script.
pwd = os.path.abspath(os.path.dirname(__file__))

# Append the directory two levels above this script to the import path,
# presumably so that `project_settings` (imported below) can be found —
# NOTE(review): assumes that directory is the project root; confirm.
sys.path.append(os.path.join(pwd, "../../"))
|
|
|
from google import genai |
|
from google.genai import types |
|
|
|
from project_settings import environment, project_path |
|
|
|
|
|
def get_args():
    """Parse this script's command-line options.

    Every option is parsed as ``str``. Defaults:

    * ``--google_application_credentials``: a service-account JSON file under
      ``project_path / "dotenv"``.
    * ``--model_name``: ``"gemini-2.5-pro"``.
    * ``--speech_audio_dir``: a hard-coded local dataset directory.
    * ``--output_file``: ``"vad.jsonl"`` (relative to the working directory).
    * ``--gemini_api_key``: the value of ``environment.get("GEMINI_API_KEY", dtype=str)``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    credentials_default = (project_path / "dotenv/potent-veld-462405-t3-8091a29b2894.json").as_posix()

    # (flag, default) pairs, registered in this exact order.
    option_specs = (
        ("--google_application_credentials", credentials_default),
        ("--model_name", "gemini-2.5-pro"),
        ("--speech_audio_dir", r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\nx-speech\en-SG\2025-06-17"),
        ("--output_file", r"vad.jsonl"),
        ("--gemini_api_key", environment.get("GEMINI_API_KEY", dtype=str)),
    )

    cli = argparse.ArgumentParser()
    for flag, default_value in option_specs:
        cli.add_argument(flag, default=default_value, type=str)

    return cli.parse_args()
|
|
|
|
|
# Prompt sent with every audio file. In English: "Give me the start and end
# times of the speech segments in this audio, in seconds with millisecond
# precision, as JSON, e.g. [[0.254, 1.214], [2.200, 3.100]]; if there is no
# speech, output []". The example is valid JSON (no trailing comma) so the
# model does not learn to emit something `json.loads` rejects.
_VAD_PROMPT = """
给我这段音频中的语音分段的开始和结束时间,单位为秒,精确到毫秒,并输出JSON格式,
例如:
```json
[[0.254, 1.214], [2.200, 3.100]]
```
如果没有语音段则输出:
```json
[]
```
""".strip()

# Captures the body of the first ```json ... ``` fence in a model answer.
_JSON_FENCE_PATTERN = re.compile(r"```json(.+?)```", flags=re.DOTALL | re.IGNORECASE)


def _load_finished_names(output_file: Path) -> set:
    """Return the set of ``"name"`` values already recorded in *output_file*.

    A missing file yields an empty set; blank lines are skipped so a stray
    empty line does not abort resuming.
    """
    finished = set()
    if not output_file.exists():
        return finished
    with open(output_file.as_posix(), "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            finished.add(json.loads(line)["name"])
    return finished


def _parse_vad_segments(answer) -> list:
    """Extract the JSON list of ``[start, end]`` pairs from a model answer.

    The body of the first ```json fence is parsed; if the answer has no such
    fence, the whole answer is parsed as JSON instead. A trailing comma after
    the list (as in ``[[0.1, 0.2]],``) is tolerated.

    Raises:
        AssertionError: if *answer* is empty/None or contains no JSON list.
    """
    if not answer:
        raise AssertionError(f"answer: {answer}")
    match = _JSON_FENCE_PATTERN.search(answer)
    payload = match.group(1) if match is not None else answer
    payload = payload.strip().rstrip(",").rstrip()
    try:
        vad_segments = json.loads(payload)
    except json.JSONDecodeError as err:
        raise AssertionError(f"answer: {answer}") from err
    if not isinstance(vad_segments, list):
        raise AssertionError(f"answer: {answer}")
    return vad_segments


def _request_vad_answer(developer_client, model_name, config, filename: Path):
    """Upload *filename*, ask *model_name* for its speech segments, return the raw text.

    The uploaded file is deleted again in a ``finally`` clause, so it is removed
    even when the generation request raises. Returns ``response.text``, which
    may be ``None`` when the response carries no text.
    """
    audio_file = developer_client.files.upload(
        file=filename.as_posix(),
        config=None,
    )
    print(f"upload file: {audio_file.name}")
    try:
        contents = [
            types.Content(
                role="user",
                parts=[
                    types.Part(text=_VAD_PROMPT),
                    types.Part.from_uri(
                        file_uri=audio_file.uri,
                        mime_type=audio_file.mime_type,
                    ),
                ],
            )
        ]
        response: types.GenerateContentResponse = developer_client.models.generate_content(
            model=model_name,
            contents=contents,
            config=config,
        )
        # `.text` gathers the text parts of the response instead of assuming
        # candidates[0].content.parts[0] exists and holds text.
        answer = response.text
        print(answer)
    finally:
        print(f"delete file: {audio_file.name}")
        developer_client.files.delete(name=audio_file.name)
    return answer


def main():
    """Label speech segments of every ``.wav`` under ``--speech_audio_dir`` with Gemini.

    For each ``*.wav`` found recursively whose file name is not yet recorded in
    ``--output_file``, the audio is uploaded through the Gemini Developer API
    client, the model is asked for the speech segment boundaries, the upload is
    deleted, and one JSON line ``{"name", "filename", "vad_segments"}`` is
    appended to ``--output_file``. Re-running the script resumes with the files
    that have no row yet.

    Raises:
        AssertionError: if a model answer contains no parseable JSON list.
        Exceptions from the genai client propagate (the upload is still deleted).
    """
    args = get_args()

    speech_audio_dir = Path(args.speech_audio_dir)
    output_file = Path(args.output_file)

    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = args.google_application_credentials
    # os.environ only accepts str values; skip the export when no key is set
    # instead of failing with TypeError.
    if args.gemini_api_key is not None:
        os.environ["gemini_api_key"] = args.gemini_api_key

    developer_client = genai.Client(
        api_key=args.gemini_api_key,
    )
    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        max_output_tokens=8192,
        response_modalities=["TEXT"],
    )

    finished_set = _load_finished_names(output_file)
    print(f"finished count: {len(finished_set)}")

    with open(output_file.as_posix(), "a", encoding="utf-8") as f:
        for filename in speech_audio_dir.glob("**/*.wav"):
            # Resume/dedup key is the bare file name, matching existing rows.
            # NOTE(review): files sharing a name in different sub-directories
            # are therefore treated as duplicates — confirm this is intended.
            name = filename.name
            if name in finished_set:
                continue
            finished_set.add(name)

            answer = _request_vad_answer(
                developer_client, args.model_name, generate_content_config, filename
            )
            row = {
                "name": name,
                "filename": filename.as_posix(),
                "vad_segments": _parse_vad_segments(answer),
            }
            f.write(f"{json.dumps(row, ensure_ascii=False)}\n")
            # Persist each row right away so an interrupted run can resume.
            f.flush()

    return
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point. main() returns None, so the exit status is 0 on success.
    sys.exit(main())
|
|