|
import importlib |
|
import os |
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
from queue import Queue |
|
from types import ModuleType |
|
from typing import Any, List |
|
|
|
from tqdm import tqdm |
|
|
|
from facefusion import logger, state_manager, wording |
|
from facefusion.exit_helper import hard_exit |
|
from facefusion.types import ProcessFrames, QueuePayload |
|
|
|
# Attribute names a processor module has to expose; load_processor_module
# checks each of them with hasattr() before accepting the module.
PROCESSORS_METHODS : List[str] = [
    'get_inference_pool',
    'clear_inference_pool',
    'register_args',
    'apply_args',
    'pre_check',
    'pre_process',
    'post_process',
    'get_reference_frame',
    'process_frame',
    'process_frames',
    'process_image',
    'process_video'
]
|
|
|
|
|
def load_processor_module(processor : str) -> Any:
    """
    Import the processor module named by ``processor`` and check its interface.

    The module is imported from below 'facefusion.processors.modules.'. Every
    name listed in PROCESSORS_METHODS has to exist as an attribute of it.
    A ModuleNotFoundError or a missing attribute is logged as an error and
    followed by hard_exit(1).

    :param processor: module name relative to 'facefusion.processors.modules.'
    :return: the imported processor module
    """
    module_name = 'facefusion.processors.modules.' + processor

    try:
        processor_module = importlib.import_module(module_name)
        is_complete = all(hasattr(processor_module, method_name) for method_name in PROCESSORS_METHODS)
        if not is_complete:
            # raised inside the try block so the handler below reports it
            raise NotImplementedError
    except ModuleNotFoundError as exception:
        logger.error(wording.get('processor_not_loaded').format(processor = processor), __name__)
        logger.debug(exception.msg, __name__)
        hard_exit(1)
    except NotImplementedError:
        logger.error(wording.get('processor_not_implemented').format(processor = processor), __name__)
        hard_exit(1)
    return processor_module
|
|
|
|
|
def get_processors_modules(processors : List[str]) -> List[ModuleType]:
    """
    Load one module per processor name, keeping the order of ``processors``.

    :param processors: processor names handed to load_processor_module
    :return: the loaded modules, one per name
    """
    return [ load_processor_module(processor_name) for processor_name in processors ]
|
|
|
|
|
def multi_process_frames(source_paths : List[str], temp_frame_paths : List[str], process_frames : ProcessFrames) -> None:
    """
    Run ``process_frames`` over all temp frames on a thread pool.

    The frame paths are turned into payloads and queued. Batches of payloads
    are submitted to a ThreadPoolExecutor sized by 'execution_thread_count'.
    Each submitted call receives ``source_paths``, its batch and the progress
    bar's ``update`` method.

    :param source_paths: passed through unchanged to every ``process_frames`` call
    :param temp_frame_paths: frame paths converted via create_queue_payloads
    :param process_frames: callable submitted to the executor once per batch
    """
    queue_payloads = create_queue_payloads(temp_frame_paths)

    with tqdm(
        total = len(queue_payloads),
        desc = wording.get('processing'),
        unit = 'frame',
        ascii = ' =',
        disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]
    ) as progress:
        progress.set_postfix(execution_providers = state_manager.get_item('execution_providers'))

        with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor:
            futures = []
            queue : Queue[QueuePayload] = create_queue(queue_payloads)
            # floor division happens first, then the multiplication; at least one payload per future
            payloads_per_thread = len(queue_payloads) // state_manager.get_item('execution_thread_count')
            queue_per_future = max(payloads_per_thread * state_manager.get_item('execution_queue_count'), 1)

            while not queue.empty():
                batch = pick_queue(queue, queue_per_future)
                futures.append(executor.submit(process_frames, source_paths, batch, progress.update))

            # result() re-raises any exception raised inside a worker
            for future_done in as_completed(futures):
                future_done.result()
|
|
|
|
|
def create_queue(queue_payloads : List[QueuePayload]) -> Queue[QueuePayload]:
    """
    Put every payload into a new Queue, in list order.

    :param queue_payloads: payloads to enqueue
    :return: the filled queue
    """
    payload_queue : Queue[QueuePayload] = Queue()

    for payload in queue_payloads:
        payload_queue.put(payload)
    return payload_queue
|
|
|
|
|
def pick_queue(queue : Queue[QueuePayload], queue_per_future : int) -> List[QueuePayload]:
    """
    Take up to ``queue_per_future`` payloads from the front of ``queue``.

    Fewer payloads are returned when the queue empties first.

    :param queue: queue to take payloads from
    :param queue_per_future: upper bound on the number of payloads taken
    :return: the taken payloads in queue order
    """
    # empty() is checked before every get(), once per requested slot
    return [ queue.get() for _ in range(queue_per_future) if not queue.empty() ]
|
|
|
|
|
def create_queue_payloads(temp_frame_paths : List[str]) -> List[QueuePayload]:
    """
    Create one payload per frame path, numbered in file name order.

    The paths are sorted by their base name and numbered from zero in that
    order. The list passed in is left untouched.

    :param temp_frame_paths: frame paths in any order
    :return: payloads with the keys 'frame_number' and 'frame_path'
    """
    ordered_frame_paths = sorted(temp_frame_paths, key = os.path.basename)

    return [
        {
            'frame_number': frame_number,
            'frame_path': frame_path
        }
        for frame_number, frame_path in enumerate(ordered_frame_paths)
    ]
|
|