@@ -0,0 +1,474 @@
+import torch
+import torch.nn as nn
+
+
+class ConvBNReLU(nn.Module):
+
+    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1,
+                 dilation=1, groups=1, bias=False):
+        super(ConvBNReLU, self).__init__()
+        self.conv = nn.Conv2d(
+            in_chan, out_chan, kernel_size=ks, stride=stride,
+            padding=padding, dilation=dilation,
+            groups=groups, bias=bias)
+        self.bn = nn.BatchNorm2d(out_chan)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        feat = self.conv(x)
+        feat = self.bn(feat)
+        feat = self.relu(feat)
+        return feat
+
+
+class UpSample(nn.Module):
+
+    def __init__(self, n_chan, factor=2):
+        super(UpSample, self).__init__()
+        out_chan = n_chan * factor * factor
+        self.proj = nn.Conv2d(n_chan, out_chan, 1, 1, 0)
+        self.up = nn.PixelShuffle(factor)
+        self.init_weight()
+
+    def forward(self, x):
+        feat = self.proj(x)
+        feat = self.up(feat)
+        return feat
+
+    def init_weight(self):
+        nn.init.xavier_normal_(self.proj.weight, gain=1.)
+
+
+class DetailBranch(nn.Module):
+
+    def __init__(self, input_channel=3):
+        super(DetailBranch, self).__init__()
+        self.S1 = nn.Sequential(
+            ConvBNReLU(input_channel, 64, 3, stride=2),
+            ConvBNReLU(64, 64, 3, stride=1),
+        )
+        self.S2 = nn.Sequential(
+            ConvBNReLU(64, 64, 3, stride=2),
+            ConvBNReLU(64, 64, 3, stride=1),
+            ConvBNReLU(64, 64, 3, stride=1),
+        )
+        self.S3 = nn.Sequential(
+            ConvBNReLU(64, 128, 3, stride=2),
+            ConvBNReLU(128, 128, 3, stride=1),
+            ConvBNReLU(128, 128, 3, stride=1),
+        )
+
+    def forward(self, x):
+        feat = self.S1(x)
+        feat = self.S2(feat)
+        feat = self.S3(feat)
+        return feat
+
+
+class StemBlock(nn.Module):
+
+    def __init__(self, input_channel=3):
+        super(StemBlock, self).__init__()
+        self.conv = ConvBNReLU(input_channel, 16, 3, stride=2)
+        self.left = nn.Sequential(
+            ConvBNReLU(16, 8, 1, stride=1, padding=0),
+            ConvBNReLU(8, 16, 3, stride=2),
+        )
+        self.right = nn.MaxPool2d(
+            kernel_size=3, stride=2, padding=1, ceil_mode=False)
+        self.fuse = ConvBNReLU(32, 16, 3, stride=1)
+
+    def forward(self, x):
+        feat = self.conv(x)
+        feat_left = self.left(feat)
+        feat_right = self.right(feat)
+        feat = torch.cat([feat_left, feat_right], dim=1)
+        feat = self.fuse(feat)
+        return feat
+
+
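+# CEBlock: context-embedding block that ends the segment branch. It batch-norms the
+# globally average-pooled feature, projects it with a 1x1 ConvBNReLU, adds it back to
+# the input as a broadcast residual, and refines the sum with a 3x3 ConvBNReLU.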
+class CEBlock(nn.Module):
+
+    def __init__(self):
+        super(CEBlock, self).__init__()
+        self.bn = nn.BatchNorm2d(128)
+        self.conv_gap = ConvBNReLU(128, 128, 1, stride=1, padding=0)
+        # TODO: in the paper this is a plain conv2d, without bn-relu
+        self.conv_last = ConvBNReLU(128, 128, 3, stride=1)
+
+    def forward(self, x):
+        feat = torch.mean(x, dim=(2, 3), keepdim=True)
+        feat = self.bn(feat)
+        feat = self.conv_gap(feat)
+        feat = feat + x
+        feat = self.conv_last(feat)
+        return feat
+
+
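+# GELayerS1/GELayerS2: gather-and-expansion layers of the segment branch. A 3x3
+# ConvBNReLU gathers local features, a depthwise 3x3 conv (groups=in_chan) expands the
+# channels by exp_ratio, and a 1x1 projection maps back to out_chan before the residual
+# add. The S2 variant downsamples with stride 2 and adds a depthwise + 1x1 shortcut.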
+class GELayerS1(nn.Module):
+
+    def __init__(self, in_chan, out_chan, exp_ratio=6):
+        super(GELayerS1, self).__init__()
+        mid_chan = in_chan * exp_ratio
+        self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
+        self.dwconv = nn.Sequential(
+            nn.Conv2d(
+                in_chan, mid_chan, kernel_size=3, stride=1,
+                padding=1, groups=in_chan, bias=False),
+            nn.BatchNorm2d(mid_chan),
+            nn.ReLU(inplace=True),  # not shown in paper
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(
+                mid_chan, out_chan, kernel_size=1, stride=1,
+                padding=0, bias=False),
+            nn.BatchNorm2d(out_chan),
+        )
+        self.conv2[1].last_bn = True
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        feat = self.conv1(x)
+        feat = self.dwconv(feat)
+        feat = self.conv2(feat)
+        feat = feat + x
+        feat = self.relu(feat)
+        return feat
+
+
+class GELayerS2(nn.Module):
+
+    def __init__(self, in_chan, out_chan, exp_ratio=6):
+        super(GELayerS2, self).__init__()
+        mid_chan = in_chan * exp_ratio
+        self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
+        self.dwconv1 = nn.Sequential(
+            nn.Conv2d(
+                in_chan, mid_chan, kernel_size=3, stride=2,
+                padding=1, groups=in_chan, bias=False),
+            nn.BatchNorm2d(mid_chan),
+        )
+        self.dwconv2 = nn.Sequential(
+            nn.Conv2d(
+                mid_chan, mid_chan, kernel_size=3, stride=1,
+                padding=1, groups=mid_chan, bias=False),
+            nn.BatchNorm2d(mid_chan),
+            nn.ReLU(inplace=True),  # not shown in paper
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(
+                mid_chan, out_chan, kernel_size=1, stride=1,
+                padding=0, bias=False),
+            nn.BatchNorm2d(out_chan),
+        )
+        self.conv2[1].last_bn = True
+        self.shortcut = nn.Sequential(
+            nn.Conv2d(
+                in_chan, in_chan, kernel_size=3, stride=2,
+                padding=1, groups=in_chan, bias=False),
+            nn.BatchNorm2d(in_chan),
+            nn.Conv2d(
+                in_chan, out_chan, kernel_size=1, stride=1,
+                padding=0, bias=False),
+            nn.BatchNorm2d(out_chan),
+        )
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        feat = self.conv1(x)
+        feat = self.dwconv1(feat)
+        feat = self.dwconv2(feat)
+        feat = self.conv2(feat)
+        shortcut = self.shortcut(x)
+        feat = feat + shortcut
+        feat = self.relu(feat)
+        return feat
+
+
+class SegmentBranch(nn.Module):
+
+    def __init__(self, input_channel=3):
+        super(SegmentBranch, self).__init__()
+        self.S1S2 = StemBlock(input_channel)
+        self.S3 = nn.Sequential(
+            GELayerS2(16, 32),
+            GELayerS1(32, 32),
+        )
+        self.S4 = nn.Sequential(
+            GELayerS2(32, 64),
+            GELayerS1(64, 64),
+        )
+        self.S5_4 = nn.Sequential(
+            GELayerS2(64, 128),
+            GELayerS1(128, 128),
+            GELayerS1(128, 128),
+            GELayerS1(128, 128),
+        )
+        self.S5_5 = CEBlock()
+
+    def forward(self, x):
+        feat2 = self.S1S2(x)
+        feat3 = self.S3(feat2)
+        feat4 = self.S4(feat3)
+        feat5_4 = self.S5_4(feat4)
+        feat5_5 = self.S5_5(feat5_4)
+        return feat2, feat3, feat4, feat5_4, feat5_5
+
+
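+# BGALayer (bilateral guided aggregation) fuses the two branches: x_d is the detail
+# branch output (1/8 resolution, 128 channels) and x_s is the segment branch output
+# after CEBlock (1/32 resolution, 128 channels). Each branch gates the other through a
+# sigmoid attention map, and up1/up2 upsample by 4x to bridge the 1/32 -> 1/8 gap
+# before the final 3x3 ConvBNReLU.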
+class BGALayer(nn.Module):
+
+    def __init__(self):
+        super(BGALayer, self).__init__()
+        self.left1 = nn.Sequential(
+            nn.Conv2d(
+                128, 128, kernel_size=3, stride=1,
+                padding=1, groups=128, bias=False),
+            nn.BatchNorm2d(128),
+            nn.Conv2d(
+                128, 128, kernel_size=1, stride=1,
+                padding=0, bias=False),
+        )
+        self.left2 = nn.Sequential(
+            nn.Conv2d(
+                128, 128, kernel_size=3, stride=2,
+                padding=1, bias=False),
+            nn.BatchNorm2d(128),
+            nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)
+        )
+        self.right1 = nn.Sequential(
+            nn.Conv2d(
+                128, 128, kernel_size=3, stride=1,
+                padding=1, bias=False),
+            nn.BatchNorm2d(128),
+        )
+        self.right2 = nn.Sequential(
+            nn.Conv2d(
+                128, 128, kernel_size=3, stride=1,
+                padding=1, groups=128, bias=False),
+            nn.BatchNorm2d(128),
+            nn.Conv2d(
+                128, 128, kernel_size=1, stride=1,
+                padding=0, bias=False),
+        )
+        self.up1 = nn.Upsample(scale_factor=4)
+        self.up2 = nn.Upsample(scale_factor=4)
+        ## TODO: does this really have no relu?
+        self.conv = nn.Sequential(
+            nn.Conv2d(
+                128, 128, kernel_size=3, stride=1,
+                padding=1, bias=False),
+            nn.BatchNorm2d(128),
+            nn.ReLU(inplace=True),  # not shown in paper
+        )
+
+    def forward(self, x_d, x_s):
+        dsize = x_d.size()[2:]
+        left1 = self.left1(x_d)
+        left2 = self.left2(x_d)
+        right1 = self.right1(x_s)
+        right2 = self.right2(x_s)
+        right1 = self.up1(right1)
+        left = left1 * torch.sigmoid(right1)
+        right = left2 * torch.sigmoid(right2)
+        right = self.up2(right)
+        out = self.conv(left + right)
+        return out
+
+
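+# SegmentHead: a 3x3 ConvBNReLU plus dropout followed by a classifier. Auxiliary heads
+# (aux=True) first upsample 2x and project to up_factor**2 channels, then a 1x1 conv
+# produces n_classes maps and a bilinear upsample by up_factor // 2 restores the input
+# resolution; the main head (aux=False) skips that stage and upsamples by up_factor.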
+class SegmentHead(nn.Module):
+
+    def __init__(self, in_chan, mid_chan, n_classes, up_factor=8, aux=True):
+        super(SegmentHead, self).__init__()
+        self.conv = ConvBNReLU(in_chan, mid_chan, 3, stride=1)
+        self.drop = nn.Dropout(0.1)
+        self.up_factor = up_factor
+
+        out_chan = n_classes
+        mid_chan2 = up_factor * up_factor if aux else mid_chan
+        up_factor = up_factor // 2 if aux else up_factor
+        self.conv_out = nn.Sequential(
+            nn.Sequential(
+                nn.Upsample(scale_factor=2),
+                ConvBNReLU(mid_chan, mid_chan2, 3, stride=1)
+            ) if aux else nn.Identity(),
+            nn.Conv2d(mid_chan2, out_chan, 1, 1, 0, bias=True),
+            nn.Upsample(scale_factor=up_factor, mode='bilinear', align_corners=False)
+        )
+
+    def forward(self, x):
+        feat = self.conv(x)
+        feat = self.drop(feat)
+        feat = self.conv_out(feat)
+        return feat
+
+
+class BiSeNetV2(nn.Module):
+
+    def __init__(self, n_classes, input_channels=3, aux_mode='train'):
+        super(BiSeNetV2, self).__init__()
+        self.aux_mode = aux_mode
+        self.detail = DetailBranch(input_channels)
+        self.segment = SegmentBranch(input_channels)
+        self.bga = BGALayer()
+
+        ## TODO: what is the right number of mid channels?
+        self.head = SegmentHead(128, 1024, n_classes, up_factor=8, aux=False)
+        if self.aux_mode == 'train':
+            self.aux2 = SegmentHead(16, 128, n_classes, up_factor=4)
+            self.aux3 = SegmentHead(32, 128, n_classes, up_factor=8)
+            self.aux4 = SegmentHead(64, 128, n_classes, up_factor=16)
+            self.aux5_4 = SegmentHead(128, 128, n_classes, up_factor=32)
+
+        self.init_weights()
+
+    def forward(self, x):
+        size = x.size()[2:]
+
+        feat_d = self.detail(x)
+        feat2, feat3, feat4, feat5_4, feat_s = self.segment(x)
+        feat_head = self.bga(feat_d, feat_s)
+
+        logits = self.head(feat_head)
+        if self.aux_mode == 'train':
+            logits_aux2 = self.aux2(feat2)
+            logits_aux3 = self.aux3(feat3)
+            logits_aux4 = self.aux4(feat4)
+            logits_aux5_4 = self.aux5_4(feat5_4)
+            return logits, logits_aux2, logits_aux3, logits_aux4, logits_aux5_4
+        elif self.aux_mode == 'eval':
+            return logits,
+        elif self.aux_mode == 'pred':
+            pred = logits.argmax(dim=1)
+            return pred
+        else:
+            raise NotImplementedError
+
+    def init_weights(self):
+        for name, module in self.named_modules():
+            if isinstance(module, (nn.Conv2d, nn.Linear)):
+                nn.init.kaiming_normal_(module.weight, mode='fan_out')
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+            elif isinstance(module, nn.modules.batchnorm._BatchNorm):
+                if hasattr(module, 'last_bn') and module.last_bn:
+                    nn.init.zeros_(module.weight)
+                else:
+                    nn.init.ones_(module.weight)
+                nn.init.zeros_(module.bias)
+        self.load_pretrain()
+
+    def load_pretrain(self):
+        # 230423: at inference time there is no need to load pretrained weights here
+        pass
+
+    def get_params(self):
+        def add_param_to_list(mod, wd_params, nowd_params):
+            for param in mod.parameters():
+                if param.dim() == 1:
+                    nowd_params.append(param)
+                elif param.dim() == 4:
+                    wd_params.append(param)
+                else:
+                    # 'name' is the child-module name captured from the loop below
+                    print(name)
+
+        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
+        for name, child in self.named_children():
+            if 'head' in name or 'aux' in name:
+                add_param_to_list(child, lr_mul_wd_params, lr_mul_nowd_params)
+            else:
+                add_param_to_list(child, wd_params, nowd_params)
+        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
+
+
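+# A minimal sketch (not from the original file) of how get_params() could feed an
+# optimizer: head/aux parameters go into separate groups so they can take a larger
+# learning rate, and the dim==1 (BN/bias) parameters skip weight decay. The lr,
+# momentum and weight_decay values below are placeholders, not values from this repo.
+#
+# model = BiSeNetV2(n_classes=19)
+# wd_params, nowd_params, lr_mul_wd, lr_mul_nowd = model.get_params()
+# optimizer = torch.optim.SGD(
+#     [
+#         {'params': wd_params},
+#         {'params': nowd_params, 'weight_decay': 0.0},
+#         {'params': lr_mul_wd, 'lr': 10 * 1e-2},
+#         {'params': lr_mul_nowd, 'lr': 10 * 1e-2, 'weight_decay': 0.0},
+#     ],
+#     lr=1e-2, momentum=0.9, weight_decay=5e-4)
+
+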
+class OhemCELoss(nn.Module):
+    """
+    OHEM (online hard example mining) cross-entropy loss.
+    Core idea: only pixels whose loss exceeds a threshold take part in the loss,
+    but at least n_min pixels are always kept.
+    """
+
+    def __init__(self, paramsDict, thresh, lb_ignore=255):
+        super(OhemCELoss, self).__init__()
+
+        self.paramsDict = paramsDict
+
+        device_str = self.paramsDict['params']['device_str']
+        # make sure the threshold tensor is sent to device_str
+        device = torch.device(device_str)
+
+        # self.thresh = 0.3567
+        self.thresh = -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float)).to(device)
+        # self.lb_ignore = 255
+        self.lb_ignore = lb_ignore
+        self.criteria = nn.CrossEntropyLoss(ignore_index=lb_ignore, reduction='none')
+
+    def forward(self, logits, labels):
+        # logits: [2,11,1088,896] batch,classNum,height,width
+        # labels: [2,1088,896] batch,height,width
+
+        # 1. compute n_min, the minimum number of pixels that must contribute:
+        #    one sixteenth of all non-ignored label pixels in the batch
+        # n_min: 121856
+        n_min = labels[labels != self.lb_ignore].numel() // 16
+        # 2. per-pixel cross-entropy loss, flattened to 1-D
+        # loss.shape = (1949696,) 1949696 = 2 * 1088 * 896
+        loss = self.criteria(logits, labels).view(-1)
+        # 3. keep only the "hard" pixels whose loss exceeds the threshold;
+        #    this avoids the full sort usually used for OHEM, which would be slow
+        # loss_hard.shape = (140232,)
+        loss_hard = loss[loss > self.thresh]
+        # 4. if fewer than n_min pixels survive, fall back to the n_min largest losses
+        if loss_hard.numel() < n_min:
+            loss_hard, _ = loss.topk(n_min)
+        # 5. otherwise, all surviving hard pixels contribute
+        # loss_hard_mean = 0.7070
+        loss_hard_mean = torch.mean(loss_hard)
+        # 6. return the mean of the hard losses
+        # 7. note: this OHEM loss is a poor estimate of the overall model loss,
+        #    since pixels below the threshold are ignored entirely
+        return loss_hard_mean
+
+
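+# A minimal sketch (not from the original file) of how the five training outputs could
+# be combined with OhemCELoss. paramsDict is a hypothetical config dict carrying
+# 'device_str'; the 0.7 threshold and equal auxiliary weights are placeholder choices.
+#
+# paramsDict = {'params': {'device_str': 'cuda'}}
+# criterion = OhemCELoss(paramsDict, thresh=0.7, lb_ignore=255)
+# logits, aux2, aux3, aux4, aux5_4 = model(images)  # aux_mode='train'
+# loss = criterion(logits, labels)
+# loss = loss + sum(criterion(lgt, labels) for lgt in (aux2, aux3, aux4, aux5_4))
+# loss.backward()
+
+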
+# if __name__ == "__main__":
+#
+#     import time
+#
+#     # ==========================================================
+#     # BiSeNetV2 with a configurable number of input channels
+#     # ==========================================================
+#
+#     input_channels = 7
+#
+#     x = torch.randn(2, input_channels, 256, 256).cuda()
+#     # x = torch.randn(2, 3, 224, 224).cuda()
+#     print("=============Input:=============")
+#     print(x.shape)
+#
+#     model = BiSeNetV2(n_classes=19, input_channels=7)
+#
+#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+#     print(device)
+#     model = model.to(device)
+#
+#     netBeforeTime = time.time()
+#     outs = model(x)
+#     netEndTime = time.time()
+#     print("Model inference time:", netEndTime - netBeforeTime)
+#     print("=============Output:=============")
+#     for out in outs:
+#         print(out.size())
+#         # print(logits.size())
+
+
+"""
|
|
|
|
|
+=============输入:=============
|
|
|
|
|
+torch.Size([2, 7, 256, 256])
|
|
|
|
|
+cuda
|
|
|
|
|
+模型推理花费时间: 0.3020000457763672
|
|
|
|
|
+=============输出:=============
|
|
|
|
|
+torch.Size([2, 19, 256, 256])
|
|
|
|
|
+torch.Size([2, 19, 256, 256])
|
|
|
|
|
+torch.Size([2, 19, 256, 256])
|
|
|
|
|
+torch.Size([2, 19, 256, 256])
|
|
|
|
|
+torch.Size([2, 19, 256, 256])
|
|
|
|
|
+
|
|
|
|
|
+进程已结束,退出代码0
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+"""
|