javi8979 committed
Commit b8b5a68 · verified · 1 Parent(s): 1dfda0b

Update app.py

Files changed (1): app.py +32 -54
app.py CHANGED
@@ -1,49 +1,48 @@
  import gradio as gr
- from huggingface_hub import snapshot_download
- from vllm import LLM, SamplingParams
+ from datetime import datetime
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
  
  # ------------------------
  # 1) Load the Model
  # ------------------------
- # Download the model repository, specify revision if needed
- model_dir = snapshot_download(repo_id="BSC-LT/salamandraTA-7B-instruct-GGUF", revision="main", allow_patterns=[
-     "salamandrata_7b_inst_q4.gguf",
-     "*tokenizer*",
-     "tokenizer_config.json",
-     "tokenizer.model",
-     "config.json",
- ])
- model_name = "salamandrata_7b_inst_q4.gguf"
- 
- # Create an LLM instance from vLLM
- llm = LLM(model=model_dir + '/' + model_name, tokenizer=model_dir)
- 
- # We can define a single helper function to call the model:
+ model_id = "BSC-LT/salamandraTA-7b-instruct"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16
+ )
+ 
+ # Common function to generate text using transformers
  def call_model(prompt: str, temperature: float = 0.1, max_tokens: int = 256):
-     """
-     Sends the prompt to the LLM using vLLM's chat interface.
-     """
-     messages = [{'role': 'user', 'content': prompt}]
-     outputs = llm.chat(
-         messages,
-         sampling_params=SamplingParams(
-             temperature=temperature,
-             stop_token_ids=[5],  # you can adjust the stop token ID if needed
-             max_tokens=max_tokens
-         )
+     message = [{"role": "user", "content": prompt}]
+     date_string = datetime.today().strftime('%Y-%m-%d')
+ 
+     chat_prompt = tokenizer.apply_chat_template(
+         message,
+         tokenize=False,
+         add_generation_prompt=True,
+         date_string=date_string
+     )
+ 
+     inputs = tokenizer.encode(chat_prompt, return_tensors="pt").to(model.device)
+     input_length = inputs.shape[1]
+     outputs = model.generate(
+         input_ids=inputs,
+         max_new_tokens=max_tokens,
+         do_sample=True,
+         temperature=temperature,
+         num_beams=5,
+         early_stopping=True
      )
-     # The model returns a list of "Generation" objects, each containing .outputs
-     return outputs[0].outputs[0].text if outputs else ""
+     return tokenizer.decode(outputs[0, input_length:], skip_special_tokens=True)
  
  # ------------------------
  # 2) Task-specific functions
  # ------------------------
  
  def general_translation(source_lang, target_lang, text):
-     """
-     General translation prompt:
-     Translate from source_lang into target_lang.
-     """
      prompt = (
          f"Translate the following text from {source_lang} into {target_lang}.\n"
          f"{source_lang}: {text}\n"
@@ -52,10 +51,6 @@ def general_translation(source_lang, target_lang, text):
      return call_model(prompt, temperature=0.1)
  
  def post_editing(source_lang, target_lang, source_text, machine_translation):
-     """
-     Post-editing prompt:
-     Ask the model to fix any mistakes in the machine translation or keep it unedited.
-     """
      prompt = (
          f"Please fix any mistakes in the following {source_lang}-{target_lang} machine translation or keep it unedited if it's correct.\n"
          f"Source: {source_text}\n"
@@ -65,10 +60,6 @@ def post_editing(source_lang, target_lang, source_text, machine_translation):
      return call_model(prompt, temperature=0.1)
  
  def document_level_translation(source_lang, target_lang, document_text):
-     """
-     Document-level translation prompt:
-     Translate a multi-paragraph document.
-     """
      prompt = (
          f"Please translate this text from {source_lang} into {target_lang}.\n"
          f"{source_lang}: {document_text}\n"
@@ -77,16 +68,7 @@ def document_level_translation(source_lang, target_lang, document_text):
      return call_model(prompt, temperature=0.1)
  
  def named_entity_recognition(tokenized_text):
-     """
-     Named-entity recognition prompt:
-     Label tokens as ORG, PER, LOC, MISC, or O.
-     Expects the user to provide a list of tokens.
-     """
-     # Convert the input string into a list of tokens, if the user typed them as space-separated words
-     # or if the user provided them as a Python list string, we can try to parse that.
-     # For simplicity, let's assume it's a space-separated string.
      tokens = tokenized_text.strip().split()
- 
      prompt = (
          "Analyse the following tokenized text and mark the tokens containing named entities.\n"
          "Use the following annotation guidelines with these tags for named entities:\n"
@@ -102,10 +84,6 @@ def named_entity_recognition(tokenized_text):
      return call_model(prompt, temperature=0.1)
  
  def grammar_checker(source_lang, sentence):
-     """
-     Grammar checker prompt:
-     Fix any mistakes in the given source_lang sentence or keep it unedited if correct.
-     """
      prompt = (
          f"Please fix any mistakes in the following {source_lang} sentence or keep it unedited if it's correct.\n"
          f"Sentence: {sentence}\n"
 
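Note that the new generate() call combines do_sample=True and temperature with num_beams=5, i.e. beam-search multinomial sampling. If reproducible output were preferred, a plain beam-search variant would drop the sampling flags; this is a sketch of an alternative, not what the commit does:

    # Alternative decoding (not in the commit): plain beam search, deterministic.
    outputs = model.generate(
        input_ids=inputs,
        max_new_tokens=max_tokens,
        do_sample=False,
        num_beams=5,
        early_stopping=True
    )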
 