# -*- coding: utf-8 -*-
# Author: Charley
# Python: 3.10+
# Date: 2026/04/16
# Desc: Async MySQL connection pool, optimized for high-concurrency crawling

import asyncio
import re
import time

import aiomysql
from loguru import logger


class AsyncMySQLPool:
    """
    Async MySQL connection pool, optimized for high-concurrency crawling.

    Advantages over the synchronous version (pymysql):
    1. Truly non-blocking I/O: concurrent crawls are not stalled by DB writes.
    2. Connection reuse via pooling, cutting connection-setup overhead.
    3. More efficient batch inserts with larger batch sizes.

    Usage:
        pool = AsyncMySQLPool()
        await pool.init()

        # Batch insert (all dicts must share the same keys)
        await pool.insert_many("table_name", [{"col1": "val1"}, {"col1": "val2"}])

        # Or use a connection directly
        async with pool.get_connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("SELECT * FROM table LIMIT 10")
                results = await cursor.fetchall()

        await pool.close()
    """

    def __init__(
        self,
        min_size=5,
        max_size=20,
        pool_recycle=3600,
        connect_timeout=10,
        write_timeout=30,
        log=None
    ):
        """
        Initialize the async connection pool.

        :param min_size: minimum number of pooled connections
        :param max_size: maximum number of pooled connections
        :param pool_recycle: recycle connections after this many seconds
            (guards against MySQL's 8-hour idle timeout)
        :param connect_timeout: connection timeout in seconds
        :param write_timeout: kept for interface compatibility only; aiomysql
            (unlike PyMySQL) does not accept a write_timeout connect option
        :param log: custom logger (defaults to loguru)
        """
        self.log = log or logger
        self.min_size = min_size
        self.max_size = max_size
        self.pool_recycle = pool_recycle
        self.connect_timeout = connect_timeout
        self.write_timeout = write_timeout
        self._pool = None

        # Read connection settings via YamlLoader
        from YamlLoader import readYaml
        yaml = readYaml()
        mysql_yaml = yaml.get("mysql")
        self.host = mysql_yaml.getValueAsString("host")
        self.port = mysql_yaml.getValueAsInt("port")
        self.user = mysql_yaml.getValueAsString("username")
        self.password = mysql_yaml.getValueAsString("password")
        self.database = mysql_yaml.getValueAsString("db")

    async def init(self):
        """Initialize the connection pool (idempotent)."""
        if self._pool is not None:
            return
        try:
            self._pool = await aiomysql.create_pool(
                minsize=self.min_size,
                maxsize=self.max_size,
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                db=self.database,
                autocommit=True,  # auto-commit, so no manual commit per statement
                charset='utf8mb4',
                connect_timeout=self.connect_timeout,
                pool_recycle=self.pool_recycle,
            )
            self.log.info(
                f"Async MySQL pool initialized: {self.host}:{self.port}/{self.database}, "
                f"pool size: {self.min_size}-{self.max_size}"
            )
        except Exception as e:
            self.log.error(f"Async MySQL pool initialization failed: {e}")
            raise

    async def close(self):
        """Close the connection pool."""
        if self._pool:
            self._pool.close()
            await self._pool.wait_closed()
            self._pool = None
            self.log.info("Async MySQL pool closed")

    def get_connection(self):
        """
        Acquire a connection (use with ``async with``).

        Note: this is deliberately a plain method, not a coroutine.
        aiomysql's ``pool.acquire()`` already returns an async context
        manager; wrapping it in a coroutine would break
        ``async with pool.get_connection()``.

        :return: an object usable with ``async with``
        """
        if self._pool is None:
            raise RuntimeError("Pool not initialized; call `await pool.init()` first")
        return self._pool.acquire()

    async def check_pool_health(self):
        """Check whether the pool can serve a simple query."""
        try:
            if self._pool is None:
                await self.init()
            async with self.get_connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute("SELECT 1")
                    result = await cursor.fetchone()
                    return result[0] == 1
        except Exception as e:
            self.log.error(f"Pool health check failed: {e}")
            return False

    @staticmethod
    def _safe_identifier(name):
        """Validate an SQL identifier (table or column name)."""
        if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name):
            raise ValueError(f"Invalid SQL identifier: {name}")
        return name

    async def execute(self, query, args=None, commit=True):
        """
        Execute a statement.

        :param query: SQL statement
        :param args: SQL parameters
        :param commit: whether to commit (ignored when autocommit=True)
        :return: the (closed) cursor; read only metadata attributes such as
            ``rowcount`` and ``lastrowid`` from it, never fetch rows
        """
        if self._pool is None:
            await self.init()
        async with self.get_connection() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, args)
                if commit and not conn.get_autocommit():
                    await conn.commit()
                self.log.debug(f"SQL executed: {query[:80]}..., rows affected: {cursor.rowcount}")
                return cursor

    async def select_one(self, query, args=None):
        """Fetch a single row (fetched before the connection is released;
        the cursor returned by execute() is already closed, so fetching
        from it would be unsafe)."""
        if self._pool is None:
            await self.init()
        async with self.get_connection() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, args)
                return await cursor.fetchone()

    async def select_all(self, query, args=None):
        """Fetch all rows (fetched before the connection is released)."""
        if self._pool is None:
            await self.init()
        async with self.get_connection() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, args)
                return await cursor.fetchall()
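    # Query usage sketch (illustrative only; assumes a hypothetical `cards`
    # table with columns `card_id` and `card_name`, which is not part of
    # this module):
    #
    #     pool = AsyncMySQLPool()
    #     await pool.init()
    #     row = await pool.select_one(
    #         "SELECT card_id, card_name FROM cards WHERE card_id = %s", (1,)
    #     )
    #     rows = await pool.select_all("SELECT card_id, card_name FROM cards LIMIT 10")
    #     await pool.close()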
    async def insert_one(self, query, args):
        """Insert a single row from raw SQL and return the new row id."""
        cursor = await self.execute(query, args)
        return cursor.lastrowid

    async def insert_one_or_dict(self, table, data, ignore=False, commit=True):
        """
        Insert a single row from a dict.

        :param table: table name
        :param data: dict of {column: value}
        :param ignore: use INSERT IGNORE
        :param commit: whether to commit
        """
        if not isinstance(data, dict):
            raise ValueError("data must be a dictionary")
        keys = ', '.join([self._safe_identifier(k) for k in data.keys()])
        values = ', '.join(['%s'] * len(data))
        ignore_clause = "IGNORE" if ignore else ""
        query = f"INSERT {ignore_clause} INTO {self._safe_identifier(table)} ({keys}) VALUES ({values})"
        args = tuple(data.values())
        try:
            cursor = await self.execute(query, args, commit=commit)
            return cursor.lastrowid
        except aiomysql.IntegrityError as e:
            if "Duplicate entry" in str(e):
                self.log.debug(f"Skipping duplicate entry: {e}")
                return -1
            raise

    async def insert_many(
        self, table, data_list, batch_size=2000, ignore=False, commit=True
    ):
        """
        Batch insert a list of dicts - high-throughput version.

        Optimizations:
        1. Batched commits on one connection, reducing I/O round trips.
        2. A batch size of 2000 keeps the Python-side loop short.
        3. On failure, falls back to row-by-row insertion (duplicates are
           skipped instead of aborting the whole batch).

        :param table: table name
        :param data_list: list of dicts [{column: value}, ...], all sharing
            the first dict's keys
        :param batch_size: rows per batch (default 2000)
        :param ignore: use INSERT IGNORE
        :param commit: whether to commit
        :return: number of rows inserted
        """
        if not data_list:
            return 0
        if not isinstance(data_list[0], dict):
            raise ValueError("data_list must be list of dictionaries")

        keys = ', '.join([self._safe_identifier(k) for k in data_list[0].keys()])
        values_placeholder = ', '.join(['%s'] * len(data_list[0]))
        ignore_clause = "IGNORE" if ignore else ""
        query = f"INSERT {ignore_clause} INTO {self._safe_identifier(table)} ({keys}) VALUES ({values_placeholder})"

        total_inserted = 0
        total_skipped = 0
        if self._pool is None:
            await self.init()
        async with self.get_connection() as conn:
            async with conn.cursor() as cursor:
                try:
                    for i in range(0, len(data_list), batch_size):
                        batch = data_list[i:i + batch_size]
                        args_list = [tuple(d.values()) for d in batch]
                        try:
                            await cursor.executemany(query, args_list)
                            if commit:
                                await conn.commit()
                            total_inserted += cursor.rowcount
                        except aiomysql.IntegrityError as e:
                            await conn.rollback()
                            if "Duplicate entry" in str(e) and not ignore:
                                # Fall back to row-by-row insertion
                                inserted, skipped = await self._insert_batch_one_by_one(
                                    conn, cursor, query, batch, commit
                                )
                                total_inserted += inserted
                                total_skipped += skipped
                            else:
                                raise
                        except Exception as e:
                            await conn.rollback()
                            self.log.error(f"Batch insert failed: {e}")
                            # Fall back to row-by-row insertion
                            inserted, skipped = await self._insert_batch_one_by_one(
                                conn, cursor, query, batch, commit
                            )
                            total_inserted += inserted
                            total_skipped += skipped
                except Exception as e:
                    self.log.exception(f"Batch insert failed permanently: {e}")
                    raise

        if total_skipped > 0:
            self.log.info(f"Insert finished: {total_inserted} inserted, {total_skipped} duplicates skipped")
        else:
            self.log.info(f"Insert finished: {total_inserted} inserted")
        return total_inserted

    async def _insert_batch_one_by_one(self, conn, cursor, query, batch, commit):
        """
        Insert rows one by one (fallback path). Takes the owning connection
        so it can commit properly instead of issuing a raw COMMIT statement.

        :return: (inserted count, skipped count)
        """
        inserted = 0
        skipped = 0
        for data in batch:
            try:
                args = tuple(data.values())
                await cursor.execute(query, args)
                if commit:
                    await conn.commit()
                inserted += 1
            except aiomysql.IntegrityError as e:
                if "Duplicate entry" in str(e):
                    skipped += 1
                else:
                    self.log.error(f"Insert failed: {e}")
            except Exception as e:
                self.log.error(f"Insert failed: {e}")
        return inserted, skipped
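    # Shape of the SQL generated by insert_many() -- a sketch, assuming a
    # hypothetical `cards` table; every dict must share the first dict's keys:
    #
    #     data_list = [{"card_id": 1, "card_name": "a"},
    #                  {"card_id": 2, "card_name": "b"}]
    #     # ignore=False -> INSERT INTO cards (card_id, card_name) VALUES (%s, %s)
    #     # ignore=True  -> INSERT IGNORE INTO cards (card_id, card_name) VALUES (%s, %s)
    #     # executed via cursor.executemany() with [(1, "a"), (2, "b")]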
    async def insert_raw_many(self, query, args_list, batch_size=2000, commit=True):
        """
        Batch insert with raw SQL - high-throughput version.

        :param query: prepared SQL statement using %s placeholders
        :param args_list: parameter list [(val1, val2), ...]
        :param batch_size: rows per batch
        :param commit: whether to commit
        :return: number of rows inserted
        """
        if not args_list:
            return 0

        total = 0
        if self._pool is None:
            await self.init()
        async with self.get_connection() as conn:
            async with conn.cursor() as cursor:
                for i in range(0, len(args_list), batch_size):
                    batch = args_list[i:i + batch_size]
                    try:
                        await cursor.executemany(query, batch)
                        if commit:
                            await conn.commit()
                        total += cursor.rowcount
                    except Exception as e:
                        await conn.rollback()
                        self.log.error(f"Batch insert failed [{i}]: {e}")
                        # Fall back to row-by-row insertion
                        for args in batch:
                            try:
                                await cursor.execute(query, args)
                                if commit:
                                    await conn.commit()
                                total += 1
                            except Exception as e2:
                                if "Duplicate" not in str(e2):
                                    self.log.error(f"Single insert failed: {e2}")
        self.log.debug(f"insert_raw_many finished: {total} rows")
        return total

    async def insert_or_update(self, table, data, update_columns=None, commit=True):
        """
        Insert or update (INSERT ... ON DUPLICATE KEY UPDATE).

        :param table: table name
        :param data: dict of {column: value}
        :param update_columns: columns to update on conflict; None updates
            all columns
        :param commit: whether to commit
        """
        if not isinstance(data, dict):
            raise ValueError("data must be a dictionary")

        keys = ', '.join([self._safe_identifier(k) for k in data.keys()])
        values_placeholder = ', '.join(['%s'] * len(data))
        query = f"INSERT INTO {self._safe_identifier(table)} ({keys}) VALUES ({values_placeholder})"
        if update_columns:
            update_clause = ', '.join(
                [f"{self._safe_identifier(col)} = VALUES({self._safe_identifier(col)})" for col in update_columns]
            )
        else:
            update_clause = ', '.join(
                [f"{self._safe_identifier(k)} = VALUES({self._safe_identifier(k)})" for k in data.keys()]
            )
        query += f" ON DUPLICATE KEY UPDATE {update_clause}"

        cursor = await self.execute(query, tuple(data.values()), commit=commit)
        return cursor.lastrowid

    async def insert_or_update_many(
        self, table, data_list, update_columns=None, batch_size=2000, commit=True
    ):
        """
        Batch insert-or-update - suited to crawler dedup scenarios.

        :param table: table name
        :param data_list: list of dicts
        :param update_columns: columns to update on conflict
        :param batch_size: rows per batch
        :param commit: whether to commit
        """
        if not data_list or not isinstance(data_list[0], dict):
            raise ValueError("data_list must be non-empty list of dictionaries")

        keys = ', '.join([self._safe_identifier(k) for k in data_list[0].keys()])
        values_placeholder = ', '.join(['%s'] * len(data_list[0]))
        query = f"INSERT INTO {self._safe_identifier(table)} ({keys}) VALUES ({values_placeholder})"
        if update_columns:
            update_clause = ', '.join(
                [f"{self._safe_identifier(col)} = VALUES({self._safe_identifier(col)})" for col in update_columns]
            )
        else:
            update_clause = ', '.join(
                [f"{self._safe_identifier(k)} = VALUES({self._safe_identifier(k)})" for k in data_list[0].keys()]
            )
        query += f" ON DUPLICATE KEY UPDATE {update_clause}"

        total = 0
        if self._pool is None:
            await self.init()
        async with self.get_connection() as conn:
            async with conn.cursor() as cursor:
                for i in range(0, len(data_list), batch_size):
                    batch = data_list[i:i + batch_size]
                    args_list = [tuple(d.values()) for d in batch]
                    try:
                        await cursor.executemany(query, args_list)
                        if commit:
                            await conn.commit()
                        total += cursor.rowcount
                    except Exception as e:
                        await conn.rollback()
                        self.log.error(f"Batch insert/update failed: {e}")
                        raise
        self.log.info(f"insert_or_update_many finished: {total} rows (updates included)")
        return total

    async def update(self, query, args=None, commit=True):
        """Execute an UPDATE statement and return the affected row count."""
        cursor = await self.execute(query, args, commit=commit)
        return cursor.rowcount

    async def update_one_or_dict(self, table, data, condition, commit=True):
        """
        Update a single row from a dict.

        :param table: table name
        :param data: data to update {column: value}
        :param condition: update condition, either {"id": 1} or a raw string
            like "id = 1" (string conditions are interpolated verbatim, so
            never build them from untrusted input)
        :param commit: whether to commit
        """
        set_clause = ', '.join([f"{self._safe_identifier(k)} = %s" for k in data.keys()])
        query = f"UPDATE {self._safe_identifier(table)} SET {set_clause}"
        if isinstance(condition, dict):
            where_clause = ' AND '.join([f"{self._safe_identifier(k)} = %s" for k in condition.keys()])
            query += f" WHERE {where_clause}"
            args = list(data.values()) + list(condition.values())
        else:
            query += f" WHERE {condition}"
            args = list(data.values())
        cursor = await self.execute(query, args, commit=commit)
        return cursor.rowcount
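# Upsert sketch: for data = {"card_id": 1, "card_name": "x"} and
# update_columns=["card_name"], insert_or_update() builds the following
# (hypothetical `cards` table, which needs a PRIMARY KEY or UNIQUE index
# on card_id for ON DUPLICATE KEY UPDATE to fire):
#
#     INSERT INTO cards (card_id, card_name) VALUES (%s, %s)
#     ON DUPLICATE KEY UPDATE card_name = VALUES(card_name)
#
# Note: VALUES() inside ON DUPLICATE KEY UPDATE is deprecated as of MySQL
# 8.0.20 in favor of row aliases, but it still works.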
class AsyncMySQLBatchWriter:
    """
    Async batch writer - designed for crawlers.

    Usage:
    1. Call add() from the crawler's main loop to queue data.
    2. Data is flushed to the database in timed batches.
    3. Call flush() on shutdown to drain any remaining data.

    Advantages:
    - Non-blocking writes that don't slow the crawler down
    - Automatic batching, reducing database I/O
    - Automatic dedup (keyed on the first column, assumed to be the primary key)
    """

    def __init__(
        self,
        pool: AsyncMySQLPool,
        table: str,
        batch_size: int = 2000,
        flush_interval: float = 2.0,
        dedup: bool = True,
        update_columns: list = None
    ):
        """
        :param pool: AsyncMySQLPool instance
        :param table: table name
        :param batch_size: rows per flush
        :param flush_interval: auto-flush interval in seconds
        :param dedup: dedup buffered dicts (keeping only the newest)
        :param update_columns: columns to update on conflict; None updates all
        """
        self.pool = pool
        self.table = table
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.dedup = dedup
        self.update_columns = update_columns
        self._buffer = {}  # dict keyed for dedup
        self._keys = None
        # time.monotonic() avoids asyncio.get_event_loop(), which is
        # deprecated outside a running loop on Python 3.10+
        self._last_flush = time.monotonic()
        self._lock = asyncio.Lock()

    def add(self, data: dict):
        """
        Add a row to the buffer. Call this from the event-loop thread only;
        the buffer is not guarded against access from other threads.

        :param data: row as a dict
        """
        if self._keys is None:
            self._keys = tuple(data.keys())
        if self.dedup:
            # Dedup on the first field (assumed to be the primary key)
            key = data.get(self._keys[0])
            self._buffer[key] = data
        else:
            self._buffer[len(self._buffer)] = data

    async def _should_flush(self) -> bool:
        """Check whether a flush is due."""
        if len(self._buffer) >= self.batch_size:
            return True
        elapsed = time.monotonic() - self._last_flush
        if elapsed >= self.flush_interval and self._buffer:
            return True
        return False

    async def flush(self):
        """Flush the buffer immediately."""
        async with self._lock:
            if not self._buffer:
                return 0
            data_list = list(self._buffer.values())
            self._buffer.clear()
            self._last_flush = time.monotonic()
            try:
                if self.update_columns:
                    return await self.pool.insert_or_update_many(
                        self.table, data_list, self.update_columns
                    )
                else:
                    return await self.pool.insert_many(self.table, data_list)
            except Exception as e:
                logger.error(f"Batch write failed: {e}")
                raise

    async def auto_flush(self):
        """Auto-flush coroutine - run as a background task; cancel it to stop."""
        while True:
            await asyncio.sleep(0.5)  # poll every 0.5 s
            if await self._should_flush():
                try:
                    await self.flush()
                except Exception as e:
                    logger.error(f"Auto-flush failed: {e}")
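# Batch-writer usage sketch (a minimal sketch; `scraped_items` and the
# `one_piece_record` table layout are assumptions for illustration):
#
#     pool = AsyncMySQLPool()
#     await pool.init()
#     writer = AsyncMySQLBatchWriter(pool, "one_piece_record", batch_size=500)
#     flusher = asyncio.create_task(writer.auto_flush())  # background flushing
#     for item in scraped_items:
#         writer.add(item)           # non-blocking buffering
#     await writer.flush()           # drain what's left on shutdown
#     flusher.cancel()               # stop the background task
#     await pool.close()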
# ============================================================
# Synchronous wrapper layer - for calling from synchronous code
# ============================================================
import threading
from concurrent.futures import ThreadPoolExecutor

# Standalone helpers for one-off synchronous calls (not used by SyncWrapper)
_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="mysql_async_")


def _run_async(coro):
    """Run a coroutine to completion on a fresh event loop."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coro)
    finally:
        loop.close()


class SyncWrapper:
    """
    Synchronous wrapper - lets the async pool be used from synchronous code.

    Usage:
        pool = AsyncMySQLPool()
        sync_pool = SyncWrapper(pool)
        sync_pool.init()  # synchronous initialization

        # synchronous calls
        sync_pool.insert_many("table", [{"col": "val"}])

        sync_pool.close()
    """

    def __init__(self, async_pool: AsyncMySQLPool):
        self.async_pool = async_pool
        self._loop = None
        self._thread = None
        self._running = False

    def _run_loop(self):
        """Run the event loop inside the worker thread."""
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

    def init(self):
        """Synchronous initialization."""
        if self._running:
            return
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()
        self._running = True
        # Initialize the pool on the worker thread's loop
        asyncio.run_coroutine_threadsafe(self.async_pool.init(), self._loop).result()

    def close(self):
        """Synchronous shutdown."""
        if not self._running:
            return
        asyncio.run_coroutine_threadsafe(self.async_pool.close(), self._loop).result()
        if self._loop:
            self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join(timeout=2)
        self._running = False

    def insert_many(self, table, data_list, **kwargs):
        """Synchronous batch insert."""
        coro = self.async_pool.insert_many(table, data_list, **kwargs)
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

    def insert_one_or_dict(self, table, data, **kwargs):
        """Synchronous single-row insert."""
        coro = self.async_pool.insert_one_or_dict(table, data, **kwargs)
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

    def select_all(self, query, args=None):
        """Synchronous fetch-all."""
        coro = self.async_pool.select_all(query, args)
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

    def select_one(self, query, args=None):
        """Synchronous fetch-one."""
        coro = self.async_pool.select_one(query, args)
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()


if __name__ == '__main__':
    async def test():
        pool = AsyncMySQLPool(min_size=5, max_size=20)
        await pool.init()

        # Health check
        health = await pool.check_pool_health()
        print(f"Pool healthy: {health}")

        # Test batch insert
        test_data = [
            {"card_id": 1, "card_name": "Test card 1", "card_type": "Character"},
            {"card_id": 2, "card_name": "Test card 2", "card_type": "Character"},
        ]
        result = await pool.insert_many("one_piece_record", test_data, ignore=True)
        print(f"Insert result: {result}")

        # Test insert-or-update
        await pool.insert_or_update_many(
            "one_piece_record",
            [{"card_id": 1, "card_name": "Updated card 1"}],
            update_columns=["card_name"]
        )

        await pool.close()
        print("Test finished!")

    asyncio.run(test())
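# Synchronous usage sketch (runs similar inserts from non-async code; the
# table name is reused from test() above and the values are illustrative):
#
#     pool = AsyncMySQLPool()
#     sync_pool = SyncWrapper(pool)
#     sync_pool.init()
#     sync_pool.insert_many(
#         "one_piece_record",
#         [{"card_id": 3, "card_name": "Test card 3", "card_type": "Character"}],
#         ignore=True,
#     )
#     sync_pool.close()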