RishabhBhardwaj committed on
Commit
468731a
·
1 Parent(s): 0a9972d
Files changed (2) hide show
  1. app.py +3 -3
  2. app_walledguard.py +88 -0
app.py CHANGED
@@ -16,7 +16,7 @@ Answer: [/INST]
16
  # Load the model and tokenizer
17
  @st.cache_resource
18
  def load_model():
19
- model_name = "walledai/walledguard-c"
20
  tokenizer = AutoTokenizer.from_pretrained(model_name)
21
  model = AutoModelForCausalLM.from_pretrained(model_name)
22
  return tokenizer, model
@@ -35,8 +35,8 @@ def evaluate_text(user_input):
35
  tokenizer, model = st.session_state.model_and_tokenizer
36
 
37
  # Prepare input
38
- input_ids = tokenizer.encode(TEMPLATE.format(prompt=user_input), return_tensors="pt")
39
-
40
  # Generate output
41
  output = model.generate(input_ids=input_ids, max_new_tokens=20, pad_token_id=0)
42
 
 
16
  # Load the model and tokenizer
17
  @st.cache_resource
18
  def load_model():
19
+ model_name = "meta-llama/Llama-Guard-3-8B"
20
  tokenizer = AutoTokenizer.from_pretrained(model_name)
21
  model = AutoModelForCausalLM.from_pretrained(model_name)
22
  return tokenizer, model
 
35
  tokenizer, model = st.session_state.model_and_tokenizer
36
 
37
  # Prepare input
38
+ input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(device)
39
+
40
  # Generate output
41
  output = model.generate(input_ids=input_ids, max_new_tokens=20, pad_token_id=0)
42
 
app_walledguard.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import requests
4
+ from PIL import Image
5
+ from io import BytesIO
6
+
7
+ # Define the template
8
+ TEMPLATE = """[INST] Evaluate if the given text is harmful i.e. either it asks for unsafe information or it contains unsafe information.
9
+ <START TEXT>
10
+ {prompt}
11
+ <END TEXT>
12
+
13
+ Answer: [/INST]
14
+ """
15
+
16
+ # Load the model and tokenizer
17
+ @st.cache_resource
18
+ def load_model():
19
+ model_name = "walledai/walledguard-c"
20
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
21
+ model = AutoModelForCausalLM.from_pretrained(model_name)
22
+ return tokenizer, model
23
+
24
+ # Function to load image from URL
25
+ @st.cache_data()
26
+ def load_image_from_url(url):
27
+ response = requests.get(url)
28
+ img = Image.open(BytesIO(response.content))
29
+ return img
30
+
31
+ # Evaluation function
32
+ def evaluate_text(user_input):
33
+ if user_input:
34
+ # Get model and tokenizer from session state
35
+ tokenizer, model = st.session_state.model_and_tokenizer
36
+
37
+ # Prepare input
38
+ input_ids = tokenizer.encode(TEMPLATE.format(prompt=user_input), return_tensors="pt")
39
+
40
+ # Generate output
41
+ output = model.generate(input_ids=input_ids, max_new_tokens=20, pad_token_id=0)
42
+
43
+ # Decode output
44
+ prompt_len = input_ids.shape[-1]
45
+ output_decoded = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
46
+
47
+ # Determine prediction
48
+ prediction = 'unsafe' if 'unsafe' in output_decoded.lower() else 'safe'
49
+
50
+ return prediction
51
+ return None
52
+
53
+ # Streamlit app
54
+ st.title("Text Safety Evaluator")
55
+
56
+ # Load model and tokenizer once and store in session state
57
+ if 'model_and_tokenizer' not in st.session_state:
58
+ st.session_state.model_and_tokenizer = load_model()
59
+
60
+ # User input
61
+ user_input = st.text_area("Enter the text you want to evaluate:", height=100)
62
+
63
+ # Create an empty container for the result
64
+ result_container = st.empty()
65
+
66
+ if st.button("Evaluate"):
67
+ prediction = evaluate_text(user_input)
68
+ if prediction:
69
+ result_container.subheader("Evaluation Result:")
70
+ result_container.write(f"The text is evaluated as: **{prediction.upper()}**")
71
+ else:
72
+ result_container.warning("Please enter some text to evaluate.")
73
+
74
+ # Add logo at the bottom center (only once)
75
+ #if 'logo_displayed' not in st.session_state:
76
+ col1, col2, col3 = st.columns([1,2,1])
77
+ with col2:
78
+ logo_url = "https://github.com/walledai/walledeval/assets/32847115/d8b1d14f-7071-448b-8997-2eeba4c2c8f6"
79
+ logo = load_image_from_url(logo_url)
80
+ st.image(logo, use_column_width=True, width=500) # Adjust the width as needed
81
+ #st.session_state.logo_displayed = True
82
+
83
+ # Add information about Walled Guard Advanced (only once)
84
+ #if 'info_displayed' not in st.session_state:
85
+ col1, col2, col3 = st.columns([1,2,1])
86
+ with col2:
87
+ st.info("For a more performant version, check out Walled Guard Advanced. Connect with us at admin@walled.ai for more information.")
88
+ #st.session_state.info_displayed = True