import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial


class Mlp(nn.Module):
    """Two-layer MLP. With channels_first=True the linear layers are 1x1
    convolutions, so the input can stay in (N, C, H, W) layout; otherwise
    plain nn.Linear layers are used."""

    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0., channels_first=True):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # A 1x1 convolution acts as a per-pixel linear layer on channels-first tensors.
        Linear = partial(nn.Conv2d, kernel_size=1, padding=0) if channels_first else nn.Linear
        self.fc1 = Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

def dsconv_3x3(in_channel, out_channel):
    """Depthwise-separable 3x3 convolution: depthwise 3x3 + pointwise 1x1,
    followed by BatchNorm and ReLU."""
    return nn.Sequential(
        nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1, groups=in_channel),
        nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, groups=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU()
    )

class changedetector(nn.Module):
    """Change-detection head: an MLP, a depthwise-separable conv used as a
    residual refinement, and a second MLP that projects to 2 classes, with
    the resulting logits upsampled 4x by bilinear interpolation."""

    def __init__(self, in_channel):
        super().__init__()
        self.mlp1 = Mlp(in_features=in_channel, out_features=in_channel)
        self.mlp2 = Mlp(in_features=in_channel, out_features=2)
        self.dwc = dsconv_3x3(in_channel, in_channel)

    def forward(self, x):
        x1 = self.mlp1(x)
        x_d = self.dwc(x1)           # depthwise-separable refinement of the MLP output
        x_out = self.mlp2(x1 + x_d)  # residual fusion, then project to 2 classes
        x_out = F.interpolate(
            x_out,
            scale_factor=(4, 4),
            mode="bilinear",
            align_corners=False,
        )
        return x_out
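

# Minimal usage sketch: the channel count and spatial size below are
# illustrative assumptions, not values taken from this repository.
if __name__ == "__main__":
    head = changedetector(in_channel=64)
    feats = torch.randn(2, 64, 32, 32)   # dummy feature map, (batch, channels, H, W)
    logits = head(feats)                 # 2-class logits, upsampled 4x
    print(logits.shape)                  # expected: torch.Size([2, 2, 128, 128])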