testsite / serve.py
peteralexandercharles's picture
Upload 11 files
98d3f44
from twisted.web.static import File
from twisted.web.resource import Resource
from twisted.web.server import Site, NOT_DONE_YET
from twisted.internet import reactor, threads
from twisted.web._responses import FOUND
import json
import logging
import multiprocessing
import os
import shutil
import uuid
import wave
from gentle.util.paths import get_resource, get_datadir
from gentle.util.cyst import Insist
import gentle
class TranscriptionStatus(Resource):
    """Twisted resource that serves one job's status dictionary as JSON.

    The resource stores a reference to ``status_dict`` (not a copy), so every
    GET serializes whatever the dictionary contains at that moment.
    """

    def __init__(self, status_dict):
        Resource.__init__(self)
        self.status_dict = status_dict

    def render_GET(self, req):
        """Return the current status dictionary encoded as a JSON body."""
        req.setHeader(b"Content-Type", "application/json")
        body = json.dumps(self.status_dict)
        return body.encode()
class Transcriber():
    """Runs alignment/transcription jobs and tracks their progress.

    Every job is identified by a ``uid`` and keeps its files under
    ``<data_dir>/transcriptions/<uid>``.  Progress is published through a
    per-uid status dictionary that is updated in place while the job runs.
    """

    def __init__(self, data_dir, nthreads=4, ntranscriptionthreads=2):
        """Create the shared gentle resources.

        data_dir -- base directory under which job directories live
        nthreads -- forwarded to ``gentle.ForcedAligner`` for each job
        ntranscriptionthreads -- forwarded to ``gentle.FullTranscriber``
        """
        self.data_dir = data_dir
        self.nthreads = nthreads
        self.ntranscriptionthreads = ntranscriptionthreads
        self.resources = gentle.Resources()
        self.full_transcriber = gentle.FullTranscriber(self.resources, nthreads=ntranscriptionthreads)
        self._status_dicts = {}

    def get_status(self, uid):
        """Return the mutable status dict for ``uid``, creating an empty one on first use."""
        return self._status_dicts.setdefault(uid, {})

    def out_dir(self, uid):
        """Return the directory that holds the files of job ``uid``."""
        return os.path.join(self.data_dir, 'transcriptions', uid)

    # TODO(maxhawkins): refactor so this is returned by transcribe()
    def next_id(self):
        """Return an 8-character hex id whose job directory does not exist yet."""
        uid = None
        # Check the real job directory (out_dir); the previous check looked at
        # <data_dir>/<uid>, which is never where jobs are stored.
        while uid is None or os.path.exists(self.out_dir(uid)):
            uid = uuid.uuid4().hex[:8]
        return uid

    def transcribe(self, uid, transcript, audio, async_mode, **kwargs):
        """Run one job and return its result.

        uid -- job id; its output directory must already exist
        transcript -- transcript text; when blank, full transcription is used
        audio -- raw bytes of the uploaded media file
        async_mode -- accepted for interface compatibility; not used here
        kwargs -- extra options forwarded to ``gentle.ForcedAligner``

        Returns the object produced by the aligner/transcriber, or ``None``
        when the job stops with an error recorded in the status dict
        (encoding failure, or no transcript and no full transcriber).
        Unexpected exceptions are recorded in the status dict and re-raised.
        """
        status = self.get_status(uid)
        status['status'] = 'STARTED'
        try:
            return self._run_job(uid, status, transcript, audio, **kwargs)
        except Exception as err:
            # Without this the status would stay at its last value forever and
            # clients polling status.json could never learn the job had died.
            logging.exception('transcription %s failed', uid)
            status['status'] = 'ERROR'
            status['error'] = 'Transcription failed: %s' % (err,)
            raise

    def _run_job(self, uid, status, transcript, audio, **kwargs):
        """Store the inputs, resample the audio, align/transcribe and save outputs."""
        outdir = self.out_dir(uid)

        tran_path = os.path.join(outdir, 'transcript.txt')
        with open(tran_path, 'w') as tranfile:
            tranfile.write(transcript)
        audio_path = os.path.join(outdir, 'upload')
        with open(audio_path, 'wb') as uploadfile:
            uploadfile.write(audio)

        status['status'] = 'ENCODING'

        wavfile = os.path.join(outdir, 'a.wav')
        # A non-zero return value from gentle.resample is treated as failure.
        if gentle.resample(audio_path, wavfile) != 0:
            status['status'] = 'ERROR'
            status['error'] = "Encoding failed. Make sure that you've uploaded a valid media file."
            # Save the status so that errors are recovered on restart of the server
            # XXX: This won't work, because the endpoint will override this file
            with open(os.path.join(outdir, 'status.json'), 'w') as jsfile:
                json.dump(status, jsfile, indent=2)
            return None

        #XXX: Maybe we should pass this wave object instead of the
        # file path to align_progress
        with wave.open(wavfile, 'rb') as wav_obj:
            # Duration in seconds = frame count / frames per second.
            status['duration'] = wav_obj.getnframes() / float(wav_obj.getframerate())
        status['status'] = 'TRANSCRIBING'

        def on_progress(p):
            # Merge each progress report into the shared status dict.
            print(p)
            status.update(p)

        if len(transcript.strip()) > 0:
            trans = gentle.ForcedAligner(self.resources, transcript, nthreads=self.nthreads, **kwargs)
        elif self.full_transcriber.available:
            trans = self.full_transcriber
        else:
            status['status'] = 'ERROR'
            status['error'] = 'No transcript provided and no language model for full transcription'
            return None

        output = trans.transcribe(wavfile, progress_cb=on_progress, logging=logging)

        # ...remove the original upload
        os.unlink(audio_path)

        # Save
        with open(os.path.join(outdir, 'align.json'), 'w') as jsfile:
            jsfile.write(output.to_json(indent=2))
        with open(os.path.join(outdir, 'align.csv'), 'w') as csvfile:
            csvfile.write(output.to_csv())

        # Inline the alignment into the index.html file.
        with open(get_resource('www/view_alignment.html')) as template:
            htmltxt = template.read()
        # The JSON carries user-supplied transcript text.  Escape "</" so it
        # cannot terminate the surrounding <script> element; "\/" is a valid
        # JSON/JavaScript escape for "/", so the parsed value is unchanged.
        inline_json = output.to_json().replace("</", "<\\/")
        htmltxt = htmltxt.replace("var INLINE_JSON;", "var INLINE_JSON=%s;" % (inline_json))
        with open(os.path.join(outdir, 'index.html'), 'w') as htmlfile:
            htmlfile.write(htmltxt)

        status['status'] = 'OK'

        logging.info('done with transcription.')

        return output
