
OCR feature upgrade

AnlaAnla, 1 month ago
Commit f0346604b9
2 changed files with 112 additions and 67 deletions
  1. Test/test02.py (+11, -63)
  2. app/services/video_service.py (+101, -4)

+ 11 - 63
Test/test02.py

@@ -1,67 +1,15 @@
-import os
-from langgraph.graph import StateGraph, START, END
-from langgraph.graph.message import add_messages
-from langchain_openai import ChatOpenAI
+from ultralytics import YOLO
 
-# --------------- Basic configuration ----------------
+# Load a model
+model = YOLO(r"C:\Code\ML\Model\yolo26n-pose.pt")  # load an official model
 
-# 1) Set the API key & base URL
-# Assumes you have already obtained an OpenAI-compatible key & endpoint via the DeepSeek QPI
-os.environ["OPENAI_API_KEY"] = "YOUR_QPI_API_KEY"
-# If the DeepSeek QPI requires a custom base URL, set:
-# os.environ["OPENAI_API_BASE"] = "https://your-provider-url/v1"
+# Predict with the model
+results = model("https://ultralytics.com/images/bus.jpg")  # predict on an image
 
-# 2) Initialize the LLM
-# deepseek/deepseek-r1 can usually be called via a QPI/OpenRouter-compatible API
-llm = ChatOpenAI(model="deepseek/deepseek-r1:latest", temperature=0.7)
+# Access the results
+for result in results:
+    xy = result.keypoints.xy  # x and y coordinates
+    xyn = result.keypoints.xyn  # normalized
+    kpts = result.keypoints.data  # x, y, visibility (if available)
 
-
-# --------------- LangGraph nodes ----------------
-
-def call_deepseek(state):
-    """
-    A simple function node that uses the LLM to interpret state["messages"]
-    and returns the next set of messages
-    """
-    user_msgs = state["messages"]
-
-    # Call the LLM
-    response = llm(
-        # The LangChain format expects messages to be a list of dicts
-        messages=user_msgs
-    )
-
-    # Get the text of the model output
-    ai_msg = response["choices"][0]["message"]
-
-    # Append the AI reply back to the state
-    return {"messages": user_msgs + [ai_msg]}
-
-
-# --------------- Build the state graph ----------------
-
-# The state type uses LangChain's message-state object
-from langgraph.graph import MessagesState
-
-graph = StateGraph(MessagesState)
-
-# Add the node to the graph
-graph.add_node(call_deepseek)
-
-# Define the edges (Start → our LLM node → End)
-graph.add_edge(START, "call_deepseek")
-graph.add_edge("call_deepseek", END)
-
-# Compile the graph
-compiled_graph = graph.compile()
-
-# --------------- Invoke and run ----------------
-
-result = compiled_graph.invoke({
-    "messages": [
-        {"role": "user", "content": "你好,帮我写一段 LangGraph 入门示例说明"}
-    ]
-})
-
-# Print the final state
-print(result["messages"][-1]["content"])
+print()
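
For reference, a minimal sketch of how the keypoint tensors returned by the new test script might be consumed; the "yolov8n-pose.pt" file name and the 0.5 confidence threshold are illustrative assumptions, while the Keypoints attributes (.xy, .conf) follow the published ultralytics Results API:

from ultralytics import YOLO

model = YOLO("yolov8n-pose.pt")  # assumed stand-in pose checkpoint
results = model("https://ultralytics.com/images/bus.jpg")

for result in results:
    kpts = result.keypoints                 # one row per detected person
    if kpts is None or kpts.conf is None:
        continue
    for person_idx, (xy, conf) in enumerate(zip(kpts.xy, kpts.conf)):
        # xy: (num_keypoints, 2) pixel coordinates, conf: (num_keypoints,) scores
        visible = int((conf > 0.5).sum().item())
        print(f"person {person_idx}: nose at {xy[0].tolist()}, "
              f"{visible} keypoints above 0.5 confidence")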

+ 101 - 4
app/services/video_service.py

@@ -228,6 +228,70 @@ class VideoService:
 
         return min(score, 1.0)
 
+    def _batch_analyze_segmentation(self, frames: list[Any]) -> list[dict[str, Any]]:
+        """批量对多张图像进行语义分割,极大提高 GPU 利用率"""
+        if not frames or self._ensure_segmentation_model() is None:
+            return [{"segmentation_used": False, "has_card": False, "has_hand": False,
+                     "card_area_ratio": 0.0, "hand_area_ratio": 0.0, "card_bbox": None}] * len(frames)
+
+        try:
+            pil_images = [self._seg_pil_image.fromarray(cv2.cvtColor(f, cv2.COLOR_BGR2RGB)) for f in frames]
+            device = next(self._seg_model.parameters()).device
+            results = []
+
+            # Process in chunks to avoid GPU memory OOM (e.g. 16 frames per batch)
+            batch_size = 16
+            for i in range(0, len(pil_images), batch_size):
+                batch_imgs = pil_images[i: i + batch_size]
+
+                inputs = self._seg_processor(images=batch_imgs, return_tensors="pt").to(device)
+
+                with self._seg_torch.no_grad():
+                    outputs = self._seg_model(**inputs)
+
+                logits = outputs.logits
+
+                # Upsample the whole batch and take the argmax
+                preds = self._seg_torch.nn.functional.interpolate(
+                    logits,
+                    size=batch_imgs[0].size[::-1],  # assumes all frames share the same resolution
+                    mode="bilinear",
+                    align_corners=False,
+                ).argmax(dim=1).cpu().numpy()
+
+                # Parse each image's mask
+                for pred in preds:
+                    card_mask = pred == settings.VIDEO_CARD_LABEL_ID
+                    hand_mask = pred == settings.VIDEO_HAND_LABEL_ID
+
+                    card_area = float(card_mask.mean()) if card_mask.size else 0.0
+                    hand_area = float(hand_mask.mean()) if hand_mask.size else 0.0
+
+                    card_bbox = self._largest_bbox(card_mask)
+                    hand_bbox = self._largest_bbox(hand_mask)
+                    focus_bbox = card_bbox if card_bbox is not None else hand_bbox
+
+                    results.append({
+                        "segmentation_used": True,
+                        "has_card": card_area >= settings.VIDEO_MIN_CARD_AREA_RATIO,
+                        "has_hand": hand_area >= settings.VIDEO_MIN_HAND_AREA_RATIO,
+                        "card_area_ratio": card_area,
+                        "hand_area_ratio": hand_area,
+                        "card_bbox": focus_bbox,
+                    })
+
+                # Free this batch's GPU memory promptly
+                del inputs, outputs, logits, preds
+                if self._seg_torch.cuda.is_available():
+                    self._seg_torch.cuda.empty_cache()
+
+            return results
+
+        except Exception as exc:
+            logger.warning(f"Batch segmentation failed, fallback: {exc}")
+            return [{"segmentation_used": False, "has_card": False, "has_hand": False,
+                     "card_area_ratio": 0.0, "hand_area_ratio": 0.0, "card_bbox": None}] * len(frames)
+
     def _analyze_segmentation(self, frame) -> dict[str, Any]:
         """对单帧图像进行语义分割分析,寻找卡片和手的区域"""
         if self._ensure_segmentation_model() is None:
@@ -346,6 +410,9 @@ class VideoService:
     ) -> list[FrameCandidate]:
         """在指定时间窗口内滑动,按步长收集视频帧作为候选"""
         candidates: list[FrameCandidate] = []
+        raw_frames = []
+        time_ms_list = []
+
         analysis_stride = self._analysis_stride(fps)
 
         # Estimate the maximum number of reads to avoid hanging in an infinite loop at the end of the video
@@ -356,7 +423,8 @@ class VideoService:
 
         read_count = 0
         while read_count < max_reads:
-            ret, frame = cap.read()
+            # Only grab the next frame's packet from the stream; skip the costly image decode
+            ret = cap.grab()
             if not ret:
                 break
 
@@ -364,12 +432,41 @@ class VideoService:
             if current_time_ms > end_time_ms:
                 break
 
-            # Sample frames for analysis at the computed stride (analysis_stride)
+            # Decode into an image matrix only once the stride is reached
             if read_count % analysis_stride == 0:
-                candidates.append(self._build_candidate(frame, int(current_time_ms), target_time_ms))
+                ret, frame = cap.retrieve()
+                if ret:
+                    raw_frames.append(frame.copy())
+                    time_ms_list.append(current_time_ms)
 
             read_count += 1
 
+        if not raw_frames:
+            return []
+
+        # 1. Run all frames through the segmentation model in batches
+        seg_results = self._batch_analyze_segmentation(raw_frames)
+
+        # 2. Iterate to assemble Candidates and compute sharpness
+        for frame, time_ms, seg_res in zip(raw_frames, time_ms_list, seg_results):
+            # Crop the focus region and compute its sharpness
+            focus_region = self._focus_region(frame, seg_res["card_bbox"])
+            sharpness = self.get_laplacian_sharpness(focus_region)
+
+            presence_score = self._compute_presence_score(
+                seg_res["segmentation_used"], seg_res["has_card"],
+                seg_res["has_hand"], seg_res["card_area_ratio"], seg_res["hand_area_ratio"]
+            )
+
+            candidates.append(FrameCandidate(
+                frame=frame,
+                time_ms=int(time_ms),
+                sharpness=sharpness,
+                time_weight=self.calculate_weight(time_ms, target_time_ms),
+                presence_score=presence_score,
+                **seg_res  # unpack to fill has_card, card_bbox, and related attributes
+            ))
+
         return candidates
 
     def _assign_dwell_scores(self, candidates: list[FrameCandidate]) -> None:
@@ -793,4 +890,4 @@ class VideoService:
             f"Frame capture finished. saved={success_count}, "
             f"filtered={filtered_count}, total={len(cards)}"
         )
-        return output_list
+        return output_list
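
The frame-collection change above replaces cap.read() on every frame with cap.grab() plus a cap.retrieve() call only at the sampling stride, so skipped frames are never decoded. A standalone sketch of that OpenCV pattern, with a placeholder file name and stride rather than the service's actual values:

import cv2

cap = cv2.VideoCapture("video.mp4")   # placeholder path
stride = 5                            # decode only every 5th frame
frames, times_ms = [], []

read_count = 0
while True:
    if not cap.grab():                # advance the stream without decoding
        break
    if read_count % stride == 0:
        ok, frame = cap.retrieve()    # decode only the sampled frames
        if ok:
            frames.append(frame)
            times_ms.append(cap.get(cv2.CAP_PROP_POS_MSEC))
    read_count += 1

cap.release()
print(f"decoded {len(frames)} of {read_count} frames read")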