Переглянути джерело

io推理上传进行性能限制

AnlaAnla 1 місяць тому
батько
коміт
8693858a4e
1 змінених файлів з 12 додано та 22 видалено
  1. 12 22
      app/api/auto_import.py

+ 12 - 22
app/api/auto_import.py

@@ -185,7 +185,6 @@ async def auto_import_script_api(
         "back_gray": back_gray
     }
 
-    # 【改动点2】过滤文件时,增加 isinstance(v, UploadFile) 的判断,剔除空字符串
     valid_main_files = {
         k: v for k, v in main_inputs.items()
         if (v is not None) and v.filename
@@ -205,11 +204,9 @@ async def auto_import_script_api(
 
     async with aiohttp.ClientSession() as session:
         try:
-            # 读取所有图片至内存
             main_bytes_data = {k: (await v.read(), v.filename) for k, v in valid_main_files.items()}
             gray_bytes_data = {k: (await v.read(), v.filename) for k, v in valid_gray_files.items()}
 
-            # Step 1: 主图顺序推理 (防止瞬间塞爆推理服务器)
             logger.info(f"--- 开始自动导入任务: {card_name} ---")
             processed_results = []
             for img_type, (f_bytes, f_name) in main_bytes_data.items():
@@ -219,22 +216,22 @@ async def auto_import_script_api(
                 res = await safe_process_main_image(session, f_bytes, f_name, img_type, is_reflect_str)
                 processed_results.append(res)
 
-            # Step 2: 在自身数据库创建卡片记录
             card_id = await create_card_record(
                 session, local_base_url, card_name, cardNo, card_type
             )
             logger.info(f"卡片记录创建成功,ID: {card_id}")
 
-            # Step 3: 并发调用自身的图片保存接口
+            # ---------- 修改点:safe_upload_task 调用 + gather ----------
             upload_tasks = []
             for res in processed_results:
-                upload_tasks.append(safe_upload_task(session, local_base_url, card_id, res))
+                upload_tasks.append(safe_upload_task(upload_main_image, session, local_base_url, card_id, res))
 
             for img_type, (f_bytes, f_name) in gray_bytes_data.items():
-                upload_tasks.append(upload_gray_image(session, local_base_url, card_id, img_type, f_bytes, f_name))
+                upload_tasks.append(safe_upload_task(upload_gray_image, session, local_base_url, card_id, img_type, f_bytes, f_name))
 
             if upload_tasks:
                 await asyncio.gather(*upload_tasks)
+            # ------------------------------------------------------------
 
             logger.info(f"--- 自动导入流程结束, Card ID: {card_id} ---")
             return {
@@ -250,7 +247,6 @@ async def auto_import_script_api(
             logger.error(f"[流程终止] 发生异常: {e}")
             raise HTTPException(status_code=500, detail=f"自动化处理异常: {str(e)}")
 
-
 @router.post("/process_and_import_url", summary="通过URL自动化处理并导入卡牌数据")
 async def auto_import_url_script_api(
         request: Request,
@@ -277,7 +273,6 @@ async def auto_import_url_script_api(
         "front_gray": front_gray, "back_gray": back_gray
     }
 
-    # 过滤掉空字符串
     valid_main_urls = {k: v for k, v in main_inputs.items() if v and v.strip()}
     valid_gray_urls = {k: v for k, v in gray_inputs.items() if v and v.strip()}
 
@@ -293,15 +288,12 @@ async def auto_import_url_script_api(
         try:
             logger.info(f"--- 开始URL自动导入任务: {card_name} ---")
 
-            # 1. 并发下载图片至内存
             async def fetch_image(img_key: str, img_url: str):
                 try:
                     async with session.get(img_url) as resp:
                         if resp.status != 200:
                             raise HTTPException(status_code=400, detail=f"下载图片失败: {img_key} -> {resp.status}")
                         file_bytes = await resp.read()
-
-                        # 尝试从 url 解析出文件名,否则使用默认名称
                         filename = img_url.split('/')[-1].split('?')[0]
                         if not filename or '.' not in filename:
                             filename = f"{img_key}.jpg"
@@ -310,13 +302,11 @@ async def auto_import_url_script_api(
                     if isinstance(e, HTTPException): raise e
                     raise HTTPException(status_code=400, detail=f"访问图片URL异常: {img_key} -> {str(e)}")
 
-            fetch_tasks = []
-            for k, url in valid_main_urls.items(): fetch_tasks.append(fetch_image(k, url))
-            for k, url in valid_gray_urls.items(): fetch_tasks.append(fetch_image(k, url))
+            fetch_tasks = [fetch_image(k, url) for k, url in valid_main_urls.items()]
+            fetch_tasks += [fetch_image(k, url) for k, url in valid_gray_urls.items()]
 
             downloaded_files = await asyncio.gather(*fetch_tasks)
 
-            # 分拣主图与灰度图的 bytes
             main_bytes_data = {}
             gray_bytes_data = {}
             for key, data in downloaded_files:
@@ -325,7 +315,6 @@ async def auto_import_url_script_api(
                 else:
                     gray_bytes_data[key] = data
 
-            # 2. 复用原有逻辑 - 主图顺序推理
             processed_results = []
             for img_type, (f_bytes, f_name) in main_bytes_data.items():
                 if len(f_bytes) == 0:
@@ -333,19 +322,20 @@ async def auto_import_url_script_api(
                 res = await safe_process_main_image(session, f_bytes, f_name, img_type, is_reflect_str)
                 processed_results.append(res)
 
-            # 3. 在自身数据库创建卡片记录
             card_id = await create_card_record(session, local_base_url, card_name, cardNo, card_type)
             logger.info(f"URL导入卡片记录创建成功,ID: {card_id}")
 
-            # 4. 并发调用自身的图片保存接口
+            # ---------- 修改点:safe_upload_task 调用 + gather ----------
             upload_tasks = []
             for res in processed_results:
-                upload_tasks.append(safe_upload_task(session, local_base_url, card_id, res))
+                upload_tasks.append(safe_upload_task(upload_main_image, session, local_base_url, card_id, res))
+
             for img_type, (f_bytes, f_name) in gray_bytes_data.items():
-                upload_tasks.append(upload_gray_image(session, local_base_url, card_id, img_type, f_bytes, f_name))
+                upload_tasks.append(safe_upload_task(upload_gray_image, session, local_base_url, card_id, img_type, f_bytes, f_name))
 
             if upload_tasks:
                 await asyncio.gather(*upload_tasks)
+            # ------------------------------------------------------------
 
             logger.info(f"--- URL自动导入流程结束, Card ID: {card_id} ---")
             return {
@@ -359,4 +349,4 @@ async def auto_import_url_script_api(
             raise
         except Exception as e:
             logger.error(f"[URL导入流程终止] 发生异常: {e}")
-            raise HTTPException(status_code=500, detail=f"自动化处理异常: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"自动化处理异常: {str(e)}")