Spaces:
Sleeping
Sleeping
File size: 10,825 Bytes
eb7f886 0032334 eb7f886 f8bc9fc 17de6a3 f8bc9fc d48ce46 eb7f886 f8bc9fc a9c787b eb7f886 a9c787b eb7f886 a9c787b eb7f886 a9c787b eb7f886 0032334 eb7f886 070f012 eb7f886 070f012 17de6a3 eb7f886 070f012 eb7f886 070f012 eb7f886 070f012 eb7f886 17de6a3 eb7f886 17de6a3 eb7f886 17de6a3 eb7f886 ef1be3c eb7f886 ef1be3c eb7f886 070f012 eb7f886 070f012 eb7f886 070f012 eb7f886 070f012 eb7f886 070f012 eb7f886 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 |
import os
from typing import Optional

import gradio as gr
from dotenv import load_dotenv
from fastapi import FastAPI
from pydantic import BaseModel
from simple_salesforce import Salesforce
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Load environment variables from a local .env file into os.environ so the
# os.getenv() lookups below can see them. override=False (python-dotenv's
# default) leaves variables already set in the real environment untouched.
load_dotenv(override=False)
# Salesforce connection
def get_salesforce_connection():
    """Open and verify an authenticated Salesforce session.

    Credentials are read from the environment: SF_USERNAME, SF_PASSWORD,
    SF_SECURITY_TOKEN and SF_DOMAIN. SF_DOMAIN falls back to "login" when
    it is unset.

    Returns:
        A connected ``Salesforce`` client.

    Raises:
        Exception: if a credential is missing or empty, or if the login or
            the verification call fails. The underlying error is chained as
            ``__cause__`` so its traceback is preserved.
    """
    try:
        # Insertion order fixes the order of names in the "missing" message.
        credentials = {
            "SF_USERNAME": os.getenv("SF_USERNAME"),
            "SF_PASSWORD": os.getenv("SF_PASSWORD"),
            "SF_SECURITY_TOKEN": os.getenv("SF_SECURITY_TOKEN"),
            # Default to production (login.salesforce.com).
            "SF_DOMAIN": os.getenv("SF_DOMAIN", "login"),
        }
        # os.getenv returns only str or None, so a falsiness check covers both
        # unset and empty values (including an explicitly empty SF_DOMAIN).
        missing = [name for name, value in credentials.items() if not value]
        if missing:
            raise ValueError(
                f"Missing environment variables: {', '.join(missing)}. "
                "Set them in .env or Space environment variables."
            )
        sf = Salesforce(
            username=credentials["SF_USERNAME"],
            password=credentials["SF_PASSWORD"],
            security_token=credentials["SF_SECURITY_TOKEN"],
            domain=credentials["SF_DOMAIN"],
        )
        # Test the connection with an authenticated call: fetch the User
        # record for sf.user_id.
        # NOTE(review): relies on the client exposing `user_id`; confirm for
        # the installed simple_salesforce version.
        sf.User.get(sf.user_id)
        return sf
    except Exception as e:
        raise Exception(f"Failed to connect to Salesforce: {e}") from e
# Load Hugging Face token (optional; read from the environment).
HF_TOKEN = os.getenv("HF_TOKEN")
# Model configuration
MODEL_PATH = "facebook/bart-large"  # Public model
# MODEL_PATH = "your_actual_username/fine_tuned_bart_construction"  # Uncomment after uploading
# An empty HF_TOKEN is treated the same as an unset one.
_HF_AUTH_TOKEN = HF_TOKEN or None
try:
    # `token=` is the current name of the deprecated `use_auth_token=` keyword.
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH, token=_HF_AUTH_TOKEN)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, token=_HF_AUTH_TOKEN)
except Exception as e:
    # Chain the original error so the root cause (network, auth, missing
    # repo) stays visible in the traceback.
    raise Exception(f"Failed to load model: {e}") from e
