Machlovi committed on
Commit
79de16d
·
verified ·
1 Parent(s): 87e3a34

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +143 -4
handler.py CHANGED
@@ -1,5 +1,3 @@
1
-
2
- # handler.py
3
  import os
4
  import torch
5
  from transformers import AutoTokenizer, TextStreamer
@@ -79,6 +77,8 @@ class EndpointHandler:
79
  lora_adapter = config["lora_adapter"]
80
 
81
  # Load the model and tokenizer
 
 
82
  self.model, self.tokenizer = FastLanguageModel.from_pretrained(
83
  model_name=model_id,
84
  max_seq_length=self.max_seq_length,
@@ -89,6 +89,9 @@ class EndpointHandler:
89
  self.model = PeftModel.from_pretrained(self.model, lora_adapter)
90
  self.model.eval()
91
 
 
 
 
92
  print(f"Loaded model: {self.selected_model_name}")
93
  print(f"Chat template: {self.chat_template}")
94
  print(f"LoRA adapter: {lora_adapter}")
@@ -106,8 +109,8 @@ class EndpointHandler:
106
  "" # Leave output blank for generation
107
  )
108
 
109
- # Tokenize input
110
- inputs = self.tokenizer([formatted_input], return_tensors="pt")
111
 
112
  # Generate response
113
  with torch.no_grad():
@@ -134,6 +137,142 @@ class EndpointHandler:
134
  }
135
 
136
  return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  # from unsloth import FastLanguageModel # FastVisionModel for LLMs
138
  # import torch
139
  # import os
 
 
 
1
  import os
2
  import torch
3
  from transformers import AutoTokenizer, TextStreamer
 
77
  lora_adapter = config["lora_adapter"]
78
 
79
  # Load the model and tokenizer
80
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
81
+
82
  self.model, self.tokenizer = FastLanguageModel.from_pretrained(
83
  model_name=model_id,
84
  max_seq_length=self.max_seq_length,
 
89
  self.model = PeftModel.from_pretrained(self.model, lora_adapter)
90
  self.model.eval()
91
 
92
+ # Move model to the device (GPU or CPU)
93
+ self.model.to(self.device)
94
+
95
  print(f"Loaded model: {self.selected_model_name}")
96
  print(f"Chat template: {self.chat_template}")
97
  print(f"LoRA adapter: {lora_adapter}")
 
109
  "" # Leave output blank for generation
110
  )
111
 
112
+ # Tokenize input and move to the same device as the model
113
+ inputs = self.tokenizer([formatted_input], return_tensors="pt").to(self.device)
114
 
115
  # Generate response
116
  with torch.no_grad():
 
137
  }
138
 
139
  return response
