File size: 11,425 Bytes
05be5a5
 
 
 
 
bb49e0d
05be5a5
a1f4a1e
 
05be5a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c13ce0c
 
cf3d6df
 
05be5a5
 
 
 
 
 
 
 
 
 
 
 
 
 
cf3d6df
 
05be5a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c13ce0c
05be5a5
 
 
 
 
 
 
 
 
 
 
 
 
 
bb49e0d
05be5a5
 
 
 
c13ce0c
05be5a5
c458a5a
 
bb49e0d
c458a5a
05be5a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fef987b
05be5a5
 
 
 
 
 
ccc081e
05be5a5
ccc081e
05be5a5
 
 
ccc081e
05be5a5
c38c187
 
05be5a5
 
 
 
 
 
 
 
ccc081e
c38c187
 
05be5a5
 
 
 
 
ccc081e
05be5a5
 
 
ccc081e
 
 
 
 
05be5a5
c38c187
 
05be5a5
ccc081e
05be5a5
 
ccc081e
 
 
 
 
 
 
 
 
 
c38c187
 
 
ccc081e
c38c187
 
 
 
ccc081e
c38c187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccc081e
 
 
c38c187
 
 
ccc081e
 
 
c38c187
ccc081e
05be5a5
 
 
c38c187
 
3bc1feb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
# utils.py
import os
import shutil
from glob import glob
from typing import List, Union
from . import utils

# CUDA debugging switches: both variables are set to "1" in the process
# environment as a side effect of importing this module.
os.environ.update({
    "TORCH_USE_CUDA_DSA": "1",
    "CUDA_LAUNCH_BLOCKING": "1",
})

def get_abs_path(relative_path: str) -> str:
    """Return the absolute, normalized version of ``relative_path``.

    Args:
        relative_path: Path to convert; resolved against the current working
            directory when it is not already absolute.

    Returns:
        The absolute path as produced by ``os.path.abspath``.
    """
    absolute_path = os.path.abspath(relative_path)
    return absolute_path

def get_image_paths(directories: Union[str, List[str]],
                    extensions: Union[List[str], None] = None) -> List[str]:
    """
    Get all image paths from given directories.

    Args:
        directories: Single directory path or list of directory paths
        extensions: File extensions to match, without the leading dot
            (e.g. ``["jpg", "png"]``). When omitted,
            ``Config.SUPPORTED_EXTENSIONS`` is used.

    Returns:
        List of absolute image file paths without duplicates. The order is
        deterministic: directories in the order given, then extensions in
        the order given, then file names sorted within each extension.
    """
    if isinstance(directories, str):
        directories = [directories]

    if extensions is None:
        # NOTE(review): Config is looked up in module globals at call time;
        # confirm it is imported before this function is first called.
        extensions = Config.SUPPORTED_EXTENSIONS

    all_images = []
    for directory in directories:
        abs_dir = os.path.abspath(directory)
        if not os.path.isdir(abs_dir):
            print(f"⚠️ Warning: Skipping non-directory {abs_dir}")
            continue

        # Support multiple image extensions
        for ext in extensions:
            pattern = os.path.join(abs_dir, f'*.{ext}')
            all_images.extend(sorted(glob(pattern)))

    # Remove duplicates (same directory listed twice, overlapping patterns)
    # while keeping first-seen order; a plain set() would scramble the
    # sorted order built above.
    return list(dict.fromkeys(all_images))

def backup_file(source_path: str, backup_path: str) -> str:
    """Copy ``source_path`` to ``backup_path`` and report where it went.

    Args:
        source_path: File to copy.
        backup_path: Destination path; converted to an absolute path, and its
            parent directory is created when it does not exist yet.

    Returns:
        The absolute destination path.
    """
    destination = get_abs_path(backup_path)
    parent_dir = os.path.dirname(destination)
    os.makedirs(parent_dir, exist_ok=True)
    shutil.copy(source_path, destination)
    print(f"βœ… File backed up to: {destination}")
    return destination

