backbone.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. import torch
  2. import torch.nn as nn
  3. class ConvBNReLU(nn.Module):
  4. def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1,
  5. dilation=1, groups=1, bias=False):
  6. super(ConvBNReLU, self).__init__()
  7. self.conv = nn.Conv2d(
  8. in_chan, out_chan, kernel_size=ks, stride=stride,
  9. padding=padding, dilation=dilation,
  10. groups=groups, bias=bias)
  11. self.bn = nn.BatchNorm2d(out_chan)
  12. self.relu = nn.ReLU(inplace=True)
  13. def forward(self, x):
  14. feat = self.conv(x)
  15. feat = self.bn(feat)
  16. feat = self.relu(feat)
  17. return feat
  18. class UpSample(nn.Module):
  19. def __init__(self, n_chan, factor=2):
  20. super(UpSample, self).__init__()
  21. out_chan = n_chan * factor * factor
  22. self.proj = nn.Conv2d(n_chan, out_chan, 1, 1, 0)
  23. self.up = nn.PixelShuffle(factor)
  24. self.init_weight()
  25. def forward(self, x):
  26. feat = self.proj(x)
  27. feat = self.up(feat)
  28. return feat
  29. def init_weight(self):
  30. nn.init.xavier_normal_(self.proj.weight, gain=1.)
  31. class DetailBranch(nn.Module):
  32. def __init__(self, input_channel=3):
  33. super(DetailBranch, self).__init__()
  34. self.S1 = nn.Sequential(
  35. ConvBNReLU(input_channel, 64, 3, stride=2),
  36. ConvBNReLU(64, 64, 3, stride=1),
  37. )
  38. self.S2 = nn.Sequential(
  39. ConvBNReLU(64, 64, 3, stride=2),
  40. ConvBNReLU(64, 64, 3, stride=1),
  41. ConvBNReLU(64, 64, 3, stride=1),
  42. )
  43. self.S3 = nn.Sequential(
  44. ConvBNReLU(64, 128, 3, stride=2),
  45. ConvBNReLU(128, 128, 3, stride=1),
  46. ConvBNReLU(128, 128, 3, stride=1),
  47. )
  48. def forward(self, x):
  49. feat = self.S1(x)
  50. feat = self.S2(feat)
  51. feat = self.S3(feat)
  52. return feat
  53. class StemBlock(nn.Module):
  54. def __init__(self, input_channel=3):
  55. super(StemBlock, self).__init__()
  56. self.conv = ConvBNReLU(input_channel, 16, 3, stride=2)
  57. self.left = nn.Sequential(
  58. ConvBNReLU(16, 8, 1, stride=1, padding=0),
  59. ConvBNReLU(8, 16, 3, stride=2),
  60. )
  61. self.right = nn.MaxPool2d(
  62. kernel_size=3, stride=2, padding=1, ceil_mode=False)
  63. self.fuse = ConvBNReLU(32, 16, 3, stride=1)
  64. def forward(self, x):
  65. feat = self.conv(x)
  66. feat_left = self.left(feat)
  67. feat_right = self.right(feat)
  68. feat = torch.cat([feat_left, feat_right], dim=1)
  69. feat = self.fuse(feat)
  70. return feat
  71. class CEBlock(nn.Module):
  72. def __init__(self):
  73. super(CEBlock, self).__init__()
  74. self.bn = nn.BatchNorm2d(128)
  75. self.conv_gap = ConvBNReLU(128, 128, 1, stride=1, padding=0)
  76. # TODO: in paper here is naive conv2d, no bn-relu
  77. self.conv_last = ConvBNReLU(128, 128, 3, stride=1)
  78. def forward(self, x):
  79. feat = torch.mean(x, dim=(2, 3), keepdim=True)
  80. feat = self.bn(feat)
  81. feat = self.conv_gap(feat)
  82. feat = feat + x
  83. feat = self.conv_last(feat)
  84. return feat
  85. class GELayerS1(nn.Module):
  86. def __init__(self, in_chan, out_chan, exp_ratio=6):
  87. super(GELayerS1, self).__init__()
  88. mid_chan = in_chan * exp_ratio
  89. self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
  90. self.dwconv = nn.Sequential(
  91. nn.Conv2d(
  92. in_chan, mid_chan, kernel_size=3, stride=1,
  93. padding=1, groups=in_chan, bias=False),
  94. nn.BatchNorm2d(mid_chan),
  95. nn.ReLU(inplace=True), # not shown in paper
  96. )
  97. self.conv2 = nn.Sequential(
  98. nn.Conv2d(
  99. mid_chan, out_chan, kernel_size=1, stride=1,
  100. padding=0, bias=False),
  101. nn.BatchNorm2d(out_chan),
  102. )
  103. self.conv2[1].last_bn = True
  104. self.relu = nn.ReLU(inplace=True)
  105. def forward(self, x):
  106. feat = self.conv1(x)
  107. feat = self.dwconv(feat)
  108. feat = self.conv2(feat)
  109. feat = feat + x
  110. feat = self.relu(feat)
  111. return feat
  112. class GELayerS2(nn.Module):
  113. def __init__(self, in_chan, out_chan, exp_ratio=6):
  114. super(GELayerS2, self).__init__()
  115. mid_chan = in_chan * exp_ratio
  116. self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
  117. self.dwconv1 = nn.Sequential(
  118. nn.Conv2d(
  119. in_chan, mid_chan, kernel_size=3, stride=2,
  120. padding=1, groups=in_chan, bias=False),
  121. nn.BatchNorm2d(mid_chan),
  122. )
  123. self.dwconv2 = nn.Sequential(
  124. nn.Conv2d(
  125. mid_chan, mid_chan, kernel_size=3, stride=1,
  126. padding=1, groups=mid_chan, bias=False),
  127. nn.BatchNorm2d(mid_chan),
  128. nn.ReLU(inplace=True), # not shown in paper
  129. )
  130. self.conv2 = nn.Sequential(
  131. nn.Conv2d(
  132. mid_chan, out_chan, kernel_size=1, stride=1,
  133. padding=0, bias=False),
  134. nn.BatchNorm2d(out_chan),
  135. )
  136. self.conv2[1].last_bn = True
  137. self.shortcut = nn.Sequential(
  138. nn.Conv2d(
  139. in_chan, in_chan, kernel_size=3, stride=2,
  140. padding=1, groups=in_chan, bias=False),
  141. nn.BatchNorm2d(in_chan),
  142. nn.Conv2d(
  143. in_chan, out_chan, kernel_size=1, stride=1,
  144. padding=0, bias=False),
  145. nn.BatchNorm2d(out_chan),
  146. )
  147. self.relu = nn.ReLU(inplace=True)
  148. def forward(self, x):
  149. feat = self.conv1(x)
  150. feat = self.dwconv1(feat)
  151. feat = self.dwconv2(feat)
  152. feat = self.conv2(feat)
  153. shortcut = self.shortcut(x)
  154. feat = feat + shortcut
  155. feat = self.relu(feat)
  156. return feat
  157. class SegmentBranch(nn.Module):
  158. def __init__(self, input_channel=3):
  159. super(SegmentBranch, self).__init__()
  160. self.S1S2 = StemBlock(input_channel)
  161. self.S3 = nn.Sequential(
  162. GELayerS2(16, 32),
  163. GELayerS1(32, 32),
  164. )
  165. self.S4 = nn.Sequential(
  166. GELayerS2(32, 64),
  167. GELayerS1(64, 64),
  168. )
  169. self.S5_4 = nn.Sequential(
  170. GELayerS2(64, 128),
  171. GELayerS1(128, 128),
  172. GELayerS1(128, 128),
  173. GELayerS1(128, 128),
  174. )
  175. self.S5_5 = CEBlock()
  176. def forward(self, x):
  177. feat2 = self.S1S2(x)
  178. feat3 = self.S3(feat2)
  179. feat4 = self.S4(feat3)
  180. feat5_4 = self.S5_4(feat4)
  181. feat5_5 = self.S5_5(feat5_4)
  182. return feat2, feat3, feat4, feat5_4, feat5_5
  183. class BGALayer(nn.Module):
  184. def __init__(self):
  185. super(BGALayer, self).__init__()
  186. self.left1 = nn.Sequential(
  187. nn.Conv2d(
  188. 128, 128, kernel_size=3, stride=1,
  189. padding=1, groups=128, bias=False),
  190. nn.BatchNorm2d(128),
  191. nn.Conv2d(
  192. 128, 128, kernel_size=1, stride=1,
  193. padding=0, bias=False),
  194. )
  195. self.left2 = nn.Sequential(
  196. nn.Conv2d(
  197. 128, 128, kernel_size=3, stride=2,
  198. padding=1, bias=False),
  199. nn.BatchNorm2d(128),
  200. nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)
  201. )
  202. self.right1 = nn.Sequential(
  203. nn.Conv2d(
  204. 128, 128, kernel_size=3, stride=1,
  205. padding=1, bias=False),
  206. nn.BatchNorm2d(128),
  207. )
  208. self.right2 = nn.Sequential(
  209. nn.Conv2d(
  210. 128, 128, kernel_size=3, stride=1,
  211. padding=1, groups=128, bias=False),
  212. nn.BatchNorm2d(128),
  213. nn.Conv2d(
  214. 128, 128, kernel_size=1, stride=1,
  215. padding=0, bias=False),
  216. )
  217. self.up1 = nn.Upsample(scale_factor=4)
  218. self.up2 = nn.Upsample(scale_factor=4)
  219. ##TODO: does this really has no relu?
  220. self.conv = nn.Sequential(
  221. nn.Conv2d(
  222. 128, 128, kernel_size=3, stride=1,
  223. padding=1, bias=False),
  224. nn.BatchNorm2d(128),
  225. nn.ReLU(inplace=True), # not shown in paper
  226. )
  227. def forward(self, x_d, x_s):
  228. dsize = x_d.size()[2:]
  229. left1 = self.left1(x_d)
  230. left2 = self.left2(x_d)
  231. right1 = self.right1(x_s)
  232. right2 = self.right2(x_s)
  233. right1 = self.up1(right1)
  234. left = left1 * torch.sigmoid(right1)
  235. right = left2 * torch.sigmoid(right2)
  236. right = self.up2(right)
  237. out = self.conv(left + right)
  238. return out
  239. class SegmentHead(nn.Module):
  240. def __init__(self, in_chan, mid_chan, n_classes, up_factor=8, aux=True):
  241. super(SegmentHead, self).__init__()
  242. self.conv = ConvBNReLU(in_chan, mid_chan, 3, stride=1)
  243. self.drop = nn.Dropout(0.1)
  244. self.up_factor = up_factor
  245. out_chan = n_classes
  246. mid_chan2 = up_factor * up_factor if aux else mid_chan
  247. up_factor = up_factor // 2 if aux else up_factor
  248. self.conv_out = nn.Sequential(
  249. nn.Sequential(
  250. nn.Upsample(scale_factor=2),
  251. ConvBNReLU(mid_chan, mid_chan2, 3, stride=1)
  252. ) if aux else nn.Identity(),
  253. nn.Conv2d(mid_chan2, out_chan, 1, 1, 0, bias=True),
  254. nn.Upsample(scale_factor=up_factor, mode='bilinear', align_corners=False)
  255. )
  256. def forward(self, x):
  257. feat = self.conv(x)
  258. feat = self.drop(feat)
  259. feat = self.conv_out(feat)
  260. return feat
  261. class BiSeNetV2(nn.Module):
  262. def __init__(self, n_classes, input_channels=3, aux_mode='train'):
  263. super(BiSeNetV2, self).__init__()
  264. self.aux_mode = aux_mode
  265. self.detail = DetailBranch(input_channels)
  266. self.segment = SegmentBranch(input_channels)
  267. self.bga = BGALayer()
  268. ## TODO: what is the number of mid chan ?
  269. self.head = SegmentHead(128, 1024, n_classes, up_factor=8, aux=False)
  270. if self.aux_mode == 'train':
  271. self.aux2 = SegmentHead(16, 128, n_classes, up_factor=4)
  272. self.aux3 = SegmentHead(32, 128, n_classes, up_factor=8)
  273. self.aux4 = SegmentHead(64, 128, n_classes, up_factor=16)
  274. self.aux5_4 = SegmentHead(128, 128, n_classes, up_factor=32)
  275. self.init_weights()
  276. def forward(self, x):
  277. size = x.size()[2:]
  278. feat_d = self.detail(x)
  279. feat2, feat3, feat4, feat5_4, feat_s = self.segment(x)
  280. feat_head = self.bga(feat_d, feat_s)
  281. logits = self.head(feat_head)
  282. if self.aux_mode == 'train':
  283. logits_aux2 = self.aux2(feat2)
  284. logits_aux3 = self.aux3(feat3)
  285. logits_aux4 = self.aux4(feat4)
  286. logits_aux5_4 = self.aux5_4(feat5_4)
  287. return logits, logits_aux2, logits_aux3, logits_aux4, logits_aux5_4
  288. elif self.aux_mode == 'eval':
  289. return logits,
  290. elif self.aux_mode == 'pred':
  291. pred = logits.argmax(dim=1)
  292. return pred
  293. else:
  294. raise NotImplementedError
  295. def init_weights(self):
  296. for name, module in self.named_modules():
  297. if isinstance(module, (nn.Conv2d, nn.Linear)):
  298. nn.init.kaiming_normal_(module.weight, mode='fan_out')
  299. if not module.bias is None: nn.init.constant_(module.bias, 0)
  300. elif isinstance(module, nn.modules.batchnorm._BatchNorm):
  301. if hasattr(module, 'last_bn') and module.last_bn:
  302. nn.init.zeros_(module.weight)
  303. else:
  304. nn.init.ones_(module.weight)
  305. nn.init.zeros_(module.bias)
  306. self.load_pretrain()
  307. def load_pretrain(self):
  308. # 230423:推理时,不必在这里加载预训练模型
  309. pass
  310. def get_params(self):
  311. def add_param_to_list(mod, wd_params, nowd_params):
  312. for param in mod.parameters():
  313. if param.dim() == 1:
  314. nowd_params.append(param)
  315. elif param.dim() == 4:
  316. wd_params.append(param)
  317. else:
  318. print(name)
  319. wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
  320. for name, child in self.named_children():
  321. if 'head' in name or 'aux' in name:
  322. add_param_to_list(child, lr_mul_wd_params, lr_mul_nowd_params)
  323. else:
  324. add_param_to_list(child, wd_params, nowd_params)
  325. return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
  326. class OhemCELoss(nn.Module):
  327. """
  328. 算法本质:
  329. Ohem本质:核心思路是取所有损失大于阈值的像素点参与计算,但是最少也要保证取n_min个
  330. """
  331. def __init__(self, paramsDict, thresh, lb_ignore=255):
  332. super(OhemCELoss, self).__init__()
  333. self.paramsDict = paramsDict
  334. device_str = self.paramsDict['params']['device_str']
  335. # 确保模型被发送到device_str
  336. device = torch.device(device_str)
  337. # self.thresh = 0.3567
  338. self.thresh = -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float)).to(device)
  339. # self.lb_ignore = 255
  340. self.lb_ignore = lb_ignore
  341. self.criteria = nn.CrossEntropyLoss(ignore_index=lb_ignore, reduction='none')
  342. def forward(self, logits, labels):
  343. # logits: [2,11,1088,896] batch,classNum,height,width
  344. # labels: [2,1088,896] batch,height,width
  345. # 1、计算n_min(最少算多少个像素点)的大小
  346. # n_min的大小:一个batch的n张h*w的label图的所有的像素点的十六分之一
  347. # n_min: 121856
  348. n_min = labels[labels != self.lb_ignore].numel() // 16
  349. # 2、交叉熵预测得到loss之后,打平成一维的
  350. # loss.shape = (1949696,) 1949696 = 2 * 1088 * 896
  351. loss = self.criteria(logits, labels).view(-1)
  352. # 3、所有loss中大于阈值的,这边叫做loss hard,这些点才参与损失计算
  353. # 注意,这里是优化了pytorch中 Ohem 排序的,不然排序太耗时间了
  354. # loss_hard.shape = (140232,)
  355. loss_hard = loss[loss > self.thresh]
  356. # 4、如果总数小于了n_min,那么肯定要保证有n_min个
  357. if loss_hard.numel() < n_min:
  358. loss_hard, _ = loss.topk(n_min)
  359. # 5、如果参与的像素点的个数大于了n_min个,那么这些点都参与计算
  360. # loss_hard_mean = 0.7070
  361. loss_hard_mean = torch.mean(loss_hard)
  362. # 6、返回损失的均值
  363. # 7、为什么Ohem的损失不能很好的评估模型的损失
  364. # 因为Ohem对应的损失只考虑了大于阈值对应部分的损失,小于阈值部分的没有考虑
  365. return loss_hard_mean
  366. # if __name__ == "__main__":
  367. #
  368. # # ==========================================================
  369. # # 支持不同输入通道的bisenetv2
  370. # # ==========================================================
  371. #
  372. # input_channels = 7
  373. #
  374. # x = torch.randn(2, input_channels, 256, 256).cuda()
  375. # # x = torch.randn(2, 3, 224, 224).cuda()
  376. # print("=============输入:=============")
  377. # print(x.shape)
  378. #
  379. # model = BiSeNetV2(n_classes=19,input_channels=7)
  380. #
  381. # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  382. # print(device)
  383. # model = model.to(device)
  384. #
  385. # netBeforeTime = time.time()
  386. # outs = model(x)
  387. # netEndTime = time.time()
  388. # print("模型推理花费时间:",netEndTime-netBeforeTime)
  389. # print("=============输出:=============")
  390. # for out in outs:
  391. # print(out.size())
  392. # # print(logits.size())
  393. """
  394. =============输入:=============
  395. torch.Size([2, 7, 256, 256])
  396. cuda
  397. 模型推理花费时间: 0.3020000457763672
  398. =============输出:=============
  399. torch.Size([2, 19, 256, 256])
  400. torch.Size([2, 19, 256, 256])
  401. torch.Size([2, 19, 256, 256])
  402. torch.Size([2, 19, 256, 256])
  403. torch.Size([2, 19, 256, 256])
  404. 进程已结束,退出代码0
  405. """