# Define input model for FastAPI
class ChecklistInput(BaseModel):
    """Request body for ``POST /generate``.

    ``role``, ``project_id``, ``project_name`` and ``milestones`` are used to
    build the generation prompt. The remaining optional fields relate to the
    Salesforce Supervisor_AI_Coaching__c record identified by ``record_id``.
    """

    role: str = "Supervisor"
    project_id: str = "Unknown"
    project_name: str = "Unknown Project"
    milestones: str = "No milestones provided"
    # Optional[...] so an explicit JSON null validates: under pydantic v2 a
    # bare `str` annotation with a None default rejects null input.
    record_id: Optional[str] = None
    supervisor_id: Optional[str] = None
    project_id_sf: Optional[str] = None
    reflection_log: Optional[str] = None
    download_link: Optional[str] = None
# Initialize FastAPI
app = FastAPI()


@app.post("/generate")
async def generate_checklist(data: ChecklistInput):
    """Generate a daily checklist and optionally store it in Salesforce.

    Builds a prompt from the role, project and milestones, runs it through
    the seq2seq model, and sets a KPI flag when the milestones mention
    "delay" or "behind" (case-insensitive).

    When ``data.record_id`` is set, that Supervisor_AI_Coaching__c record is
    updated as follows:
    - the checklist, tips and KPI flag are written;
    - the engagement score is increased by 10;
    - each optional field (supervisor, project, reflection log, download
      link) overwrites the stored value only when a non-empty value is
      provided.

    Returns:
        ``{"checklist": str, "tips": str, "kpi_flag": bool}`` on success, or
        ``{"error": str}`` if any step fails.

    NOTE(review): tokenization, model.generate and the Salesforce calls are
    synchronous and run directly inside this coroutine, so they block the
    event loop while they execute.
    """
    try:
        prompt = (
            f"Role: {data.role} Project: {data.project_id} "
            f"({data.project_name}) Milestones: {data.milestones}"
        )
        input_ids = tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).input_ids
        outputs = model.generate(input_ids, max_length=128, num_beams=4, early_stopping=True)
        checklist = tokenizer.decode(outputs[0], skip_special_tokens=True)
        tips = "1. Prioritize safety checks\n2. Review milestones\n3. Log progress"
        milestones_lower = data.milestones.lower()
        kpi_flag = "delay" in milestones_lower or "behind" in milestones_lower

        if data.record_id:
            sf = get_salesforce_connection()
            # SFType.get() has no `default=` option: extra keyword arguments
            # are forwarded to the HTTP request, which rejects them with a
            # TypeError. Fetch plainly instead. A missing record raises and is
            # reported through the error response; updating a missing record
            # would fail anyway.
            existing = sf.Supervisor_AI_Coaching__c.get(data.record_id)
            update_data = {
                'Daily_Checklist__c': checklist,
                'Suggested_Tips__c': tips,
                # A null number field is returned as None rather than being
                # absent, so `or 0` is required; .get(key, 0) would give None + 10.
                'Engagement_Score__c': (existing.get('Engagement_Score__c') or 0) + 10,
                'KPI_Flag__c': kpi_flag,
                'Supervisor_ID__c': data.supervisor_id or existing.get('Supervisor_ID__c'),
                'Project_ID__c': data.project_id_sf or existing.get('Project_ID__c'),
                'Reflection_Log__c': data.reflection_log or existing.get('Reflection_Log__c', ''),
                'Download_Link__c': data.download_link or existing.get('Download_Link__c', ''),
            }
            sf.Supervisor_AI_Coaching__c.update(data.record_id, update_data)
        return {
            "checklist": checklist,
            "tips": tips,
            "kpi_flag": kpi_flag
        }
    except Exception as e:
        return {"error": str(e)}
# Login and display records

# Backslash escapes for characters that are special inside a single-quoted
# SOQL string literal. str.translate works per character, so backslashes
# added by one mapping are never escaped a second time.
_SOQL_STRING_ESCAPES = str.maketrans({
    "\\": "\\\\",
    "'": "\\'",
    '"': '\\"',
    "\n": "\\n",
    "\r": "\\r",
    "\t": "\\t",
})


def _soql_quote(value):
    """Escape *value* for safe use inside a single-quoted SOQL string literal."""
    return value.translate(_SOQL_STRING_ESCAPES)


