jebin2's picture
inf all
fbf6388
# inference.py
from .yolo_manager import YOLOManager
from .utils import get_abs_path, get_image_paths
import os
from .config import Config
def run_inference(weights_path: str, images_dirs, output_dir: str = 'temp_dir') -> None:
    """
    Run inference on images using a trained model.

    Args:
        weights_path: Path to model weights; resolved with ``get_abs_path``
            before it is checked.
        images_dirs: Directory or list of directories containing images.
        output_dir: Directory to save annotated results.

    Raises:
        FileNotFoundError: If the resolved weights path is not an existing file.
        ValueError: If ``get_image_paths`` returns no images.
        Exception: Any other failure is reported on stdout and re-raised
            unchanged, so callers still see the original exception.
    """
    try:
        # Validate weights file
        weights_path = get_abs_path(weights_path)
        if not os.path.isfile(weights_path):
            raise FileNotFoundError(f"❌ Weights file not found: {weights_path}")

        # Get image paths
        image_paths = get_image_paths(images_dirs)
        if not image_paths:
            raise ValueError("❌ No images found in the provided directories.")
        print(f"🔍 Found {len(image_paths)} images for inference")

        # Use the manager as a context manager, the same way annotate_all_image
        # does, so whatever cleanup its __exit__ performs also runs here --
        # including when loading or annotation raises.
        with YOLOManager() as yolo_manager:
            yolo_manager.load_model(weights_path)
            yolo_manager.annotate_images(image_paths, output_dir)

        print("🎉 Inference completed successfully!")
    except Exception as e:
        # Top-level boundary: report the failure, then propagate it untouched.
        print(f"❌ Inference failed: {e}")
        raise
def main():
    """
    Run inference over the train, val and test image folders.

    Uses the weights path from ``Config.yolo_trained_model_path`` and hands
    ``./temp_dir`` to ``run_inference`` as the output directory.
    """
    trained_weights = Config.yolo_trained_model_path
    split_dirs = [
        './dataset/images/train',
        './dataset/images/val',
        './dataset/images/test',
    ]
    run_inference(trained_weights, split_dirs, output_dir='./temp_dir')
def annotate_all_image():
    """
    Annotate every image under ``dataset/images`` that has no label file yet.

    Walks ``<Config.current_path>/dataset/images`` recursively and, for each
    ``.jpg`` / ``.jpeg`` / ``.png`` file (case-insensitive), looks for
    ``<Config.current_path>/image_labels/<name>.txt``. Images whose label file
    already exists are skipped; the rest are passed one at a time to
    ``YOLOManager.annotate_images`` with ``save_image=False``.

    The label path is built from the bare file name only, so two images with
    the same name in different sub-directories map to the same ``.txt`` path.
    """
    image_suffixes = (".jpg", ".jpeg", ".png")

    with YOLOManager() as yolo_manager:
        yolo_manager.load_model(Config.yolo_trained_model_path)

        images_root = os.path.join(Config.current_path, "dataset/images")
        labels_root = os.path.join(Config.current_path, "image_labels")

        for current_dir, _, filenames in os.walk(images_root):
            for filename in sorted(filenames):
                # Only image files are candidates for annotation.
                if not filename.lower().endswith(image_suffixes):
                    continue

                stem = os.path.splitext(filename)[0]
                label_path = os.path.join(labels_root, f'{stem}.txt')  # YOLO label txt

                # An existing label file means this image was already handled.
                if os.path.exists(label_path):
                    continue

                yolo_manager.annotate_images(
                    image_paths=[os.path.join(current_dir, filename)],
                    output_dir=labels_root,
                    save_image=False
                )
# Script entry point: running this module directly executes annotate_all_image();
# main() is defined in this file but is not invoked from here.
# NOTE(review): this file uses relative imports (from .yolo_manager, .utils, .config),
# so it presumably has to be launched as a module inside its package
# (python -m <package>.inference) rather than as a bare script -- confirm.
if __name__ == "__main__":
    annotate_all_image()