File size: 10,825 Bytes
eb7f886
 
 
0032334
eb7f886
 
 
f8bc9fc
 
17de6a3
f8bc9fc
d48ce46
eb7f886
f8bc9fc
a9c787b
eb7f886
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9c787b
eb7f886
 
 
 
a9c787b
 
eb7f886
 
 
a9c787b
eb7f886
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0032334
eb7f886
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
070f012
eb7f886
 
 
 
070f012
17de6a3
eb7f886
070f012
eb7f886
 
070f012
eb7f886
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
070f012
eb7f886
17de6a3
eb7f886
 
17de6a3
 
eb7f886
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17de6a3
eb7f886
 
ef1be3c
eb7f886
ef1be3c
eb7f886
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
070f012
eb7f886
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
070f012
eb7f886
 
070f012
eb7f886
070f012
eb7f886
070f012
eb7f886
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234

import os
from typing import Optional

import gradio as gr
from dotenv import load_dotenv
from fastapi import FastAPI
from pydantic import BaseModel
from simple_salesforce import Salesforce
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load environment variables from a local .env file (if one exists) into
# os.environ, so the os.getenv() lookups in this module (SF_* credentials,
# HF_TOKEN) can be configured either via .env or the process environment.
load_dotenv()

# Salesforce connection
def get_salesforce_connection():
    """Create and verify an authenticated Salesforce client.

    Credentials are read from the environment (populated from ``.env`` by
    ``load_dotenv()`` or set directly as Space variables):

    * ``SF_USERNAME``, ``SF_PASSWORD``, ``SF_SECURITY_TOKEN`` -- required.
    * ``SF_DOMAIN`` -- optional, defaults to ``"login"`` (production,
      login.salesforce.com); an explicitly empty value is rejected.

    Returns:
        A ``Salesforce`` client on which one API call has already succeeded.

    Raises:
        Exception: on missing credentials or any login/verification failure.
            The message starts with ``"Failed to connect to Salesforce: "``
            and the underlying error is chained as ``__cause__``.
    """
    try:
        # Insertion order fixes the order names appear in the error message.
        credentials = {
            "SF_USERNAME": os.getenv("SF_USERNAME"),
            "SF_PASSWORD": os.getenv("SF_PASSWORD"),
            "SF_SECURITY_TOKEN": os.getenv("SF_SECURITY_TOKEN"),
            "SF_DOMAIN": os.getenv("SF_DOMAIN", "login"),
        }
        # Unset (None) and set-but-empty ("") values both count as missing.
        missing = [name for name, value in credentials.items() if not value]
        if missing:
            raise ValueError(
                f"Missing environment variables: {', '.join(missing)}. "
                "Set them in .env or Space environment variables."
            )
        # os.getenv returns only str or None, so every value is now a non-empty str.

        sf = Salesforce(
            username=credentials["SF_USERNAME"],
            password=credentials["SF_PASSWORD"],
            security_token=credentials["SF_SECURITY_TOKEN"],
            domain=credentials["SF_DOMAIN"],
        )

        # Round-trip once so bad credentials/permissions fail here rather than
        # on the caller's first real query.
        # NOTE(review): confirm the installed simple_salesforce version exposes
        # ``sf.user_id``; if it does not, this line fails every connection.
        sf.User.get(sf.user_id)
        return sf
    except Exception as e:
        raise Exception(f"Failed to connect to Salesforce: {str(e)}") from e

# Load Hugging Face token (optional; only needed for private or gated models).
HF_TOKEN = os.getenv("HF_TOKEN")

# Model configuration
MODEL_PATH = "facebook/bart-large"  # Public model
# MODEL_PATH = "your_actual_username/fine_tuned_bart_construction"  # Uncomment after uploading
# NOTE(review): the active checkpoint is the public base model; per the line
# above, a fine-tuned checkpoint is intended to replace it.

# Load the model and tokenizer once at import time; the request handlers in
# this module share these module-level objects.
# ``token`` supersedes the deprecated ``use_auth_token`` keyword; an unset or
# empty HF_TOKEN is normalised to None (anonymous access).
try:
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH, token=HF_TOKEN or None)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, token=HF_TOKEN or None)
except Exception as e:
    raise Exception(f"Failed to load model: {str(e)}") from e