class TranscriptionsController(Resource):
    """HTTP resource that creates transcription jobs and serves their output.

    POST on the resource itself starts a job; child paths are treated as job
    ids and map to the job's output directory plus a live ``status.json``.
    """

    def __init__(self, transcriber):
        Resource.__init__(self)
        self.transcriber = transcriber

    @staticmethod
    def _is_safe_uid(uid):
        """Return True when ``uid`` cannot point outside the transcriptions directory."""
        if uid in ('.', '..'):
            return False
        return not any(ch in uid for ch in ('/', '\\', '\0'))

    def getChild(self, uid, req):
        """Return a file resource for job ``uid`` with a dynamic ``status.json`` child.

        Segments that are not valid text or that contain path separators get
        the default "no such child" resource.  The segment is URL-decoded
        before it reaches this method, so an encoded slash would otherwise be
        joined straight into the filesystem path.
        """
        try:
            name = uid.decode()
        except UnicodeDecodeError:
            return Resource.getChild(self, uid, req)
        if not self._is_safe_uid(name):
            return Resource.getChild(self, uid, req)

        out_dir = self.transcriber.out_dir(name)
        trans_ctrl = File(out_dir)

        # Add a Status endpoint to the file
        trans_status = TranscriptionStatus(self.transcriber.get_status(name))
        trans_ctrl.putChild(b"status.json", trans_status)

        return trans_ctrl

    def render_POST(self, req):
        """Start a transcription job from the posted form.

        Form fields read: ``audio`` (required), ``transcript``, ``disfluency``,
        ``conservative`` and ``async``.  Unless ``async`` is ``false`` the
        client is redirected to the job page immediately; otherwise the
        response is held open and the result JSON is written on completion.
        """
        audio_values = req.args.get(b'audio')
        if not audio_values:
            # Previously a missing field raised KeyError and produced an
            # unhelpful server error page.
            req.setResponseCode(400)
            req.setHeader(b"Content-Type", b"application/json")
            return json.dumps({'error': 'No audio file was uploaded.'}).encode()
        audio = audio_values[0]

        uid = self.transcriber.next_id()

        tran = req.args.get(b'transcript', [b''])[0].decode()

        disfluency = b'disfluency' in req.args
        conservative = b'conservative' in req.args
        kwargs = {'disfluency': disfluency,
                  'conservative': conservative,
                  'disfluencies': set(['uh', 'um'])}

        async_mode = True
        if b'async' in req.args and req.args[b'async'][0] == b'false':
            async_mode = False

        # We need to make the transcription directory here, so that
        # when we redirect the user we are sure that there's a place
        # for them to go.
        outdir = self.transcriber.out_dir(uid)
        os.makedirs(outdir)

        # Copy over the HTML
        shutil.copy(get_resource('www/view_alignment.html'), os.path.join(outdir, 'index.html'))

        result_promise = threads.deferToThreadPool(
            reactor, reactor.getThreadPool(),
            self.transcriber.transcribe,
            uid, tran, audio, async_mode, **kwargs)

        if not async_mode:
            # Set once the client has gone away; nothing may be written after that.
            state = {'client_gone': False}

            def write_result(result):
                '''Write JSON to client on completion'''
                req.setHeader("Content-Type", "application/json")
                if result is None:
                    # transcribe() returned without a result, i.e. it recorded
                    # an error in the status dict; report that instead of
                    # leaving the request open forever.
                    req.setResponseCode(500)
                    body = json.dumps(self.transcriber.get_status(uid))
                else:
                    body = result.to_json(indent=2)
                req.write(body.encode())
                req.finish()

            def write_failure(failure):
                '''Finish the request with an error document when the job failed'''
                if state['client_gone']:
                    # Cancelled because the client disconnected; nobody to answer.
                    return
                message = failure.getErrorMessage()
                logging.error('transcription %s failed: %s', uid, message)
                req.setResponseCode(500)
                req.setHeader("Content-Type", "application/json")
                req.write(json.dumps({'status': 'ERROR', 'error': message}).encode())
                req.finish()

            def on_disconnect(_):
                state['client_gone'] = True
                result_promise.cancel()

            # addCallbacks (not addCallback + addErrback) so an exception
            # raised while writing the result is not answered a second time.
            result_promise.addCallbacks(write_result, write_failure)

            req.notifyFinish().addErrback(on_disconnect)
            return NOT_DONE_YET

        req.setResponseCode(FOUND)
        req.setHeader(b"Location", "/transcriptions/%s" % (uid))
        return b''
