Spaces:
Build error
Build error
import datetime
import os
import secrets
from subprocess import Popen
from typing import Optional

import requests
import uvicorn
import yaml
from fastapi import BackgroundTasks, FastAPI, Header, HTTPException
from fastapi.responses import FileResponse
from huggingface_hub.hf_api import HfApi

from src.models import WebhookPayload, config
# ASGI application served by uvicorn.
app = FastAPI()

# Shared secret for authenticating incoming webhook calls (None when unset).
WEBHOOK_SECRET = os.environ.get("WEBHOOK_SECRET")
# Hugging Face Hub token used when posting notifications (None when unset).
HF_ACCESS_TOKEN = os.environ.get("HF_ACCESS_TOKEN")
@app.get("/")
async def home():
    """Serve the static landing page ``home.html``.

    The path is relative, so it is resolved against the process's current
    working directory.
    """
    return FileResponse("home.html")
@app.post("/webhook")
async def post_webhook(
    payload: WebhookPayload,
    task_queue: BackgroundTasks,
    x_webhook_secret: Optional[str] = Header(default=None),
):
    """Authenticate a Hub webhook call and schedule a retrain when relevant.

    Only content updates to the dataset repo named by
    ``config.input_dataset`` trigger a retrain. Every other event is
    acknowledged as a no-op.

    Args:
        payload: Parsed webhook body.
        task_queue: FastAPI background-task queue. The retrain runs after the
            response is sent, so the webhook caller gets a prompt reply.
        x_webhook_secret: Value of the ``X-Webhook-Secret`` request header.

    Returns:
        ``{"processed": True}`` when a retrain was scheduled, otherwise
        ``{"processed": False}``.

    Raises:
        HTTPException: 401 when the secret header is missing; 403 when it
            does not match ``WEBHOOK_SECRET`` (or no secret is configured).
    """
    if x_webhook_secret is None:
        raise HTTPException(401)
    # Fail closed when no secret is configured; compare in constant time so
    # the secret cannot be recovered through response-timing differences.
    if WEBHOOK_SECRET is None or not secrets.compare_digest(
        x_webhook_secret.encode("utf-8"), WEBHOOK_SECRET.encode("utf-8")
    ):
        raise HTTPException(403)

    # Filter events: without this, any event (including the discussion that
    # notify_success opens on the dataset) would trigger another retrain.
    is_dataset_content_update = (
        payload.event.action == "update"
        and payload.event.scope.startswith("repo.content")
        and payload.repo.name == config.input_dataset
        and payload.repo.type == "dataset"
    )
    if not is_dataset_content_update:
        return {"processed": False}

    task_queue.add_task(schedule_retrain, payload)
    return {"processed": True}
def schedule_retrain(payload: WebhookPayload):
    """Start an AutoTrain run from ``config.yaml`` and announce it.

    ``config.yaml`` in the current working directory is rewritten with a
    fresh, timestamp-based ``project_name``. The ``autotrain`` CLI is then
    started on it in a child process; this function does not wait for
    training to finish. Finally a discussion is opened on the input dataset
    via ``notify_success``.

    Args:
        payload: The webhook payload that triggered the retrain. It is not
            read here; it is kept for the background-task call signature.

    Returns:
        ``{"processed": True}`` once the AutoTrain process has been started.

    Raises:
        OSError: if ``config.yaml`` cannot be read or written, or the
            ``autotrain`` executable cannot be launched.
    """
    yaml_path = os.path.join(os.getcwd(), "config.yaml")
    # Hub repo names only allow letters, digits, "-", "_" and ".", so avoid
    # isoformat(), whose ":" separators would make the name invalid.
    project_name = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

    try:
        with open(yaml_path, encoding="utf-8") as f:
            # An empty YAML file loads as None; treat it as an empty mapping.
            run_config = yaml.safe_load(f) or {}
        run_config["project_name"] = project_name
        # The file must be reopened in write mode: yaml.dump writes to it.
        with open(yaml_path, "w", encoding="utf-8") as f:
            yaml.dump(run_config, f, default_flow_style=False)
        process = Popen(["autotrain", "--config", yaml_path])
    except OSError as err:
        print(f"ERROR while launching AutoTrain: {err}")
        raise

    # Popen returns immediately, so there is no return code yet; log the pid.
    print(f"AutoTrain started for project {project_name!r} (pid {process.pid})")

    # Notify in the community tab
    notify_success(project_name)
    return {"processed": True}
def notify_success(project_id: str):
    """Open a "Retraining started" discussion on the input dataset repo.

    Args:
        project_id: AutoTrain project identifier. It is forwarded to the
            message template as the ``project_id`` field (ignored if the
            template does not reference it).

    Returns:
        The object returned by ``HfApi.create_discussion``.
    """
    hub = HfApi(token=HF_ACCESS_TOKEN)
    dataset_repo = config.input_dataset
    body = NOTIFICATION_TEMPLATE.format(
        input_model=config.input_model,
        input_dataset=dataset_repo,
        project_id=project_id,
    )
    return hub.create_discussion(
        repo_id=dataset_repo,
        repo_type="dataset",
        title="✨ Retraining started!",
        description=body,
        token=HF_ACCESS_TOKEN,
    )
# Markdown body of the "Retraining started" discussion. Filled in with
# str.format(); it uses the {input_dataset} and {input_model} fields, and any
# extra keyword arguments passed to format() are ignored. Blank lines separate
# paragraphs: Markdown would otherwise merge consecutive lines into one.
NOTIFICATION_TEMPLATE = """\
🌸 Hello there!

Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_dataset}), an automatic re-training of [{input_model}](https://huggingface.co/{input_model}) has been scheduled on AutoTrain!

(This is an automated message)
"""
if __name__ == "__main__":
    # Let the hosting environment choose the port via $PORT; defaults to 8000.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))