fry_bisenetv2_predictor_V04_250819.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
  1. import numpy as np
  2. import json
  3. import torch
  4. import datetime
  5. import torch.nn as nn
  6. import cv2
  7. from pathlib import Path
  8. import copy
  9. from typing import Dict, List, Tuple, Optional
  10. from app.utils.card_inference.backbone import BiSeNetV2
  11. from app.utils.card_inference.predict_preprocess import predict_preprocess
  12. from app.utils.card_inference.create_predict_result import create_result_singleImg
  13. from app.utils.card_inference.handle_result import process_detection_result
  14. import logging
  15. logging.basicConfig(level=logging.INFO)
  16. def fry_algo_print(level_str: str, info_str: str):
  17. logging.info(f"[{level_str}] : {info_str}")
  18. def fry_cv2_imread(filename, flags=cv2.IMREAD_COLOR):
  19. """支持中文路径的图像读取"""
  20. try:
  21. with open(filename, 'rb') as f:
  22. chunk = f.read()
  23. chunk_arr = np.frombuffer(chunk, dtype=np.uint8)
  24. img = cv2.imdecode(chunk_arr, flags)
  25. if img is None:
  26. fry_algo_print("警告", f"Warning: Unable to decode image: {filename}")
  27. return img
  28. except IOError as e:
  29. fry_algo_print("错误", f"IOError: Unable to read file: {filename}")
  30. fry_algo_print("错误", f"Error details: {str(e)}")
  31. return None
  32. def fry_cv2_imwrite(filename, img, params=None):
  33. """支持中文路径的图像保存"""
  34. try:
  35. ext = Path(filename).suffix.lower()
  36. result, encoded_img = cv2.imencode(ext, img, params)
  37. if result:
  38. with open(filename, 'wb') as f:
  39. encoded_img.tofile(f)
  40. return True
  41. else:
  42. fry_algo_print("警告", f"Warning: Unable to encode image: {filename}")
  43. return False
  44. except Exception as e:
  45. fry_algo_print("错误", f"Error: Unable to write file: {filename}")
  46. fry_algo_print("错误", f"Error details: {str(e)}")
  47. return False
  48. def fry_opencv_chinese_path_init():
  49. """初始化OpenCV中文路径支持"""
  50. cv2.imread = fry_cv2_imread
  51. cv2.imwrite = fry_cv2_imwrite
  52. # 初始化OpenCV中文路径支持
  53. OPENCV_IO_ALREADY_INIT = False
  54. if not OPENCV_IO_ALREADY_INIT:
  55. fry_opencv_chinese_path_init()
  56. OPENCV_IO_ALREADY_INIT = True
  57. class FryBisenetV2Predictor:
  58. """BiSeNetV2 语义分割预测器"""
  59. def __init__(self,
  60. pth_path: str,
  61. real_seg_class_dict: Dict[int, str],
  62. imgSize_train_dict: Dict[str, int],
  63. confidence: float = 0.5,
  64. label_colors_dict: Optional[Dict[str, Tuple[int, int, int]]] = None,
  65. input_channels: int = 3,
  66. aux_mode: str = "eval"):
  67. """
  68. 初始化预测器
  69. Args:
  70. pth_path: 模型权重文件路径
  71. real_seg_class_dict: 真实的分割类别字典,格式为 {类别id: 类别名称}
  72. imgSize_train_dict: 训练时的图像尺寸,格式为 {'width': 宽度, 'height': 高度}
  73. confidence: 置信度阈值
  74. label_colors_dict: 类别颜色字典,格式为 {类别名称: (R, G, B)}
  75. input_channels: 输入通道数
  76. aux_mode: 辅助模式
  77. """
  78. self.pth_path = pth_path
  79. self.real_seg_class_dict = real_seg_class_dict
  80. self.imgSize_train_dict = imgSize_train_dict
  81. self.confidence = confidence
  82. self.input_channels = input_channels
  83. self.aux_mode = aux_mode
  84. # 构建完整的分割类别字典(包含背景类)
  85. self.seg_class_dict = {0: '___background___'}
  86. self.seg_class_dict.update(real_seg_class_dict)
  87. self.n_classes = len(self.seg_class_dict)
  88. # 生成或使用提供的颜色字典
  89. self.label_colors_dict = self._generate_label_colors(label_colors_dict)
  90. # 获取设备
  91. self.device = self._get_device()
  92. # 初始化模型
  93. self.model = self._init_model()
  94. @staticmethod
  95. def _get_device():
  96. """获取计算设备"""
  97. return torch.device("cuda" if torch.cuda.is_available() else "cpu")
  98. def _generate_label_colors(self, label_colors_dict: Optional[Dict[str, Tuple[int, int, int]]]) -> Dict[
  99. str, Tuple[int, int, int]]:
  100. """
  101. 生成或补充类别颜色字典
  102. Args:
  103. label_colors_dict: 用户提供的颜色字典
  104. Returns:
  105. 完整的颜色字典
  106. """
  107. if label_colors_dict is None:
  108. label_colors_dict = {}
  109. # 为所有类别生成颜色(除了背景)
  110. np.random.seed(42) # 设置随机种子以保证颜色一致性
  111. for class_id, class_name in self.seg_class_dict.items():
  112. if class_id == 0: # 跳过背景类
  113. continue
  114. if class_name not in label_colors_dict:
  115. # 生成随机颜色,避免太暗的颜色
  116. color = tuple(np.random.randint(50, 256, 3).tolist())
  117. label_colors_dict[class_name] = color
  118. return label_colors_dict
  119. def _load_model_weights(self, model: nn.Module, modelLoadPth: str) -> nn.Module:
  120. """
  121. 加载模型权重
  122. Args:
  123. model: 模型对象
  124. modelLoadPth: 权重文件路径
  125. Returns:
  126. 加载权重后的模型
  127. """
  128. fry_algo_print("信息", "加载预训练参数...")
  129. weights_dict = torch.load(modelLoadPth, map_location=self.device)
  130. new_weights_dict = {}
  131. exclude_layer_list = ["aux2", 'aux3', 'aux4', 'aux5']
  132. all_layer_num = 0
  133. ok_layer_num = 0
  134. for k, v in weights_dict.items():
  135. all_layer_num += 1
  136. is_exclude = False
  137. # 检查是否需要排除该层
  138. for exclude_str in exclude_layer_list:
  139. if exclude_str in k:
  140. is_exclude = True
  141. break
  142. if not is_exclude:
  143. new_weights_dict[k] = v
  144. ok_layer_num += 1
  145. else:
  146. fry_algo_print("信息", f"被排除的层:{k}")
  147. # 加载权重,不要求严格对等
  148. model.load_state_dict(new_weights_dict, strict=False)
  149. fry_algo_print("信息", f"成功加载模型层数:{ok_layer_num}/{all_layer_num}")
  150. return model
  151. def _init_model(self) -> nn.Module:
  152. """初始化并加载模型"""
  153. model = BiSeNetV2(self.n_classes, self.input_channels, self.aux_mode)
  154. model = model.to(self.device)
  155. model = self._load_model_weights(model, self.pth_path)
  156. model.eval()
  157. return model
  158. def _predict_tensor(self, CHW: torch.Tensor) -> Dict:
  159. """
  160. 对单个图像张量进行预测
  161. Args:
  162. CHW: 形状为 (C, H, W) 的图像张量
  163. Returns:
  164. 包含预测结果的字典
  165. """
  166. with torch.no_grad():
  167. NCHW = CHW.unsqueeze(0)
  168. # 因为单张图片推理 batch norm 层会报错,所以复制一份
  169. NCHW2 = torch.cat([NCHW, NCHW], dim=0)
  170. # 模型推理
  171. logits, *logits_aux = self.model(NCHW2)
  172. # 计算概率和预测类别
  173. probs = torch.softmax(logits, dim=1)
  174. preds = torch.argmax(probs, dim=1)
  175. # 转换为numpy数组
  176. probs_np = probs.detach().cpu().numpy()
  177. preds_np = preds.detach().cpu().numpy()
  178. # 取第一张图片的结果
  179. ansImg_needSave = preds_np[0]
  180. ansProbs = probs_np[0]
  181. return {
  182. "ans_img": ansImg_needSave,
  183. "probs": ansProbs,
  184. "file_name": "result"
  185. }
  186. def _save_result_json(self, result: Dict, json_path: Path):
  187. """
  188. 保存预测结果为JSON文件
  189. Args:
  190. result: 预测结果字典
  191. json_path: JSON文件保存路径
  192. """
  193. # 将numpy数组转换为可序列化的格式
  194. json_result = {}
  195. for key, value in result.items():
  196. if isinstance(value, np.ndarray):
  197. json_result[key] = value.tolist()
  198. elif isinstance(value, dict):
  199. json_result[key] = {}
  200. for k, v in value.items():
  201. if isinstance(v, np.ndarray):
  202. json_result[key][k] = v.tolist()
  203. else:
  204. json_result[key][k] = v
  205. else:
  206. json_result[key] = value
  207. with open(json_path, 'w', encoding='utf-8') as f:
  208. json.dump(json_result, f, ensure_ascii=False, indent=2)
  209. def predict_single_image_np(self,
  210. img_bgr: np.ndarray,
  211. image_path_str: str = None,
  212. save_visualization: bool = True,
  213. save_json: bool = True,
  214. answer_json_dir_str: Optional[str] = None,
  215. input_channels=3
  216. ) -> Dict:
  217. """
  218. 预测单张图片
  219. Args:
  220. img_path: 图片路径
  221. save_visualization: 是否保存可视化结果
  222. save_json: 是否保存JSON结果
  223. answer_json_dir_str: JSON结果保存目录
  224. Returns:
  225. 预测结果字典
  226. """
  227. if image_path_str is None:
  228. timestamp = datetime.now().strftime('%y%m%d_%H%M%S_%f')
  229. image_path_real_str = f"{timestamp}.jpg"
  230. else:
  231. image_path_real_str = str(image_path_str)
  232. img_path_real_obj = Path(image_path_real_str).resolve()
  233. answer_json_dir_str_obj = Path(answer_json_dir_str).resolve()
  234. shape = img_bgr.shape
  235. image_channel = shape[2]
  236. fry_algo_print("信息", f"模型需要的通道数为:{input_channels}")
  237. fry_algo_print("信息", f"测试的图片实际的通道数为:{image_channel}")
  238. if image_channel != input_channels:
  239. # if image_channel==3 and input_channels==1:
  240. # img_bgr = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
  241. # elif image_channel==4 and input_channels==1:
  242. # img_bgr = cv2.cvtColor(img_bgr, cv2.COLOR_BGRA2GRAY)
  243. # elif image_channel==1 and input_channels==3:
  244. # img_bgr = cv2.cvtColor(img_bgr, cv2.COLOR_GRAY2BGR)
  245. # else:
  246. # raise ValueError(f"输入图片的通道数和模型不匹配:image_channel:{image_channel},input_channels:{input_channels}")
  247. raise ValueError(
  248. f"输入图片的通道数和模型不匹配:image_channel:{image_channel},input_channels:{input_channels}")
  249. # 获取原始图片尺寸
  250. height, width = img_bgr.shape[:2]
  251. originImgSize = {'width': width, 'height': height}
  252. # 预处理
  253. imgTensor_CHW_norm = predict_preprocess(img_bgr, self.imgSize_train_dict)
  254. # 预测
  255. ansImgDict = self._predict_tensor(imgTensor_CHW_norm)
  256. image_path_name = str(img_path_real_obj.name)
  257. # 创建结果
  258. per_img_seg_result = create_result_singleImg(
  259. self.seg_class_dict,
  260. ansImgDict,
  261. originImgSize,
  262. self.imgSize_train_dict,
  263. confidence=self.confidence,
  264. )
  265. # 保存JSON结果
  266. if save_json and answer_json_dir_str:
  267. json_dir = Path(answer_json_dir_str)
  268. json_dir.mkdir(parents=True, exist_ok=True)
  269. if image_path_str is None:
  270. cv2.imwrite(image_path_real_str, img_bgr)
  271. # 获取图片文件名(不含扩展名)
  272. img_name = Path(img_path_real_obj).stem
  273. json_path = json_dir / f"{img_name}.json"
  274. self._save_result_json(per_img_seg_result, json_path)
  275. fry_algo_print("成功", f"JSON结果已保存到:{json_path}")
  276. # 保存可视化结果
  277. if save_visualization:
  278. result_img = process_detection_result(img_bgr, per_img_seg_result, self.label_colors_dict)
  279. output_path = str(answer_json_dir_str_obj / f"{img_path_real_obj.name}")
  280. cv2.imwrite(output_path, result_img)
  281. fry_algo_print("成功", f"可视化结果已保存到:{output_path}")
  282. return per_img_seg_result
  283. def _save_result_json(self, result: Dict, json_path: Path):
  284. """
  285. 保存预测结果为JSON文件
  286. Args:
  287. result: 预测结果字典
  288. json_path: JSON文件保存路径
  289. """
  290. # 将numpy数组转换为可序列化的格式
  291. json_result = {}
  292. for key, value in result.items():
  293. if isinstance(value, np.ndarray):
  294. json_result[key] = value.tolist()
  295. elif isinstance(value, dict):
  296. json_result[key] = {}
  297. for k, v in value.items():
  298. if isinstance(v, np.ndarray):
  299. json_result[key][k] = v.tolist()
  300. else:
  301. json_result[key][k] = v
  302. else:
  303. json_result[key] = value
  304. with open(json_path, 'w', encoding='utf-8') as f:
  305. json.dump(json_result, f, ensure_ascii=False, indent=2)
  306. def predict_from_image(self, img_bgr: np.ndarray) -> Dict:
  307. """
  308. 直接从解码后的图像数据(numpy数组)进行预测。
  309. Args:
  310. img_bgr: BGR格式的图像,作为一个numpy数组。
  311. Returns:
  312. 预测结果字典。
  313. """
  314. # 检查通道数是否匹配
  315. shape = img_bgr.shape
  316. image_channel = shape[2] if len(shape) == 3 else 1
  317. if image_channel != self.input_channels:
  318. raise ValueError(
  319. f"输入图片的通道数和模型不匹配:image_channel:{image_channel},input_channels:{self.input_channels}")
  320. # 获取原始图片尺寸
  321. height, width = img_bgr.shape[:2]
  322. originImgSize = {'width': width, 'height': height}
  323. # 预处理
  324. imgTensor_CHW_norm = predict_preprocess(img_bgr, self.imgSize_train_dict)
  325. # 预测
  326. ansImgDict = self._predict_tensor(imgTensor_CHW_norm)
  327. # 创建结果
  328. per_img_seg_result = create_result_singleImg(
  329. self.seg_class_dict,
  330. ansImgDict,
  331. originImgSize,
  332. self.imgSize_train_dict,
  333. confidence=self.confidence
  334. )
  335. return per_img_seg_result
  336. def predict_single_image(self,
  337. img_path: str,
  338. save_visualization: bool = True,
  339. save_json: bool = True,
  340. answer_json_dir_str: Optional[str] = None,
  341. input_channels=3
  342. ) -> Dict:
  343. """
  344. 预测单张图片
  345. Args:
  346. img_path: 图片路径
  347. save_visualization: 是否保存可视化结果
  348. save_json: 是否保存JSON结果
  349. answer_json_dir_str: JSON结果保存目录
  350. Returns:
  351. 预测结果字典
  352. """
  353. img_path_obj = Path(img_path).resolve()
  354. img_path_parent_obj = img_path_obj.parent
  355. answer_json_dir_str_obj = Path(answer_json_dir_str).resolve()
  356. # 读取图片
  357. img_bgr = cv2.imread(str(img_path_obj))
  358. if img_bgr is None:
  359. raise ValueError(f"无法读取图片:{img_path}")
  360. per_img_seg_result = self.predict_single_image_np(
  361. img_bgr=img_bgr,
  362. image_path_str=str(img_path_obj),
  363. save_visualization=save_visualization,
  364. save_json=save_json,
  365. answer_json_dir_str=answer_json_dir_str,
  366. input_channels=input_channels
  367. )
  368. return per_img_seg_result
  369. def predict_batch(self,
  370. img_paths: List[str],
  371. save_visualization: bool = True,
  372. save_json: bool = True,
  373. answer_json_dir_str: Optional[str] = None,
  374. input_channels=3
  375. ) -> List[Dict]:
  376. """
  377. 批量预测图片
  378. Args:
  379. img_paths: 图片路径列表
  380. save_visualization: 是否保存可视化结果
  381. save_json: 是否保存JSON结果
  382. answer_json_dir_str: JSON结果保存目录
  383. output_dir: 可视化结果保存目录
  384. Returns:
  385. 所有图片的预测结果列表
  386. """
  387. answer_json_dir_str_obj = Path(answer_json_dir_str).resolve()
  388. results = []
  389. Path(answer_json_dir_str).mkdir(parents=True, exist_ok=True)
  390. # 批量处理
  391. for i, img_path in enumerate(img_paths):
  392. fry_algo_print("信息", f"处理图片 {i + 1}/{len(img_paths)}: {img_path}")
  393. try:
  394. # 读取图片
  395. img_bgr = cv2.imread(img_path)
  396. if img_bgr is None:
  397. fry_algo_print("信息", f"警告:无法读取图片 {img_path}")
  398. continue
  399. shape = img_bgr.shape
  400. image_channel = shape[2]
  401. if image_channel != input_channels:
  402. fry_algo_print("信息", f"模型需要的通道数为:{input_channels}")
  403. fry_algo_print("信息", f"测试的图片实际的通道数为:{image_channel}")
  404. fry_algo_print("错误",
  405. f"输入图片的通道数和模型不匹配:image_channel:{image_channel},input_channels:{input_channels}")
  406. continue
  407. # 获取原始图片尺寸
  408. height, width = img_bgr.shape[:2]
  409. originImgSize = {'width': width, 'height': height}
  410. # 预处理
  411. imgTensor_CHW_norm = predict_preprocess(img_bgr, self.imgSize_train_dict)
  412. # 预测
  413. ansImgDict = self._predict_tensor(imgTensor_CHW_norm)
  414. img_path_obj = Path(img_path).resolve()
  415. image_path_name = str(img_path_obj.name)
  416. # 创建结果
  417. per_img_seg_result = create_result_singleImg(
  418. self.seg_class_dict,
  419. ansImgDict,
  420. originImgSize,
  421. self.imgSize_train_dict,
  422. confidence=self.confidence,
  423. image_path_name=image_path_name
  424. )
  425. # 保存JSON结果
  426. if save_json and answer_json_dir_str:
  427. json_dir = Path(answer_json_dir_str)
  428. json_dir.mkdir(parents=True, exist_ok=True)
  429. img_name = Path(img_path).stem
  430. json_path = json_dir / f"{img_name}.json"
  431. self._save_result_json(per_img_seg_result, json_path)
  432. # 保存可视化结果
  433. if save_visualization:
  434. result_img = process_detection_result(img_bgr, per_img_seg_result, self.label_colors_dict)
  435. output_path = answer_json_dir_str_obj / f"{Path(img_path).name}"
  436. cv2.imwrite(str(output_path), result_img)
  437. results.append(per_img_seg_result)
  438. except Exception as e:
  439. fry_algo_print("失败", f"处理图片 {img_path} 时出错:{e}")
  440. continue
  441. fry_algo_print("成功", f"批量处理完成,成功处理 {len(results)}/{len(img_paths)} 张图片")
  442. return results
  443. def main():
  444. """使用示例"""
  445. # 配置参数
  446. pth_path = r"segmentation_bisenetv2.pth"
  447. input_channels = 3
  448. real_seg_class_dict = {1: 'outer_box'}
  449. # 为不同类别设置不同颜色(可选)
  450. label_colors_dict = {
  451. 'outer_box': (255, 0, 0),
  452. }
  453. imgSize_train_dict = {'width': 1280, 'height': 1280}
  454. confidence = 0.5
  455. # 创建预测器
  456. predictor = FryBisenetV2Predictor(
  457. pth_path=pth_path,
  458. real_seg_class_dict=real_seg_class_dict,
  459. imgSize_train_dict=imgSize_train_dict,
  460. confidence=confidence,
  461. label_colors_dict=label_colors_dict,
  462. input_channels=input_channels,
  463. )
  464. # 单张图片预测
  465. print("=== 单张图片预测 ===")
  466. now_img_path = r"input_output\images\coaxis_0008.jpg"
  467. answer_json_dir_str = r"input_output\images_answer_json_dir_str"
  468. result = predictor.predict_single_image(
  469. img_path=now_img_path,
  470. save_visualization=True,
  471. save_json=True,
  472. answer_json_dir_str=answer_json_dir_str,
  473. input_channels=input_channels,
  474. )
  475. # 批量预测示例
  476. # print("\n=== 批量图片预测 ===")
  477. # img_paths = [
  478. # r"input_output\images\coaxis_0008.jpg",
  479. # r"input_output\images\coaxis_0082.jpg",
  480. # r"input_output\images\ring_0001.jpg",
  481. # r"input_output\images\Pokemon_back_for_Edge_0001.jpg",
  482. # ]
  483. #
  484. # results = predictor.predict_batch(
  485. # img_paths=img_paths,
  486. # save_visualization=True,
  487. # save_json=True,
  488. # answer_json_dir_str=answer_json_dir_str,
  489. # input_channels=input_channels,
  490. # )
  491. def _test_pokemon_inner_box():
  492. # 配置参数
  493. pth_path = r"E:\_250807_训练好的导出的模型\_250808_1043_宝可梦内框训练效果还行\pth_and_images\segmentation_bisenetv2.pth"
  494. input_channels = 3
  495. real_seg_class_dict = {1: 'inner_box'}
  496. # 为不同类别设置不同颜色(可选)
  497. label_colors_dict = {
  498. 'outer_box': (255, 0, 0),
  499. }
  500. imgSize_train_dict = {'width': 1280, 'height': 1280}
  501. confidence = 0.5
  502. # 创建预测器
  503. predictor = FryBisenetV2Predictor(
  504. pth_path=pth_path,
  505. real_seg_class_dict=real_seg_class_dict,
  506. imgSize_train_dict=imgSize_train_dict,
  507. confidence=confidence,
  508. label_colors_dict=label_colors_dict,
  509. input_channels=input_channels,
  510. )
  511. # 单张图片预测
  512. print("=== 单张图片预测 ===")
  513. now_img_path = r"E:\_250807_训练好的导出的模型\_250808_1043_宝可梦内框训练效果还行\pth_and_images\images\diff_big_00065.jpg"
  514. answer_json_dir_str = r"E:\_250807_训练好的导出的模型\_250808_1043_宝可梦内框训练效果还行\pth_and_images\images_answer"
  515. result = predictor.predict_single_image(
  516. img_path=now_img_path,
  517. save_visualization=True,
  518. save_json=True,
  519. answer_json_dir_str=answer_json_dir_str
  520. )
  521. def _test_pokemon_back_edge():
  522. # 配置参数
  523. pth_path = r"E:\_250807_训练好的导出的模型\_250811_1104_宝可梦背面边角\pth_and_images\segmentation_bisenetv2.pth"
  524. input_channels = 3
  525. real_seg_class_dict = {
  526. 1: 'wear',
  527. 2: 'wear_and_impact',
  528. 3: 'impact',
  529. 4: 'damaged',
  530. 5: 'wear_and_stain',
  531. }
  532. # 为不同类别设置不同颜色(可选)
  533. # label_colors_dict = {
  534. # 'outer_box': (255, 0, 0),
  535. # }
  536. imgSize_train_dict = {'width': 512, 'height': 512}
  537. confidence = 0.5
  538. # 创建预测器
  539. predictor = FryBisenetV2Predictor(
  540. pth_path=pth_path,
  541. real_seg_class_dict=real_seg_class_dict,
  542. imgSize_train_dict=imgSize_train_dict,
  543. confidence=confidence,
  544. input_channels=input_channels,
  545. )
  546. # 单张图片预测
  547. print("=== 单张图片预测 ===")
  548. now_img_path = r"E:\_250807_训练好的导出的模型\_250811_1104_宝可梦背面边角\pth_and_images\images\split\Pokémon_back_for_Edge_0001_bottom_grid_r0_c0.jpg"
  549. answer_json_dir_str = r"E:\_250807_训练好的导出的模型\_250811_1104_宝可梦背面边角\pth_and_images\images_answer"
  550. result = predictor.predict_single_image(
  551. img_path=now_img_path,
  552. save_visualization=True,
  553. save_json=True,
  554. answer_json_dir_str=answer_json_dir_str,
  555. input_channels=input_channels
  556. )
# Script entry point: runs the usage example. The manual checks below can be
# enabled instead; they rely on hard-coded local paths.
if __name__ == "__main__":
    main()
    # _test_pokemon_inner_box()
    # _test_pokemon_back_edge()