# backbone.py
  1. import torch
  2. import torch.nn as nn
  3. import time
  4. class ConvBNReLU(nn.Module):
  5. def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1,
  6. dilation=1, groups=1, bias=False):
  7. super(ConvBNReLU, self).__init__()
  8. self.conv = nn.Conv2d(
  9. in_chan, out_chan, kernel_size=ks, stride=stride,
  10. padding=padding, dilation=dilation,
  11. groups=groups, bias=bias)
  12. self.bn = nn.BatchNorm2d(out_chan)
  13. self.relu = nn.ReLU(inplace=True)
  14. def forward(self, x):
  15. feat = self.conv(x)
  16. feat = self.bn(feat)
  17. feat = self.relu(feat)
  18. return feat
  19. class UpSample(nn.Module):
  20. def __init__(self, n_chan, factor=2):
  21. super(UpSample, self).__init__()
  22. out_chan = n_chan * factor * factor
  23. self.proj = nn.Conv2d(n_chan, out_chan, 1, 1, 0)
  24. self.up = nn.PixelShuffle(factor)
  25. self.init_weight()
  26. def forward(self, x):
  27. feat = self.proj(x)
  28. feat = self.up(feat)
  29. return feat
  30. def init_weight(self):
  31. nn.init.xavier_normal_(self.proj.weight, gain=1.)
  32. class DetailBranch(nn.Module):
  33. def __init__(self, input_channel=3):
  34. super(DetailBranch, self).__init__()
  35. self.S1 = nn.Sequential(
  36. ConvBNReLU(input_channel, 64, 3, stride=2),
  37. ConvBNReLU(64, 64, 3, stride=1),
  38. )
  39. self.S2 = nn.Sequential(
  40. ConvBNReLU(64, 64, 3, stride=2),
  41. ConvBNReLU(64, 64, 3, stride=1),
  42. ConvBNReLU(64, 64, 3, stride=1),
  43. )
  44. self.S3 = nn.Sequential(
  45. ConvBNReLU(64, 128, 3, stride=2),
  46. ConvBNReLU(128, 128, 3, stride=1),
  47. ConvBNReLU(128, 128, 3, stride=1),
  48. )
  49. def forward(self, x):
  50. feat = self.S1(x)
  51. feat = self.S2(feat)
  52. feat = self.S3(feat)
  53. return feat
  54. class StemBlock(nn.Module):
  55. def __init__(self, input_channel=3):
  56. super(StemBlock, self).__init__()
  57. self.conv = ConvBNReLU(input_channel, 16, 3, stride=2)
  58. self.left = nn.Sequential(
  59. ConvBNReLU(16, 8, 1, stride=1, padding=0),
  60. ConvBNReLU(8, 16, 3, stride=2),
  61. )
  62. self.right = nn.MaxPool2d(
  63. kernel_size=3, stride=2, padding=1, ceil_mode=False)
  64. self.fuse = ConvBNReLU(32, 16, 3, stride=1)
  65. def forward(self, x):
  66. feat = self.conv(x)
  67. feat_left = self.left(feat)
  68. feat_right = self.right(feat)
  69. feat = torch.cat([feat_left, feat_right], dim=1)
  70. feat = self.fuse(feat)
  71. return feat
  72. class CEBlock(nn.Module):
  73. def __init__(self):
  74. super(CEBlock, self).__init__()
  75. self.bn = nn.BatchNorm2d(128)
  76. self.conv_gap = ConvBNReLU(128, 128, 1, stride=1, padding=0)
  77. # TODO: in paper here is naive conv2d, no bn-relu
  78. self.conv_last = ConvBNReLU(128, 128, 3, stride=1)
  79. def forward(self, x):
  80. feat = torch.mean(x, dim=(2, 3), keepdim=True)
  81. feat = self.bn(feat)
  82. feat = self.conv_gap(feat)
  83. feat = feat + x
  84. feat = self.conv_last(feat)
  85. return feat
  86. class GELayerS1(nn.Module):
  87. def __init__(self, in_chan, out_chan, exp_ratio=6):
  88. super(GELayerS1, self).__init__()
  89. mid_chan = in_chan * exp_ratio
  90. self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
  91. self.dwconv = nn.Sequential(
  92. nn.Conv2d(
  93. in_chan, mid_chan, kernel_size=3, stride=1,
  94. padding=1, groups=in_chan, bias=False),
  95. nn.BatchNorm2d(mid_chan),
  96. nn.ReLU(inplace=True), # not shown in paper
  97. )
  98. self.conv2 = nn.Sequential(
  99. nn.Conv2d(
  100. mid_chan, out_chan, kernel_size=1, stride=1,
  101. padding=0, bias=False),
  102. nn.BatchNorm2d(out_chan),
  103. )
  104. self.conv2[1].last_bn = True
  105. self.relu = nn.ReLU(inplace=True)
  106. def forward(self, x):
  107. feat = self.conv1(x)
  108. feat = self.dwconv(feat)
  109. feat = self.conv2(feat)
  110. feat = feat + x
  111. feat = self.relu(feat)
  112. return feat
  113. class GELayerS2(nn.Module):
  114. def __init__(self, in_chan, out_chan, exp_ratio=6):
  115. super(GELayerS2, self).__init__()
  116. mid_chan = in_chan * exp_ratio
  117. self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1)
  118. self.dwconv1 = nn.Sequential(
  119. nn.Conv2d(
  120. in_chan, mid_chan, kernel_size=3, stride=2,
  121. padding=1, groups=in_chan, bias=False),
  122. nn.BatchNorm2d(mid_chan),
  123. )
  124. self.dwconv2 = nn.Sequential(
  125. nn.Conv2d(
  126. mid_chan, mid_chan, kernel_size=3, stride=1,
  127. padding=1, groups=mid_chan, bias=False),
  128. nn.BatchNorm2d(mid_chan),
  129. nn.ReLU(inplace=True), # not shown in paper
  130. )
  131. self.conv2 = nn.Sequential(
  132. nn.Conv2d(
  133. mid_chan, out_chan, kernel_size=1, stride=1,
  134. padding=0, bias=False),
  135. nn.BatchNorm2d(out_chan),
  136. )
  137. self.conv2[1].last_bn = True
  138. self.shortcut = nn.Sequential(
  139. nn.Conv2d(
  140. in_chan, in_chan, kernel_size=3, stride=2,
  141. padding=1, groups=in_chan, bias=False),
  142. nn.BatchNorm2d(in_chan),
  143. nn.Conv2d(
  144. in_chan, out_chan, kernel_size=1, stride=1,
  145. padding=0, bias=False),
  146. nn.BatchNorm2d(out_chan),
  147. )
  148. self.relu = nn.ReLU(inplace=True)
  149. def forward(self, x):
  150. feat = self.conv1(x)
  151. feat = self.dwconv1(feat)
  152. feat = self.dwconv2(feat)
  153. feat = self.conv2(feat)
  154. shortcut = self.shortcut(x)
  155. feat = feat + shortcut
  156. feat = self.relu(feat)
  157. return feat
  158. class SegmentBranch(nn.Module):
  159. def __init__(self, input_channel=3):
  160. super(SegmentBranch, self).__init__()
  161. self.S1S2 = StemBlock(input_channel)
  162. self.S3 = nn.Sequential(
  163. GELayerS2(16, 32),
  164. GELayerS1(32, 32),
  165. )
  166. self.S4 = nn.Sequential(
  167. GELayerS2(32, 64),
  168. GELayerS1(64, 64),
  169. )
  170. self.S5_4 = nn.Sequential(
  171. GELayerS2(64, 128),
  172. GELayerS1(128, 128),
  173. GELayerS1(128, 128),
  174. GELayerS1(128, 128),
  175. )
  176. self.S5_5 = CEBlock()
  177. def forward(self, x):
  178. feat2 = self.S1S2(x)
  179. feat3 = self.S3(feat2)
  180. feat4 = self.S4(feat3)
  181. feat5_4 = self.S5_4(feat4)
  182. feat5_5 = self.S5_5(feat5_4)
  183. return feat2, feat3, feat4, feat5_4, feat5_5
  184. class BGALayer(nn.Module):
  185. def __init__(self):
  186. super(BGALayer, self).__init__()
  187. self.left1 = nn.Sequential(
  188. nn.Conv2d(
  189. 128, 128, kernel_size=3, stride=1,
  190. padding=1, groups=128, bias=False),
  191. nn.BatchNorm2d(128),
  192. nn.Conv2d(
  193. 128, 128, kernel_size=1, stride=1,
  194. padding=0, bias=False),
  195. )
  196. self.left2 = nn.Sequential(
  197. nn.Conv2d(
  198. 128, 128, kernel_size=3, stride=2,
  199. padding=1, bias=False),
  200. nn.BatchNorm2d(128),
  201. nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)
  202. )
  203. self.right1 = nn.Sequential(
  204. nn.Conv2d(
  205. 128, 128, kernel_size=3, stride=1,
  206. padding=1, bias=False),
  207. nn.BatchNorm2d(128),
  208. )
  209. self.right2 = nn.Sequential(
  210. nn.Conv2d(
  211. 128, 128, kernel_size=3, stride=1,
  212. padding=1, groups=128, bias=False),
  213. nn.BatchNorm2d(128),
  214. nn.Conv2d(
  215. 128, 128, kernel_size=1, stride=1,
  216. padding=0, bias=False),
  217. )
  218. self.up1 = nn.Upsample(scale_factor=4)
  219. self.up2 = nn.Upsample(scale_factor=4)
  220. ##TODO: does this really has no relu?
  221. self.conv = nn.Sequential(
  222. nn.Conv2d(
  223. 128, 128, kernel_size=3, stride=1,
  224. padding=1, bias=False),
  225. nn.BatchNorm2d(128),
  226. nn.ReLU(inplace=True), # not shown in paper
  227. )
  228. def forward(self, x_d, x_s):
  229. dsize = x_d.size()[2:]
  230. left1 = self.left1(x_d)
  231. left2 = self.left2(x_d)
  232. right1 = self.right1(x_s)
  233. right2 = self.right2(x_s)
  234. right1 = self.up1(right1)
  235. left = left1 * torch.sigmoid(right1)
  236. right = left2 * torch.sigmoid(right2)
  237. right = self.up2(right)
  238. out = self.conv(left + right)
  239. return out
  240. class SegmentHead(nn.Module):
  241. def __init__(self, in_chan, mid_chan, n_classes, up_factor=8, aux=True):
  242. super(SegmentHead, self).__init__()
  243. self.conv = ConvBNReLU(in_chan, mid_chan, 3, stride=1)
  244. self.drop = nn.Dropout(0.1)
  245. self.up_factor = up_factor
  246. out_chan = n_classes
  247. mid_chan2 = up_factor * up_factor if aux else mid_chan
  248. up_factor = up_factor // 2 if aux else up_factor
  249. self.conv_out = nn.Sequential(
  250. nn.Sequential(
  251. nn.Upsample(scale_factor=2),
  252. ConvBNReLU(mid_chan, mid_chan2, 3, stride=1)
  253. ) if aux else nn.Identity(),
  254. nn.Conv2d(mid_chan2, out_chan, 1, 1, 0, bias=True),
  255. nn.Upsample(scale_factor=up_factor, mode='bilinear', align_corners=False)
  256. )
  257. def forward(self, x):
  258. feat = self.conv(x)
  259. feat = self.drop(feat)
  260. feat = self.conv_out(feat)
  261. return feat
  262. class BiSeNetV2(nn.Module):
  263. def __init__(self, n_classes, input_channels=3, aux_mode='train'):
  264. super(BiSeNetV2, self).__init__()
  265. self.aux_mode = aux_mode
  266. self.detail = DetailBranch(input_channels)
  267. self.segment = SegmentBranch(input_channels)
  268. self.bga = BGALayer()
  269. ## TODO: what is the number of mid chan ?
  270. self.head = SegmentHead(128, 1024, n_classes, up_factor=8, aux=False)
  271. if self.aux_mode == 'train':
  272. self.aux2 = SegmentHead(16, 128, n_classes, up_factor=4)
  273. self.aux3 = SegmentHead(32, 128, n_classes, up_factor=8)
  274. self.aux4 = SegmentHead(64, 128, n_classes, up_factor=16)
  275. self.aux5_4 = SegmentHead(128, 128, n_classes, up_factor=32)
  276. self.init_weights()
  277. def forward(self, x):
  278. size = x.size()[2:]
  279. feat_d = self.detail(x)
  280. feat2, feat3, feat4, feat5_4, feat_s = self.segment(x)
  281. feat_head = self.bga(feat_d, feat_s)
  282. logits = self.head(feat_head)
  283. if self.aux_mode == 'train':
  284. logits_aux2 = self.aux2(feat2)
  285. logits_aux3 = self.aux3(feat3)
  286. logits_aux4 = self.aux4(feat4)
  287. logits_aux5_4 = self.aux5_4(feat5_4)
  288. return logits, logits_aux2, logits_aux3, logits_aux4, logits_aux5_4
  289. elif self.aux_mode == 'eval':
  290. return logits,
  291. elif self.aux_mode == 'pred':
  292. pred = logits.argmax(dim=1)
  293. return pred
  294. else:
  295. raise NotImplementedError
  296. def init_weights(self):
  297. for name, module in self.named_modules():
  298. if isinstance(module, (nn.Conv2d, nn.Linear)):
  299. nn.init.kaiming_normal_(module.weight, mode='fan_out')
  300. if not module.bias is None: nn.init.constant_(module.bias, 0)
  301. elif isinstance(module, nn.modules.batchnorm._BatchNorm):
  302. if hasattr(module, 'last_bn') and module.last_bn:
  303. nn.init.zeros_(module.weight)
  304. else:
  305. nn.init.ones_(module.weight)
  306. nn.init.zeros_(module.bias)
  307. self.load_pretrain()
  308. def load_pretrain(self):
  309. # 230423:推理时,不必在这里加载预训练模型
  310. pass
  311. def get_params(self):
  312. def add_param_to_list(mod, wd_params, nowd_params):
  313. for param in mod.parameters():
  314. if param.dim() == 1:
  315. nowd_params.append(param)
  316. elif param.dim() == 4:
  317. wd_params.append(param)
  318. else:
  319. print(name)
  320. wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
  321. for name, child in self.named_children():
  322. if 'head' in name or 'aux' in name:
  323. add_param_to_list(child, lr_mul_wd_params, lr_mul_nowd_params)
  324. else:
  325. add_param_to_list(child, wd_params, nowd_params)
  326. return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
  327. class OhemCELoss(nn.Module):
  328. """
  329. 算法本质:
  330. Ohem本质:核心思路是取所有损失大于阈值的像素点参与计算,但是最少也要保证取n_min个
  331. """
  332. def __init__(self, paramsDict, thresh, lb_ignore=255):
  333. super(OhemCELoss, self).__init__()
  334. self.paramsDict = paramsDict
  335. device_str = self.paramsDict['params']['device_str']
  336. # 确保模型被发送到device_str
  337. device = torch.device(device_str)
  338. # self.thresh = 0.3567
  339. self.thresh = -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float)).to(device)
  340. # self.lb_ignore = 255
  341. self.lb_ignore = lb_ignore
  342. self.criteria = nn.CrossEntropyLoss(ignore_index=lb_ignore, reduction='none')
  343. def forward(self, logits, labels):
  344. # logits: [2,11,1088,896] batch,classNum,height,width
  345. # labels: [2,1088,896] batch,height,width
  346. # 1、计算n_min(最少算多少个像素点)的大小
  347. # n_min的大小:一个batch的n张h*w的label图的所有的像素点的十六分之一
  348. # n_min: 121856
  349. n_min = labels[labels != self.lb_ignore].numel() // 16
  350. # 2、交叉熵预测得到loss之后,打平成一维的
  351. # loss.shape = (1949696,) 1949696 = 2 * 1088 * 896
  352. loss = self.criteria(logits, labels).view(-1)
  353. # 3、所有loss中大于阈值的,这边叫做loss hard,这些点才参与损失计算
  354. # 注意,这里是优化了pytorch中 Ohem 排序的,不然排序太耗时间了
  355. # loss_hard.shape = (140232,)
  356. loss_hard = loss[loss > self.thresh]
  357. # 4、如果总数小于了n_min,那么肯定要保证有n_min个
  358. if loss_hard.numel() < n_min:
  359. loss_hard, _ = loss.topk(n_min)
  360. # 5、如果参与的像素点的个数大于了n_min个,那么这些点都参与计算
  361. # loss_hard_mean = 0.7070
  362. loss_hard_mean = torch.mean(loss_hard)
  363. # 6、返回损失的均值
  364. # 7、为什么Ohem的损失不能很好的评估模型的损失
  365. # 因为Ohem对应的损失只考虑了大于阈值对应部分的损失,小于阈值部分的没有考虑
  366. return loss_hard_mean
  367. if __name__ == "__main__":
  368. # ==========================================================
  369. # 支持不同输入通道的bisenetv2
  370. # ==========================================================
  371. input_channels = 7
  372. x = torch.randn(2, input_channels, 256, 256).cuda()
  373. # x = torch.randn(2, 3, 224, 224).cuda()
  374. print("=============输入:=============")
  375. print(x.shape)
  376. model = BiSeNetV2(n_classes=19, input_channels=7)
  377. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  378. print(device)
  379. model = model.to(device)
  380. netBeforeTime = time.time()
  381. outs = model(x)
  382. netEndTime = time.time()
  383. print("模型推理花费时间:", netEndTime - netBeforeTime)
  384. print("=============输出:=============")
  385. for out in outs:
  386. print(out.size())
  387. # print(logits.size())
  388. """
  389. =============输入:=============
  390. torch.Size([2, 7, 256, 256])
  391. cuda
  392. 模型推理花费时间: 0.3020000457763672
  393. =============输出:=============
  394. torch.Size([2, 19, 256, 256])
  395. torch.Size([2, 19, 256, 256])
  396. torch.Size([2, 19, 256, 256])
  397. torch.Size([2, 19, 256, 256])
  398. torch.Size([2, 19, 256, 256])
  399. 进程已结束,退出代码0
  400. """