class LazyZipper(Insist):
    """Zip archive of one job's output directory, stored at ``<cachedir>/<uid>.zip``.

    NOTE(review): presumably ``Insist`` calls ``serialize_computation`` to
    create the file when it is first needed -- confirm in gentle.util.cyst.
    """

    def __init__(self, cachedir, transcriber, uid):
        self.transcriber = transcriber
        self.uid = uid
        zip_path = os.path.join(cachedir, '%s.zip' % (uid))
        Insist.__init__(self, zip_path)

    def serialize_computation(self, outpath):
        """Archive the job's output directory so that it ends up at ``outpath``."""
        # `shutil.make_archive` appends ".zip" itself, so hand it the path
        # with everything from the last "." onwards removed.
        archive_base = outpath.rpartition('.')[0]
        job_dir = self.transcriber.out_dir(self.uid)
        shutil.make_archive(archive_base, "zip", job_dir)
class TranscriptionZipper(Resource):
    """Resource whose children are zip downloads of job output directories."""

    def __init__(self, cachedir, transcriber):
        self.cachedir = cachedir
        self.transcriber = transcriber
        Resource.__init__(self)

    def getChild(self, path, req):
        """Return a ``LazyZipper`` for the job named by ``path``.

        The job id is the part of ``path`` before the first ``.``.  Paths that
        are not valid text, that give an empty id, or whose id contains a path
        separator get the default "no such child" resource, as do ids without
        an output directory.
        """
        try:
            name = path.decode()
        except UnicodeDecodeError:
            return Resource.getChild(self, path, req)

        uid = name.split('.')[0]
        # The segment arrives URL-decoded, so it may contain "/" (e.g. from
        # "%2F").  Joining such an id would let a request archive a directory
        # outside the transcriptions tree and write the zip outside the cache.
        # An empty id would archive the whole transcriptions tree.
        if not uid or any(sep in uid for sep in ('/', '\\', '\0')):
            return Resource.getChild(self, path, req)

        t_dir = self.transcriber.out_dir(uid)
        if not os.path.exists(t_dir):
            return Resource.getChild(self, path, req)

        # TODO: Check that "status" is complete and only create a LazyZipper if so
        # Otherwise, we could have incomplete transcriptions that get permanently zipped.
        # For now, a solution will be hiding the button in the client until it's done.
        lz = LazyZipper(self.cachedir, self.transcriber, uid)
        # Register the child so later requests for the same path reuse it.
        # `path` is known to be bytes here because it was decoded above.
        self.putChild(path, lz)
        return lz
