Bladeren bron

修改前的准备

AnlaAnla 5 dagen geleden
bovenliggende
commit
3991ce2752
9 gewijzigde bestanden met toevoegingen van 243 en 13 verwijderingen
  1. 2 0
      .gitignore
  2. 10 0
      Test/RapidOCR_test.py
  3. 47 0
      Test/seg_test02.py
  4. 1 0
      Test/test01.py
  5. 20 0
      Test/视频提取语音.py
  6. 75 0
      Test/语音转文本.py
  7. 6 0
      app/main.py
  8. 80 12
      app/services/video_service.py
  9. 2 1
      run_CardVideoSummary.py

+ 2 - 0
.gitignore

@@ -0,0 +1,2 @@
+/static/
+/Temp/

+ 10 - 0
Test/RapidOCR_test.py

@@ -0,0 +1,10 @@
+from rapidocr import RapidOCR
+
+engine = RapidOCR()
+
+img_url = r"C:\Code\ML\Image\_TEST_DATA\Card_test\test05\945e0cc0884c8766a5883ea9593def9d.png"
+result = engine(img_url)
+print(result)
+
+result.vis("vis_result.jpg")
+

+ 47 - 0
Test/seg_test02.py

@@ -0,0 +1,47 @@
+import torch
+import numpy as np
+from PIL import Image
+import matplotlib.pyplot as plt
+from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
+
+model_dir = r"C:\Code\ML\Model\Card_Seg\segformer_card_hand02_safetensors"
+# img_path = r"C:\Users\wow38\Pictures\videoframe_6871967.png"
+
+
+
+processor = AutoImageProcessor.from_pretrained(model_dir)
+model = AutoModelForSemanticSegmentation.from_pretrained(model_dir)
+
+def show(img_path):
+    image = Image.open(img_path).convert("RGB")
+    inputs = processor(images=image, return_tensors="pt")
+
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    logits = outputs.logits
+    pred = torch.nn.functional.interpolate(
+        logits,
+        size=image.size[::-1],
+        mode="bilinear",
+        align_corners=False
+    ).argmax(dim=1)[0].cpu().numpy()
+
+    plt.figure(figsize=(10, 5))
+
+    plt.subplot(1, 2, 1)
+    plt.imshow(image)
+    plt.title("image")
+    plt.axis("off")
+
+    plt.subplot(1, 2, 2)
+    plt.imshow(pred)
+    plt.title("mask")
+    plt.axis("off")
+
+    plt.tight_layout()
+    plt.show()
+
+if __name__ == '__main__':
+    show(r"C:\Users\wow38\Pictures\videoframe_6871967 - 副本.png")
+    print()

+ 1 - 0
Test/test01.py

@@ -1,3 +1,4 @@
 
 if __name__ == '__main__':
     print("1235456")
+    "C:/Code/ML/Video/直播数据/video/vortexcards.mp4"

+ 20 - 0
Test/视频提取语音.py

@@ -0,0 +1,20 @@
+from moviepy.audio.io.AudioFileClip import AudioFileClip
+from moviepy.editor import VideoFileClip, CompositeAudioClip
+import os
+import pandas as pd
+import numpy as np
+
+if __name__ == '__main__':
+    video_path = r"C:\Code\ML\Video\直播数据\video\2026_02_25 16_47_46.mp4"
+    audio_file_save_path = r"C:\Code\ML\Video\直播数据\video\2026_02_25 16_47_46.mp3"
+
+    # 加载视频文件
+    video = VideoFileClip(video_path)
+    audio = video.audio
+
+    # start_time = 1
+    # end_time = 8
+    # cut_audio = audio.subclip(start_time, end_time)
+
+    audio.write_audiofile(audio_file_save_path)
+    print('end')

+ 75 - 0
Test/语音转文本.py

