Spaces:
Sleeping
Sleeping
File size: 7,645 Bytes
0135475 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
# Ultralytics YOLO 🚀, GPL-3.0 license
from copy import copy
import torch
import torch.nn.functional as F
from ultralytics.nn.tasks import SegmentationModel
from ultralytics.yolo import v8
from ultralytics.yolo.utils import DEFAULT_CFG, RANK
from ultralytics.yolo.utils.ops import crop_mask, xyxy2xywh
from ultralytics.yolo.utils.plotting import plot_images, plot_results
from ultralytics.yolo.utils.tal import make_anchors
from ultralytics.yolo.utils.torch_utils import de_parallel
from ultralytics.yolo.v8.detect.train import Loss
# Trainer for the 'segment' task, extending the detection trainer
class SegmentationTrainer(v8.detect.DetectionTrainer):
    """Trainer for the 'segment' task, built on top of the detection trainer.

    Forces the ``task`` override to ``'segment'`` and provides the segmentation-specific
    model, validator, loss criterion and plotting hooks.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None):
        """Initialise the trainer with ``task`` forced to ``'segment'``.

        Args:
            cfg: Base configuration, forwarded unchanged to the parent trainer.
            overrides: Optional dict of configuration overrides. It is copied before
                ``'task'`` is written, so the caller's dict is left untouched.
        """
        # Copy first: writing 'task' straight into the caller's dict would silently leak
        # into any later reuse of that dict.
        overrides = dict(overrides) if overrides is not None else {}
        overrides['task'] = 'segment'
        super().__init__(cfg, overrides)

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Build a ``SegmentationModel`` and optionally load weights into it.

        Args:
            cfg: Model configuration passed to ``SegmentationModel``.
            weights: If truthy, passed to ``model.load``.
            verbose: Verbose construction is requested only when this is true and ``RANK == -1``.

        Returns:
            The constructed model (created with ``ch=3`` and ``nc=self.data['nc']``).
        """
        model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def get_validator(self):
        """Record the loss component names and return a ``SegmentationValidator``.

        The validator receives ``self.test_loader``, ``self.save_dir`` and a shallow copy of
        ``self.args``.
        """
        # Order matches the 4-element loss tensor returned by SegLoss: box, seg, cls, dfl.
        self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
        return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

    def criterion(self, preds, batch):
        """Return the ``SegLoss`` result for ``preds`` and ``batch``.

        The ``SegLoss`` instance is created on first use from the de-paralleled model and
        cached on ``self.compute_loss``.
        """
        if not hasattr(self, 'compute_loss'):
            self.compute_loss = SegLoss(de_parallel(self.model), overlap=self.args.overlap_mask)
        return self.compute_loss(preds, batch)

    def plot_training_samples(self, batch, ni):
        """Pass one training batch to ``plot_images``, targeting ``train_batch{ni}.jpg`` in ``save_dir``.

        Args:
            batch: Mapping read for ``'img'``, ``'masks'``, ``'cls'``, ``'bboxes'``,
                ``'im_file'`` and ``'batch_idx'``.
            ni: Value interpolated into the output file name.
        """
        images = batch['img']
        masks = batch['masks']
        cls = batch['cls'].squeeze(-1)
        bboxes = batch['bboxes']
        paths = batch['im_file']
        batch_idx = batch['batch_idx']
        plot_images(images, batch_idx, cls, bboxes, masks, paths=paths, fname=self.save_dir / f'train_batch{ni}.jpg')

    def plot_metrics(self):
        """Call ``plot_results`` on ``self.csv`` with ``segment=True``."""
        plot_results(file=self.csv, segment=True)  # save results.png