# Define input model for FastAPI
class ChecklistInput(BaseModel):
    """Request body for ``POST /generate``.

    ``role``, ``project_id``, ``project_name`` and ``milestones`` build the
    model prompt. The remaining optional fields are written to the
    ``Supervisor_AI_Coaching__c`` record identified by ``record_id``. No
    Salesforce write happens when ``record_id`` is omitted.
    """

    role: str = "Supervisor"
    project_id: str = "Unknown"
    project_name: str = "Unknown Project"
    milestones: str = "No milestones provided"
    # Optional[...] so that an explicit JSON ``null`` validates: pydantic v2
    # (unlike v1) does not infer Optional from a ``None`` default.
    record_id: Optional[str] = None
    supervisor_id: Optional[str] = None
    project_id_sf: Optional[str] = None
    reflection_log: Optional[str] = None
    download_link: Optional[str] = None

# Initialize FastAPI
# Application object that hosts the JSON ``POST /generate`` endpoint below.
app = FastAPI()

@app.post("/generate")
async def generate_checklist(data: ChecklistInput):
    """Generate a daily checklist and optionally sync it to Salesforce.

    The checklist is produced by the seq2seq model from ``role``,
    ``project_id``, ``project_name`` and ``milestones``. When
    ``data.record_id`` is set, the matching ``Supervisor_AI_Coaching__c``
    record is updated as follows:

    * Checklist, tips and KPI flag are overwritten.
    * The engagement score is incremented by 10.
    * The optional fields (supervisor, project, reflection log, download link)
      are written only when provided; otherwise the record's current values
      are kept.

    Returns:
        ``{"checklist": str, "tips": str, "kpi_flag": bool}`` on success, or
        ``{"error": str}`` if generation or the Salesforce update fails. Both
        are returned with the default 200 status.
    """
    # NOTE(review): tokenisation and generation are synchronous, so they block
    # the event loop for the duration of each request.
    try:
        prompt = f"Role: {data.role} Project: {data.project_id} ({data.project_name}) Milestones: {data.milestones}"
        input_ids = tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).input_ids
        outputs = model.generate(input_ids, max_length=128, num_beams=4, early_stopping=True)
        checklist = tokenizer.decode(outputs[0], skip_special_tokens=True)
        tips = "1. Prioritize safety checks\n2. Review milestones\n3. Log progress"
        # Keyword heuristic: raise the KPI flag when milestones mention slippage.
        milestones_lower = data.milestones.lower()
        kpi_flag = "delay" in milestones_lower or "behind" in milestones_lower

        if data.record_id:
            sf = get_salesforce_connection()
            # SFType.get() has no ``default`` parameter: extra keyword arguments
            # are forwarded to the HTTP request, and a missing record raises
            # instead of returning a fallback. So fetch plainly and fall back
            # per field below.
            existing_record = sf.Supervisor_AI_Coaching__c.get(data.record_id)
            update_data = {
                'Daily_Checklist__c': checklist,
                'Suggested_Tips__c': tips,
                # A null score comes back as None; treat it as 0 instead of failing on None + 10.
                'Engagement_Score__c': (existing_record.get('Engagement_Score__c') or 0) + 10,
                'KPI_Flag__c': kpi_flag,
                'Supervisor_ID__c': data.supervisor_id or existing_record.get('Supervisor_ID__c'),
                'Project_ID__c': data.project_id_sf or existing_record.get('Project_ID__c'),
                'Reflection_Log__c': data.reflection_log or existing_record.get('Reflection_Log__c', ''),
                'Download_Link__c': data.download_link or existing_record.get('Download_Link__c', ''),
            }
            sf.Supervisor_AI_Coaching__c.update(data.record_id, update_data)

        return {
            "checklist": checklist,
            "tips": tips,
            "kpi_flag": kpi_flag
        }
    except Exception as e:
        return {"error": str(e)}

# Login and display records

# Escape sequences SOQL defines for quoted string literals.
_SOQL_ESCAPES = str.maketrans({
    "\\": "\\\\",
    "'": "\\'",
    '"': '\\"',
    "\n": "\\n",
    "\r": "\\r",
    "\t": "\\t",
    "\b": "\\b",
    "\f": "\\f",
})