def _format_coaching_record(record):
    """Render one Supervisor_AI_Coaching__c query row as labelled text lines."""
    return (
        f"Record ID: {record['Id']}\n"
        f"Name: {record['Name']}\n"
        f"Supervisor ID: {record['Supervisor_ID__c']}\n"
        f"Project ID: {record['Project_ID__c']}\n"
        f"Daily Checklist: {record['Daily_Checklist__c'] or 'N/A'}\n"
        f"Suggested Tips: {record['Suggested_Tips__c'] or 'N/A'}\n"
        f"Reflection Log: {record['Reflection_Log__c'] or 'N/A'}\n"
        f"Engagement Score: {record['Engagement_Score__c'] or 0}%\n"
        f"KPI Flag: {record['KPI_Flag__c']}\n"
        f"Download Link: {record['Download_Link__c'] or 'N/A'}\n"
        f"{'-' * 50}\n"
    )


def login_and_display(project_id_sf):
    """List the Supervisor_AI_Coaching__c records for one project as text.

    Args:
        project_id_sf: Value matched against Project_ID__c. Surrounding
            whitespace is trimmed, and an empty value is rejected without
            contacting Salesforce.

    Returns:
        A 3-tuple ``(text, "", False)`` for the Gradio outputs. ``text`` is
        the formatted records, a "no records" or validation message, or an
        error message. The second and third elements are fixed placeholders.
    """
    project_id_sf = (project_id_sf or "").strip()
    if not project_id_sf:
        return "Please enter a Project ID.", "", False
    try:
        sf = get_salesforce_connection()
        # The value comes straight from the UI. Escape it so it cannot close
        # the string literal and change the query (SOQL injection).
        query = (
            "SELECT Id, Name, Supervisor_ID__c, Project_ID__c, Daily_Checklist__c, "
            "Suggested_Tips__c, Reflection_Log__c, Engagement_Score__c, KPI_Flag__c, "
            "Download_Link__c FROM Supervisor_AI_Coaching__c "
            f"WHERE Project_ID__c = '{_soql_quote(project_id_sf)}'"
        )
        # NOTE(review): sf.query returns only the first result batch; switch
        # to sf.query_all if one project can exceed a single batch.
        records = sf.query(query)["records"]
        if not records:
            return "No records found for Project ID.", "", False
        output = "Supervisor_AI_Coaching__c Records:\n" + "".join(
            _format_coaching_record(record) for record in records
        )
        return output, "", False
    except Exception as e:
        return f"Error querying Salesforce: {str(e)}", "", False
# Generate checklist from record
def gradio_generate_checklist(
    record_id,
    role="Supervisor",
    project_id="Unknown",
    project_name="Unknown Project",
    milestones="No milestones provided",
    supervisor_id="",
    project_id_sf="",
    reflection_log="",
    download_link="",
):
    """Generate a checklist for one record and write it back to Salesforce.

    Runs the seq2seq model on a prompt built from the role, project and
    milestones. The KPI flag is set when the milestones mention "delay" or
    "behind" (case-insensitive).

    The Supervisor_AI_Coaching__c record ``record_id`` is then updated:
    - the checklist, tips and flag are written;
    - the engagement score is increased by 10;
    - each optional field overwrites the stored value only when it is
      non-empty.

    Returns:
        ``(checklist, tips, kpi_flag, status)`` for the Gradio outputs. On
        failure, or when ``record_id`` is blank, returns
        ``("Error: ...", "", False, "")``.
    """
    record_id = (record_id or "").strip()
    if not record_id:
        return "Error: Record ID is required.", "", False, ""
    try:
        sf = get_salesforce_connection()
        # SFType.get() has no `default=` option: extra keyword arguments are
        # forwarded to the HTTP request, which rejects them with a TypeError.
        # A missing record raises here and is reported below.
        existing = sf.Supervisor_AI_Coaching__c.get(record_id)
        prompt = (
            f"Role: {role} Project: {project_id} "
            f"({project_name}) Milestones: {milestones}"
        )
        input_ids = tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).input_ids
        outputs = model.generate(input_ids, max_length=128, num_beams=4, early_stopping=True)
        checklist = tokenizer.decode(outputs[0], skip_special_tokens=True)
        tips = "1. Prioritize safety checks\n2. Review milestones\n3. Log progress"
        milestones_lower = (milestones or "").lower()
        kpi_flag = "delay" in milestones_lower or "behind" in milestones_lower
        update_data = {
            'Daily_Checklist__c': checklist,
            'Suggested_Tips__c': tips,
            # A null number field is returned as None rather than being
            # absent, so `or 0` is required; .get(key, 0) would give None + 10.
            'Engagement_Score__c': (existing.get('Engagement_Score__c') or 0) + 10,
            'KPI_Flag__c': kpi_flag,
            'Supervisor_ID__c': supervisor_id or existing.get('Supervisor_ID__c'),
            'Project_ID__c': project_id_sf or existing.get('Project_ID__c'),
            'Reflection_Log__c': reflection_log or existing.get('Reflection_Log__c', ''),
            'Download_Link__c': download_link or existing.get('Download_Link__c', ''),
        }
        sf.Supervisor_AI_Coaching__c.update(record_id, update_data)
        status = f"Updated Salesforce record {record_id}"
        return checklist, tips, kpi_flag, status
    except Exception as e:
        return f"Error: {str(e)}", "", False, ""
