""" | |
Endpoint management for video generation services. | |
""" | |
import time | |
import datetime | |
import logging | |
from asyncio import Lock | |
from contextlib import asynccontextmanager | |
from typing import List | |
from .models import Endpoint | |
from .api_config import VIDEO_ROUND_ROBIN_ENDPOINT_URLS | |
logger = logging.getLogger(__name__) | |
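
# Note: Endpoint (defined in .models) is expected to provide at least these
# attributes, inferred from how it is used below: id, url, busy, last_used,
# error_count, and error_until. See .models for the authoritative definition.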

class EndpointManager:
    """Manages multiple video generation endpoints with load balancing and error handling."""

    def __init__(self):
        self.endpoints: List[Endpoint] = []
        self.lock = Lock()
        self.initialize_endpoints()
        self.last_used_index = -1  # Track the last used endpoint for round-robin

    def initialize_endpoints(self):
        """Initialize the list of endpoints."""
        for i, url in enumerate(VIDEO_ROUND_ROBIN_ENDPOINT_URLS):
            endpoint = Endpoint(id=i + 1, url=url)
            self.endpoints.append(endpoint)

    def _get_next_free_endpoint(self):
        """Get the next available non-busy endpoint, or the oldest endpoint if all are busy."""
        current_time = time.time()

        # First priority: get any non-busy and non-error endpoint
        free_endpoints = [
            ep for ep in self.endpoints
            if not ep.busy and current_time > ep.error_until
        ]

        if free_endpoints:
            # Return the least recently used free endpoint
            return min(free_endpoints, key=lambda ep: ep.last_used)

        # Second priority: if all are busy or erroring, use round-robin but skip error endpoints
        tried_count = 0
        next_index = self.last_used_index

        while tried_count < len(self.endpoints):
            next_index = (next_index + 1) % len(self.endpoints)
            tried_count += 1

            # If the endpoint is not in an error state, use it
            if current_time > self.endpoints[next_index].error_until:
                self.last_used_index = next_index
                return self.endpoints[next_index]

        # If all endpoints are in an error state, use the one whose error expires first
        self.last_used_index = next_index
        return min(self.endpoints, key=lambda ep: ep.error_until)
    @asynccontextmanager
    async def get_endpoint(self, max_wait_time: int = 10):
        """Get the next available endpoint, for use as an async context manager."""
        start_time = time.time()
        endpoint = None

        try:
            while True:
                if time.time() - start_time > max_wait_time:
                    raise TimeoutError(f"Could not acquire an endpoint within {max_wait_time} seconds")

                async with self.lock:
                    # Get the next available endpoint using our selection strategy
                    endpoint = self._get_next_free_endpoint()

                    # Mark it as busy
                    endpoint.busy = True
                    endpoint.last_used = time.time()
                    break

            yield endpoint

        finally:
            if endpoint:
                async with self.lock:
                    endpoint.busy = False
                    endpoint.last_used = time.time()
    async def mark_endpoint_error(self, endpoint: Endpoint, is_timeout: bool = False):
        """Mark an endpoint as being in an error state, with exponential backoff."""
        async with self.lock:
            endpoint.error_count += 1

            # Calculate backoff time exponentially based on error count:
            # start with 15 seconds, then 30, 60, etc. up to a max of 5 minutes.
            # Shorter backoffs are used since generation should be fast.
            backoff_seconds = min(15 * (2 ** (endpoint.error_count - 1)), 300)
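            # For example, consecutive error counts 1..6 yield backoffs of
            # 15, 30, 60, 120, 240, and 300 (capped) seconds, doubled again
            # below when the failure was a timeout.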

            # Add extra backoff for timeouts, which are more indicative of serious issues
            if is_timeout:
                backoff_seconds *= 2

            endpoint.error_until = time.time() + backoff_seconds

            logger.warning(
                f"Endpoint {endpoint.id} marked as in error state (count: {endpoint.error_count}, "
                f"unavailable until: {datetime.datetime.fromtimestamp(endpoint.error_until).strftime('%H:%M:%S')})"
            )
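

# Minimal usage sketch (illustrative only, kept as a comment): acquire an
# endpoint around a generation call and report failures back to the manager.
# The names `manager`, `call_video_api`, and `prompt` are hypothetical
# placeholders and are not defined by this module.
#
#     manager = EndpointManager()
#
#     async def generate(prompt: str):
#         async with manager.get_endpoint(max_wait_time=10) as endpoint:
#             try:
#                 return await call_video_api(endpoint.url, prompt)
#             except Exception:
#                 await manager.mark_endpoint_error(endpoint)
#                 raise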