fry_bisenetv2_predictor_V04_250819.py 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3
  1. import numpy as np
  2. import json
  3. import torch
  4. import datetime
  5. import torch.nn as nn
  6. import cv2
  7. from pathlib import Path
  8. import copy
  9. from typing import Dict, List, Tuple, Optional
  10. from app.utils.card_inference.backbone import BiSeNetV2
  11. from app.utils.card_inference.predict_preprocess import predict_preprocess
  12. from app.utils.card_inference.create_predict_result import create_result_singleImg
  13. from app.utils.card_inference.handle_result import process_detection_result
  14. import logging
# NOTE(review): this configures the process-wide root logger at import time (logging.basicConfig
# is a no-op if the root logger already has handlers). Nothing in this module currently logs
# through `logging` -- fry_algo_print's logging call is commented out.
logging.basicConfig(level=logging.INFO)
  16. def fry_algo_print(level_str: str, info_str: str):
  17. # logging.info(f"[{level_str}] : {info_str}")
  18. pass
  19. def fry_cv2_imread(filename, flags=cv2.IMREAD_COLOR):
  20. """支持中文路径的图像读取"""
  21. try:
  22. with open(filename, 'rb') as f:
  23. chunk = f.read()
  24. chunk_arr = np.frombuffer(chunk, dtype=np.uint8)
  25. img = cv2.imdecode(chunk_arr, flags)
  26. if img is None:
  27. fry_algo_print("警告", f"Warning: Unable to decode image: {filename}")
  28. return img
  29. except IOError as e:
  30. fry_algo_print("错误", f"IOError: Unable to read file: {filename}")
  31. fry_algo_print("错误", f"Error details: {str(e)}")
  32. return None
  33. def fry_cv2_imwrite(filename, img, params=None):
  34. """支持中文路径的图像保存"""
  35. try:
  36. ext = Path(filename).suffix.lower()
  37. result, encoded_img = cv2.imencode(ext, img, params)
  38. if result:
  39. with open(filename, 'wb') as f:
  40. encoded_img.tofile(f)
  41. return True
  42. else:
  43. fry_algo_print("警告", f"Warning: Unable to encode image: {filename}")
  44. return False
  45. except Exception as e:
  46. fry_algo_print("错误", f"Error: Unable to write file: {filename}")
  47. fry_algo_print("错误", f"Error details: {str(e)}")
  48. return False
  49. def fry_opencv_chinese_path_init():
  50. """初始化OpenCV中文路径支持"""
  51. cv2.imread = fry_cv2_imread
  52. cv2.imwrite = fry_cv2_imwrite
  53. # 初始化OpenCV中文路径支持
  54. OPENCV_IO_ALREADY_INIT = False
  55. if not OPENCV_IO_ALREADY_INIT:
  56. fry_opencv_chinese_path_init()
  57. OPENCV_IO_ALREADY_INIT = True
  58. class FryBisenetV2Predictor:
  59. """BiSeNetV2 语义分割预测器"""
  60. def __init__(self,
  61. pth_path: str,
  62. real_seg_class_dict: Dict[int, str],
  63. imgSize_train_dict: Dict[str, int],
  64. confidence: float = 0.5,
  65. label_colors_dict: Optional[Dict[str, Tuple[int, int, int]]] = None,
  66. input_channels: int = 3,
  67. aux_mode: str = "eval"):
  68. """
  69. 初始化预测器
  70. Args:
  71. pth_path: 模型权重文件路径
  72. real_seg_class_dict: 真实的分割类别字典,格式为 {类别id: 类别名称}
  73. imgSize_train_dict: 训练时的图像尺寸,格式为 {'width': 宽度, 'height': 高度}
  74. confidence: 置信度阈值
  75. label_colors_dict: 类别颜色字典,格式为 {类别名称: (R, G, B)}
  76. input_channels: 输入通道数
  77. aux_mode: 辅助模式
  78. """
  79. self.pth_path = pth_path
  80. self.real_seg_class_dict = real_seg_class_dict
  81. self.imgSize_train_dict = imgSize_train_dict
  82. self.confidence = confidence
  83. self.input_channels = input_channels
  84. self.aux_mode = aux_mode
  85. # 构建完整的分割类别字典(包含背景类)
  86. self.seg_class_dict = {0: '___background___'}
  87. self.seg_class_dict.update(real_seg_class_dict)
  88. self.n_classes = len(self.seg_class_dict)
  89. # 生成或使用提供的颜色字典
  90. self.label_colors_dict = self._generate_label_colors(label_colors_dict)
  91. # 获取设备
  92. self.device = self._get_device()
  93. # 初始化模型
  94. self.model = self._init_model()
  95. @staticmethod
  96. def _get_device():
  97. """获取计算设备"""
  98. return torch.device("cuda" if torch.cuda.is_available() else "cpu")
  99. def _generate_label_colors(self, label_colors_dict: Optional[Dict[str, Tuple[int, int, int]]]) -> Dict[
  100. str, Tuple[int, int, int]]:
  101. """
  102. 生成或补充类别颜色字典
  103. Args:
  104. label_colors_dict: 用户提供的颜色字典
  105. Returns:
  106. 完整的颜色字典
  107. """
  108. if label_colors_dict is None:
  109. label_colors_dict = {}
  110. # 为所有类别生成颜色(除了背景)
  111. np.random.seed(42) # 设置随机种子以保证颜色一致性
  112. for class_id, class_name in self.seg_class_dict.items():
  113. if class_id == 0: # 跳过背景类
  114. continue
  115. if class_name not in label_colors_dict:
  116. # 生成随机颜色,避免太暗的颜色
  117. color = tuple(np.random.randint(50, 256, 3).tolist())
  118. label_colors_dict[class_name] = color
  119. return label_colors_dict
  120. def _load_model_weights(self, model: nn.Module, modelLoadPth: str) -> nn.Module:
  121. """
  122. 加载模型权重
  123. Args:
  124. model: 模型对象
  125. modelLoadPth: 权重文件路径
  126. Returns:
  127. 加载权重后的模型
  128. """
  129. fry_algo_print("信息", "加载预训练参数...")
  130. weights_dict = torch.load(modelLoadPth, map_location=self.device)
  131. new_weights_dict = {}
  132. exclude_layer_list = ["aux2", 'aux3', 'aux4', 'aux5']
  133. all_layer_num = 0
  134. ok_layer_num = 0
  135. for k, v in weights_dict.items():
  136. all_layer_num += 1
  137. is_exclude = False
  138. # 检查是否需要排除该层
  139. for exclude_str in exclude_layer_list:
  140. if exclude_str in k:
  141. is_exclude = True
  142. break
  143. if not is_exclude:
  144. new_weights_dict[k] = v
  145. ok_layer_num += 1
  146. else:
  147. fry_algo_print("信息", f"被排除的层:{k}")
  148. # 加载权重,不要求严格对等
  149. model.load_state_dict(new_weights_dict, strict=False)
  150. fry_algo_print("信息", f"成功加载模型层数:{ok_layer_num}/{all_layer_num}")
  151. return model
  152. def _init_model(self) -> nn.Module:
  153. """初始化并加载模型"""
  154. model = BiSeNetV2(self.n_classes, self.input_channels, self.aux_mode)
  155. model = model.to(self.device)
  156. model = self._load_model_weights(model, self.pth_path)
  157. model.eval()
  158. return model
  159. def _predict_tensor(self, CHW: torch.Tensor) -> Dict:
  160. """
  161. 对单个图像张量进行预测
  162. Args:
  163. CHW: 形状为 (C, H, W) 的图像张量
  164. Returns:
  165. 包含预测结果的字典
  166. """
  167. with torch.no_grad():
  168. NCHW = CHW.unsqueeze(0)
  169. # 因为单张图片推理 batch norm 层会报错,所以复制一份
  170. NCHW2 = torch.cat([NCHW, NCHW], dim=0)
  171. # 模型推理
  172. logits, *logits_aux = self.model(NCHW2)
  173. # 计算概率和预测类别
  174. probs = torch.softmax(logits, dim=1)
  175. preds = torch.argmax(probs, dim=1)
  176. # 转换为numpy数组
  177. probs_np = probs.detach().cpu().numpy()
  178. preds_np = preds.detach().cpu().numpy()
  179. # 取第一张图片的结果
  180. ansImg_needSave = preds_np[0]
  181. ansProbs = probs_np[0]
  182. return {
  183. "ans_img": ansImg_needSave,
  184. "probs": ansProbs,
  185. "file_name": "result"
  186. }
  187. def _save_result_json(self, result: Dict, json_path: Path):
  188. """
  189. 保存预测结果为JSON文件
  190. Args:
  191. result: 预测结果字典
  192. json_path: JSON文件保存路径
  193. """
  194. # 将numpy数组转换为可序列化的格式
  195. json_result = {}
  196. for key, value in result.items():
  197. if isinstance(value, np.ndarray):
  198. json_result[key] = value.tolist()
  199. elif isinstance(value, dict):
  200. json_result[key] = {}
  201. for k, v in value.items():
  202. if isinstance(v, np.ndarray):
  203. json_result[key][k] = v.tolist()
  204. else:
  205. json_result[key][k] = v
  206. else:
  207. json_result[key] = value
  208. with open(json_path, 'w', encoding='utf-8') as f:
  209. json.dump(json_result, f, ensure_ascii=False, indent=2)
  210. def predict_single_image_np(self,
  211. img_bgr: np.ndarray,
  212. image_path_str: str = None,
  213. save_visualization: bool = True,
  214. save_json: bool = True,
  215. answer_json_dir_str: Optional[str] = None,
  216. input_channels=3
  217. ) -> Dict:
  218. """
  219. 预测单张图片
  220. Args:
  221. img_path: 图片路径
  222. save_visualization: 是否保存可视化结果
  223. save_json: 是否保存JSON结果
  224. answer_json_dir_str: JSON结果保存目录
  225. Returns:
  226. 预测结果字典
  227. """
  228. if image_path_str is None:
  229. timestamp = datetime.now().strftime('%y%m%d_%H%M%S_%f')
  230. image_path_real_str = f"{timestamp}.jpg"
  231. else:
  232. image_path_real_str = str(image_path_str)
  233. img_path_real_obj = Path(image_path_real_str).resolve()
  234. answer_json_dir_str_obj = Path(answer_json_dir_str).resolve()
  235. shape = img_bgr.shape
  236. image_channel = shape[2]
  237. fry_algo_print("信息", f"模型需要的通道数为:{input_channels}")
  238. fry_algo_print("信息", f"测试的图片实际的通道数为:{image_channel}")
  239. if image_channel != input_channels:
  240. # if image_channel==3 and input_channels==1:
  241. # img_bgr = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
  242. # elif image_channel==4 and input_channels==1:
  243. # img_bgr = cv2.cvtColor(img_bgr, cv2.COLOR_BGRA2GRAY)
  244. # elif image_channel==1 and input_channels==3:
  245. # img_bgr = cv2.cvtColor(img_bgr, cv2.COLOR_GRAY2BGR)
  246. # else:
  247. # raise ValueError(f"输入图片的通道数和模型不匹配:image_channel:{image_channel},input_channels:{input_channels}")
  248. raise ValueError(
  249. f"输入图片的通道数和模型不匹配:image_channel:{image_channel},input_channels:{input_channels}")
  250. # 获取原始图片尺寸
  251. height, width = img_bgr.shape[:2]
  252. originImgSize = {'width': width, 'height': height}
  253. # 预处理
  254. imgTensor_CHW_norm = predict_preprocess(img_bgr, self.imgSize_train_dict)
  255. # 预测
  256. ansImgDict = self._predict_tensor(imgTensor_CHW_norm)
  257. image_path_name = str(img_path_real_obj.name)
  258. # 创建结果
  259. per_img_seg_result = create_result_singleImg(
  260. self.seg_class_dict,
  261. ansImgDict,
  262. originImgSize,
  263. self.imgSize_train_dict,
  264. confidence=self.confidence,
  265. )
  266. # 保存JSON结果
  267. if save_json and answer_json_dir_str:
  268. json_dir = Path(answer_json_dir_str)
  269. json_dir.mkdir(parents=True, exist_ok=True)
  270. if image_path_str is None:
  271. cv2.imwrite(image_path_real_str, img_bgr)
  272. # 获取图片文件名(不含扩展名)
  273. img_name = Path(img_path_real_obj).stem
  274. json_path = json_dir / f"{img_name}.json"
  275. self._save_result_json(per_img_seg_result, json_path)
  276. fry_algo_print("成功", f"JSON结果已保存到:{json_path}")
  277. # 保存可视化结果
  278. if save_visualization:
  279. result_img = process_detection_result(img_bgr, per_img_seg_result, self.label_colors_dict)
  280. output_path = str(answer_json_dir_str_obj / f"{img_path_real_obj.name}")
  281. cv2.imwrite(output_path, result_img)
  282. fry_algo_print("成功", f"可视化结果已保存到:{output_path}")
  283. return per_img_seg_result
  284. def _save_result_json(self, result: Dict, json_path: Path):
  285. """
  286. 保存预测结果为JSON文件
  287. Args:
  288. result: 预测结果字典
  289. json_path: JSON文件保存路径
  290. """
  291. # 将numpy数组转换为可序列化的格式
  292. json_result = {}
  293. for key, value in result.items():
  294. if isinstance(value, np.ndarray):
  295. json_result[key] = value.tolist()
  296. elif isinstance(value, dict):
  297. json_result[key] = {}
  298. for k, v in value.items():
  299. if isinstance(v, np.ndarray):
  300. json_result[key][k] = v.tolist()
  301. else:
  302. json_result[key][k] = v
  303. else:
  304. json_result[key] = value
  305. with open(json_path, 'w', encoding='utf-8') as f:
  306. json.dump(json_result, f, ensure_ascii=False, indent=2)
  307. def predict_from_image(self, img_bgr: np.ndarray) -> Dict:
  308. """
  309. 直接从解码后的图像数据(numpy数组)进行预测。
  310. Args:
  311. img_bgr: BGR格式的图像,作为一个numpy数组。
  312. Returns:
  313. 预测结果字典。
  314. """
  315. # 检查通道数是否匹配
  316. shape = img_bgr.shape
  317. image_channel = shape[2] if len(shape) == 3 else 1
  318. if image_channel != self.input_channels:
  319. raise ValueError(
  320. f"输入图片的通道数和模型不匹配:image_channel:{image_channel},input_channels:{self.input_channels}")
  321. # 获取原始图片尺寸
  322. height, width = img_bgr.shape[:2]
  323. originImgSize = {'width': width, 'height': height}
  324. # 预处理
  325. imgTensor_CHW_norm = predict_preprocess(img_bgr, self.imgSize_train_dict)
  326. # 预测
  327. ansImgDict = self._predict_tensor(imgTensor_CHW_norm)
  328. # 创建结果
  329. per_img_seg_result = create_result_singleImg(
  330. self.seg_class_dict,
  331. ansImgDict,
  332. originImgSize,
  333. self.imgSize_train_dict,
  334. confidence=self.confidence
  335. )
  336. return per_img_seg_result
  337. def predict_single_image(self,
  338. img_path: str,
  339. save_visualization: bool = True,
  340. save_json: bool = True,
  341. answer_json_dir_str: Optional[str] = None,
  342. input_channels=3
  343. ) -> Dict:
  344. """
  345. 预测单张图片
  346. Args:
  347. img_path: 图片路径
  348. save_visualization: 是否保存可视化结果
  349. save_json: 是否保存JSON结果
  350. answer_json_dir_str: JSON结果保存目录
  351. Returns:
  352. 预测结果字典
  353. """
  354. img_path_obj = Path(img_path).resolve()
  355. img_path_parent_obj = img_path_obj.parent
  356. answer_json_dir_str_obj = Path(answer_json_dir_str).resolve()
  357. # 读取图片
  358. img_bgr = cv2.imread(str(img_path_obj))
  359. if img_bgr is None:
  360. raise ValueError(f"无法读取图片:{img_path}")
  361. per_img_seg_result = self.predict_single_image_np(
  362. img_bgr=img_bgr,
  363. image_path_str=str(img_path_obj),
  364. save_visualization=save_visualization,
  365. save_json=save_json,
  366. answer_json_dir_str=answer_json_dir_str,
  367. input_channels=input_channels
  368. )
  369. return per_img_seg_result
  370. def predict_batch(self,
  371. img_paths: List[str],
  372. save_visualization: bool = True,
  373. save_json: bool = True,
  374. answer_json_dir_str: Optional[str] = None,
  375. input_channels=3
  376. ) -> List[Dict]:
  377. """
  378. 批量预测图片
  379. Args:
  380. img_paths: 图片路径列表
  381. save_visualization: 是否保存可视化结果
  382. save_json: 是否保存JSON结果
  383. answer_json_dir_str: JSON结果保存目录
  384. output_dir: 可视化结果保存目录
  385. Returns:
  386. 所有图片的预测结果列表
  387. """
  388. answer_json_dir_str_obj = Path(answer_json_dir_str).resolve()
  389. results = []
  390. Path(answer_json_dir_str).mkdir(parents=True, exist_ok=True)
  391. # 批量处理
  392. for i, img_path in enumerate(img_paths):
  393. fry_algo_print("信息", f"处理图片 {i + 1}/{len(img_paths)}: {img_path}")
  394. try:
  395. # 读取图片
  396. img_bgr = cv2.imread(img_path)
  397. if img_bgr is None:
  398. fry_algo_print("信息", f"警告:无法读取图片 {img_path}")
  399. continue
  400. shape = img_bgr.shape
  401. image_channel = shape[2]
  402. if image_channel != input_channels:
  403. fry_algo_print("信息", f"模型需要的通道数为:{input_channels}")
  404. fry_algo_print("信息", f"测试的图片实际的通道数为:{image_channel}")
  405. fry_algo_print("错误",
  406. f"输入图片的通道数和模型不匹配:image_channel:{image_channel},input_channels:{input_channels}")
  407. continue
  408. # 获取原始图片尺寸
  409. height, width = img_bgr.shape[:2]
  410. originImgSize = {'width': width, 'height': height}
  411. # 预处理
  412. imgTensor_CHW_norm = predict_preprocess(img_bgr, self.imgSize_train_dict)
  413. # 预测
  414. ansImgDict = self._predict_tensor(imgTensor_CHW_norm)
  415. img_path_obj = Path(img_path).resolve()
  416. image_path_name = str(img_path_obj.name)
  417. # 创建结果
  418. per_img_seg_result = create_result_singleImg(
  419. self.seg_class_dict,
  420. ansImgDict,
  421. originImgSize,
  422. self.imgSize_train_dict,
  423. confidence=self.confidence,
  424. image_path_name=image_path_name
  425. )
  426. # 保存JSON结果
  427. if save_json and answer_json_dir_str:
  428. json_dir = Path(answer_json_dir_str)
  429. json_dir.mkdir(parents=True, exist_ok=True)
  430. img_name = Path(img_path).stem
  431. json_path = json_dir / f"{img_name}.json"
  432. self._save_result_json(per_img_seg_result, json_path)
  433. # 保存可视化结果
  434. if save_visualization:
  435. result_img = process_detection_result(img_bgr, per_img_seg_result, self.label_colors_dict)
  436. output_path = answer_json_dir_str_obj / f"{Path(img_path).name}"
  437. cv2.imwrite(str(output_path), result_img)
  438. results.append(per_img_seg_result)
  439. except Exception as e:
  440. fry_algo_print("失败", f"处理图片 {img_path} 时出错:{e}")
  441. continue
  442. fry_algo_print("成功", f"批量处理完成,成功处理 {len(results)}/{len(img_paths)} 张图片")
  443. return results
  444. def main():
  445. """使用示例"""
  446. # 配置参数
  447. pth_path = r"segmentation_bisenetv2.pth"
  448. input_channels = 3
  449. real_seg_class_dict = {1: 'outer_box'}
  450. # 为不同类别设置不同颜色(可选)
  451. label_colors_dict = {
  452. 'outer_box': (255, 0, 0),
  453. }
  454. imgSize_train_dict = {'width': 1280, 'height': 1280}
  455. confidence = 0.5
  456. # 创建预测器
  457. predictor = FryBisenetV2Predictor(
  458. pth_path=pth_path,
  459. real_seg_class_dict=real_seg_class_dict,
  460. imgSize_train_dict=imgSize_train_dict,
  461. confidence=confidence,
  462. label_colors_dict=label_colors_dict,
  463. input_channels=input_channels,
  464. )
  465. # 单张图片预测
  466. print("=== 单张图片预测 ===")
  467. now_img_path = r"input_output\images\coaxis_0008.jpg"
  468. answer_json_dir_str = r"input_output\images_answer_json_dir_str"
  469. result = predictor.predict_single_image(
  470. img_path=now_img_path,
  471. save_visualization=True,
  472. save_json=True,
  473. answer_json_dir_str=answer_json_dir_str,
  474. input_channels=input_channels,
  475. )
  476. # 批量预测示例
  477. # print("\n=== 批量图片预测 ===")
  478. # img_paths = [
  479. # r"input_output\images\coaxis_0008.jpg",
  480. # r"input_output\images\coaxis_0082.jpg",
  481. # r"input_output\images\ring_0001.jpg",
  482. # r"input_output\images\Pokemon_back_for_Edge_0001.jpg",
  483. # ]
  484. #
  485. # results = predictor.predict_batch(
  486. # img_paths=img_paths,
  487. # save_visualization=True,
  488. # save_json=True,
  489. # answer_json_dir_str=answer_json_dir_str,
  490. # input_channels=input_channels,
  491. # )
  492. def _test_pokemon_inner_box():
  493. # 配置参数
  494. pth_path = r"E:\_250807_训练好的导出的模型\_250808_1043_宝可梦内框训练效果还行\pth_and_images\segmentation_bisenetv2.pth"
  495. input_channels = 3
  496. real_seg_class_dict = {1: 'inner_box'}
  497. # 为不同类别设置不同颜色(可选)
  498. label_colors_dict = {
  499. 'outer_box': (255, 0, 0),
  500. }
  501. imgSize_train_dict = {'width': 1280, 'height': 1280}
  502. confidence = 0.5
  503. # 创建预测器
  504. predictor = FryBisenetV2Predictor(
  505. pth_path=pth_path,
  506. real_seg_class_dict=real_seg_class_dict,
  507. imgSize_train_dict=imgSize_train_dict,
  508. confidence=confidence,
  509. label_colors_dict=label_colors_dict,
  510. input_channels=input_channels,
  511. )
  512. # 单张图片预测
  513. print("=== 单张图片预测 ===")
  514. now_img_path = r"E:\_250807_训练好的导出的模型\_250808_1043_宝可梦内框训练效果还行\pth_and_images\images\diff_big_00065.jpg"
  515. answer_json_dir_str = r"E:\_250807_训练好的导出的模型\_250808_1043_宝可梦内框训练效果还行\pth_and_images\images_answer"
  516. result = predictor.predict_single_image(
  517. img_path=now_img_path,
  518. save_visualization=True,
  519. save_json=True,
  520. answer_json_dir_str=answer_json_dir_str
  521. )
  522. def _test_pokemon_back_edge():
  523. # 配置参数
  524. pth_path = r"E:\_250807_训练好的导出的模型\_250811_1104_宝可梦背面边角\pth_and_images\segmentation_bisenetv2.pth"
  525. input_channels = 3
  526. real_seg_class_dict = {
  527. 1: 'wear',
  528. 2: 'wear_and_impact',
  529. 3: 'impact',
  530. 4: 'damaged',
  531. 5: 'wear_and_stain',
  532. }
  533. # 为不同类别设置不同颜色(可选)
  534. # label_colors_dict = {
  535. # 'outer_box': (255, 0, 0),
  536. # }
  537. imgSize_train_dict = {'width': 512, 'height': 512}
  538. confidence = 0.5
  539. # 创建预测器
  540. predictor = FryBisenetV2Predictor(
  541. pth_path=pth_path,
  542. real_seg_class_dict=real_seg_class_dict,
  543. imgSize_train_dict=imgSize_train_dict,
  544. confidence=confidence,
  545. input_channels=input_channels,
  546. )
  547. # 单张图片预测
  548. print("=== 单张图片预测 ===")
  549. now_img_path = r"E:\_250807_训练好的导出的模型\_250811_1104_宝可梦背面边角\pth_and_images\images\split\Pokémon_back_for_Edge_0001_bottom_grid_r0_c0.jpg"
  550. answer_json_dir_str = r"E:\_250807_训练好的导出的模型\_250811_1104_宝可梦背面边角\pth_and_images\images_answer"
  551. result = predictor.predict_single_image(
  552. img_path=now_img_path,
  553. save_visualization=True,
  554. save_json=True,
  555. answer_json_dir_str=answer_json_dir_str,
  556. input_channels=input_channels
  557. )
if __name__ == "__main__":
    main()
    # Local experiments (machine-specific absolute paths):
    # _test_pokemon_inner_box()
    # _test_pokemon_back_edge()