|
@@ -1,667 +0,0 @@
|
|
|
-# -*- coding: utf-8 -*-
|
|
|
|
|
-# Author: Charley
|
|
|
|
|
-# Python: 3.10+
|
|
|
|
|
-# Date: 2026/04/16
|
|
|
|
|
-# Desc: 异步MySQL连接池 - 专为爬虫高并发场景优化
|
|
|
|
|
import asyncio
import re
import time
from contextlib import asynccontextmanager

import aiomysql
from loguru import logger
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
class AsyncMySQLPool:
    """Async MySQL connection pool optimized for high-concurrency crawlers.

    Advantages over a synchronous (pymysql) version:
        1. Truly non-blocking I/O: crawling is not stalled by DB writes.
        2. Connections are pooled and reused, cutting setup overhead.
        3. Batch inserts are more efficient and support larger batches.

    Usage:
        pool = AsyncMySQLPool()
        await pool.init()

        # batch insert
        await pool.insert_many("table_name", [{"col1": "val1"}, {"col2": "val2"}])

        # or borrow a connection directly
        async with pool.get_connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("SELECT * FROM table LIMIT 10")
                results = await cursor.fetchall()

        await pool.close()
    """

    def __init__(
        self,
        min_size=5,
        max_size=20,
        pool_recycle=3600,
        connect_timeout=10,
        write_timeout=30,
        log=None
    ):
        """Store pool configuration; no I/O happens here.

        Host/port/credentials are read from the project's YamlLoader config
        under the "mysql" key.

        :param min_size: minimum number of pooled connections
        :param max_size: maximum number of pooled connections
        :param pool_recycle: recycle connections after this many seconds
                             (guards against MySQL's 8-hour idle timeout)
        :param connect_timeout: connection timeout in seconds
        :param write_timeout: write timeout in seconds
        :param log: custom logger; defaults to loguru's logger
        """
        self.log = log or logger
        self.min_size = min_size
        self.max_size = max_size
        self.pool_recycle = pool_recycle
        self.connect_timeout = connect_timeout
        self.write_timeout = write_timeout
        self._pool = None

        # Connection settings come from the project-level YAML config.
        from YamlLoader import readYaml
        yaml = readYaml()
        mysql_yaml = yaml.get("mysql")
        self.host = mysql_yaml.getValueAsString("host")
        self.port = mysql_yaml.getValueAsInt("port")
        self.user = mysql_yaml.getValueAsString("username")
        self.password = mysql_yaml.getValueAsString("password")
        self.database = mysql_yaml.getValueAsString("db")

    async def init(self):
        """Create the aiomysql pool; idempotent if already initialized."""
        if self._pool is not None:
            return

        try:
            self._pool = await aiomysql.create_pool(
                minsize=self.min_size,
                maxsize=self.max_size,
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                db=self.database,
                autocommit=True,  # commit automatically; no per-statement commit
                charset='utf8mb4',
                connect_timeout=self.connect_timeout,
                # NOTE(review): write_timeout is a pymysql kwarg; some aiomysql
                # versions do not accept it in connect() — confirm against the
                # pinned aiomysql release.
                write_timeout=self.write_timeout,
                pool_recycle=self.pool_recycle,
            )
            self.log.info(
                f"异步MySQL连接池初始化成功: {self.host}:{self.port}/{self.database}, "
                f"连接池大小: {self.min_size}-{self.max_size}"
            )
        except Exception as e:
            self.log.error(f"异步MySQL连接池初始化失败: {e}")
            raise

    async def close(self):
        """Close the pool and wait for all connections to be released."""
        if self._pool:
            self._pool.close()
            await self._pool.wait_closed()
            self._pool = None
            self.log.info("异步MySQL连接池已关闭")

    @asynccontextmanager
    async def get_connection(self):
        """Borrow a pooled connection; use as ``async with pool.get_connection()``.

        Initializes the pool lazily on first use.

        Fix: previously this was a plain ``async def`` that *returned*
        ``self._pool.acquire()``; calling it produced a coroutine object, so
        the documented ``async with pool.get_connection() as conn`` pattern
        (used by every method below) raised a TypeError at runtime. Decorating
        with ``asynccontextmanager`` makes that pattern work while keeping the
        lazy initialization.
        """
        if self._pool is None:
            await self.init()
        async with self._pool.acquire() as conn:
            yield conn

    async def check_pool_health(self):
        """Return True when a trivial ``SELECT 1`` round-trip succeeds."""
        try:
            async with self.get_connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute("SELECT 1")
                    result = await cursor.fetchone()
                    return result[0] == 1
        except Exception as e:
            self.log.error(f"连接池健康检查失败: {e}")
            return False

    @staticmethod
    def _safe_identifier(name):
        """Validate *name* as a SQL identifier (table/column) to block injection."""
        if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name):
            raise ValueError(f"Invalid SQL identifier: {name}")
        return name

    async def execute(self, query, args=None, commit=True):
        """Execute one statement and return its cursor.

        :param query: SQL text with %s placeholders
        :param args: parameters for the placeholders
        :param commit: commit after execution (ignored when autocommit is on)
        :return: the DictCursor used. NOTE(review): the connection has been
                 returned to the pool by the time the cursor is returned;
                 fetchone()/fetchall() afterwards rely on the default buffered
                 cursor behavior — confirm before switching cursor classes.
        """
        async with self.get_connection() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, args)
                if commit and not conn.get_autocommit():
                    await conn.commit()
                self.log.debug(f"SQL执行: {query[:80]}..., 影响行数: {cursor.rowcount}")
                return cursor

    async def select_one(self, query, args=None):
        """Fetch a single row (dict) or None."""
        cursor = await self.execute(query, args, commit=False)
        return await cursor.fetchone()

    async def select_all(self, query, args=None):
        """Fetch all rows as a list of dicts."""
        cursor = await self.execute(query, args, commit=False)
        return await cursor.fetchall()

    async def insert_one(self, query, args):
        """Insert one row via raw SQL and return the new AUTO_INCREMENT id."""
        cursor = await self.execute(query, args)
        return cursor.lastrowid

    async def insert_one_or_dict(self, table, data, ignore=False, commit=True):
        """Insert one row given as a dict.

        :param table: table name
        :param data: mapping of column name -> value
        :param ignore: use INSERT IGNORE
        :param commit: commit after execution
        :return: lastrowid, or -1 when a duplicate entry was skipped
        """
        if not isinstance(data, dict):
            raise ValueError("data must be a dictionary")

        keys = ', '.join([self._safe_identifier(k) for k in data.keys()])
        values = ', '.join(['%s'] * len(data))
        ignore_clause = "IGNORE" if ignore else ""
        query = f"INSERT {ignore_clause} INTO {self._safe_identifier(table)} ({keys}) VALUES ({values})"
        args = tuple(data.values())

        try:
            cursor = await self.execute(query, args, commit=commit)
            return cursor.lastrowid
        except aiomysql.IntegrityError as e:
            if "Duplicate entry" in str(e):
                self.log.debug(f"跳过重复条目: {e}")
                return -1
            raise

    async def insert_many(
        self,
        table,
        data_list,
        batch_size=2000,
        ignore=False,
        commit=True
    ):
        """Batch insert a list of dicts (high-throughput path).

        Optimizations:
            1. Batched executemany + commit reduces I/O round trips.
            2. Large default batch size (2000) limits Python-level looping.
            3. On failure a batch degrades to row-by-row inserts so duplicates
               are skipped without aborting the whole run.

        :param table: table name
        :param data_list: list of dicts [{column: value}, ...]
        :param batch_size: rows per batch (default 2000)
        :param ignore: use INSERT IGNORE
        :param commit: commit after each batch
        :return: number of rows inserted
        """
        if not data_list:
            return 0

        if not isinstance(data_list[0], dict):
            raise ValueError("data_list must be list of dictionaries")

        # Column order is taken from the first dict; all rows must share it.
        keys = ', '.join([self._safe_identifier(k) for k in data_list[0].keys()])
        values_placeholder = ', '.join(['%s'] * len(data_list[0]))
        ignore_clause = "IGNORE" if ignore else ""
        query = f"INSERT {ignore_clause} INTO {self._safe_identifier(table)} ({keys}) VALUES ({values_placeholder})"

        total_inserted = 0
        total_skipped = 0

        async with self.get_connection() as conn:
            async with conn.cursor() as cursor:
                try:
                    for i in range(0, len(data_list), batch_size):
                        batch = data_list[i:i + batch_size]
                        args_list = [tuple(d.values()) for d in batch]

                        try:
                            await cursor.executemany(query, args_list)
                            if commit:
                                await conn.commit()
                            total_inserted += cursor.rowcount

                        except aiomysql.IntegrityError as e:
                            await conn.rollback()
                            if "Duplicate entry" in str(e) and not ignore:
                                # Degrade to row-by-row so only dupes are skipped.
                                inserted, skipped = await self._insert_batch_one_by_one(
                                    cursor, query, batch, commit
                                )
                                total_inserted += inserted
                                total_skipped += skipped
                            else:
                                raise
                        except Exception as e:
                            await conn.rollback()
                            self.log.error(f"批量插入失败: {e}")
                            # Degrade to row-by-row for this batch.
                            inserted, skipped = await self._insert_batch_one_by_one(
                                cursor, query, batch, commit
                            )
                            total_inserted += inserted
                            total_skipped += skipped

                except Exception as e:
                    self.log.exception(f"批量插入最终失败: {e}")
                    raise

        if total_skipped > 0:
            self.log.info(f"插入完成: 成功{total_inserted}条, 跳过重复{total_skipped}条")
        else:
            self.log.info(f"插入完成: 成功{total_inserted}条")

        return total_inserted

    async def _insert_batch_one_by_one(self, cursor, query, batch, commit):
        """Row-by-row fallback used when a batched insert fails.

        :return: (inserted_count, skipped_count)
        """
        inserted = 0
        skipped = 0
        for data in batch:
            try:
                args = tuple(data.values())
                await cursor.execute(query, args)
                if commit:
                    # Only the cursor is available here, so commit via SQL.
                    await cursor.execute("COMMIT")
                inserted += 1
            except aiomysql.IntegrityError as e:
                if "Duplicate entry" in str(e):
                    skipped += 1
                else:
                    self.log.error(f"插入失败: {e}")
            except Exception as e:
                self.log.error(f"插入失败: {e}")
        return inserted, skipped

    async def insert_raw_many(self, query, args_list, batch_size=2000, commit=True):
        """Batch insert with a caller-supplied SQL statement.

        :param query: prepared SQL with %s placeholders
        :param args_list: list of parameter tuples [(v1, v2), ...]
        :param batch_size: rows per batch
        :param commit: commit after each batch
        :return: number of rows inserted
        """
        if not args_list:
            return 0

        total = 0
        async with self.get_connection() as conn:
            async with conn.cursor() as cursor:
                for i in range(0, len(args_list), batch_size):
                    batch = args_list[i:i + batch_size]
                    try:
                        await cursor.executemany(query, batch)
                        if commit:
                            await conn.commit()
                        total += cursor.rowcount
                    except Exception as e:
                        await conn.rollback()
                        self.log.error(f"批量插入失败 [{i}]: {e}")
                        # Degrade to row-by-row; duplicates are skipped quietly.
                        for args in batch:
                            try:
                                await cursor.execute(query, args)
                                if commit:
                                    await conn.commit()
                                total += 1
                            except Exception as e2:
                                if "Duplicate" not in str(e2):
                                    self.log.error(f"单条插入失败: {e2}")

        self.log.debug(f"insert_raw_many 完成: {total}条")
        return total

    async def insert_or_update(self, table, data, update_columns=None, commit=True):
        """Upsert one row (INSERT ... ON DUPLICATE KEY UPDATE).

        :param table: table name
        :param data: mapping of column name -> value
        :param update_columns: columns to refresh on duplicate key;
                               None updates every inserted column
        :param commit: commit after execution
        """
        if not isinstance(data, dict):
            raise ValueError("data must be a dictionary")

        keys = ', '.join([self._safe_identifier(k) for k in data.keys()])
        values_placeholder = ', '.join(['%s'] * len(data))
        query = f"INSERT INTO {self._safe_identifier(table)} ({keys}) VALUES ({values_placeholder})"

        if update_columns:
            update_clause = ', '.join([f"{self._safe_identifier(col)} = VALUES({self._safe_identifier(col)})"
                                       for col in update_columns])
        else:
            update_clause = ', '.join([f"{self._safe_identifier(k)} = VALUES({self._safe_identifier(k)})"
                                       for k in data.keys()])

        query += f" ON DUPLICATE KEY UPDATE {update_clause}"

        cursor = await self.execute(query, tuple(data.values()), commit=commit)
        return cursor.lastrowid

    async def insert_or_update_many(
        self,
        table,
        data_list,
        update_columns=None,
        batch_size=2000,
        commit=True
    ):
        """Batch upsert — suited to crawler de-duplication workflows.

        :param table: table name
        :param data_list: list of dicts
        :param update_columns: columns to refresh on duplicate key;
                               None updates every inserted column
        :param batch_size: rows per batch
        :param commit: commit after each batch
        :return: affected row count (MySQL counts an update as 2)
        """
        if not data_list or not isinstance(data_list[0], dict):
            raise ValueError("data_list must be non-empty list of dictionaries")

        keys = ', '.join([self._safe_identifier(k) for k in data_list[0].keys()])
        values_placeholder = ', '.join(['%s'] * len(data_list[0]))
        query = f"INSERT INTO {self._safe_identifier(table)} ({keys}) VALUES ({values_placeholder})"

        if update_columns:
            update_clause = ', '.join([f"{self._safe_identifier(col)} = VALUES({self._safe_identifier(col)})"
                                       for col in update_columns])
        else:
            update_clause = ', '.join([f"{self._safe_identifier(k)} = VALUES({self._safe_identifier(k)})"
                                       for k in data_list[0].keys()])

        query += f" ON DUPLICATE KEY UPDATE {update_clause}"

        total = 0
        async with self.get_connection() as conn:
            async with conn.cursor() as cursor:
                for i in range(0, len(data_list), batch_size):
                    batch = data_list[i:i + batch_size]
                    args_list = [tuple(d.values()) for d in batch]
                    try:
                        await cursor.executemany(query, args_list)
                        if commit:
                            await conn.commit()
                        total += cursor.rowcount
                    except Exception as e:
                        await conn.rollback()
                        self.log.error(f"批量插入/更新失败: {e}")
                        raise

        self.log.info(f"insert_or_update_many 完成: {total}条(包含更新)")
        return total

    async def update(self, query, args=None, commit=True):
        """Execute an UPDATE statement; return the affected row count."""
        cursor = await self.execute(query, args, commit=commit)
        return cursor.rowcount

    async def update_one_or_dict(self, table, data, condition, commit=True):
        """Update rows using dict-shaped data.

        :param table: table name
        :param data: columns to set {column: value}
        :param condition: WHERE clause as a dict {"id": 1} (parameterized) or a
                          raw SQL string "id = 1" (caller must keep it safe)
        :return: affected row count
        """
        set_clause = ', '.join([f"{self._safe_identifier(k)} = %s" for k in data.keys()])
        query = f"UPDATE {self._safe_identifier(table)} SET {set_clause}"

        if isinstance(condition, dict):
            where_clause = ' AND '.join([f"{self._safe_identifier(k)} = %s" for k in condition.keys()])
            query += f" WHERE {where_clause}"
            args = list(data.values()) + list(condition.values())
        else:
            query += f" WHERE {condition}"
            args = list(data.values())

        cursor = await self.execute(query, args, commit=commit)
        return cursor.rowcount
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
class AsyncMySQLBatchWriter:
    """Buffered batch writer designed for crawler pipelines.

    Usage:
        1. Call add() from the crawl loop to queue rows.
        2. Run auto_flush() as a background task for periodic batched writes.
        3. Call flush() before shutdown to persist any remaining rows.

    Benefits:
        - Writes are batched automatically, reducing database I/O.
        - Optional de-duplication keyed on the first column of each row
          (assumed to be the primary key).

    NOTE(review): despite the original claim of thread safety, add() mutates a
    plain dict without the lock — it is only safe when all callers run on the
    same event loop/thread.
    """

    def __init__(
        self,
        pool: AsyncMySQLPool,
        table: str,
        batch_size: int = 2000,
        flush_interval: float = 2.0,
        dedup: bool = True,
        update_columns: list | None = None
    ):
        """
        :param pool: AsyncMySQLPool instance used for the actual writes
        :param table: target table name
        :param batch_size: flush when the buffer reaches this many rows
        :param flush_interval: also flush after this many seconds
        :param dedup: de-duplicate on the first column (newest row wins)
        :param update_columns: columns to refresh on duplicate key; when set,
                               flush() upserts instead of plain-inserting
        """
        self.pool = pool
        self.table = table
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.dedup = dedup
        self.update_columns = update_columns
        self._buffer = {}  # keyed dict doubles as the de-dup structure
        self._keys = None  # column names, captured from the first row added
        # Fix: time.monotonic() replaces asyncio.get_event_loop().time() —
        # get_event_loop() outside a running loop is deprecated (3.10+), and a
        # monotonic clock is all that interval tracking needs.
        self._last_flush = time.monotonic()
        self._lock = asyncio.Lock()

    def add(self, data: dict):
        """Queue one row.

        The column set of the very first row defines the column order for all
        subsequent rows. Only safe from a single event loop/thread.

        :param data: row as a dict {column: value}
        """
        if self._keys is None:
            self._keys = tuple(data.keys())

        if self.dedup:
            # De-dup on the first column, assumed primary key; newest row wins.
            key = data.get(self._keys[0])
            self._buffer[key] = data
        else:
            # Monotonic integer keys preserve insertion order without de-dup.
            self._buffer[len(self._buffer)] = data

    async def _should_flush(self) -> bool:
        """True when the buffer is full or the flush interval has elapsed."""
        if len(self._buffer) >= self.batch_size:
            return True
        elapsed = time.monotonic() - self._last_flush
        if elapsed >= self.flush_interval and self._buffer:
            return True
        return False

    async def flush(self):
        """Force-write everything currently buffered; return rows written."""
        async with self._lock:
            if not self._buffer:
                return 0

            data_list = list(self._buffer.values())
            self._buffer.clear()
            self._last_flush = time.monotonic()

            try:
                if self.update_columns:
                    return await self.pool.insert_or_update_many(
                        self.table, data_list, self.update_columns
                    )
                else:
                    return await self.pool.insert_many(self.table, data_list)
            except Exception as e:
                # Fix: rows used to be dropped on failure (the buffer was
                # already cleared). Re-queue them so the next flush retries;
                # entries added in the meantime win on key collisions.
                for row in data_list:
                    key = row.get(self._keys[0]) if self.dedup else len(self._buffer)
                    self._buffer.setdefault(key, row)
                logger.error(f"批量写入失败: {e}")
                raise

    async def auto_flush(self):
        """Background task: poll twice a second and flush when due."""
        while True:
            await asyncio.sleep(0.5)  # poll interval
            if await self._should_flush():
                try:
                    await self.flush()
                except Exception as e:
                    logger.error(f"自动刷新失败: {e}")
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-# ============================================================
|
|
|
|
|
-# 同步封装层 - 方便在同步代码中调用
|
|
|
|
|
-# ============================================================
|
|
|
|
|
-import threading
|
|
|
|
|
-from concurrent.futures import ThreadPoolExecutor
|
|
|
|
|
-
|
|
|
|
|
-_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="mysql_async_")
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def _run_async(coro):
|
|
|
|
|
- """在新线程中运行异步协程"""
|
|
|
|
|
- loop = asyncio.new_event_loop()
|
|
|
|
|
- asyncio.set_event_loop(loop)
|
|
|
|
|
- try:
|
|
|
|
|
- return loop.run_until_complete(coro)
|
|
|
|
|
- finally:
|
|
|
|
|
- loop.close()
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
class SyncWrapper:
    """
    Synchronous facade that lets AsyncMySQLPool be used from synchronous code.

    A dedicated daemon thread runs a private event loop; every public method
    submits the corresponding coroutine to that loop and blocks on the result.

    Usage:
        pool = AsyncMySQLPool()
        sync_pool = SyncWrapper(pool)
        sync_pool.init()  # synchronous initialization

        # synchronous calls
        sync_pool.insert_many("table", [{"col": "val"}])
        sync_pool.close()
    """

    def __init__(self, async_pool: AsyncMySQLPool):
        self.async_pool = async_pool
        self._loop = None      # private event loop owned by the worker thread
        self._thread = None    # daemon thread driving self._loop
        self._running = False

    def _run_loop(self):
        """Thread target: run the private event loop until stop() is called."""
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

    def init(self):
        """Start the worker thread and initialize the pool (blocking)."""
        if self._running:
            return

        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()
        self._running = True

        # Block until the pool is ready on the worker loop.
        asyncio.run_coroutine_threadsafe(self.async_pool.init(), self._loop).result()

    def close(self):
        """Close the pool, stop the worker loop and release its resources."""
        if not self._running:
            return

        asyncio.run_coroutine_threadsafe(self.async_pool.close(), self._loop).result()

        if self._loop:
            self._loop.call_soon_threadsafe(self._loop.stop)
            self._thread.join(timeout=2)
            # Fix: the loop object was previously leaked; close it once its
            # thread has stopped so selector/file descriptors are released.
            self._loop.close()
        self._running = False

    def insert_many(self, table, data_list, **kwargs):
        """Blocking wrapper around AsyncMySQLPool.insert_many()."""
        coro = self.async_pool.insert_many(table, data_list, **kwargs)
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

    def insert_one_or_dict(self, table, data, **kwargs):
        """Blocking wrapper around AsyncMySQLPool.insert_one_or_dict()."""
        coro = self.async_pool.insert_one_or_dict(table, data, **kwargs)
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

    def select_all(self, query, args=None):
        """Blocking wrapper around AsyncMySQLPool.select_all()."""
        coro = self.async_pool.select_all(query, args)
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

    def select_one(self, query, args=None):
        """Blocking wrapper around AsyncMySQLPool.select_one()."""
        coro = self.async_pool.select_one(query, args)
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
if __name__ == '__main__':
    import asyncio


    async def _demo():
        """Smoke-test the pool against the configured database."""
        pool = AsyncMySQLPool(min_size=5, max_size=20)
        await pool.init()

        # Health check round trip.
        health = await pool.check_pool_health()
        print(f"连接池健康: {health}")

        # Exercise the batch-insert path.
        rows = [
            {"card_id": 1, "card_name": "测试卡牌1", "card_type": "角色"},
            {"card_id": 2, "card_name": "测试卡牌2", "card_type": "角色"},
        ]
        result = await pool.insert_many("one_piece_record", rows, ignore=True)
        print(f"插入结果: {result}")

        # Exercise the upsert path.
        await pool.insert_or_update_many(
            "one_piece_record",
            [{"card_id": 1, "card_name": "更新卡牌1"}],
            update_columns=["card_name"],
        )

        await pool.close()
        print("测试完成!")


    asyncio.run(_demo())
|
|
|