def _soql_quote(value):
    """Return *value* as a single-quoted SOQL string literal with reserved characters escaped."""
    return "'" + str(value).translate(_SOQL_ESCAPES) + "'"


def login_and_display(project_id_sf):
    """Query and format the coaching records linked to one project.

    Args:
        project_id_sf: Value matched against ``Project_ID__c``. It is typed by
            the user, so it is escaped before being placed in the SOQL query.

    Returns:
        A 3-tuple ``(text, "", False)``. ``text`` is the formatted record
        listing, a "no records" notice, or an error message. The two trailing
        values feed the Login tab's hidden placeholder outputs.
    """
    try:
        sf = get_salesforce_connection()
        query = (
            "SELECT Id, Name, Supervisor_ID__c, Project_ID__c, Daily_Checklist__c, "
            "Suggested_Tips__c, Reflection_Log__c, Engagement_Score__c, KPI_Flag__c, "
            "Download_Link__c FROM Supervisor_AI_Coaching__c "
            f"WHERE Project_ID__c = {_soql_quote(project_id_sf)}"
        )
        # NOTE(review): sf.query returns only the first result batch; use
        # sf.query_all if a project can have more records than one batch.
        records = sf.query(query)["records"]
        if not records:
            return "No records found for Project ID.", "", False

        sections = ["Supervisor_AI_Coaching__c Records:\n"]
        for record in records:
            sections.append(
                f"Record ID: {record['Id']}\n"
                f"Name: {record['Name']}\n"
                f"Supervisor ID: {record['Supervisor_ID__c']}\n"
                f"Project ID: {record['Project_ID__c']}\n"
                f"Daily Checklist: {record['Daily_Checklist__c'] or 'N/A'}\n"
                f"Suggested Tips: {record['Suggested_Tips__c'] or 'N/A'}\n"
                f"Reflection Log: {record['Reflection_Log__c'] or 'N/A'}\n"
                f"Engagement Score: {record['Engagement_Score__c'] or 0}%\n"
                f"KPI Flag: {record['KPI_Flag__c']}\n"
                f"Download Link: {record['Download_Link__c'] or 'N/A'}\n"
                f"{'-'*50}\n"
            )
        return "".join(sections), "", False
    except Exception as e:
        return f"Error querying Salesforce: {str(e)}", "", False

# Generate checklist from record
def gradio_generate_checklist(record_id, role="Supervisor", project_id="Unknown", project_name="Unknown Project", milestones="No milestones provided", supervisor_id="", project_id_sf="", reflection_log="", download_link=""):
    """Gradio handler: generate a checklist and write it to a Salesforce record.

    Args:
        record_id: Id of the ``Supervisor_AI_Coaching__c`` record to update.
            Required; surrounding whitespace is ignored.
        role, project_id, project_name, milestones: Prompt inputs for the
            model. ``milestones`` also drives the KPI flag.
        supervisor_id, project_id_sf, reflection_log, download_link: Optional
            field values. An empty string keeps the record's current value.

    Returns:
        ``(checklist, tips, kpi_flag, status)`` for the four output
        components. On failure it returns ``("Error: <message>", "", False, "")``.
    """
    record_id = (record_id or "").strip()
    if not record_id:
        # Without an Id, SFType.get/update would not address a specific record.
        return "Error: Record ID is required.", "", False, ""
    try:
        sf = get_salesforce_connection()
        # SFType.get() has no ``default`` parameter (extra keyword arguments go
        # to the HTTP request). A missing record raises and is reported below.
        existing_record = sf.Supervisor_AI_Coaching__c.get(record_id)

        prompt = f"Role: {role} Project: {project_id} ({project_name}) Milestones: {milestones}"
        input_ids = tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).input_ids
        outputs = model.generate(input_ids, max_length=128, num_beams=4, early_stopping=True)
        checklist = tokenizer.decode(outputs[0], skip_special_tokens=True)
        tips = "1. Prioritize safety checks\n2. Review milestones\n3. Log progress"
        # Keyword heuristic: raise the KPI flag when milestones mention slippage.
        milestones_lower = milestones.lower()
        kpi_flag = "delay" in milestones_lower or "behind" in milestones_lower

        update_data = {
            'Daily_Checklist__c': checklist,
            'Suggested_Tips__c': tips,
            # A null score comes back as None; treat it as 0 instead of failing on None + 10.
            'Engagement_Score__c': (existing_record.get('Engagement_Score__c') or 0) + 10,
            'KPI_Flag__c': kpi_flag,
            'Supervisor_ID__c': supervisor_id or existing_record.get('Supervisor_ID__c'),
            'Project_ID__c': project_id_sf or existing_record.get('Project_ID__c'),
            'Reflection_Log__c': reflection_log or existing_record.get('Reflection_Log__c', ''),
            'Download_Link__c': download_link or existing_record.get('Download_Link__c', ''),
        }
        sf.Supervisor_AI_Coaching__c.update(record_id, update_data)
        status = f"Updated Salesforce record {record_id}"

        return checklist, tips, kpi_flag, status
    except Exception as e:
        return f"Error: {str(e)}", "", False, ""

