Reality123b commited on
Commit
9b71c73
·
verified ·
1 Parent(s): 1794547

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -16
app.py CHANGED
@@ -1,7 +1,9 @@
1
- from config import Flask,pipeline_dict,Response,convHandler
 
2
  from application.chat_inference import ChatInference
3
- from flask import render_template,request
4
  from application.utils.image_captioning import ImageCaptioning
 
5
 
6
  app = Flask(__name__, template_folder='application/templates', static_folder='application/static')
7
 
@@ -10,38 +12,53 @@ image_captioning = ImageCaptioning()
10
 
11
  @app.route('/')
12
  def home():
13
- return render_template('index.html')
 
 
 
14
 
15
- @app.route('/completions',methods=['POST'])
16
  def completeions():
 
17
  data = request.json
18
  models = pipeline_dict['api']['models']
19
- if(data.get('model',None) not in models):
20
  return "Model Not Found", 404
21
  model_info = models[data['model']]
22
- data.update(
23
- {
24
- "base_url": model_info['api_url'],
25
- "type": model_info['type']
26
- }
27
- )
28
- return chat_inference.chat(data=data,handle_stream=pipeline_dict['handle_stream'],user=request.headers.get("X-Forwarded-For", "1,2,3").split(',')[0].strip())
29
-
30
  @app.route('/convs')
31
  def get_conv():
32
- return convHandler.get_conv(request.headers.get("X-Forwarded-For", "1,2,3").split(',')[0].strip())
 
33
 
34
  @app.route('/create', methods=['POST'])
35
  def create_conv():
 
36
  sysPrompt = request.json.get('system_prompt', '')
37
- return convHandler.create_conv(ip=request.headers.get("X-Forwarded-For", "1,2,3").split(',')[0].strip(),sysPrompt=sysPrompt)
 
38
  @app.route('/fetch', methods=['POST'])
39
  def fetch():
 
40
  convId = request.json.get('convId')
41
- return convHandler.fetch_conv(convId=convId,ip=request.headers.get("X-Forwarded-For", "1,2,3").split(',')[0].strip())
 
42
  @app.route('/models')
43
  def models():
44
  return list(pipeline_dict['api']['models'].keys())
45
 
 
 
 
 
 
 
 
 
46
  if __name__ == "__main__":
47
  app.run(host="0.0.0.0", port=7860, debug=False)
 
1
+ # app.py
2
+ from config import Flask, pipeline_dict, Response, convHandler, get_user_id
3
  from application.chat_inference import ChatInference
4
+ from flask import render_template, request, make_response
5
  from application.utils.image_captioning import ImageCaptioning
6
+ from application.utils.text_to_speech import generate_tts # Import
7
 
8
  app = Flask(__name__, template_folder='application/templates', static_folder='application/static')
9
 
 
12
 
13
  @app.route('/')
14
  def home():
15
+ user_id = get_user_id()
16
+ response = make_response(render_template('index.html'))
17
+ response.set_cookie('user_id', user_id) # Set the cookie
18
+ return response
19
 
20
+ @app.route('/completions', methods=['POST'])
21
  def completeions():
22
+ user_id = get_user_id()
23
  data = request.json
24
  models = pipeline_dict['api']['models']
25
+ if data.get('model', None) not in models:
26
  return "Model Not Found", 404
27
  model_info = models[data['model']]
28
+ data.update({
29
+ "base_url": model_info['api_url'],
30
+ "type": model_info['type']
31
+ })
32
+ return chat_inference.chat(data=data, handle_stream=pipeline_dict['handle_stream'], user=user_id)
33
+
 
 
34
  @app.route('/convs')
35
  def get_conv():
36
+ user_id = get_user_id()
37
+ return convHandler.get_conv(user_id)
38
 
39
  @app.route('/create', methods=['POST'])
40
  def create_conv():
41
+ user_id = get_user_id()
42
  sysPrompt = request.json.get('system_prompt', '')
43
+ return convHandler.create_conv(ip=user_id, sysPrompt=sysPrompt)
44
+
45
  @app.route('/fetch', methods=['POST'])
46
  def fetch():
47
+ user_id = get_user_id()
48
  convId = request.json.get('convId')
49
+ return convHandler.fetch_conv(convId=convId, ip=user_id)
50
+
51
  @app.route('/models')
52
  def models():
53
  return list(pipeline_dict['api']['models'].keys())
54
 
55
+ @app.route('/tts') # New route for TTS
56
+ def tts():
57
+ text = request.args.get('text')
58
+ if not text:
59
+ return "No text provided", 400
60
+ audio_stream = generate_tts(text)
61
+ return Response(audio_stream, mimetype="audio/wav")
62
+
63
  if __name__ == "__main__":
64
  app.run(host="0.0.0.0", port=7860, debug=False)