# Define Gradio interface
# Two tabs: "Login" lists the Supervisor_AI_Coaching__c records for a project
# (via login_and_display); "Generate Checklist" runs the model for one record
# and writes the result back (via gradio_generate_checklist).
# Component creation order here determines the on-screen layout.
with gr.Blocks() as iface:
    gr.Markdown("# AI Coach for Site Supervisors")
    gr.Markdown("Enter a Project ID to view Supervisor_AI_Coaching__c records and generate checklists.")
    with gr.Tab("Login"):
        project_id_input = gr.Textbox(label="Project ID (Salesforce Project__c ID)", placeholder="Enter Project ID")
        login_button = gr.Button("Submit")
        records_output = gr.Textbox(label="Records", lines=10)
        # login_and_display returns a 3-tuple, but only the first element is
        # displayed. The two invisible components are created inline solely to
        # receive the remaining placeholder values.
        login_button.click(
            fn=login_and_display,
            inputs=project_id_input,
            outputs=[records_output, gr.Textbox(visible=False), gr.Checkbox(visible=False)]
        )
    with gr.Tab("Generate Checklist"):
        # Inputs, in the same order as gradio_generate_checklist's parameters.
        record_id = gr.Textbox(label="Record ID", placeholder="Enter Record ID from above")
        role = gr.Textbox(label="Role", value="Supervisor")
        project_id = gr.Textbox(label="Project ID", value="P001")
        project_name = gr.Textbox(label="Project Name", value="Building A")
        milestones = gr.Textbox(label="Milestones", value="Complete foundation by 5/15")
        supervisor_id = gr.Textbox(label="Supervisor ID (Salesforce User ID, optional)", value="")
        project_id_sf = gr.Textbox(label="Project ID (Salesforce Project__c ID, optional)", value="")
        reflection_log = gr.Textbox(label="Reflection Log (optional)", value="")
        download_link = gr.Textbox(label="Download Link (optional)", value="")
        generate_button = gr.Button("Generate and Update")
        # Outputs, matching the (checklist, tips, kpi_flag, status) tuple
        # returned by gradio_generate_checklist.
        checklist_output = gr.Textbox(label="Checklist")
        tips_output = gr.Textbox(label="Tips")
        kpi_flag_output = gr.Checkbox(label="KPI Flag")
        status_output = gr.Textbox(label="Salesforce Status")
        generate_button.click(
            fn=gradio_generate_checklist,
            inputs=[record_id, role, project_id, project_name, milestones, supervisor_id, project_id_sf, reflection_log, download_link],
            outputs=[checklist_output, tips_output, kpi_flag_output, status_output]
        )
# Mount FastAPI: attach the app that defines POST /generate to the Blocks object.
# NOTE(review): Blocks.launch() appears to create its own server app and assign
# it to `iface.app`, which would discard this assignment and leave /generate
# unserved. Consider gr.mount_gradio_app(app, iface, path="/") served by
# uvicorn instead; confirm against the installed gradio version.
iface.app = app

if __name__ == "__main__":
    # Listen on all interfaces on the fixed port 7860, without a public share link.
    _launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    try:
        iface.launch(**_launch_options)
    except Exception as launch_error:
        print("Failed to launch Gradio: " + str(launch_error))
|