ViDove / SRT.py
worldqwq
Bug Fixed
fb45ef4
raw
history blame
8.76 kB
from datetime import timedelta
import os
import whisper
from csv import reader
import re
import openai
class SRT_segment(object):
    """A single subtitle cue: its timing, source text and translation.

    Built either from a segment dict with ``'start'``/``'end'`` (seconds) and
    ``'text'`` keys, or from a parsed SRT cue given as a list/tuple
    ``[index, "HH:MM:SS,mmm --> HH:MM:SS,mmm", text, ...]``.
    """

    def __init__(self, *args) -> None:
        """Initialise the segment from ``args[0]``.

        Raises:
            TypeError: if ``args[0]`` is neither a dict nor a list/tuple.
        """
        data = args[0]
        if isinstance(data, dict):
            self.start_time_str = self._format_timestamp(data['start'])
            self.end_time_str = self._format_timestamp(data['end'])
            self.source_text = data['text']
            self.duration = f"{self.start_time_str} --> {self.end_time_str}"
            self.translation = ""
        elif isinstance(data, (list, tuple)):
            # Parsed SRT cue: [index, timing line, text, ...].
            self.source_text = data[2]
            self.duration = data[1]
            self.start_time_str = self.duration.split(" --> ")[0]
            self.end_time_str = self.duration.split(" --> ")[1]
            self.translation = ""
        else:
            raise TypeError(
                f"SRT_segment expects a dict or a list/tuple, got {type(data).__name__}"
            )

    @staticmethod
    def _format_timestamp(seconds) -> str:
        """Format a time in seconds as an SRT timestamp ``HH:MM:SS,mmm``."""
        # Round to whole milliseconds first so float error (1.234 -> 1233.99...)
        # does not lose a millisecond, then split with integer arithmetic so
        # hours >= 10 are not zero-padded to three digits.
        total_ms = int(round(seconds * 1000))
        hours, rem = divmod(total_ms, 3_600_000)
        minutes, rem = divmod(rem, 60_000)
        secs, millis = divmod(rem, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    def merge_seg(self, seg):
        """Append ``seg`` (the cue right after this one) to this segment in place.

        Source texts are joined with a single space unless one side already
        has whitespace at the junction (e.g. texts that start with " ").
        Translations are concatenated as-is. The end time becomes ``seg``'s.
        """
        needs_space = (
            bool(self.source_text)
            and bool(seg.source_text)
            and not self.source_text[-1].isspace()
            and not seg.source_text[0].isspace()
        )
        self.source_text += (" " if needs_space else "") + seg.source_text
        self.translation += seg.translation
        self.end_time_str = seg.end_time_str
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    def __str__(self) -> str:
        """Return the cue body (timing + source text) without its index."""
        return f'{self.duration}\n{self.source_text}\n\n'

    def get_trans_str(self) -> str:
        """Return the cue body (timing + translation) without its index."""
        return f'{self.duration}\n{self.translation}\n\n'

    def get_bilingual_str(self) -> str:
        """Return the cue body with source text followed by translation."""
        return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'
class SRT_script():
    """An ordered list of ``SRT_segment`` cues, with helpers to merge them into
    sentences, attach translations, correct terms and write SRT files.
    """

    def __init__(self, segments) -> None:
        """Wrap every item of ``segments`` in an ``SRT_segment``, keeping order.

        :param segments: iterable of items accepted by ``SRT_segment`` (segment
            dicts with ``start``/``end``/``text`` keys, or parsed SRT cue lists
            ``[index, timing, text, ...]``).
        """
        self.segments = [SRT_segment(seg) for seg in segments]

    @classmethod
    def parse_from_srt_file(cls, path: str):
        """Read the SRT file at ``path`` and return a new ``SRT_script``.

        Cues are separated by blank (or whitespace-only) lines. Within a cue,
        the first line is the index and the second is the ``start --> end``
        timing line. Any remaining lines are the text, joined with newlines.
        Blocks with fewer than two lines are skipped.
        """
        # 'utf-8-sig' reads plain UTF-8 too and drops a leading BOM if present.
        with open(path, 'r', encoding="utf-8-sig") as f:
            script_lines = f.read().splitlines()

        segments = []
        block = []
        # The appended '' is a sentinel that flushes the final cue even when
        # the file does not end with a blank line.
        for line in script_lines + ['']:
            if line.strip():
                block.append(line)
                continue
            if len(block) >= 2:
                segments.append([block[0], block[1].strip(), '\n'.join(block[2:]), ''])
            block = []
        return cls(segments)

    def merge_segs(self, idx_list) -> "SRT_segment":
        """Merge the segments at the indices in ``idx_list`` into the first one.

        The first segment object, which is still the one stored in
        ``self.segments``, is modified in place and returned. The remaining
        segments are not removed from ``self.segments``.
        """
        final_seg = self.segments[idx_list[0]]
        for idx in idx_list[1:]:
            final_seg.merge_seg(self.segments[idx])
        return final_seg

    def form_whole_sentence(self, end_marks=('.',)):
        """Merge consecutive segments so each segment holds a whole sentence.

        A segment whose text, ignoring trailing whitespace, ends with one of
        ``end_marks`` closes the current sentence. Segments after the last
        sentence end are kept as one final merged segment.

        :param end_marks: sentence-terminating suffixes (default ``('.',)``).
        """
        end_marks = tuple(end_marks)
        merge_list = []  # groups of indices to merge, e.g. [[0], [1, 2, 3], [4]]
        sentence = []
        for i, seg in enumerate(self.segments):
            sentence.append(i)
            if seg.source_text.rstrip().endswith(end_marks):
                merge_list.append(sentence)
                sentence = []
        if sentence:
            # Unterminated trailing segments were previously dropped entirely.
            merge_list.append(sentence)
        self.segments = [self.merge_segs(idx_list) for idx_list in merge_list]

    @staticmethod
    def _split_translation(text: str) -> list:
        """Split ``text`` on blank lines, dropping empty chunks and "(Note:" chunks."""
        return [chunk for chunk in text.split('\n\n')
                if chunk.strip() and "(Note:" not in chunk]

    @staticmethod
    def _strip_label(line: str) -> str:
        """Drop everything up to the first ':' (e.g. "SENTENCE 3:"), keeping later colons."""
        _, sep, rest = line.partition(":")
        return (rest if sep else line).strip()

    def set_translation(self, translate: str, id_range: tuple):
        """Assign translated chunks to the segments in ``id_range``.

        :param translate: translated text with one chunk per segment, chunks
            separated by blank lines. Each chunk may carry a leading
            ``label:`` prefix, which is removed.
        :param id_range: ``(first, last)`` segment ids, 1-based and inclusive.

        Chunks containing "(Note:" and empty chunks are ignored. If fewer
        chunks remain than target segments, the text is sent once to the
        OpenAI chat API to be re-aligned with the source sentences. Extra
        chunks are ignored, and segments without a chunk keep their current
        translation.
        """
        start_seg_id, end_seg_id = id_range[0], id_range[1]
        target_segs = self.segments[start_seg_id - 1:end_seg_id]

        lines = self._split_translation(translate)
        if len(lines) < len(target_segs):
            # Number each source sentence so GPT can keep track of sentence
            # breaks, then append the translation to be realigned.
            input_str = "\n" + "".join(
                'Sentence %d: ' % i + seg.source_text + '\n'
                for i, seg in enumerate(target_segs, start=1)
            ) + translate
            response = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages = [
                    {"role": "system", "content": "You are a helpful assistant that help calibrates English to Chinese subtitle translations in starcraft2."},
                    {"role": "system", "content": "You are provided with a translated Chinese transcript, you need to reformat the Chinese sentence to match the meaning and sentence number as the English transcript"},
                    {"role": "system", "content": "There is no need for you to add any comments or notes, and do not modify the English transcript."},
                    {"role": "user", "content": 'Reformat the Chinese with the English transcript given: "{}"'.format(input_str)}
                ],
                temperature=0.15
            )
            translate = response['choices'][0]['message']['content'].strip()
            lines = self._split_translation(translate)
            print("block used")

        # Naive positional alignment: chunk k goes to target segment k.
        # TODO: need a smarter solution for merged/split translations.
        for seg, line in zip(target_segs, lines):
            seg.translation = self._strip_label(line)

    def split_seg(self, seg_id):
        """Not implemented yet; does nothing.

        TODO: evenly split the segment into 2 parts and add the new segment
        into self.segments.
        """

    def check_len_and_split(self, threshold):
        """Not implemented yet; does nothing.

        TODO: if a sentence's length >= threshold, split the segment in two.
        """

    def get_source_only(self):
        """Return every source text as a ``SENTENCE n: text`` block (1-based)."""
        return "".join(f'SENTENCE {i}: {seg.source_text}\n\n\n'
                       for i, seg in enumerate(self.segments, start=1))

    def reform_src_str(self):
        """Render the script as SRT text using the source texts."""
        return "".join(f'{i}\n{seg}' for i, seg in enumerate(self.segments, start=1))

    def reform_trans_str(self):
        """Render the script as SRT text using the translations."""
        return "".join(f'{i}\n{seg.get_trans_str()}'
                       for i, seg in enumerate(self.segments, start=1))

    def form_bilingual_str(self):
        """Render the script as SRT text with source followed by translation."""
        return "".join(f'{i}\n{seg.get_bilingual_str()}'
                       for i, seg in enumerate(self.segments, start=1))

    def write_srt_file_src(self, path: str):
        """Write the source-text SRT to ``path`` (UTF-8, overwriting)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_src_str())

    def write_srt_file_translate(self, path: str):
        """Write the translation SRT to ``path`` (UTF-8, overwriting)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_trans_str())

    def write_srt_file_bilingual(self, path: str):
        """Write the bilingual SRT to ``path`` (UTF-8, overwriting)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.form_bilingual_str())

    def correct_with_force_term(self, dict_path: str = "finetune_data/dict.csv"):
        """Replace dictionary terms in every segment's source text, in place.

        ``dict_path`` is a CSV file whose first two columns are
        ``term,replacement``. Rows with fewer than two columns are skipped.
        Terms are matched case-insensitively against whitespace-separated
        words. A word also matches once trailing ``.,!?;:`` is removed; that
        punctuation is then kept after the replacement. All whitespace,
        including newlines, is preserved.
        """
        # TODO: shortcut translation i.e. VA, ob
        # TODO: variety of translation
        # newline='' is what the csv module expects; utf-8-sig drops a BOM.
        with open(dict_path, 'r', encoding='utf-8-sig', newline='') as f:
            term_dict = {row[0].strip().lower(): row[1].strip()
                         for row in reader(f) if len(row) >= 2}

        def replace_word(match):
            word = match.group(0)
            if word.lower() in term_dict:
                return term_dict[word.lower()]
            core = word.rstrip('.,!?;:')
            if core and core != word and core.lower() in term_dict:
                return term_dict[core.lower()] + word[len(core):]
            return word

        for seg in self.segments:
            seg.source_text = re.sub(r'\S+', replace_word, seg.source_text)