diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..64966c7640f2792e5611250eb1ba050f7954aff8 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,25 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/acoustics/gsam_whisper_inpainting_demo.png filter=lfs diff=lfs merge=lfs -text +assets/acoustics/gsam_whisper_inpainting_pipeline.png filter=lfs diff=lfs merge=lfs -text +assets/demo9.jpg filter=lfs diff=lfs merge=lfs -text +assets/gradio_demo.png filter=lfs diff=lfs merge=lfs -text +assets/grounded_sam_demo3_demo4.png filter=lfs diff=lfs merge=lfs -text +assets/grounded_sam_inpainting_demo.png filter=lfs diff=lfs merge=lfs -text +assets/grounded_sam_new_demo_image.png filter=lfs diff=lfs merge=lfs -text +assets/mask_3dbox.png filter=lfs diff=lfs merge=lfs -text +assets/osx/grounded_sam_osx_demo.png filter=lfs diff=lfs merge=lfs -text +assets/osx/grouned_sam_osx_demo.gif filter=lfs diff=lfs merge=lfs -text +assets/ram_grounded_sam_new.png filter=lfs diff=lfs merge=lfs -text +EfficientSAM/LightHQSAM/example_light_hqsam.png filter=lfs diff=lfs merge=lfs -text +GroundingDINO/.asset/GD_GLIGEN.png filter=lfs diff=lfs merge=lfs -text +GroundingDINO/.asset/GD_SD.png filter=lfs diff=lfs merge=lfs -text +GroundingDINO/.asset/hero_figure.png filter=lfs diff=lfs merge=lfs -text +segment_anything/assets/masks1.png filter=lfs diff=lfs merge=lfs -text +segment_anything/assets/notebook2.png filter=lfs diff=lfs merge=lfs -text +voxelnext_3d_box/images/image_boxes1.png filter=lfs diff=lfs merge=lfs -text +voxelnext_3d_box/images/image_boxes2.png filter=lfs diff=lfs merge=lfs -text +voxelnext_3d_box/images/image_boxes3.png filter=lfs diff=lfs merge=lfs -text +voxelnext_3d_box/images/mask_box.png filter=lfs diff=lfs merge=lfs -text +voxelnext_3d_box/images/sam-voxelnext.png filter=lfs diff=lfs merge=lfs -text diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 0000000000000000000000000000000000000000..0c3221a96e68e96b5fd69a8abae833895fb7923d --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,8 @@ +cff-version: 1.2.0 +message: "If you use this software, please cite it as below." +authors: + - name: "Grounded-SAM Contributors" +title: "Grounded-Segment-Anything" +date-released: 2023-04-06 +url: "https://github.com/IDEA-Research/Grounded-Segment-Anything" +license: Apache-2.0 diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..3383b82a5995b722ccd413a111b3f6a9edc051fa --- /dev/null +++ b/Dockerfile @@ -0,0 +1,26 @@ +FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel + +# Arguments to build Docker Image using CUDA +ARG USE_CUDA=0 +ARG TORCH_ARCH= + +ENV AM_I_DOCKER True +ENV BUILD_WITH_CUDA "${USE_CUDA}" +ENV TORCH_CUDA_ARCH_LIST "${TORCH_ARCH}" +ENV CUDA_HOME /usr/local/cuda-11.6/ + +RUN mkdir -p /home/appuser/Grounded-Segment-Anything +COPY . 
/home/appuser/Grounded-Segment-Anything/ + +RUN apt-get update && apt-get install --no-install-recommends wget ffmpeg=7:* \ + libsm6=2:* libxext6=2:* git=1:* nano=2.* \ + vim=2:* -y \ + && apt-get clean && apt-get autoremove && rm -rf /var/lib/apt/lists/* + +WORKDIR /home/appuser/Grounded-Segment-Anything +RUN python -m pip install --no-cache-dir -e segment_anything && \ + python -m pip install --no-cache-dir -e GroundingDINO +WORKDIR /home/appuser +RUN pip install --no-cache-dir diffusers[torch]==0.15.1 opencv-python==4.7.0.72 \ + pycocotools==2.0.6 matplotlib==3.5.3 \ + onnxruntime==1.14.1 onnx==1.13.1 ipykernel==6.16.2 scipy gradio openai diff --git a/EfficientSAM/EdgeSAM/common.py b/EfficientSAM/EdgeSAM/common.py new file mode 100644 index 0000000000000000000000000000000000000000..be321e5384b3e65c77cb3acf1a4e4b68d8de823d --- /dev/null +++ b/EfficientSAM/EdgeSAM/common.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +def val2list(x: list or tuple or any, repeat_time=1) -> list: + if isinstance(x, (list, tuple)): + return list(x) + return [x for _ in range(repeat_time)] + + +def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: + x = val2list(x) + + # repeat elements if necessary + if len(x) > 0: + x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))] + + return tuple(x) + + +def list_sum(x: list) -> any: + return x[0] if len(x) == 1 else x[0] + list_sum(x[1:]) + + +def resize( + x: torch.Tensor, + size: any or None = None, + scale_factor=None, + mode: str = "bicubic", + align_corners: bool or None = False, +) -> torch.Tensor: + if mode in ["bilinear", "bicubic"]: + return F.interpolate( + x, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + ) + elif mode in ["nearest", "area"]: + return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode) + else: + raise NotImplementedError(f"resize(mode={mode}) not implemented.") + + +class UpSampleLayer(nn.Module): + def __init__( + self, + mode="bicubic", + size=None, + factor=2, + align_corners=False, + ): + super(UpSampleLayer, self).__init__() + self.mode = mode + self.size = 
val2list(size, 2) if size is not None else None + self.factor = None if self.size is not None else factor + self.align_corners = align_corners + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return resize(x, self.size, self.factor, self.mode, self.align_corners) + + +class OpSequential(nn.Module): + def __init__(self, op_list): + super(OpSequential, self).__init__() + valid_op_list = [] + for op in op_list: + if op is not None: + valid_op_list.append(op) + self.op_list = nn.ModuleList(valid_op_list) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for op in self.op_list: + x = op(x) + return x \ No newline at end of file diff --git a/EfficientSAM/EdgeSAM/rep_vit.py b/EfficientSAM/EdgeSAM/rep_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..b8e9ed2e3efe0df679325cd95d9be8a192b734bf --- /dev/null +++ b/EfficientSAM/EdgeSAM/rep_vit.py @@ -0,0 +1,370 @@ +import torch.nn as nn +from EdgeSAM.common import LayerNorm2d, UpSampleLayer, OpSequential + +__all__ = ['rep_vit_m1', 'rep_vit_m2', 'rep_vit_m3', 'RepViT'] + +m1_cfgs = [ + # k, t, c, SE, HS, s + [3, 2, 48, 1, 0, 1], + [3, 2, 48, 0, 0, 1], + [3, 2, 48, 0, 0, 1], + [3, 2, 96, 0, 0, 2], + [3, 2, 96, 1, 0, 1], + [3, 2, 96, 0, 0, 1], + [3, 2, 96, 0, 0, 1], + [3, 2, 192, 0, 1, 2], + [3, 2, 192, 1, 1, 1], + [3, 2, 192, 0, 1, 1], + [3, 2, 192, 1, 1, 1], + [3, 2, 192, 0, 1, 1], + [3, 2, 192, 1, 1, 1], + [3, 2, 192, 0, 1, 1], + [3, 2, 192, 1, 1, 1], + [3, 2, 192, 0, 1, 1], + [3, 2, 192, 1, 1, 1], + [3, 2, 192, 0, 1, 1], + [3, 2, 192, 1, 1, 1], + [3, 2, 192, 0, 1, 1], + [3, 2, 192, 1, 1, 1], + [3, 2, 192, 0, 1, 1], + [3, 2, 192, 0, 1, 1], + [3, 2, 384, 0, 1, 2], + [3, 2, 384, 1, 1, 1], + [3, 2, 384, 0, 1, 1] +] + +m2_cfgs = [ + # k, t, c, SE, HS, s + [3, 2, 64, 1, 0, 1], + [3, 2, 64, 0, 0, 1], + [3, 2, 64, 0, 0, 1], + [3, 2, 128, 0, 0, 2], + [3, 2, 128, 1, 0, 1], + [3, 2, 128, 0, 0, 1], + [3, 2, 128, 0, 0, 1], + [3, 2, 256, 0, 1, 2], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 512, 0, 1, 2], + [3, 2, 512, 1, 1, 1], + [3, 2, 512, 0, 1, 1] +] + +m3_cfgs = [ + # k, t, c, SE, HS, s + [3, 2, 64, 1, 0, 1], + [3, 2, 64, 0, 0, 1], + [3, 2, 64, 1, 0, 1], + [3, 2, 64, 0, 0, 1], + [3, 2, 64, 0, 0, 1], + [3, 2, 128, 0, 0, 2], + [3, 2, 128, 1, 0, 1], + [3, 2, 128, 0, 0, 1], + [3, 2, 128, 1, 0, 1], + [3, 2, 128, 0, 0, 1], + [3, 2, 128, 0, 0, 1], + [3, 2, 256, 0, 1, 2], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 1, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 256, 0, 1, 1], + [3, 2, 512, 0, 1, 2], + [3, 2, 512, 1, 1, 1], + [3, 2, 512, 0, 1, 1] +] + + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. 
+ It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +from timm.models.layers import SqueezeExcite + +import torch + + +class Conv2d_BN(torch.nn.Sequential): + def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1, resolution=-10000): + super().__init__() + self.add_module('c', torch.nn.Conv2d( + a, b, ks, stride, pad, dilation, groups, bias=False)) + self.add_module('bn', torch.nn.BatchNorm2d(b)) + torch.nn.init.constant_(self.bn.weight, bn_weight_init) + torch.nn.init.constant_(self.bn.bias, 0) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps) ** 0.5 + m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size( + 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, + groups=self.c.groups, + device=c.weight.device) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class Residual(torch.nn.Module): + def __init__(self, m, drop=0.): + super().__init__() + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1, + device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + + @torch.no_grad() + def fuse(self): + if isinstance(self.m, Conv2d_BN): + m = self.m.fuse() + assert (m.groups == m.in_channels) + identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1) + identity = torch.nn.functional.pad(identity, [1, 1, 1, 1]) + m.weight += identity.to(m.weight.device) + return m + elif isinstance(self.m, torch.nn.Conv2d): + m = self.m + assert (m.groups != m.in_channels) + identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1) + identity = torch.nn.functional.pad(identity, [1, 1, 1, 1]) + m.weight += identity.to(m.weight.device) + return m + else: + return self + + +class RepVGGDW(torch.nn.Module): + def __init__(self, ed) -> None: + super().__init__() + self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed) + self.conv1 = Conv2d_BN(ed, ed, 1, 1, 0, groups=ed) + self.dim = ed + + def forward(self, x): + return self.conv(x) + self.conv1(x) + x + + @torch.no_grad() + def fuse(self): + conv = self.conv.fuse() + conv1 = self.conv1.fuse() + + conv_w = conv.weight + conv_b = conv.bias + conv1_w = conv1.weight + conv1_b = conv1.bias + + conv1_w = torch.nn.functional.pad(conv1_w, [1, 1, 1, 1]) + + identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), + [1, 1, 1, 1]) + + final_conv_w = conv_w + conv1_w + identity + final_conv_b = conv_b + conv1_b + + conv.weight.data.copy_(final_conv_w) + conv.bias.data.copy_(final_conv_b) + return conv + + +class RepViTBlock(nn.Module): + def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs, skip_downsample=False): + super(RepViTBlock, self).__init__() + assert stride in [1, 2] + + self.identity = stride == 1 and inp == oup + assert (hidden_dim == 2 * inp) + 
+ if stride == 2: + if skip_downsample: + stride = 1 + self.token_mixer = nn.Sequential( + Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp), + SqueezeExcite(inp, 0.25) if use_se else nn.Identity(), + Conv2d_BN(inp, oup, ks=1, stride=1, pad=0) + ) + self.channel_mixer = Residual(nn.Sequential( + # pw + Conv2d_BN(oup, 2 * oup, 1, 1, 0), + nn.GELU() if use_hs else nn.GELU(), + # pw-linear + Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0), + )) + else: + assert (self.identity) + self.token_mixer = nn.Sequential( + RepVGGDW(inp), + SqueezeExcite(inp, 0.25) if use_se else nn.Identity(), + ) + self.channel_mixer = Residual(nn.Sequential( + # pw + Conv2d_BN(inp, hidden_dim, 1, 1, 0), + nn.GELU() if use_hs else nn.GELU(), + # pw-linear + Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0), + )) + + def forward(self, x): + return self.channel_mixer(self.token_mixer(x)) + + +from timm.models.vision_transformer import trunc_normal_ + + +class BN_Linear(torch.nn.Sequential): + def __init__(self, a, b, bias=True, std=0.02): + super().__init__() + self.add_module('bn', torch.nn.BatchNorm1d(a)) + self.add_module('l', torch.nn.Linear(a, b, bias=bias)) + trunc_normal_(self.l.weight, std=std) + if bias: + torch.nn.init.constant_(self.l.bias, 0) + + @torch.no_grad() + def fuse(self): + bn, l = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + b = bn.bias - self.bn.running_mean * \ + self.bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = l.weight * w[None, :] + if l.bias is None: + b = b @ self.l.weight.T + else: + b = (l.weight @ b[:, None]).view(-1) + self.l.bias + m = torch.nn.Linear(w.size(1), w.size(0), device=l.weight.device) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class RepViT(nn.Module): + arch_settings = { + 'm1': m1_cfgs, + 'm2': m2_cfgs, + 'm3': m3_cfgs + } + + def __init__(self, arch, img_size=1024, upsample_mode='bicubic'): + super(RepViT, self).__init__() + # setting of inverted residual blocks + self.cfgs = self.arch_settings[arch] + self.img_size = img_size + + # building first layer + input_channel = self.cfgs[0][2] + patch_embed = torch.nn.Sequential(Conv2d_BN(3, input_channel // 2, 3, 2, 1), torch.nn.GELU(), + Conv2d_BN(input_channel // 2, input_channel, 3, 2, 1)) + layers = [patch_embed] + # building inverted residual blocks + block = RepViTBlock + self.stage_idx = [] + prev_c = input_channel + for idx, (k, t, c, use_se, use_hs, s) in enumerate(self.cfgs): + output_channel = _make_divisible(c, 8) + exp_size = _make_divisible(input_channel * t, 8) + skip_downsample = False + if c != prev_c: + self.stage_idx.append(idx - 1) + prev_c = c + layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs, skip_downsample)) + input_channel = output_channel + self.stage_idx.append(idx) + self.features = nn.ModuleList(layers) + + stage2_channels = _make_divisible(self.cfgs[self.stage_idx[2]][2], 8) + stage3_channels = _make_divisible(self.cfgs[self.stage_idx[3]][2], 8) + self.fuse_stage2 = nn.Conv2d(stage2_channels, 256, kernel_size=1, bias=False) + self.fuse_stage3 = OpSequential([ + nn.Conv2d(stage3_channels, 256, kernel_size=1, bias=False), + UpSampleLayer(factor=2, mode=upsample_mode), + ]) + + self.neck = nn.Sequential( + nn.Conv2d(256, 256, kernel_size=1, bias=False), + LayerNorm2d(256), + nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), + LayerNorm2d(256), + ) + + def forward(self, x): + counter = 0 + output_dict = dict() + # patch_embed + x = self.features[0](x) + 
output_dict['stem'] = x + # stages + for idx, f in enumerate(self.features[1:]): + x = f(x) + if idx in self.stage_idx: + output_dict[f'stage{counter}'] = x + counter += 1 + + x = self.fuse_stage2(output_dict['stage2']) + self.fuse_stage3(output_dict['stage3']) + + x = self.neck(x) + # hack this place because we modified the predictor of SAM for HQ-SAM in + # segment_anything/segment_anything/predictor.py line 91 to return intern features of the backbone + # self.features, self.interm_features = self.model.image_encoder(input_image) + return x, None + + +def rep_vit_m1(img_size=1024, **kwargs): + return RepViT('m1', img_size, **kwargs) + + +def rep_vit_m2(img_size=1024, **kwargs): + return RepViT('m2', img_size, **kwargs) + + +def rep_vit_m3(img_size=1024, **kwargs): + return RepViT('m3', img_size, **kwargs) \ No newline at end of file diff --git a/EfficientSAM/EdgeSAM/setup_edge_sam.py b/EfficientSAM/EdgeSAM/setup_edge_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..4fa99254fb901f6606e37d8e319efced8ff86223 --- /dev/null +++ b/EfficientSAM/EdgeSAM/setup_edge_sam.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from functools import partial + +from segment_anything.modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer +from EdgeSAM.rep_vit import RepViT + + +prompt_embed_dim = 256 +image_size = 1024 +vit_patch_size = 16 +image_embedding_size = image_size // vit_patch_size + + +def build_edge_sam(checkpoint=None, upsample_mode="bicubic"): + image_encoder = RepViT( + arch="m1", + img_size=image_size, + upsample_mode=upsample_mode + ) + return _build_sam(image_encoder, checkpoint) + + +sam_model_registry = { + "default": build_edge_sam, + "edge_sam": build_edge_sam, +} + +def _build_sam_encoder( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, +): + image_encoder = ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ) + return image_encoder + + +def _build_sam( + image_encoder, + checkpoint=None, +): + sam = Sam( + image_encoder=image_encoder, + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f, map_location="cpu") + sam.load_state_dict(state_dict) + return sam \ No newline at end of file diff --git a/EfficientSAM/FastSAM/tools.py b/EfficientSAM/FastSAM/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..d43c4ea51ff16e7a9a595692e05ad78a40c69bd3 --- /dev/null +++ b/EfficientSAM/FastSAM/tools.py @@ 
-0,0 +1,413 @@ +import numpy as np +from PIL import Image +import matplotlib.pyplot as plt +import cv2 +import torch +import os +import clip + + +def convert_box_xywh_to_xyxy(box): + x1 = box[0] + y1 = box[1] + x2 = box[0] + box[2] + y2 = box[1] + box[3] + return [x1, y1, x2, y2] + + +def segment_image(image, bbox): + image_array = np.array(image) + segmented_image_array = np.zeros_like(image_array) + x1, y1, x2, y2 = bbox + segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2] + segmented_image = Image.fromarray(segmented_image_array) + black_image = Image.new("RGB", image.size, (255, 255, 255)) + # transparency_mask = np.zeros_like((), dtype=np.uint8) + transparency_mask = np.zeros( + (image_array.shape[0], image_array.shape[1]), dtype=np.uint8 + ) + transparency_mask[y1:y2, x1:x2] = 255 + transparency_mask_image = Image.fromarray(transparency_mask, mode="L") + black_image.paste(segmented_image, mask=transparency_mask_image) + return black_image + + +def format_results(result, filter=0): + annotations = [] + n = len(result.masks.data) + for i in range(n): + annotation = {} + mask = result.masks.data[i] == 1.0 + + if torch.sum(mask) < filter: + continue + annotation["id"] = i + annotation["segmentation"] = mask.cpu().numpy() + annotation["bbox"] = result.boxes.data[i] + annotation["score"] = result.boxes.conf[i] + annotation["area"] = annotation["segmentation"].sum() + annotations.append(annotation) + return annotations + + +def filter_masks(annotations): # filte the overlap mask + annotations.sort(key=lambda x: x["area"], reverse=True) + to_remove = set() + for i in range(0, len(annotations)): + a = annotations[i] + for j in range(i + 1, len(annotations)): + b = annotations[j] + if i != j and j not in to_remove: + # check if + if b["area"] < a["area"]: + if (a["segmentation"] & b["segmentation"]).sum() / b[ + "segmentation" + ].sum() > 0.8: + to_remove.add(j) + + return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove + + +def get_bbox_from_mask(mask): + mask = mask.astype(np.uint8) + contours, hierarchy = cv2.findContours( + mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) + x1, y1, w, h = cv2.boundingRect(contours[0]) + x2, y2 = x1 + w, y1 + h + if len(contours) > 1: + for b in contours: + x_t, y_t, w_t, h_t = cv2.boundingRect(b) + # 将多个bbox合并成一个 + x1 = min(x1, x_t) + y1 = min(y1, y_t) + x2 = max(x2, x_t + w_t) + y2 = max(y2, y_t + h_t) + h = y2 - y1 + w = x2 - x1 + return [x1, y1, x2, y2] + + +def fast_process( + annotations, args, mask_random_color, bbox=None, points=None, edges=False +): + if isinstance(annotations[0], dict): + annotations = [annotation["segmentation"] for annotation in annotations] + result_name = os.path.basename(args.img_path) + image = cv2.imread(args.img_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + original_h = image.shape[0] + original_w = image.shape[1] + plt.figure(figsize=(original_w/100, original_h/100)) + plt.imshow(image) + if args.better_quality == True: + if isinstance(annotations[0], torch.Tensor): + annotations = np.array(annotations.cpu()) + for i, mask in enumerate(annotations): + mask = cv2.morphologyEx( + mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8) + ) + annotations[i] = cv2.morphologyEx( + mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8) + ) + if args.device == "cpu": + annotations = np.array(annotations) + fast_show_mask( + annotations, + plt.gca(), + random_color=mask_random_color, + bbox=bbox, + points=points, + pointlabel=args.point_label, + 
retinamask=args.retina, + target_height=original_h, + target_width=original_w, + ) + else: + if isinstance(annotations[0], np.ndarray): + annotations = torch.from_numpy(annotations) + fast_show_mask_gpu( + annotations, + plt.gca(), + random_color=args.randomcolor, + bbox=bbox, + points=points, + pointlabel=args.point_label, + retinamask=args.retina, + target_height=original_h, + target_width=original_w, + ) + if isinstance(annotations, torch.Tensor): + annotations = annotations.cpu().numpy() + if args.withContours == True: + contour_all = [] + temp = np.zeros((original_h, original_w, 1)) + for i, mask in enumerate(annotations): + if type(mask) == dict: + mask = mask["segmentation"] + annotation = mask.astype(np.uint8) + if args.retina == False: + annotation = cv2.resize( + annotation, + (original_w, original_h), + interpolation=cv2.INTER_NEAREST, + ) + contours, hierarchy = cv2.findContours( + annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE + ) + for contour in contours: + contour_all.append(contour) + cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2) + color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8]) + contour_mask = temp / 255 * color.reshape(1, 1, -1) + plt.imshow(contour_mask) + + save_path = args.output + if not os.path.exists(save_path): + os.makedirs(save_path) + plt.axis("off") + fig = plt.gcf() + plt.draw() + buf = fig.canvas.tostring_rgb() + cols, rows = fig.canvas.get_width_height() + img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3) + return img_array + # cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)) + + + +# CPU post process +def fast_show_mask( + annotation, + ax, + random_color=False, + bbox=None, + points=None, + pointlabel=None, + retinamask=True, + target_height=960, + target_width=960, +): + msak_sum = annotation.shape[0] + height = annotation.shape[1] + weight = annotation.shape[2] + # 将annotation 按照面积 排序 + areas = np.sum(annotation, axis=(1, 2)) + sorted_indices = np.argsort(areas) + annotation = annotation[sorted_indices] + + index = (annotation != 0).argmax(axis=0) + if random_color == True: + color = np.random.random((msak_sum, 1, 1, 3)) + else: + color = np.ones((msak_sum, 1, 1, 3)) * np.array( + [30 / 255, 144 / 255, 255 / 255] + ) + transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6 + visual = np.concatenate([color, transparency], axis=-1) + mask_image = np.expand_dims(annotation, -1) * visual + + show = np.zeros((height, weight, 4)) + h_indices, w_indices = np.meshgrid( + np.arange(height), np.arange(weight), indexing="ij" + ) + indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) + # 使用向量化索引更新show的值 + show[h_indices, w_indices, :] = mask_image[indices] + if bbox is not None: + x1, y1, x2, y2 = bbox + ax.add_patch( + plt.Rectangle( + (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1 + ) + ) + # draw point + if points is not None: + plt.scatter( + [point[0] for i, point in enumerate(points) if pointlabel[i] == 1], + [point[1] for i, point in enumerate(points) if pointlabel[i] == 1], + s=20, + c="y", + ) + plt.scatter( + [point[0] for i, point in enumerate(points) if pointlabel[i] == 0], + [point[1] for i, point in enumerate(points) if pointlabel[i] == 0], + s=20, + c="m", + ) + + if retinamask == False: + show = cv2.resize( + show, (target_width, target_height), interpolation=cv2.INTER_NEAREST + ) + ax.imshow(show) + + +def fast_show_mask_gpu( + annotation, + ax, + random_color=False, + bbox=None, + points=None, + pointlabel=None, + 
retinamask=True, + target_height=960, + target_width=960, +): + msak_sum = annotation.shape[0] + height = annotation.shape[1] + weight = annotation.shape[2] + areas = torch.sum(annotation, dim=(1, 2)) + sorted_indices = torch.argsort(areas, descending=False) + annotation = annotation[sorted_indices] + # 找每个位置第一个非零值下标 + index = (annotation != 0).to(torch.long).argmax(dim=0) + if random_color == True: + color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device) + else: + color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor( + [30 / 255, 144 / 255, 255 / 255] + ).to(annotation.device) + transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6 + visual = torch.cat([color, transparency], dim=-1) + mask_image = torch.unsqueeze(annotation, -1) * visual + # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式 + show = torch.zeros((height, weight, 4)).to(annotation.device) + h_indices, w_indices = torch.meshgrid( + torch.arange(height), torch.arange(weight), indexing="ij" + ) + indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) + # 使用向量化索引更新show的值 + show[h_indices, w_indices, :] = mask_image[indices] + show_cpu = show.cpu().numpy() + if bbox is not None: + x1, y1, x2, y2 = bbox + ax.add_patch( + plt.Rectangle( + (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1 + ) + ) + # draw point + if points is not None: + plt.scatter( + [point[0] for i, point in enumerate(points) if pointlabel[i] == 1], + [point[1] for i, point in enumerate(points) if pointlabel[i] == 1], + s=20, + c="y", + ) + plt.scatter( + [point[0] for i, point in enumerate(points) if pointlabel[i] == 0], + [point[1] for i, point in enumerate(points) if pointlabel[i] == 0], + s=20, + c="m", + ) + if retinamask == False: + show_cpu = cv2.resize( + show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST + ) + ax.imshow(show_cpu) + + +# clip +@torch.no_grad() +def retriev( + model, preprocess, elements, search_text: str, device +) -> int: + preprocessed_images = [preprocess(image).to(device) for image in elements] + tokenized_text = clip.tokenize([search_text]).to(device) + stacked_images = torch.stack(preprocessed_images) + image_features = model.encode_image(stacked_images) + text_features = model.encode_text(tokenized_text) + image_features /= image_features.norm(dim=-1, keepdim=True) + text_features /= text_features.norm(dim=-1, keepdim=True) + probs = 100.0 * image_features @ text_features.T + return probs[:, 0].softmax(dim=0) + + +def crop_image(annotations, image_path): + image = Image.open(image_path) + ori_w, ori_h = image.size + mask_h, mask_w = annotations[0]["segmentation"].shape + if ori_w != mask_w or ori_h != mask_h: + image = image.resize((mask_w, mask_h)) + cropped_boxes = [] + cropped_images = [] + not_crop = [] + filter_id = [] + # annotations, _ = filter_masks(annotations) + # filter_id = list(_) + for _, mask in enumerate(annotations): + if np.sum(mask["segmentation"]) <= 100: + filter_id.append(_) + continue + bbox = get_bbox_from_mask(mask["segmentation"]) # mask 的 bbox + cropped_boxes.append(segment_image(image, bbox)) # 保存裁剪的图片 + # cropped_boxes.append(segment_image(image,mask["segmentation"])) + cropped_images.append(bbox) # 保存裁剪的图片的bbox + + return cropped_boxes, cropped_images, not_crop, filter_id, annotations + + +def box_prompt(masks, bbox, target_height, target_width): + h = masks.shape[1] + w = masks.shape[2] + if h != target_height or w != target_width: + bbox = [ + int(bbox[0] * w / target_width), + 
int(bbox[1] * h / target_height), + int(bbox[2] * w / target_width), + int(bbox[3] * h / target_height), + ] + bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0 + bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0 + bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w + bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h + + # IoUs = torch.zeros(len(masks), dtype=torch.float32) + bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) + + masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2)) + orig_masks_area = torch.sum(masks, dim=(1, 2)) + + union = bbox_area + orig_masks_area - masks_area + IoUs = masks_area / union + max_iou_index = torch.argmax(IoUs) + + return masks[max_iou_index].cpu().numpy(), max_iou_index + + +def point_prompt(masks, points, pointlabel, target_height, target_width): # numpy 处理 + h = masks[0]["segmentation"].shape[0] + w = masks[0]["segmentation"].shape[1] + if h != target_height or w != target_width: + points = [ + [int(point[0] * w / target_width), int(point[1] * h / target_height)] + for point in points + ] + onemask = np.zeros((h, w)) + for i, annotation in enumerate(masks): + if type(annotation) == dict: + mask = annotation["segmentation"] + else: + mask = annotation + for i, point in enumerate(points): + if mask[point[1], point[0]] == 1 and pointlabel[i] == 1: + onemask += mask + if mask[point[1], point[0]] == 1 and pointlabel[i] == 0: + onemask -= mask + onemask = onemask >= 1 + return onemask, 0 + + +def text_prompt(annotations, args): + cropped_boxes, cropped_images, not_crop, filter_id, annotaions = crop_image( + annotations, args.img_path + ) + clip_model, preprocess = clip.load("ViT-B/32", device=args.device) + scores = retriev( + clip_model, preprocess, cropped_boxes, args.text_prompt, device=args.device + ) + max_idx = scores.argsort() + max_idx = max_idx[-1] + max_idx += sum(np.array(filter_id) <= int(max_idx)) + return annotaions[max_idx]["segmentation"], max_idx \ No newline at end of file diff --git a/EfficientSAM/LightHQSAM/example_light_hqsam.png b/EfficientSAM/LightHQSAM/example_light_hqsam.png new file mode 100644 index 0000000000000000000000000000000000000000..179b8c8d9961b7bd7cc4441daf8e806c6aceef19 --- /dev/null +++ b/EfficientSAM/LightHQSAM/example_light_hqsam.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:866820ace9a150b791c00f955c2b436fc72a2e6a43b36187aba975be196161c4 +size 2324039 diff --git a/EfficientSAM/LightHQSAM/grounded_light_hqsam_annotated_image.jpg b/EfficientSAM/LightHQSAM/grounded_light_hqsam_annotated_image.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5a0af6b52022de48dcdcf2d34ef15914ee141fe4 Binary files /dev/null and b/EfficientSAM/LightHQSAM/grounded_light_hqsam_annotated_image.jpg differ diff --git a/EfficientSAM/LightHQSAM/setup_light_hqsam.py b/EfficientSAM/LightHQSAM/setup_light_hqsam.py new file mode 100644 index 0000000000000000000000000000000000000000..3a34cf7512223c330e8d724104fb9e893058330c --- /dev/null +++ b/EfficientSAM/LightHQSAM/setup_light_hqsam.py @@ -0,0 +1,45 @@ +from LightHQSAM.tiny_vit_sam import TinyViT +from segment_anything.modeling import MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer + +def setup_model(): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + mobile_sam = Sam( + image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000, + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + 
window_sizes=[7, 7, 14, 7], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.0, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=0.8 + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoderHQ( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + vit_dim=160, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + return mobile_sam \ No newline at end of file diff --git a/EfficientSAM/LightHQSAM/tiny_vit_sam.py b/EfficientSAM/LightHQSAM/tiny_vit_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..65f04aa374599f6bb70fe69c81660df9d4e786e1 --- /dev/null +++ b/EfficientSAM/LightHQSAM/tiny_vit_sam.py @@ -0,0 +1,724 @@ +# -------------------------------------------------------- +# TinyViT Model Architecture +# Copyright (c) 2022 Microsoft +# Adapted from LeViT and Swin Transformer +# LeViT: (https://github.com/facebookresearch/levit) +# Swin: (https://github.com/microsoft/swin-transformer) +# Build the TinyViT Model +# -------------------------------------------------------- + +import itertools +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath as TimmDropPath,\ + to_2tuple, trunc_normal_ +from timm.models.registry import register_model +from typing import Tuple + + +class Conv2d_BN(torch.nn.Sequential): + def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1): + super().__init__() + self.add_module('c', torch.nn.Conv2d( + a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size( + 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class DropPath(TimmDropPath): + def __init__(self, drop_prob=None): + super().__init__(drop_prob=drop_prob) + self.drop_prob = drop_prob + + def __repr__(self): + msg = super().__repr__() + msg += f'(drop_prob={self.drop_prob})' + return msg + + +class PatchEmbed(nn.Module): + def __init__(self, in_chans, embed_dim, resolution, activation): + super().__init__() + img_size: Tuple[int, int] = to_2tuple(resolution) + self.patches_resolution = (img_size[0] // 4, img_size[1] // 4) + self.num_patches = self.patches_resolution[0] * \ + self.patches_resolution[1] + self.in_chans = in_chans + self.embed_dim = embed_dim + n = embed_dim + self.seq = nn.Sequential( + Conv2d_BN(in_chans, n // 2, 3, 2, 1), + activation(), + Conv2d_BN(n // 2, n, 3, 2, 1), + ) + + def forward(self, x): + return self.seq(x) + + +class MBConv(nn.Module): + def __init__(self, in_chans, out_chans, expand_ratio, + activation, drop_path): + super().__init__() + 
self.in_chans = in_chans + self.hidden_chans = int(in_chans * expand_ratio) + self.out_chans = out_chans + + self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1) + self.act1 = activation() + + self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, + ks=3, stride=1, pad=1, groups=self.hidden_chans) + self.act2 = activation() + + self.conv3 = Conv2d_BN( + self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0) + self.act3 = activation() + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + + x = self.conv1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.act2(x) + + x = self.conv3(x) + + x = self.drop_path(x) + + x += shortcut + x = self.act3(x) + + return x + + +class PatchMerging(nn.Module): + def __init__(self, input_resolution, dim, out_dim, activation): + super().__init__() + + self.input_resolution = input_resolution + self.dim = dim + self.out_dim = out_dim + self.act = activation() + self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0) + stride_c=2 + if(out_dim==320 or out_dim==448 or out_dim==576): + stride_c=1 + self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim) + self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0) + + def forward(self, x): + if x.ndim == 3: + H, W = self.input_resolution + B = len(x) + # (B, C, H, W) + x = x.view(B, H, W, -1).permute(0, 3, 1, 2) + + x = self.conv1(x) + x = self.act(x) + + x = self.conv2(x) + x = self.act(x) + x = self.conv3(x) + x = x.flatten(2).transpose(1, 2) + return x + + +class ConvLayer(nn.Module): + def __init__(self, dim, input_resolution, depth, + activation, + drop_path=0., downsample=None, use_checkpoint=False, + out_dim=None, + conv_expand_ratio=4., + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + MBConv(dim, dim, conv_expand_ratio, activation, + drop_path[i] if isinstance(drop_path, list) else drop_path, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, out_dim=out_dim, activation=activation) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, + out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.norm = nn.LayerNorm(in_features) + self.fc1 = nn.Linear(in_features, hidden_features) + self.fc2 = nn.Linear(hidden_features, out_features) + self.act = act_layer() + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.norm(x) + + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(torch.nn.Module): + def __init__(self, dim, key_dim, num_heads=8, + attn_ratio=4, + resolution=(14, 14), + ): + super().__init__() + # (h, w) + assert isinstance(resolution, tuple) and len(resolution) == 2 + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = 
self.dh + nh_kd * 2 + + self.norm = nn.LayerNorm(dim) + self.qkv = nn.Linear(dim, h) + self.proj = nn.Linear(self.dh, dim) + + points = list(itertools.product( + range(resolution[0]), range(resolution[1]))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N, N), + persistent=False) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.register_buffer('ab', + self.attention_biases[:, self.attention_bias_idxs], + persistent=False) + + def forward(self, x): # x (B,N,C) + B, N, _ = x.shape + + # Normalization + x = self.norm(x) + + qkv = self.qkv(x) + # (B, N, num_heads, d) + q, k, v = qkv.view(B, N, self.num_heads, - + 1).split([self.key_dim, self.key_dim, self.d], dim=3) + # (B, num_heads, N, d) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = ( + (q @ k.transpose(-2, -1)) * self.scale + + + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab) + ) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class TinyViTBlock(nn.Module): + r""" TinyViT Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int, int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + local_conv_size (int): the kernel size of the convolution between + Attention and MLP. Default: 3 + activation: the activation function. Default: nn.GELU + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, + mlp_ratio=4., drop=0., drop_path=0., + local_conv_size=3, + activation=nn.GELU, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + assert window_size > 0, 'window_size must be greater than 0' + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + + assert dim % num_heads == 0, 'dim must be divisible by num_heads' + head_dim = dim // num_heads + + window_resolution = (window_size, window_size) + self.attn = Attention(dim, head_dim, num_heads, + attn_ratio=1, resolution=window_resolution) + + mlp_hidden_dim = int(dim * mlp_ratio) + mlp_activation = activation + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=mlp_activation, drop=drop) + + pad = local_conv_size // 2 + self.local_conv = Conv2d_BN( + dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + res_x = x + if H == self.window_size and W == self.window_size: + x = self.attn(x) + else: + x = x.view(B, H, W, C) + pad_b = (self.window_size - H % + self.window_size) % self.window_size + pad_r = (self.window_size - W % + self.window_size) % self.window_size + padding = pad_b > 0 or pad_r > 0 + + if padding: + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_size + nW = pW // self.window_size + # window partition + x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape( + B * nH * nW, self.window_size * self.window_size, C) + x = self.attn(x) + # window reverse + x = x.view(B, nH, nW, self.window_size, self.window_size, + C).transpose(2, 3).reshape(B, pH, pW, C) + + if padding: + x = x[:, :H, :W].contiguous() + + x = x.view(B, L, C) + + x = res_x + self.drop_path(x) + + x = x.transpose(1, 2).reshape(B, C, H, W) + x = self.local_conv(x) + x = x.view(B, C, L).transpose(1, 2) + + x = x + self.drop_path(self.mlp(x)) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" + + +class BasicLayer(nn.Module): + """ A basic TinyViT layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3 + activation: the activation function. Default: nn.GELU + out_dim: the output dimension of the layer. 
Default: dim + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., drop=0., + drop_path=0., downsample=None, use_checkpoint=False, + local_conv_size=3, + activation=nn.GELU, + out_dim=None, + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + TinyViTBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance( + drop_path, list) else drop_path, + local_conv_size=local_conv_size, + activation=activation, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, out_dim=out_dim, activation=activation) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x +class TinyViT(nn.Module): + def __init__(self, img_size=224, in_chans=3, num_classes=1000, + embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.1, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=1.0, + ): + super().__init__() + self.img_size=img_size + self.num_classes = num_classes + self.depths = depths + self.num_layers = len(depths) + self.mlp_ratio = mlp_ratio + + activation = nn.GELU + + self.patch_embed = PatchEmbed(in_chans=in_chans, + embed_dim=embed_dims[0], + resolution=img_size, + activation=activation) + + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, + sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + kwargs = dict(dim=embed_dims[i_layer], + input_resolution=(patches_resolution[0] // (2 ** (i_layer-1 if i_layer == 3 else i_layer)), + patches_resolution[1] // (2 ** (i_layer-1 if i_layer == 3 else i_layer))), + # input_resolution=(patches_resolution[0] // (2 ** i_layer), + # patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + downsample=PatchMerging if ( + i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + out_dim=embed_dims[min( + i_layer + 1, len(embed_dims) - 1)], + activation=activation, + ) + if i_layer == 0: + layer = ConvLayer( + conv_expand_ratio=mbconv_expand_ratio, + **kwargs, + ) + else: + layer = BasicLayer( + num_heads=num_heads[i_layer], + 
window_size=window_sizes[i_layer], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + local_conv_size=local_conv_size, + **kwargs) + self.layers.append(layer) + + # Classifier head + self.norm_head = nn.LayerNorm(embed_dims[-1]) + self.head = nn.Linear( + embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + + # init weights + self.apply(self._init_weights) + self.set_layer_lr_decay(layer_lr_decay) + self.neck = nn.Sequential( + nn.Conv2d( + embed_dims[-1], + 256, + kernel_size=1, + bias=False, + ), + LayerNorm2d(256), + nn.Conv2d( + 256, + 256, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(256), + ) + def set_layer_lr_decay(self, layer_lr_decay): + decay_rate = layer_lr_decay + + # layers -> blocks (depth) + depth = sum(self.depths) + lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)] + #print("LR SCALES:", lr_scales) + + def _set_lr_scale(m, scale): + for p in m.parameters(): + p.lr_scale = scale + + self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0])) + i = 0 + for layer in self.layers: + for block in layer.blocks: + block.apply(lambda x: _set_lr_scale(x, lr_scales[i])) + i += 1 + if layer.downsample is not None: + layer.downsample.apply( + lambda x: _set_lr_scale(x, lr_scales[i - 1])) + assert i == depth + for m in [self.norm_head, self.head]: + m.apply(lambda x: _set_lr_scale(x, lr_scales[-1])) + + for k, p in self.named_parameters(): + p.param_name = k + + def _check_lr_scale(m): + for p in m.parameters(): + assert hasattr(p, 'lr_scale'), p.param_name + + self.apply(_check_lr_scale) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'attention_biases'} + + def forward_features(self, x): + # x: (N, C, H, W) + x = self.patch_embed(x) + + x = self.layers[0](x) + start_i = 1 + + interm_embeddings=[] + for i in range(start_i, len(self.layers)): + layer = self.layers[i] + x = layer(x) + # print('x shape:', x.shape, '---i:', i) + if i == 1: + interm_embeddings.append(x.view(x.shape[0], 64, 64, -1)) + + B,_,C=x.size() + x = x.view(B, 64, 64, C) + x=x.permute(0, 3, 1, 2) + x=self.neck(x) + return x, interm_embeddings + + def forward(self, x): + x, interm_embeddings = self.forward_features(x) + #x = self.norm_head(x) + #x = self.head(x) + # print('come to here is correct'* 3) + return x, interm_embeddings + + +_checkpoint_url_format = \ + 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/{}.pth' +_provided_checkpoints = { + 'tiny_vit_5m_224': 'tiny_vit_5m_22kto1k_distill', + 'tiny_vit_11m_224': 'tiny_vit_11m_22kto1k_distill', + 'tiny_vit_21m_224': 'tiny_vit_21m_22kto1k_distill', + 'tiny_vit_21m_384': 'tiny_vit_21m_22kto1k_384_distill', + 'tiny_vit_21m_512': 'tiny_vit_21m_22kto1k_512_distill', +} + + +def register_tiny_vit_model(fn): + '''Register a TinyViT model + It is a wrapper of `register_model` with loading the pretrained checkpoint. + ''' + def fn_wrapper(pretrained=False, **kwargs): + model = fn() + if pretrained: + model_name = fn.__name__ + assert model_name in _provided_checkpoints, \ + f'Sorry that the checkpoint `{model_name}` is not provided yet.' 
+ url = _checkpoint_url_format.format( + _provided_checkpoints[model_name]) + checkpoint = torch.hub.load_state_dict_from_url( + url=url, + map_location='cpu', check_hash=False, + ) + model.load_state_dict(checkpoint['model']) + + return model + + # rename the name of fn_wrapper + fn_wrapper.__name__ = fn.__name__ + return register_model(fn_wrapper) + + +@register_tiny_vit_model +def tiny_vit_5m_224(pretrained=False, num_classes=1000, drop_path_rate=0.0): + return TinyViT( + num_classes=num_classes, + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_11m_224(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + num_classes=num_classes, + embed_dims=[64, 128, 256, 448], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 8, 14], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_224(pretrained=False, num_classes=1000, drop_path_rate=0.2): + return TinyViT( + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_384(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + img_size=384, + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[12, 12, 24, 12], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_512(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + img_size=512, + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[16, 16, 32, 16], + drop_path_rate=drop_path_rate, + ) diff --git a/EfficientSAM/MobileSAM/setup_mobile_sam.py b/EfficientSAM/MobileSAM/setup_mobile_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..49d8c17cb207dd5c0022bfa8e4f60b53d48cd6e9 --- /dev/null +++ b/EfficientSAM/MobileSAM/setup_mobile_sam.py @@ -0,0 +1,44 @@ +from MobileSAM.tiny_vit_sam import TinyViT +from segment_anything.modeling import MaskDecoder, PromptEncoder, Sam, TwoWayTransformer + +def setup_model(): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + mobile_sam = Sam( + image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000, + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.0, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=0.8 + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + return mobile_sam \ No newline at end of file diff --git a/EfficientSAM/MobileSAM/tiny_vit_sam.py b/EfficientSAM/MobileSAM/tiny_vit_sam.py new file mode 100644 index 
0000000000000000000000000000000000000000..fb0062b0e1b982a3004c92c8573c4c3754d4aeaa --- /dev/null +++ b/EfficientSAM/MobileSAM/tiny_vit_sam.py @@ -0,0 +1,716 @@ +# -------------------------------------------------------- +# TinyViT Model Architecture +# Copyright (c) 2022 Microsoft +# Adapted from LeViT and Swin Transformer +# LeViT: (https://github.com/facebookresearch/levit) +# Swin: (https://github.com/microsoft/swin-transformer) +# Build the TinyViT Model +# -------------------------------------------------------- + +import itertools +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath as TimmDropPath,\ + to_2tuple, trunc_normal_ +from timm.models.registry import register_model +from typing import Tuple + + +class Conv2d_BN(torch.nn.Sequential): + def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1): + super().__init__() + self.add_module('c', torch.nn.Conv2d( + a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size( + 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class DropPath(TimmDropPath): + def __init__(self, drop_prob=None): + super().__init__(drop_prob=drop_prob) + self.drop_prob = drop_prob + + def __repr__(self): + msg = super().__repr__() + msg += f'(drop_prob={self.drop_prob})' + return msg + + +class PatchEmbed(nn.Module): + def __init__(self, in_chans, embed_dim, resolution, activation): + super().__init__() + img_size: Tuple[int, int] = to_2tuple(resolution) + self.patches_resolution = (img_size[0] // 4, img_size[1] // 4) + self.num_patches = self.patches_resolution[0] * \ + self.patches_resolution[1] + self.in_chans = in_chans + self.embed_dim = embed_dim + n = embed_dim + self.seq = nn.Sequential( + Conv2d_BN(in_chans, n // 2, 3, 2, 1), + activation(), + Conv2d_BN(n // 2, n, 3, 2, 1), + ) + + def forward(self, x): + return self.seq(x) + + +class MBConv(nn.Module): + def __init__(self, in_chans, out_chans, expand_ratio, + activation, drop_path): + super().__init__() + self.in_chans = in_chans + self.hidden_chans = int(in_chans * expand_ratio) + self.out_chans = out_chans + + self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1) + self.act1 = activation() + + self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, + ks=3, stride=1, pad=1, groups=self.hidden_chans) + self.act2 = activation() + + self.conv3 = Conv2d_BN( + self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0) + self.act3 = activation() + + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, x): + shortcut = x + + x = self.conv1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.act2(x) + + x = self.conv3(x) + + x = self.drop_path(x) + + x += shortcut + x = self.act3(x) + + return x + + +class PatchMerging(nn.Module): + def __init__(self, input_resolution, dim, out_dim, activation): + super().__init__() + + self.input_resolution = input_resolution + self.dim = dim + self.out_dim = out_dim + self.act = activation() + self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0) + stride_c=2 + if(out_dim==320 or out_dim==448 or out_dim==576):#handongshen 576 + stride_c=1 + self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim) + self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0) + + def forward(self, x): + if x.ndim == 3: + H, W = self.input_resolution + B = len(x) + # (B, C, H, W) + x = x.view(B, H, W, -1).permute(0, 3, 1, 2) + + x = self.conv1(x) + x = self.act(x) + + x = self.conv2(x) + x = self.act(x) + x = self.conv3(x) + x = x.flatten(2).transpose(1, 2) + return x + + +class ConvLayer(nn.Module): + def __init__(self, dim, input_resolution, depth, + activation, + drop_path=0., downsample=None, use_checkpoint=False, + out_dim=None, + conv_expand_ratio=4., + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + MBConv(dim, dim, conv_expand_ratio, activation, + drop_path[i] if isinstance(drop_path, list) else drop_path, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, out_dim=out_dim, activation=activation) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, + out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.norm = nn.LayerNorm(in_features) + self.fc1 = nn.Linear(in_features, hidden_features) + self.fc2 = nn.Linear(hidden_features, out_features) + self.act = act_layer() + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.norm(x) + + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(torch.nn.Module): + def __init__(self, dim, key_dim, num_heads=8, + attn_ratio=4, + resolution=(14, 14), + ): + super().__init__() + # (h, w) + assert isinstance(resolution, tuple) and len(resolution) == 2 + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + + self.norm = nn.LayerNorm(dim) + self.qkv = nn.Linear(dim, h) + self.proj = nn.Linear(self.dh, dim) + + points = list(itertools.product( + range(resolution[0]), range(resolution[1]))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + 
self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N, N), + persistent=False) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,N,C) + B, N, _ = x.shape + + # Normalization + x = self.norm(x) + + qkv = self.qkv(x) + # (B, N, num_heads, d) + q, k, v = qkv.view(B, N, self.num_heads, - + 1).split([self.key_dim, self.key_dim, self.d], dim=3) + # (B, num_heads, N, d) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = ( + (q @ k.transpose(-2, -1)) * self.scale + + + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab) + ) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class TinyViTBlock(nn.Module): + r""" TinyViT Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int, int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + local_conv_size (int): the kernel size of the convolution between + Attention and MLP. Default: 3 + activation: the activation function. Default: nn.GELU + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, + mlp_ratio=4., drop=0., drop_path=0., + local_conv_size=3, + activation=nn.GELU, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + assert window_size > 0, 'window_size must be greater than 0' + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + + assert dim % num_heads == 0, 'dim must be divisible by num_heads' + head_dim = dim // num_heads + + window_resolution = (window_size, window_size) + self.attn = Attention(dim, head_dim, num_heads, + attn_ratio=1, resolution=window_resolution) + + mlp_hidden_dim = int(dim * mlp_ratio) + mlp_activation = activation + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=mlp_activation, drop=drop) + + pad = local_conv_size // 2 + self.local_conv = Conv2d_BN( + dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + res_x = x + if H == self.window_size and W == self.window_size: + x = self.attn(x) + else: + x = x.view(B, H, W, C) + pad_b = (self.window_size - H % + self.window_size) % self.window_size + pad_r = (self.window_size - W % + self.window_size) % self.window_size + padding = pad_b > 0 or pad_r > 0 + + if padding: + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_size + nW = pW // self.window_size + # window partition + x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape( + B * nH * nW, self.window_size * self.window_size, C) + x = self.attn(x) + # window reverse + x = x.view(B, nH, nW, self.window_size, self.window_size, + C).transpose(2, 3).reshape(B, pH, pW, C) + + if padding: + x = x[:, :H, :W].contiguous() + + x = x.view(B, L, C) + + x = res_x + self.drop_path(x) + + x = x.transpose(1, 2).reshape(B, C, H, W) + x = self.local_conv(x) + x = x.view(B, C, L).transpose(1, 2) + + x = x + self.drop_path(self.mlp(x)) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" + + +class BasicLayer(nn.Module): + """ A basic TinyViT layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3 + activation: the activation function. Default: nn.GELU + out_dim: the output dimension of the layer. 
Default: dim + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., drop=0., + drop_path=0., downsample=None, use_checkpoint=False, + local_conv_size=3, + activation=nn.GELU, + out_dim=None, + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + TinyViTBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance( + drop_path, list) else drop_path, + local_conv_size=local_conv_size, + activation=activation, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, out_dim=out_dim, activation=activation) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x +class TinyViT(nn.Module): + def __init__(self, img_size=224, in_chans=3, num_classes=1000, + embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.1, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=1.0, + ): + super().__init__() + self.img_size=img_size + self.num_classes = num_classes + self.depths = depths + self.num_layers = len(depths) + self.mlp_ratio = mlp_ratio + + activation = nn.GELU + + self.patch_embed = PatchEmbed(in_chans=in_chans, + embed_dim=embed_dims[0], + resolution=img_size, + activation=activation) + + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, + sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + kwargs = dict(dim=embed_dims[i_layer], + input_resolution=(patches_resolution[0] // (2 ** (i_layer-1 if i_layer == 3 else i_layer)), + patches_resolution[1] // (2 ** (i_layer-1 if i_layer == 3 else i_layer))), + # input_resolution=(patches_resolution[0] // (2 ** i_layer), + # patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + downsample=PatchMerging if ( + i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + out_dim=embed_dims[min( + i_layer + 1, len(embed_dims) - 1)], + activation=activation, + ) + if i_layer == 0: + layer = ConvLayer( + conv_expand_ratio=mbconv_expand_ratio, + **kwargs, + ) + else: + layer = BasicLayer( + num_heads=num_heads[i_layer], + 
window_size=window_sizes[i_layer], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + local_conv_size=local_conv_size, + **kwargs) + self.layers.append(layer) + + # Classifier head + self.norm_head = nn.LayerNorm(embed_dims[-1]) + self.head = nn.Linear( + embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + + # init weights + self.apply(self._init_weights) + self.set_layer_lr_decay(layer_lr_decay) + self.neck = nn.Sequential( + nn.Conv2d( + embed_dims[-1],#handongshen + 256, + kernel_size=1, + bias=False, + ), + LayerNorm2d(256), + nn.Conv2d( + 256, + 256, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(256), + ) + def set_layer_lr_decay(self, layer_lr_decay): + decay_rate = layer_lr_decay + + # layers -> blocks (depth) + depth = sum(self.depths) + lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)] + print("LR SCALES:", lr_scales) + + def _set_lr_scale(m, scale): + for p in m.parameters(): + p.lr_scale = scale + + self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0])) + i = 0 + for layer in self.layers: + for block in layer.blocks: + block.apply(lambda x: _set_lr_scale(x, lr_scales[i])) + i += 1 + if layer.downsample is not None: + layer.downsample.apply( + lambda x: _set_lr_scale(x, lr_scales[i - 1])) + assert i == depth + for m in [self.norm_head, self.head]: + m.apply(lambda x: _set_lr_scale(x, lr_scales[-1])) + + for k, p in self.named_parameters(): + p.param_name = k + + def _check_lr_scale(m): + for p in m.parameters(): + assert hasattr(p, 'lr_scale'), p.param_name + + self.apply(_check_lr_scale) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'attention_biases'} + + def forward_features(self, x): + # x: (N, C, H, W) + x = self.patch_embed(x) + + x = self.layers[0](x) + start_i = 1 + + for i in range(start_i, len(self.layers)): + layer = self.layers[i] + x = layer(x) + B,_,C=x.size() + x = x.view(B, 64, 64, C) + x=x.permute(0, 3, 1, 2) + x=self.neck(x) + return x + + def forward(self, x): + x = self.forward_features(x) + + # We have made some hack changes here to make it compatible with SAM-HQ + return x, None + + +_checkpoint_url_format = \ + 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/{}.pth' +_provided_checkpoints = { + 'tiny_vit_5m_224': 'tiny_vit_5m_22kto1k_distill', + 'tiny_vit_11m_224': 'tiny_vit_11m_22kto1k_distill', + 'tiny_vit_21m_224': 'tiny_vit_21m_22kto1k_distill', + 'tiny_vit_21m_384': 'tiny_vit_21m_22kto1k_384_distill', + 'tiny_vit_21m_512': 'tiny_vit_21m_22kto1k_512_distill', +} + + +def register_tiny_vit_model(fn): + '''Register a TinyViT model + It is a wrapper of `register_model` with loading the pretrained checkpoint. + ''' + def fn_wrapper(pretrained=False, **kwargs): + model = fn() + if pretrained: + model_name = fn.__name__ + assert model_name in _provided_checkpoints, \ + f'Sorry that the checkpoint `{model_name}` is not provided yet.' 
+ url = _checkpoint_url_format.format( + _provided_checkpoints[model_name]) + checkpoint = torch.hub.load_state_dict_from_url( + url=url, + map_location='cpu', check_hash=False, + ) + model.load_state_dict(checkpoint['model']) + + return model + + # rename the name of fn_wrapper + fn_wrapper.__name__ = fn.__name__ + return register_model(fn_wrapper) + + +@register_tiny_vit_model +def tiny_vit_5m_224(pretrained=False, num_classes=1000, drop_path_rate=0.0): + return TinyViT( + num_classes=num_classes, + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_11m_224(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + num_classes=num_classes, + embed_dims=[64, 128, 256, 448], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 8, 14], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_224(pretrained=False, num_classes=1000, drop_path_rate=0.2): + return TinyViT( + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_384(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + img_size=384, + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[12, 12, 24, 12], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_512(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + img_size=512, + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[16, 16, 32, 16], + drop_path_rate=drop_path_rate, + ) \ No newline at end of file diff --git a/EfficientSAM/README.md b/EfficientSAM/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b93bdacbca1df347d5365e882f14c8ad4353a55a --- /dev/null +++ b/EfficientSAM/README.md @@ -0,0 +1,194 @@ +## Efficient Grounded-SAM + +We're going to combine [Grounding-DINO](https://github.com/IDEA-Research/GroundingDINO) with efficient SAM variants for faster annotating. + + + + +### Table of Contents +- [Installation](#installation) +- [Efficient SAM Series](#efficient-sams) +- [Run Grounded-FastSAM Demo](#run-grounded-fastsam-demo) +- [Run Grounded-MobileSAM Demo](#run-grounded-mobilesam-demo) +- [Run Grounded-LightHQSAM Demo](#run-grounded-light-hqsam-demo) +- [Run Grounded-Efficient-SAM Demo](#run-grounded-efficient-sam-demo) +- [Run Grounded-Edge-SAM Demo](#run-grounded-edge-sam-demo) +- [Run Grounded-RepViT-SAM Demo](#run-grounded-repvit-sam-demo) + + +### Installation + +- Install [Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything#installation) + +- Install [Fast-SAM](https://github.com/CASIA-IVA-Lab/FastSAM#installation) + +- Note that we may use the sam image as the demo image in order to compare the inference results of different efficient-sam variants. + +### Efficient SAMs +Here's the list of Efficient SAM variants: + +
+|     | name            | backbone | Data             | box AP on COCO                      | Checkpoint             | Config |
+| --- | --------------- | -------- | ---------------- | ----------------------------------- | ---------------------- | ------ |
+| 1   | GroundingDINO-T | Swin-T   | O365,GoldG,Cap4M | 48.4 (zero-shot) / 57.2 (fine-tune) | Github link \| HF link | link   |
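
For context, here is a minimal, hedged sketch of how the GroundingDINO-T checkpoint listed above (the detector used to prompt these SAM variants) is typically loaded for open-set box prediction. The config and weight paths, the image path, and the thresholds are illustrative assumptions rather than part of this diff, and a CUDA device is assumed by the default inference helpers.

```python
from groundingdino.util.inference import load_model, load_image, predict

# Assumed local paths; fetch the checkpoint and config via the links in the table above.
CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
WEIGHTS_PATH = "groundingdino_swint_ogc.pth"

model = load_model(CONFIG_PATH, WEIGHTS_PATH)
image_source, image = load_image("path/to/your_image.jpg")  # placeholder image path

# Boxes come back as normalized (cx, cy, w, h); phrases are the matched text spans.
boxes, logits, phrases = predict(
    model=model,
    image=image,
    caption="dog. chair.",
    box_threshold=0.3,
    text_threshold=0.25,
)
```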
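The detected boxes can then be handed to one of the efficient SAM variants for mask prediction. Below is a hedged sketch that runs the MobileSAM model assembled by `setup_mobile_sam.py` from this diff behind the standard `SamPredictor` interface; the checkpoint filename `mobile_sam.pt`, the placeholder image, and the placeholder box are assumptions for illustration, and the `EfficientSAM` directory is assumed to be on `PYTHONPATH`.

```python
import numpy as np
import torch
from MobileSAM.setup_mobile_sam import setup_model  # builder added in this diff
from segment_anything import SamPredictor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed checkpoint name; MobileSAM weights are distributed separately from this repo.
state_dict = torch.load("mobile_sam.pt", map_location="cpu")
mobile_sam = setup_model()
mobile_sam.load_state_dict(state_dict, strict=True)
mobile_sam.to(device).eval()

predictor = SamPredictor(mobile_sam)

# Placeholder inputs: in practice the image is your RGB frame and the box is a
# Grounding DINO detection converted from normalized cxcywh to absolute xyxy pixels.
image_rgb = np.zeros((1024, 1024, 3), dtype=np.uint8)
box_xyxy = np.array([100, 100, 400, 400])

predictor.set_image(image_rgb)
masks, scores, _ = predictor.predict(box=box_xyxy, multimask_output=False)
print(masks.shape, scores)  # (1, H, W) boolean mask and its predicted IoU score
```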