# chatbot-with-guardrail / guardrail.py
# Last commit: "Update guardrail.py" (b6b98c7, verified) by pratikshahp
import logging
import os

from dotenv import load_dotenv
from together import Together
# Pull variables from a .env file into the process environment before
# reading the API key below.
load_dotenv()

# Together API key, taken from the API_KEY environment variable
# (None when it is not set).
api_key = os.environ.get("API_KEY")

# Module-level Together client; is_safe() sends its classification
# requests through this object.
client = Together(api_key=api_key)
# Child-safety policy text. is_safe() inserts it verbatim between the
# <BEGIN SAFETY POLICY> / <END SAFETY POLICY> markers of the guard prompt.
_POLICY_TEXT = """
O1: Topic Restriction.
Should not:
- Respond to questions outside the domain of safe child-friendly content.
- Provide answers to sensitive or inappropriate topics such as violence, hate speech, profanity, or sexual content.
- Provide information about alcohol, drugs, or any content promoting illegal or unsafe activities for individuals under 16.
O2: Sensitive Information.
Should not:
- Request or respond with any personal or sensitive user data.
Can:
- Discuss general knowledge, child-friendly educational content, or safe topics.
O3: Language.
Should not:
- Use profane, offensive, or inappropriate language.
Can:
- Use simple, clear, and friendly language suitable for children.
"""

# Public policy mapping; callers read the text via the "policy" key.
safe_content_policy = {"policy": _POLICY_TEXT}
def _build_guard_prompt(message: str, policy: str) -> str:
    """Return the Llama Guard classification prompt for *message* under *policy*.

    Args:
        message (str): Text to classify (user message or chatbot response).
        policy (str): Policy text placed between the safety-policy markers.

    Returns:
        str: The complete prompt string sent to the completion endpoint.
    """
    # NOTE(review): `message` is interpolated verbatim, so text that itself
    # contains "<END MESSAGE>" or "[/INST]" can blur the prompt framing;
    # consider escaping or stripping those markers.
    # NOTE(review): the [INST] ... [/INST] wrapper is the older Llama 2-style
    # instruction format; Meta documents a Llama 3.1 chat template for
    # Llama Guard 3 -- confirm which format the Together endpoint expects.
    return f"""[INST] Task: Validate if the given 'message' adheres to the safety policy below:
<BEGIN SAFETY POLICY>
{policy}
<END SAFETY POLICY>
<BEGIN MESSAGE>
{message}
<END MESSAGE>
Provide your safety assessment for the above message:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""


def _parse_verdict(text) -> bool:
    """Return True only if the first word of the model output is 'safe'.

    The prompt asks the model to answer 'safe' or 'unsafe' on its first line.
    Any other shape (None, empty output, an unexpected first word) counts as
    unsafe, so the guardrail fails closed.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return False
    # After strip(), the first whitespace-delimited token lies on the first line.
    first_token = cleaned.split()[0]
    # Tolerate surrounding punctuation/quotes such as "Safe." but reject words
    # that merely begin with "safe" (e.g. "safeguard").
    return first_token.strip(".,:;!'\"").lower() == "safe"


def is_safe(message: str) -> bool:
    """
    Validates if the given message complies with the safety policy using LlamaGuard.

    The message is classified by ``meta-llama/Meta-Llama-Guard-3-8B`` through
    the module-level Together ``client``, using the text in
    ``safe_content_policy['policy']``. The check fails closed: any API or
    response-shape error is logged and reported as unsafe.

    Args:
        message (str): User message or chatbot response.

    Returns:
        bool: True if the model's verdict is 'safe', False otherwise
        (including when an error occurs).
    """
    prompt = _build_guard_prompt(message, safe_content_policy["policy"])
    try:
        response = client.completions.create(
            model="meta-llama/Meta-Llama-Guard-3-8B",
            prompt=prompt,
        )
        return _parse_verdict(response.choices[0].text)
    except Exception as exc:
        # Deliberately broad: this is the guardrail boundary, and any failure
        # (network, auth, empty `choices`, malformed text) must block content
        # rather than let it through.
        logging.getLogger(__name__).exception("Error in guardrail check: %s", exc)
        return False  # Default to unsafe if an error occurs