import torch
import torch.nn as nn


class ConvBNReLU(nn.Module):
    """Conv2d + BatchNorm2d + ReLU block."""

    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1,
                 dilation=1, groups=1, bias=False):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(
            in_chan, out_chan, kernel_size=ks, stride=stride,
            padding=padding, dilation=dilation,
            groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        feat = self.conv(x)
        feat = self.bn(feat)
        feat = self.relu(feat)
        return feat


class UpSample(nn.Module):
    """Learned upsampling via a 1x1 projection followed by PixelShuffle."""

    def __init__(self, n_chan, factor=2):
        super(UpSample, self).__init__()
        out_chan = n_chan * factor * factor
        self.proj = nn.Conv2d(n_chan, out_chan, 1, 1, 0)
        self.up = nn.PixelShuffle(factor)
        self.init_weight()

    def forward(self, x):
        feat = self.proj(x)
        feat = self.up(feat)
        return feat

    def init_weight(self):
        nn.init.xavier_normal_(self.proj.weight, gain=1.)


class DetailBranch(nn.Module):
    """Detail Branch: wide, shallow stages that preserve spatial detail (output stride 8)."""

    def __init__(self, input_channel=3):
        super(DetailBranch, self).__init__()
        self.S1 = nn.Sequential(
            ConvBNReLU(input_channel, 64, 3, stride=2),
            ConvBNReLU(64, 64, 3, stride=1),
        )
        self.S2 = nn.Sequential(
            ConvBNReLU(64, 64, 3, stride=2),
            ConvBNReLU(64, 64, 3, stride=1),
            ConvBNReLU(64, 64, 3, stride=1),
        )
        self.S3 = nn.Sequential(
            ConvBNReLU(64, 128, 3, stride=2),
            ConvBNReLU(128, 128, 3, stride=1),
            ConvBNReLU(128, 128, 3, stride=1),
        )

    def forward(self, x):
        feat = self.S1(x)
        feat = self.S2(feat)
        feat = self.S3(feat)
        return feat


class StemBlock(nn.Module):
    """Stem Block: fast downsampling head of the Segment Branch (output stride 4)."""

    def __init__(self, input_channel=3):
        super(StemBlock, self).__init__()
        self.conv = ConvBNReLU(input_channel, 16, 3, stride=2)
        self.left = nn.Sequential(
            ConvBNReLU(16, 8, 1, stride=1, padding=0),
            ConvBNReLU(8, 16, 3, stride=2),
        )
        self.right = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, ceil_mode=False)
        self.fuse = ConvBNReLU(32, 16, 3, stride=1)

    def forward(self, x):
        feat = self.conv(x)
        feat_left = self.left(feat)
        feat_right = self.right(feat)
        feat = torch.cat([feat_left, feat_right], dim=1)
        feat = self.fuse(feat)
        return feat


class CEBlock(nn.Module):
    """Context Embedding Block: injects global context via global average pooling."""

    def __init__(self):
        super(CEBlock, self).__init__()
        self.bn = nn.BatchNorm2d(128)
        self.conv_gap = ConvBNReLU(128, 128, 1, stride=1, padding=0)
        # TODO: in the paper this is a plain conv2d without BN-ReLU
        self.conv_last = ConvBNReLU(128, 128, 3, stride=1)

    def forward(self, x):
        feat = torch.mean(x, dim=(2, 3), keepdim=True)
        feat = self.bn(feat)
        feat = self.conv_gap(feat)
        feat = feat + x
        feat = self.conv_last(feat)
        return feat


