Spaces:
Build error
Build error
# inference.py
from .yolo_manager import YOLOManager | |
from .utils import get_abs_path, get_image_paths | |
import os | |
from .config import Config | |
def run_inference(weights_path: str, images_dirs, output_dir: str = 'temp_dir') -> None:
    """
    Run inference on images using a trained model.

    Args:
        weights_path: Path to model weights; it is resolved with
            ``get_abs_path`` before being checked.
        images_dirs: Directory or list of directories containing images.
        output_dir: Directory to save annotated results.

    Raises:
        FileNotFoundError: If the resolved weights path is not an existing file.
        ValueError: If ``get_image_paths`` yields nothing for ``images_dirs``.
        Exception: Any other failure is reported on stdout and re-raised unchanged.
    """
    try:
        # Validate the weights file before doing any heavier work.
        weights_path = get_abs_path(weights_path)
        if not os.path.isfile(weights_path):
            raise FileNotFoundError(f"β Weights file not found: {weights_path}")

        # Collect the images to process; an empty result is treated as an error.
        image_paths = get_image_paths(images_dirs)
        if not image_paths:
            raise ValueError("β No images found in the provided directories.")
        print(f"π Found {len(image_paths)} images for inference")

        # Scope the manager with ``with`` (as annotate_all_image does) so its
        # exit hook runs even when loading or annotation raises.
        with YOLOManager() as yolo_manager:
            yolo_manager.load_model(weights_path)
            yolo_manager.annotate_images(image_paths, output_dir)

        print("π Inference completed successfully!")
    except Exception as e:
        # Report the failure for console users, then let the caller handle it.
        print(f"β Inference failed: {e}")
        raise
def main():
    """Run inference over the train, val and test image folders.

    Uses the weights at ``Config.yolo_trained_model_path`` and writes the
    results to ``./temp_dir``.
    """
    dataset_splits = [
        './dataset/images/train',
        './dataset/images/val',
        './dataset/images/test',
    ]
    run_inference(
        Config.yolo_trained_model_path,
        dataset_splits,
        './temp_dir',
    )
def annotate_all_image():
    """Annotate every not-yet-labelled image under ``dataset/images``.

    Walks ``<Config.current_path>/dataset/images`` recursively and, for each
    ``.jpg``/``.jpeg``/``.png`` file (case-insensitive) that has no
    ``<name>.txt`` in ``<Config.current_path>/image_labels``, calls
    ``annotate_images`` on that single image with ``save_image=False``.
    """
    with YOLOManager() as yolo_manager:
        yolo_manager.load_model(Config.yolo_trained_model_path)

        images_root = os.path.join(Config.current_path, "dataset/images")
        labels_root = os.path.join(Config.current_path, "image_labels")
        image_suffixes = (".jpg", ".jpeg", ".png")

        for folder, _, filenames in os.walk(images_root):
            for filename in sorted(filenames):
                if not filename.lower().endswith(image_suffixes):
                    continue

                # The label file is keyed by the image's base name only.
                stem = os.path.splitext(filename)[0]
                label_file = os.path.join(labels_root, f'{stem}.txt')  # YOLO label txt
                if os.path.exists(label_file):
                    # A label file already exists for this name; skip the image.
                    continue

                # NOTE(review): presumably annotate_images writes <stem>.txt
                # into output_dir -- confirm against YOLOManager.
                yolo_manager.annotate_images(
                    image_paths=[os.path.join(folder, filename)],
                    output_dir=labels_root,
                    save_image=False,
                )
# Script entry point: runs the bulk labelling pass (annotate_all_image), not main().
# This module uses relative imports, so it has to be run as part of its package.
if __name__ == "__main__":
    annotate_all_image()