# Criterion class for computing training losses
class SegLoss(Loss):
    """Training criterion for segmentation: box, seg (mask), cls and dfl loss components.

    Extends the detection ``Loss`` with a per-instance mask term computed from the predicted
    mask coefficients and the prototype masks.
    """

    def __init__(self, model, overlap=True):  # model must be de-paralleled
        """Set up the criterion.

        Args:
            model: De-paralleled model; ``model.model[-1].nm`` supplies the number of masks.
            overlap: Selects how ground-truth masks are read in ``__call__``. When true, an
                image's instance masks are recovered by comparing its mask values with
                ``instance index + 1``; otherwise masks are indexed per instance.
        """
        super().__init__(model)
        self.nm = model.model[-1].nm  # number of masks
        self.overlap = overlap

    def __call__(self, preds, batch):
        """Compute the segmentation training loss for one batch.

        Args:
            preds: Either a 3-item sequence ``(feats, pred_masks, proto)`` or a sequence whose
                second item is that triple.
            batch: Mapping read for ``'batch_idx'``, ``'cls'``, ``'bboxes'`` and ``'masks'``.

        Returns:
            A tuple ``(total, items)``. ``total`` is the sum of the four weighted components
            multiplied by the batch size; ``items`` is the detached 4-element tensor ordered
            ``(box, seg, cls, dfl)``.

        Raises:
            TypeError: If building the targets raises ``RuntimeError``, which indicates the
                dataset is not a correctly formatted segment dataset.
        """
        loss = torch.zeros(4, device=self.device)  # box, seg, cls, dfl
        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1)

        # b, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_masks = pred_masks.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # targets
        try:
            batch_idx = batch['batch_idx'].view(-1, 1)
            targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes'].to(dtype)), 1)
            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
            gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
        except RuntimeError as e:
            raise TypeError('ERROR β segment dataset incorrectly formatted or not a segment dataset.\n'
                            "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
                            "i.e. 'yolo train model=yolov8n-seg.pt data=coco128.yaml'.\nVerify your dataset is a "
                            "correctly formatted 'segment' dataset using 'data=coco128-seg.yaml' "
                            'as an example.\nSee https://docs.ultralytics.com/tasks/segment/ for help.') from e

        # pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)

        # Clamp the normaliser to at least 1 so the divisions below never divide by zero.
        target_scores_sum = max(target_scores.sum(), 1)

        # cls loss
        loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        if fg_mask.sum():
            # bbox loss
            loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
                                              target_scores, target_scores_sum, fg_mask)
            # masks loss
            masks = batch['masks'].to(self.device).float()
            if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
                masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]

            # Loop invariants: built once here instead of once per image.
            img_whwh = imgsz[[1, 0, 1, 0]]  # imgsz is (h, w), so this is (w, h, w, h)
            mask_whwh = torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)

            for i in range(batch_size):
                fg_i = fg_mask[i]
                if fg_i.sum():
                    mask_idx = target_gt_idx[i][fg_i]
                    if self.overlap:
                        gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0)
                    else:
                        gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
                    xyxyn = target_bboxes[i][fg_i] / img_whwh  # boxes normalised by image size
                    marea = xyxy2xywh(xyxyn)[:, 2:].prod(1)  # normalised box area (w * h)
                    mxyxy = xyxyn * mask_whwh  # boxes rescaled to mask resolution
                    loss[1] += self.single_mask_loss(gt_mask, pred_masks[i][fg_i], proto[i], mxyxy, marea)  # seg

                # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
                else:
                    loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

        # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
        else:
            loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.box / batch_size  # seg gain (uses the box gain, divided by batch size)
        loss[2] *= self.hyp.cls  # cls gain
        loss[3] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, seg, cls, dfl)

    def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
        """Return the mask loss for one image.

        Args:
            gt_mask: Ground-truth masks, one per instance, matching the predicted mask shape.
            pred: Per-instance mask coefficients; multiplied with the flattened prototypes.
            proto: Prototype masks for the image, viewed as ``(self.nm, -1)``.
            xyxy: Per-instance boxes passed to ``crop_mask`` (presumably at mask resolution
                to restrict the loss to each box; confirm against ``crop_mask``).
            area: Per-instance divisor applied to the mean loss of each instance.

        Returns:
            Scalar tensor: the mean over instances of each instance's area-normalised loss.
        """
        pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:])  # (n, 32) @ (32,80,80) -> (n,80,80)
        loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
        return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
def train(cfg=DEFAULT_CFG, use_python=False):
    """Launch segmentation training from a configuration object.

    Args:
        cfg: Configuration read for ``model``, ``data`` and ``device``. A falsy ``model`` or
            ``data`` falls back to ``'yolov8n-seg.pt'`` / ``'coco128-seg.yaml'``; a ``device``
            of ``None`` becomes ``''``.
        use_python: When true, train through ``ultralytics.YOLO``; otherwise through
            ``SegmentationTrainer``.
    """
    weights = cfg.model or 'yolov8n-seg.pt'
    dataset = cfg.data or 'coco128-seg.yaml'
    run_args = {
        'model': weights,
        'data': dataset,
        'device': '' if cfg.device is None else cfg.device,
    }

    if not use_python:
        SegmentationTrainer(overrides=run_args).train()
        return

    # Imported lazily so the package-level API is only loaded on this path.
    from ultralytics import YOLO
    YOLO(weights).train(**run_args)
# Script entry point: run training with the default configuration.
if __name__ == '__main__':
    train()
|