@@ -0,0 +1,75 @@
+import os
+from faster_whisper import WhisperModel
+
+
+def format_timestamp(seconds: float):
+    """
+    将秒数转换为 SRT 时间戳格式 (HH:MM:SS,mmm)
+    """
+    whole_seconds = int(seconds)
+    # milliseconds = int((seconds - whole_seconds) * 1000)
+
+    hours = whole_seconds // 3600
+    minutes = (whole_seconds % 3600) // 60
+    secs = whole_seconds % 60
+
+    return f"{hours:02d}:{minutes:02d}:{secs:02d}"
+
+
+def transcribe_mp3_to_srt(mp3_path, model_size="large-v3", device="cuda", compute_type="int8_float16",
+                          language=None):
+    print(f"正在加载模型: {model_size} ({compute_type})...")
+
+    # 1. 初始化模型
+    # 这里是核心:使用 cuda 和 int8_float16 达到速度与精度的平衡
+    model = WhisperModel(model_size, device=device, compute_type=compute_type)
+
+    print(f"正在转录音频: {mp3_path} ...")
+
+    # 2. 开始转录
+    # beam_size=5 是官方推荐的精度设置
+    # vad_filter=True 会自动过滤静音片段,极大提升长音频的处理速度
+    segments, info = model.transcribe(
+        mp3_path,
+        beam_size=5,
+        vad_filter=True,
+        language=language
+    )
+
+    print(f"检测到语言: {info.language}, 置信度: {info.language_probability:.2f}")
+
+    # 3. 输出文件名
+    srt_filename = os.path.splitext(mp3_path)[0] + ".txt"
+
+    # 4. 写入 SRT 文件
+    # 注意:segments 是一个生成器,只有在遍历时才会真正开始计算(流式处理)
+    with open(srt_filename, "w", encoding="utf-8") as f:
+        for i, segment in enumerate(segments, start=1):
+            start_time = format_timestamp(segment.start)
+            end_time = format_timestamp(segment.end)
+            text = segment.text.strip()
+
+            # 写入 SRT 格式
+            # f.write(f"{i}\n")
+            # f.write(f"{start_time} --> {end_time}\n")
+            # f.write(f"{text}\n\n")
+
+            # txt 格式
+            f.write(f"{start_time}\n")
+            f.write(f"{text}\n\n")
+
+            # 可选:实时打印进度
+            print(f"[{start_time} -> {end_time}] {text}")
+
+    print(f"\n✅ 提取完成!字幕已保存为: {srt_filename}")
+
+
+if __name__ == "__main__":
+    # 替换为你的 mp3 文件路径
+    audio_file = "/home/martin/ML/RemoteProject/untitled10/Audio/temp/audio/backyardhits2.mp3"
+
+    if os.path.exists(audio_file):
+        transcribe_mp3_to_srt(audio_file, compute_type="default")
+    else:
+        print(f"找不到文件: {audio_file}")
+        print('==')

+ 6 - 0
app/main.py

@@ -1,7 +1,9 @@
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import HTMLResponse
 from fastapi.staticfiles import StaticFiles
 from app.api.routes import router
+from pathlib import Path
 from app.core.config import settings
 
 app = FastAPI(title="Card Extraction API", version="1.0")
@@ -16,6 +18,10 @@ app.add_middleware(
     allow_headers=["*"]
 )
 
+@app.get("/", response_class=HTMLResponse)
+async def root():
+    html_path = Path("static/view_results.html")
+    return html_path.read_text(encoding="utf-8")
 
 # 注册路由
 app.include_router(router, prefix="/api")

+ 80 - 12
app/services/video_service.py

@@ -1,6 +1,7 @@
 import cv2
 import os
 import uuid
+import math
 from app.core.config import settings
 from app.core.logger import get_logger
 from app.schemas.models import CardInfoInput, CardInfoOutput
@@ -9,6 +10,10 @@ logger = get_logger("VideoService")
 
 
 class VideoService:
+    def __init__(self):
+        # 高斯函数中的 sigma (标准差) 决定了时间权重的下降速度。
+        self.weight_sigma = 10.0
+
     def time_str_to_ms(self, time_str: str) -> int:
         try:
             parts = list(map(int, time_str.split(':')))
@@ -22,6 +27,21 @@ class VideoService:
         except ValueError:
             return 0
 
+    def get_laplacian_sharpness(self, frame) -> float:
+        """
+        计算图像的拉普拉斯方差。
+        方差越大,代表图像包含的高频边缘信息越多,也就意味着对焦越准、越清晰。
+        """
+        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+        return cv2.Laplacian(gray, cv2.CV_64F).var()
+
+    def calculate_weight(self, current_time_ms: int, target_time_ms: int) -> float:
+        """
+        计算时间权重:使用高斯衰减函数。距离目标时间越近,权重越高。
+        """
+        diff_seconds = abs(current_time_ms - target_time_ms) / 1000.0
+        return math.exp(- (diff_seconds ** 2) / (2 * self.weight_sigma ** 2))
+
     def capture_frames(self, video_path: str, cards: list[CardInfoInput]) -> list[CardInfoOutput]:
         if not os.path.exists(video_path):
             logger.error(f"❌ 找不到视频文件: {video_path}")
@@ -31,37 +51,85 @@ class VideoService:
         logger.info(f"📋 待处理卡片数量: {len(cards)}")
 
         cap = cv2.VideoCapture(video_path)
