rishiraj committed on
Commit
043c55e
·
verified ·
1 Parent(s): d439d5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -11
app.py CHANGED
@@ -1,25 +1,40 @@
1
  import spaces
2
  import gradio as gr
3
- from transformers import pipeline
4
  import torch
5
 
6
- # Initialize the pipeline
7
- pipe = pipeline(
8
- "text-generation",
9
- model="sarvamai/sarvam-translate",
10
- torch_dtype=torch.float32,
11
- device="cuda:0",
12
- )
13
 
14
- @spaces.GPU
15
  def generate_response(tgt_lang, user_prompt):
16
  messages = [
17
  {"role": "system", "content": f"Translate the following sentence into {tgt_lang}."},
18
  {"role": "user", "content": user_prompt},
19
  ]
20
 
21
- output = pipe(messages, max_new_tokens=2048)
22
- return output[0]["generated_text"][-1]["content"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Create Gradio UI
25
  demo = gr.Interface(
 
1
  import spaces
2
  import gradio as gr
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
 
6
+ model_name = "sarvamai/sarvam-translate"
7
+
8
+ # Load tokenizer and model
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).to('cuda:0')
 
 
11
 
12
+ @spaces.GPU(duration=120)
13
  def generate_response(tgt_lang, user_prompt):
14
  messages = [
15
  {"role": "system", "content": f"Translate the following sentence into {tgt_lang}."},
16
  {"role": "user", "content": user_prompt},
17
  ]
18
 
19
+ # Apply chat template to structure the conversation
20
+ text = tokenizer.apply_chat_template(
21
+ messages,
22
+ tokenize=False,
23
+ )
24
+
25
+ # Tokenize and move input to model device
26
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
27
+
28
+ # Generate the output
29
+ generated_ids = model.generate(
30
+ **model_inputs,
31
+ max_new_tokens=1024,
32
+ do_sample=True,
33
+ temperature=0.01,
34
+ num_return_sequences=1
35
+ )
36
+ output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
37
+ return tokenizer.decode(output_ids, skip_special_tokens=True)
38
 
39
  # Create Gradio UI
40
  demo = gr.Interface(