def serve(port=8765, interface='0.0.0.0', installSignalHandlers=0, nthreads=4, ntranscriptionthreads=2, data_dir=get_datadir('webdata')):
    """Build the resource tree and run the Twisted reactor until it stops.

    port, interface -- address the HTTP server listens on
    installSignalHandlers -- forwarded to ``reactor.run``
    nthreads, ntranscriptionthreads -- forwarded to ``Transcriber``
    data_dir -- directory served as the site root; also holds the ``zip`` cache
    """
    logging.info("SERVE %d, %s, %d", port, interface, installSignalHandlers)

    # Make sure the data directory and its zip cache exist.
    zip_dir = os.path.join(data_dir, 'zip')
    for required_dir in (data_dir, zip_dir):
        if not os.path.exists(required_dir):
            os.makedirs(required_dir)

    # Static front-end files are mounted on top of the data directory.
    root = File(data_dir)
    static_children = (
        (b'', 'www/index.html'),
        (b'status.html', 'www/status.html'),
        (b'preloader.gif', 'www/preloader.gif'),
    )
    for child_name, resource_name in static_children:
        root.putChild(child_name, File(get_resource(resource_name)))

    transcriber = Transcriber(data_dir, nthreads=nthreads, ntranscriptionthreads=ntranscriptionthreads)
    root.putChild(b'transcriptions', TranscriptionsController(transcriber))
    root.putChild(b'zip', TranscriptionZipper(zip_dir, transcriber))

    site = Site(root)
    logging.info("about to listen")
    reactor.listenTCP(port, site, interface=interface)
    logging.info("listening")

    reactor.run(installSignalHandlers=installSignalHandlers)
if __name__=='__main__':
    # Command-line entry point: parse the options, set the log level, serve.
    import argparse

    cli = argparse.ArgumentParser(
        description='Align a transcript to audio by generating a new language model.')
    cli.add_argument(
        '--host', default="0.0.0.0",
        help='host to run http server on')
    cli.add_argument(
        '--port', default=8765, type=int,
        help='port number to run http server on')
    cli.add_argument(
        '--nthreads', default=multiprocessing.cpu_count(), type=int,
        help='number of alignment threads')
    cli.add_argument(
        '--ntranscriptionthreads', default=2, type=int,
        help='number of full-transcription threads (memory intensive)')
    cli.add_argument(
        '--log', default="INFO",
        help='the log level (DEBUG, INFO, WARNING, ERROR, or CRITICAL)')

    opts = cli.parse_args()

    # Level names are accepted case-insensitively.
    logging.getLogger().setLevel(opts.log.upper())
    logging.info('gentle %s', gentle.__version__)
    logging.info('listening at %s:%d\n', opts.host, opts.port)

    serve(
        port=opts.port,
        interface=opts.host,
        nthreads=opts.nthreads,
        ntranscriptionthreads=opts.ntranscriptionthreads,
        installSignalHandlers=1,
    )