-        output_list = []
+        # 获取视频帧率,用于计算安全边界
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        if fps <= 0:
+            fps = 30.0
 
+        output_list = []
         success_count = 0
 
         for idx, card_input in enumerate(cards):
-            # 将 Input 模型转为 Output 模型 (此时 path 为 None)
             card_output = CardInfoOutput(**card_input.dict())
+            target_time_ms = self.time_str_to_ms(card_output.time)
+
+            # 设定搜索窗口区间: [目标时间 - 1秒, 目标时间 + 4秒]
+            start_time_ms = max(0, target_time_ms - 1000)
+            end_time_ms = target_time_ms + 4000
 
-            time_ms = self.time_str_to_ms(card_output.time)
             logger.info(
-                f"📸 [{idx + 1}/{len(cards)}] 正在截取 {card_output.time} ({time_ms}ms) - {card_output.card_name_cn or '未知卡名'}")
+                f"📸[{idx + 1}/{len(cards)}] 智能截取 {card_output.time} ({target_time_ms}ms) - {card_output.card_name_cn or '未知卡名'}")
+            logger.info(f"   => 搜索区间: [{start_time_ms}ms ~ {end_time_ms}ms]")
+
+            # 定位到窗口开始时间
+            cap.set(cv2.CAP_PROP_POS_MSEC, start_time_ms)
+
+            best_frame = None
+            best_score = -1.0
+            best_time_ms = start_time_ms
+            best_sharpness = 0.0
 
-            # 设定位置
-            cap.set(cv2.CAP_PROP_POS_MSEC, time_ms)
-            ret, frame = cap.read()
+            # 保护机制:最多读取这么多次,防止由于视频末尾造成的无限死循环
+            max_reads = int((end_time_ms - start_time_ms) / 1000.0 * fps) + 30
+            read_count = 0
 
-            if ret:
-                filename = f"{uuid.uuid4()}_{time_ms}.jpg"
+            while read_count < max_reads:
+                current_pos_ms = cap.get(cv2.CAP_PROP_POS_MSEC)
+
+                # 超出窗口最大时间,停止当前卡片的搜索
+                if current_pos_ms > end_time_ms:
+                    break
+
+                ret, frame = cap.read()
+                if not ret:
+                    break  # 视频结束
+
+                # 计算原图清晰度
+                sharpness = self.get_laplacian_sharpness(frame)
+                # 计算时间偏移带来的衰减权重
+                weight = self.calculate_weight(current_pos_ms, target_time_ms)
+
+                # 综合评分 = 清晰度 * 时间权重
+                score = sharpness * weight
+
+                # 更新最佳候选帧
+                if score > best_score:
+                    best_score = score
+                    best_frame = frame
+                    best_time_ms = current_pos_ms
+                    best_sharpness = sharpness
+
+                read_count += 1
+
+            # 保存最清晰的一张
+            if best_frame is not None:
+                filename = f"{uuid.uuid4()}_{int(best_time_ms)}.jpg"
                 save_path = os.path.join(settings.FRAMES_DIR, filename)
 
                 try:
-                    cv2.imwrite(save_path, frame)
+                    cv2.imwrite(save_path, best_frame)
 
                     image_url = f"{settings.BASE_URL}/static/frames/{filename}"
                     card_output.frame_image_path = image_url
                     success_count += 1
-                    logger.info(f"   ✅ 保存成功: {save_path}")
+
+                    time_diff = (best_time_ms - target_time_ms) / 1000.0
+                    logger.info(
+                        f"   ✅ 保存成功: {save_path} (偏移: {time_diff:+.2f}s, 清晰度: {best_sharpness:.1f}, 综合分: {best_score:.1f})")
                 except Exception as e:
                     logger.error(f"   ❌ 保存图片失败: {e}")
             else:
-                logger.warning(f"   ⚠️ 无法读取视频帧 (可能时间戳超出视频长度)")
+                logger.warning(f"   ⚠️ 无法在窗口内读取视频帧 (可能时间戳超出视频长度)")
 
             output_list.append(card_output)
 

+ 2 - 1
run_CardVideoSummary.py

@@ -4,5 +4,6 @@ import socket
 if __name__ == "__main__":
     ip = socket.gethostbyname(socket.gethostname())
     port = 7721
+    print(f"http://{ip}:{port}")
     print(f"http://{ip}:{port}/docs")
-    uvicorn.run("app.main:app", host="0.0.0.0", port=port, reload=True)
+    uvicorn.run("app.main:app", host="0.0.0.0", port=port)