class GELayerS1(nn.Module):
    """Gather-and-Expansion Layer with stride 1 (identity shortcut)."""

    def __init__(self, in_chan, out_chan, exp_ratio=6):
        super(GELayerS1, self).__init__()
        mid_chan = in_chan * exp_ratio
        self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
        self.dwconv = nn.Sequential(
            nn.Conv2d(
                in_chan, mid_chan, kernel_size=3, stride=1,
                padding=1, groups=in_chan, bias=False),
            nn.BatchNorm2d(mid_chan),
            nn.ReLU(inplace=True),  # not shown in paper
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                mid_chan, out_chan, kernel_size=1, stride=1,
                padding=0, bias=False),
            nn.BatchNorm2d(out_chan),
        )
        self.conv2[1].last_bn = True
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.dwconv(feat)
        feat = self.conv2(feat)
        feat = feat + x
        feat = self.relu(feat)
        return feat


class GELayerS2(nn.Module):
    """Gather-and-Expansion Layer with stride 2 (downsampling, depthwise shortcut)."""

    def __init__(self, in_chan, out_chan, exp_ratio=6):
        super(GELayerS2, self).__init__()
        mid_chan = in_chan * exp_ratio
        self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
        self.dwconv1 = nn.Sequential(
            nn.Conv2d(
                in_chan, mid_chan, kernel_size=3, stride=2,
                padding=1, groups=in_chan, bias=False),
            nn.BatchNorm2d(mid_chan),
        )
        self.dwconv2 = nn.Sequential(
            nn.Conv2d(
                mid_chan, mid_chan, kernel_size=3, stride=1,
                padding=1, groups=mid_chan, bias=False),
            nn.BatchNorm2d(mid_chan),
            nn.ReLU(inplace=True),  # not shown in paper
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                mid_chan, out_chan, kernel_size=1, stride=1,
                padding=0, bias=False),
            nn.BatchNorm2d(out_chan),
        )
        self.conv2[1].last_bn = True
        self.shortcut = nn.Sequential(
            nn.Conv2d(
                in_chan, in_chan, kernel_size=3, stride=2,
                padding=1, groups=in_chan, bias=False),
            nn.BatchNorm2d(in_chan),
            nn.Conv2d(
                in_chan, out_chan, kernel_size=1, stride=1,
                padding=0, bias=False),
            nn.BatchNorm2d(out_chan),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.dwconv1(feat)
        feat = self.dwconv2(feat)
        feat = self.conv2(feat)
        shortcut = self.shortcut(x)
        feat = feat + shortcut
        feat = self.relu(feat)
        return feat


class SegmentBranch(nn.Module):
    """Segment Branch: narrow, deep stages that capture high-level semantic context."""

    def __init__(self, input_channel=3):
        super(SegmentBranch, self).__init__()
        self.S1S2 = StemBlock(input_channel)
        self.S3 = nn.Sequential(
            GELayerS2(16, 32),
            GELayerS1(32, 32),
        )
        self.S4 = nn.Sequential(
            GELayerS2(32, 64),
            GELayerS1(64, 64),
        )
        self.S5_4 = nn.Sequential(
            GELayerS2(64, 128),
            GELayerS1(128, 128),
            GELayerS1(128, 128),
            GELayerS1(128, 128),
        )
        self.S5_5 = CEBlock()

    def forward(self, x):
        feat2 = self.S1S2(x)
        feat3 = self.S3(feat2)
        feat4 = self.S4(feat3)
        feat5_4 = self.S5_4(feat4)
        feat5_5 = self.S5_5(feat5_4)
        return feat2, feat3, feat4, feat5_4, feat5_5


class BGALayer(nn.Module):
    """Bilateral Guided Aggregation Layer: fuses Detail Branch and Segment Branch features."""

    def __init__(self):
        super(BGALayer, self).__init__()
        self.left1 = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=1,
                padding=1, groups=128, bias=False),
            nn.BatchNorm2d(128),
            nn.Conv2d(
                128, 128, kernel_size=1, stride=1,
                padding=0, bias=False),
        )
        self.left2 = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=2,
                padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)
        )
        self.right1 = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=1,
                padding=1, bias=False),
            nn.BatchNorm2d(128),
        )
        self.right2 = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=1,
                padding=1, groups=128, bias=False),
            nn.BatchNorm2d(128),
            nn.Conv2d(
                128, 128, kernel_size=1, stride=1,
                padding=0, bias=False),
        )
        self.up1 = nn.Upsample(scale_factor=4)
        self.up2 = nn.Upsample(scale_factor=4)
        # TODO: does this really have no relu?
        self.conv = nn.Sequential(
            nn.Conv2d(
                128, 128, kernel_size=3, stride=1,
                padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),  # not shown in paper
        )

    def forward(self, x_d, x_s):
        dsize = x_d.size()[2:]
        left1 = self.left1(x_d)
        left2 = self.left2(x_d)
        right1 = self.right1(x_s)
        right2 = self.right2(x_s)
        right1 = self.up1(right1)
        left = left1 * torch.sigmoid(right1)
        right = left2 * torch.sigmoid(right2)
        right = self.up2(right)
        out = self.conv(left + right)
        return out


