"""FastAPI app exposing a webhook endpoint that launches an AutoTrain run."""

import datetime
import os
from subprocess import Popen
from typing import Optional

import requests
import uvicorn
import yaml
from fastapi import FastAPI, Header, BackgroundTasks
from fastapi.responses import FileResponse
from huggingface_hub.hf_api import HfApi

from src.models import config, WebhookPayload

app = FastAPI()

# Both values are read once at import time; either may be None if unset.
WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")


@app.get("/")
async def home():
    """Return ``home.html`` as a file response."""
    return FileResponse("home.html")


@app.post("/webhook")
async def post_webhook(
    payload: WebhookPayload,
    task_queue: BackgroundTasks,
    x_webhook_secret: Optional[str] = Header(default=None),
):
    """Handle ``POST /webhook`` by triggering a retrain.

    Args:
        payload: Parsed webhook body.
        task_queue: FastAPI background-task queue (currently unused; the
            retrain is scheduled inline).
        x_webhook_secret: Value of the ``X-Webhook-Secret`` header, if any
            (currently unused, see the note below).

    Returns:
        ``{"processed": True}`` once the retrain has been scheduled.
    """
    # NOTE(review): secret and payload validation are disabled below, so every
    # request to this endpoint triggers a retrain. Re-enabling the checks also
    # requires importing HTTPException from fastapi.
    # if x_webhook_secret is None:
    #     raise HTTPException(401)
    # if x_webhook_secret != WEBHOOK_SECRET:
    #     raise HTTPException(403)
    # if not (
    #     payload.event.action == "update"
    #     and payload.event.scope.startswith("repo.content")
    #     and payload.repo.name == config.input_dataset
    #     and payload.repo.type == "dataset"
    # ):
    #     # no-op
    #     return {"processed": False}

    schedule_retrain(payload=payload)
    # task_queue.add_task(
    #     schedule_retrain,
    #     payload
    # )

    return {"processed": True}


def schedule_retrain(payload: WebhookPayload):
    """Stamp ``config.yaml`` with a fresh project name and launch AutoTrain.

    Reads ``config.yaml`` from the current working directory, sets its
    ``project_name`` key to the current local time in ISO format, writes the
    file back, starts ``autotrain --config <path>`` as a child process and
    then posts a notification through ``notify_success``.

    Args:
        payload: Webhook payload that triggered the retrain (not read here).

    Returns:
        ``{"processed": True}``.

    Raises:
        requests.HTTPError: Logged and re-raised if raised inside the
            scheduling step.
    """
    # Create the autotrain project
    try:
        yaml_path = os.path.join(os.getcwd(), "config.yaml")
        with open(yaml_path) as f:
            list_doc = yaml.safe_load(f)

        list_doc['project_name'] = datetime.datetime.now().isoformat()

        # Must be opened for writing: dumping into a read-only handle raises
        # io.UnsupportedOperation and the training would never be launched.
        with open(yaml_path, "w") as f:
            yaml.dump(list_doc, f, default_flow_style=False)

        # Fire-and-forget: the child process is not waited on.
        result = Popen(['autotrain', '--config', yaml_path])
        # project = AutoTrain.create_project(payload)
        # AutoTrain.add_data(project_id=project["id"])
        # AutoTrain.start_processing(project_id=project["id"])
    except requests.HTTPError as err:
        print("ERROR while requesting AutoTrain API:")
        print(f"  code: {err.response.status_code}")
        print(f"  {err.response.json()}")
        raise

    # Notify in the community tab
    notify_success('vicuna')
    # NOTE(review): returncode stays None until the child is polled or waited
    # on, so this most likely prints None.
    print(result.returncode)
    return {"processed": True}


def notify_success(project_id: str):
    """Open a discussion on the input dataset announcing the retrain.

    Args:
        project_id: Identifier passed to the message template (the current
            template has no ``{project_id}`` placeholder, so it is ignored).

    Returns:
        Whatever ``HfApi.create_discussion`` returns.
    """
    message = NOTIFICATION_TEMPLATE.format(
        input_model=config.input_model,
        input_dataset=config.input_dataset,
        project_id=project_id,
    )
    return HfApi(token=HF_ACCESS_TOKEN).create_discussion(
        repo_id=config.input_dataset,
        repo_type="dataset",
        title="✨ Retraining started!",
        description=message,
        token=HF_ACCESS_TOKEN,
    )


# Markdown body of the discussion created by notify_success. Looked up at call
# time, so defining it after the function is fine.
NOTIFICATION_TEMPLATE = """\
🌸 Hello there!

Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_dataset}), an automatic re-training of [{input_model}](https://huggingface.co/{input_model}) has been scheduled on AutoTrain!

(This is an automated message)
"""

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)