# yolo_manager.py
import os
import cv2
from ultralytics import YOLO
from typing import List, Optional, Dict, Any
from .utils import get_abs_path, clean_directory
from .config import Config
from dotenv import load_dotenv
# Runs at import time: python-dotenv's load_dotenv() reads a .env file (if one
# is found) and adds its entries to the process environment.
load_dotenv()

class YOLOManager:
    """Manages YOLO model training and inference operations.

    The manager holds at most one model in ``self.model``. It can be used as
    a context manager: leaving the ``with`` block calls ``unload_model``,
    which drops the model and clears the CUDA cache.
    """

    def __init__(self, model_name: Optional[str] = None):
        """Create a manager without loading any weights.

        Args:
            model_name: Default run name used by ``train`` and
                ``get_best_weights_path``. Falls back to
                ``Config.YOLO_MODEL_NAME`` when omitted or empty.
        """
        # Assign ``model`` first so cleanup code always finds the attribute,
        # even if resolving the default name below raises.
        self.model = None
        self.model_name = model_name or Config.YOLO_MODEL_NAME

    def load_model(self, weights_path: Optional[str] = None) -> "YOLO":
        """Load YOLO model from weights or pretrained model.

        Args:
            weights_path: Weights file to load. When it is omitted or is not
                an existing file, the model at ``Config.yolo_base_model_path``
                is loaded instead.

        Returns:
            The loaded model, which is also stored in ``self.model``.
        """
        if weights_path and os.path.isfile(weights_path):
            print(f"πŸ“¦ Loading model from: {weights_path}")
            self.model = YOLO(weights_path)
        else:
            print(f"✨ Loading pretrained model '{Config.yolo_base_model_path}'")
            self.model = YOLO(f"{Config.yolo_base_model_path}")
        return self.model

    def train(self,
              data_yaml_path: str,
              run_name: Optional[str] = None,
              device: int = 0,
              resume: bool = True,
              **kwargs) -> "YOLO":
        """
        Train YOLO model with given parameters.

        Args:
            data_yaml_path: Path to dataset YAML file
            run_name: Name for the training run (defaults to ``self.model_name``)
            device: Device to use for training
            resume: Whether to resume from checkpoint if available
            **kwargs: Additional training parameters; they override the
                defaults assembled below

        Returns:
            The model stored in ``self.model`` after training.
        """
        run_name = run_name or self.model_name
        checkpoint_path = f"{Config.current_path}/runs/detect/{run_name}/weights/last.pt"

        # Resume only when asked to AND a checkpoint of this run exists;
        # otherwise start from the base model.
        if resume and os.path.isfile(checkpoint_path):
            print(f"πŸ”„ Resuming training from checkpoint: {checkpoint_path}")
            self.model = YOLO(checkpoint_path)
            resume_flag = True
        else:
            self.load_model()
            resume_flag = False

        # Default training parameters
        train_params = {
            'data': data_yaml_path,
            'imgsz': Config.DEFAULT_IMAGE_SIZE,
            'epochs': Config.EPOCH,
            'batch': 10,
            'name': run_name,
            'device': device,
            'cache': True,
            'project': f'{Config.current_path}/runs/detect',
            'exist_ok': True,
            'pose': False,
            'resume': resume_flag,
            'save_period': 10,
            'amp': False,  # Disable AMP to prevent yolo11n.pt download
        }

        # Update with custom parameters
        train_params.update(kwargs)

        print(f"πŸš€ Starting training with parameters: {train_params}")
        self.model.train(**train_params)
        return self.model

    def validate(self) -> Dict[str, Any]:
        """Validate the model and return metrics.

        Returns:
            Whatever ``self.model.val()`` returns.

        Raises:
            ValueError: If no model has been loaded or trained yet.
        """
        if self.model is None:
            raise ValueError("❌ No model loaded. Please train or load a model first.")

        metrics = self.model.val()
        print("πŸ“Š Validation Metrics:", metrics)
        return metrics

    def get_best_weights_path(self, run_name: Optional[str] = None) -> str:
        """Get path to best trained weights.

        Args:
            run_name: Training run to look up (defaults to ``self.model_name``).

        Returns:
            Path of ``best.pt`` under ``<Config.current_path>/runs/detect/<run_name>/weights``.

        Raises:
            FileNotFoundError: If that file does not exist.
        """
        run_name = run_name or self.model_name
        weights_path = os.path.join(Config.current_path, 'runs', 'detect', run_name, 'weights', 'best.pt')

        if not os.path.isfile(weights_path):
            raise FileNotFoundError(f"❌ Trained weights not found at: {weights_path}")

        return weights_path

    def annotate_images(self, image_paths: List[str], output_dir: str = 'temp_dir', image_size: Optional[int] = None, save_image: bool = True, label_path: Optional[str] = None) -> Optional[tuple]:
        """
        Annotate images with model predictions and save YOLO-format label files.

        Images that do not exist are skipped with a warning; an error while
        processing one image is reported and the remaining images are still
        processed.

        Args:
            image_paths: List of image file paths
            output_dir: Directory to save annotated images and labels
                (created when missing)
            image_size: Size for inference (defaults to ``Config.DEFAULT_IMAGE_SIZE``)
            save_image: Whether to save annotated images
            label_path: Optional extra destination the label file is copied
                to. The copy happens for every processed image, so with
                several images the last one wins.

        Returns:
            For a single-image call, a tuple ``(annotated_image_path, label_path)``;
            ``annotated_image_path`` is ``None`` when ``save_image`` is false,
            and the tuple is ``(None, None)`` when the image was skipped or
            failed. For calls with several images, ``None``.

        Raises:
            ValueError: If no model is loaded or ``image_paths`` is empty.
        """
        if self.model is None:
            raise ValueError("❌ No model loaded. Please load a model first.")

        if not image_paths:
            raise ValueError("❌ No images provided for annotation.")

        image_size = image_size or Config.DEFAULT_IMAGE_SIZE
        # Without this, writing into a fresh output directory fails for
        # every image.
        os.makedirs(output_dir, exist_ok=True)
        total_images = len(image_paths)
        print(f"🎨 Annotating {total_images} images and saving labels...")

        last_outcome = (None, None)
        processed = 0
        for idx, image_path in enumerate(image_paths, start=1):
            if not os.path.isfile(image_path):
                print(f"⚠️ Warning: Skipping non-existent file {image_path}")
                continue

            print(f'πŸ” Processing ({idx}/{total_images}): {os.path.basename(image_path)}')

            try:
                last_outcome = self._annotate_one(image_path, output_dir, image_size, save_image, label_path)
                processed += 1
            except Exception as e:
                # Deliberate best-effort boundary: one bad image must not
                # abort the rest of the batch.
                print(f"❌ Error processing {image_path}: {str(e)}")
                last_outcome = (None, None)

        if processed:
            print(f"πŸŽ‰ Annotation and label saving complete! Results saved to: {output_dir}")

        # Only single-image calls report paths back to the caller.
        if total_images == 1:
            return last_outcome
        return None

    def _annotate_one(self, image_path, output_dir, image_size, save_image, label_path):
        """Run inference on one image, write its outputs, return ``(image_path_or_None, label_path)``."""
        # Load image for size info
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread signals an unreadable file by returning None.
            raise ValueError(f"Could not read image: {image_path}")
        h, w = img.shape[:2]

        # Run inference
        result = self.model(image_path, imgsz=image_size)[0]

        # Prepare save paths
        name, ext = os.path.splitext(os.path.basename(image_path))
        save_txt_path = os.path.join(output_dir, f'{name}.txt')  # YOLO label txt
        save_img_path = None
        if save_image:
            save_img_path = os.path.join(output_dir, f'annotated_{name}{ext}')
            # Render the annotated frame only when it is actually saved.
            cv2.imwrite(save_img_path, result.plot())

        # Write YOLO label file
        with open(save_txt_path, 'w') as f:
            f.writelines(self._label_lines(result, w, h))

        if label_path:
            shutil.copyfile(save_txt_path, label_path)

        if save_img_path:
            print(f'βœ… Saved annotated image: {save_img_path}')
        print(f'βœ… Saved label file: {save_txt_path}')
        return save_img_path, save_txt_path

    def _label_lines(self, result, w, h) -> List[str]:
        """Build the label-file lines for one result: polygons if it has masks, else boxes."""
        masks = getattr(result, 'masks', None)
        boxes = getattr(result, 'boxes', None)

        # Check if we have segmentation masks (YOLO-seg model)
        if masks is not None:
            print("πŸ“ Processing segmentation masks...")
            # masks.xy gives polygon coordinates in pixels; the class of
            # polygon i is read from the i-th box.
            return [
                self._polygon_label_line(int(result.boxes.cls[i].item()), polygon, w, h)
                for i, polygon in enumerate(masks.xy)
            ]

        # Fallback to bounding boxes if no masks (YOLO detection model)
        if boxes is not None:
            print("πŸ“¦ Processing bounding boxes...")
            return [
                self._box_label_line(int(box.cls[0].item()), box.xyxy[0].tolist(), w, h)
                for box in boxes
            ]

        print("⚠️ No detections found in this image")
        return []

    @staticmethod
    def _polygon_label_line(cls_id, polygon, w, h) -> str:
        """Format one polygon as ``class_id x1 y1 x2 y2 ...`` with coordinates normalized to [0, 1]."""
        if len(polygon) == 0:
            # An empty polygon would produce a malformed "<cls> " line.
            return ""
        normalized_coords = []
        for point in polygon:
            normalized_coords.append(point[0] / w)
            normalized_coords.append(point[1] / h)
        coords_str = ' '.join(f'{coord:.6f}' for coord in normalized_coords)
        return f"{cls_id} {coords_str}\n"

    @staticmethod
    def _box_label_line(cls_id, xyxy, w, h) -> str:
        """Format one ``(xmin, ymin, xmax, ymax)`` pixel box as ``class_id x_center y_center width height``."""
        xmin, ymin, xmax, ymax = xyxy
        # Convert to YOLO format (normalized)
        x_center = ((xmin + xmax) / 2) / w
        y_center = ((ymin + ymax) / 2) / h
        width = (xmax - xmin) / w
        height = (ymax - ymin) / h
        return f"{cls_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"

    def __enter__(self):
        # When entering context, just return self
        return self

    def __del__(self):
        # Best-effort cleanup on garbage collection. Stay silent when there is
        # nothing to unload (including instances whose __init__ never ran),
        # and never let a finalizer raise.
        try:
            if getattr(self, "model", None) is not None:
                self.unload_model()
        except Exception:
            pass

    def __exit__(self, exc_type, exc_value, traceback):
        # On exit, unload model and clear cache; exceptions from the with
        # block are not suppressed (nothing truthy is returned).
        self.unload_model()

    def unload_model(self):
        """Drop the loaded model and release cached GPU memory.

        Prints a warning when no model is loaded. Errors while clearing the
        cache are reported, not raised.
        """
        if getattr(self, "model", None) is None:
            print("⚠️ No model loaded to unload.")
            return

        print("🧹 Unloading YOLO model and clearing CUDA cache...")
        # Release the reference first so the model is gone even if the
        # torch import or the cache calls below fail.
        self.model = None
        try:
            import gc
            gc.collect()
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()
            print("βœ… Model unloaded and GPU cache cleared.")
        except Exception as e:
            print(f"❌ Error unloading model: {e}")