class SegmentHead(nn.Module):
    """Segmentation head: ConvBNReLU + dropout + 1x1 classifier, upsampled to input resolution."""

    def __init__(self, in_chan, mid_chan, n_classes, up_factor=8, aux=True):
        super(SegmentHead, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, 3, stride=1)
        self.drop = nn.Dropout(0.1)
        self.up_factor = up_factor

        out_chan = n_classes
        mid_chan2 = up_factor * up_factor if aux else mid_chan
        up_factor = up_factor // 2 if aux else up_factor
        self.conv_out = nn.Sequential(
            nn.Sequential(
                nn.Upsample(scale_factor=2),
                ConvBNReLU(mid_chan, mid_chan2, 3, stride=1)
            ) if aux else nn.Identity(),
            nn.Conv2d(mid_chan2, out_chan, 1, 1, 0, bias=True),
            nn.Upsample(scale_factor=up_factor, mode='bilinear', align_corners=False)
        )

    def forward(self, x):
        feat = self.conv(x)
        feat = self.drop(feat)
        feat = self.conv_out(feat)
        return feat


class BiSeNetV2(nn.Module):
    """BiSeNetV2: Detail Branch + Segment Branch fused by the BGA layer, with auxiliary heads in training."""

    def __init__(self, n_classes, input_channels=3, aux_mode='train'):
        super(BiSeNetV2, self).__init__()
        self.aux_mode = aux_mode
        self.detail = DetailBranch(input_channels)
        self.segment = SegmentBranch(input_channels)
        self.bga = BGALayer()

        # TODO: what should the number of mid channels be?
        self.head = SegmentHead(128, 1024, n_classes, up_factor=8, aux=False)
        if self.aux_mode == 'train':
            self.aux2 = SegmentHead(16, 128, n_classes, up_factor=4)
            self.aux3 = SegmentHead(32, 128, n_classes, up_factor=8)
            self.aux4 = SegmentHead(64, 128, n_classes, up_factor=16)
            self.aux5_4 = SegmentHead(128, 128, n_classes, up_factor=32)

        self.init_weights()

    def forward(self, x):
        size = x.size()[2:]
        feat_d = self.detail(x)
        feat2, feat3, feat4, feat5_4, feat_s = self.segment(x)
        feat_head = self.bga(feat_d, feat_s)

        logits = self.head(feat_head)
        if self.aux_mode == 'train':
            logits_aux2 = self.aux2(feat2)
            logits_aux3 = self.aux3(feat3)
            logits_aux4 = self.aux4(feat4)
            logits_aux5_4 = self.aux5_4(feat5_4)
            return logits, logits_aux2, logits_aux3, logits_aux4, logits_aux5_4
        elif self.aux_mode == 'eval':
            return logits,
        elif self.aux_mode == 'pred':
            pred = logits.argmax(dim=1)
            return pred
        else:
            raise NotImplementedError

    def init_weights(self):
        for name, module in self.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.modules.batchnorm._BatchNorm):
                if hasattr(module, 'last_bn') and module.last_bn:
                    nn.init.zeros_(module.weight)
                else:
                    nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
        self.load_pretrain()

    def load_pretrain(self):
        # 2023-04-23: no need to load pretrained weights here at inference time
        pass

    def get_params(self):
        def add_param_to_list(mod, wd_params, nowd_params):
            for param in mod.parameters():
                if param.dim() == 1:
                    nowd_params.append(param)
                elif param.dim() == 4:
                    wd_params.append(param)
                else:
                    print(name)

        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for name, child in self.named_children():
            if 'head' in name or 'aux' in name:
                add_param_to_list(child, lr_mul_wd_params, lr_mul_nowd_params)
            else:
                add_param_to_list(child, wd_params, nowd_params)
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params