140
+
141
+ # # handler.py
142
+ # import os
143
+ # import torch
144
+ # from transformers import AutoTokenizer, TextStreamer
145
+ # from unsloth import FastLanguageModel
146
+ # from peft import PeftModel
147
+
148
+ # class EndpointHandler:
149
+ # def __init__(self, model_dir):
150
+ # # Configuration for your safety model
151
+ # self.max_seq_length = 2048
152
+ # self.load_in_4bit = True
153
+
154
+ # # Get model configuration from environment variables or use defaults
155
+ # self.selected_model_name = os.environ.get("SELECTED_MODEL", "Phi-4-old")
156
+
157
+ # # Model configurations
158
+ # self.model_options = {
159
+ # "Llama-3.1-8B": {
160
+ # "model_id": "unsloth/Meta-Llama-3.1-8B-Instruct",
161
+ # "chat_template": "llama-3",
162
+ # "lora_adapter": "Machlovi/SafeLlama"
163
+ # },
164
+ # "Qwen2.5-7B": {
165
+ # "model_id": "unsloth/Qwen2.5-7B",
166
+ # "chat_template": "chatml",
167
+ # "lora_adapter": "Machlovi/Qwen2.57bHB"
168
+ # },
169
+ # "Phi-4": {
170
+ # "model_id": "unsloth/Phi-4-unsloth-bnb-4bit",
171
+ # "chat_template": "phi-4",
172
+ # "lora_adapter": "Machlovi/Phi4_Safe"
173
+ # },
174
+ # "Gemma3-4": {
175
+ # "model_id": "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
176
+ # "chat_template": "gemma-3",
177
+ # "lora_adapter": "Machlovi/Gemma3-4"
178
+ # },
179
+ # "Phi-4-old": {
180
+ # "model_id": "unsloth/Phi-4-unsloth-bnb-4bit",
181
+ # "chat_template": "safe_phi",
182
+ # "lora_adapter": "Machlovi/Safe_Phi4"
183
+ # },
184
+ # }
185
+
186
+ # # Safety policy template
187
+ # self.safe_phi_template = """
188
+ # You are a policy expert trying to help determine whether a user
189
+ # prompt is in violation of the defined safety policies.
190
+ # Our safety principle is defined in the below:
191
+ # ### Safety_policy:
192
+ # {}
193
+ # <start_of_turn>
194
+ # ### Input:
195
+ # {}
196
+ # <end_of_turn>
197
+ # Does the human question violate the above principle? Your answer must start
198
+ # with 'Safe' or 'Unsafe'.
199
+ # - First line must read 'safe' or 'unsafe'.
200
+ # - If unsafe, a second line must include a comma-separated list of violated categories.
201
+ # ### Response:
202
+ # {}
203
+ # """
204
+
205
+ # self.policy = """You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.
206
+ # * "No Harassment": The input shall not contain or seek generation of content that is malicious, intimidating, bullying, or abusive content targeting another individual (e.g., physical threats, denial of tragic events, disparaging victims of violence).
207
+ # * "No Dangerous Content": The input shall not contain or seek generation of content that harming oneself and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide, violent protest).
208
+ # * "No Hate Speech": The input shall not contain or seek generation of content that targeting identity and/or protected attributes (e.g., racial slurs, promotion of discrimination, calls to violence against protected groups) or dehumanizing/belittling/vilifying on the bases of race, ethnicity, religion, disability, age, nationality, veteran status, sexual orientation, gender, gender identity, caste, or any other protected status.
209
+ # * "No Sexually Explicit Information": The input shall not contain or seek generation of content that contains references to sexual acts or other lewd content (e.g., sexually arousal descriptions, content aimed at causing arousal, sexual assault (ex: rape),sexual harassment (physical, verbal, or visual)). Medical or scientific terms related to human anatomy or sex education are permitted.
210
+ # Does the human question violate the above principle? Your answer must start with 'Safe' or 'Unsafe'.
211
+ # - If unsafe, a second line must include a comma-separated list of violated categories.
212
+ # """
213
+
214
+ # # Load model configuration
215
+ # config = self.model_options[self.selected_model_name]
216
+ # model_id = config["model_id"]
217
+ # self.chat_template = config["chat_template"]
218
+ # lora_adapter = config["lora_adapter"]
219
+
220
+ # # Load the model and tokenizer
221
+ # self.model, self.tokenizer = FastLanguageModel.from_pretrained(
222
+ # model_name=model_id,
223
+ # max_seq_length=self.max_seq_length,
224
+ # load_in_4bit=self.load_in_4bit,
225
+ # )
226
+
227
+ # # Load LoRA adapter
228
+ # self.model = PeftModel.from_pretrained(self.model, lora_adapter)
229
+ # self.model.eval()
230
+
231
+ # print(f"Loaded model: {self.selected_model_name}")
232
+ # print(f"Chat template: {self.chat_template}")
233
+ # print(f"LoRA adapter: {lora_adapter}")
234
+
235
+ # def __call__(self, data):
236
+ # """
237
+ # Run safety check on input text
238
+ # """
239
+ # input_text = data.get("inputs", "")
240
+
241
+ # # Prepare input with the safety template
242
+ # formatted_input = self.safe_phi_template.format(
243
+ # self.policy,
244
+ # input_text,
245
+ # "" # Leave output blank for generation
246
+ # )
247
+
248
+ # # Tokenize input
249
+ # inputs = self.tokenizer([formatted_input], return_tensors="pt")
250
+
251
+ # # Generate response
252
+ # with torch.no_grad():
253
+ # text_streamer = TextStreamer(self.tokenizer)
254
+ # output = self.model.generate(
255
+ # **inputs,
256
+ # streamer=text_streamer,
257
+ # max_new_tokens=24
258
+ # )
259
+
260
+ # # Decode the output
261
+ # decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)
262
+
263
+ # # Extract safety classification
264
+ # safety_result = decoded_output.split("### Response:")[-1].strip()
265
+
266
+ # # Determine if the input is safe or not
267
+ # is_safe = safety_result.lower().startswith("safe")
268
+
269
+ # # Prepare the response
270
+ # response = {
271
+ # "is_safe": is_safe,
272
+ # "safety_result": safety_result
273
+ # }
274
+
275
+ # return response
276
  # from unsloth import FastLanguageModel # FastVisionModel for LLMs
277
  # import torch
278
  # import os