# Define Gradio interface
# Two tabs: "Login" lists the records for a Project ID (login_and_display);
# "Generate Checklist" runs the model and updates one record
# (gradio_generate_checklist).
with gr.Blocks() as iface:
    gr.Markdown("# AI Coach for Site Supervisors")
    gr.Markdown("Enter a Project ID to view Supervisor_AI_Coaching__c records and generate checklists.")

    with gr.Tab("Login"):
        project_id_input = gr.Textbox(label="Project ID (Salesforce Project__c ID)", placeholder="Enter Project ID")
        login_button = gr.Button("Submit")
        records_output = gr.Textbox(label="Records", lines=10)
        # login_and_display returns three values; the last two are routed to
        # hidden placeholder components created inline here.
        login_button.click(
            fn=login_and_display,
            inputs=project_id_input,
            outputs=[records_output, gr.Textbox(visible=False), gr.Checkbox(visible=False)]
        )

    with gr.Tab("Generate Checklist"):
        # These module-level component names match gradio_generate_checklist's
        # parameter names; inside that function the parameters shadow them.
        record_id = gr.Textbox(label="Record ID", placeholder="Enter Record ID from above")
        role = gr.Textbox(label="Role", value="Supervisor")
        project_id = gr.Textbox(label="Project ID", value="P001")
        project_name = gr.Textbox(label="Project Name", value="Building A")
        milestones = gr.Textbox(label="Milestones", value="Complete foundation by 5/15")
        supervisor_id = gr.Textbox(label="Supervisor ID (Salesforce User ID, optional)", value="")
        project_id_sf = gr.Textbox(label="Project ID (Salesforce Project__c ID, optional)", value="")
        reflection_log = gr.Textbox(label="Reflection Log (optional)", value="")
        download_link = gr.Textbox(label="Download Link (optional)", value="")
        generate_button = gr.Button("Generate and Update")
        checklist_output = gr.Textbox(label="Checklist")
        tips_output = gr.Textbox(label="Tips")
        kpi_flag_output = gr.Checkbox(label="KPI Flag")
        status_output = gr.Textbox(label="Salesforce Status")
        # `inputs` are passed positionally, so their order must match
        # gradio_generate_checklist's parameters; `outputs` receive its 4-tuple.
        generate_button.click(
            fn=gradio_generate_checklist,
            inputs=[record_id, role, project_id, project_name, milestones, supervisor_id, project_id_sf, reflection_log, download_link],
            outputs=[checklist_output, tips_output, kpi_flag_output, status_output]
        )

# Mount FastAPI
# ``iface.launch()`` creates and serves its own FastAPI app, so attaching ours
# via ``iface.app`` would be ignored. Instead, mount the Gradio UI onto our app
# at "/". Serving ``app`` with an ASGI server then exposes both the UI and
# ``POST /generate``. /generate was registered first, so it takes precedence
# over the mount.
app = gr.mount_gradio_app(app, iface, path="/")

if __name__ == "__main__":
    # NOTE: launch() starts Gradio's own server, which serves only the UI.
    # ``POST /generate`` is reachable only when ``app`` above is served.
    try:
        iface.launch(server_name="0.0.0.0", server_port=7860, share=False)
    except Exception as e:
        print(f"Failed to launch Gradio: {str(e)}")