class OhemCELoss(nn.Module):
""" 算法本质: Ohem本质:核心思路是取所有损失大于阈值的像素点参与计算,但是最少也要保证取n_min个 """ def __init__(self, paramsDict, thresh, lb_ignore=255): super(OhemCELoss, self).__init__() self.paramsDict = paramsDict device_str = self.paramsDict['params']['device_str'] # 确保模型被发送到device_str device = torch.device(device_str) # self.thresh = 0.3567 self.thresh = -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float)).to(device) # self.lb_ignore = 255 self.lb_ignore = lb_ignore self.criteria = nn.CrossEntropyLoss(ignore_index=lb_ignore, reduction='none') def forward(self, logits, labels): # logits: [2,11,1088,896] batch,classNum,height,width # labels: [2,1088,896] batch,height,width # 1、计算n_min(最少算多少个像素点)的大小 # n_min的大小:一个batch的n张h*w的label图的所有的像素点的十六分之一 # n_min: 121856 n_min = labels[labels != self.lb_ignore].numel() // 16 # 2、交叉熵预测得到loss之后,打平成一维的 # loss.shape = (1949696,) 1949696 = 2 * 1088 * 896 loss = self.criteria(logits, labels).view(-1) # 3、所有loss中大于阈值的,这边叫做loss hard,这些点才参与损失计算 # 注意,这里是优化了pytorch中 Ohem 排序的,不然排序太耗时间了 # loss_hard.shape = (140232,) loss_hard = loss[loss > self.thresh] # 4、如果总数小于了n_min,那么肯定要保证有n_min个 if loss_hard.numel() < n_min: loss_hard, _ = loss.topk(n_min) # 5、如果参与的像素点的个数大于了n_min个,那么这些点都参与计算 # loss_hard_mean = 0.7070 loss_hard_mean = torch.mean(loss_hard) # 6、返回损失的均值 # 7、为什么Ohem的损失不能很好的评估模型的损失 # 因为Ohem对应的损失只考虑了大于阈值对应部分的损失,小于阈值部分的没有考虑 return loss_hard_mean # if __name__ == "__main__": # # # ========================================================== # # 支持不同输入通道的bisenetv2 # # ========================================================== # # input_channels = 7 # # x = torch.randn(2, input_channels, 256, 256).cuda() # # x = torch.randn(2, 3, 224, 224).cuda() # print("=============输入:=============") # print(x.shape) # # model = BiSeNetV2(n_classes=19,input_channels=7) # # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # print(device) # model = model.to(device) # # netBeforeTime = time.time() # outs = model(x) # netEndTime = time.time() # print("模型推理花费时间:",netEndTime-netBeforeTime) # print("=============输出:=============") # for out in outs: # print(out.size()) # # print(logits.size()) """ =============输入:============= torch.Size([2, 7, 256, 256]) cuda 模型推理花费时间: 0.3020000457763672 =============输出:============= torch.Size([2, 19, 256, 256]) torch.Size([2, 19, 256, 256]) torch.Size([2, 19, 256, 256]) torch.Size([2, 19, 256, 256]) torch.Size([2, 19, 256, 256]) 进程已结束,退出代码0 """