# -*- coding: utf-8 -*-
# Author: Charley
# Python: 3.10+
# Date: 2026/04/16
# Desc: Async MySQL connection pool, optimized for high-concurrency crawler workloads
import asyncio
import re
import time
from contextlib import asynccontextmanager

import aiomysql
from loguru import logger


class AsyncMySQLPool:
    """
    Async MySQL connection pool, optimized for high-concurrency crawler workloads.

    Advantages over the synchronous version (pymysql):
    1. Truly asynchronous, non-blocking I/O: concurrent crawling is not blocked by DB writes
    2. Connection pooling and reuse, reducing connection setup overhead
    3. More efficient bulk inserts, supporting larger batches

    Usage example:
        pool = AsyncMySQLPool()
        await pool.init()
        # Bulk insert (every row must share the same columns)
        await pool.insert_many("table_name", [{"col1": "val1"}, {"col1": "val2"}])
        # Or work with a connection directly via async with
        async with pool.get_connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("SELECT * FROM table LIMIT 10")
                results = await cursor.fetchall()
        await pool.close()
    """

    def __init__(
        self,
        min_size=5,
        max_size=20,
        pool_recycle=3600,
        connect_timeout=10,
        write_timeout=30,
        log=None
    ):
- """
- 初始化异步连接池
- :param min_size: 池中最小连接数
- :param max_size: 池中最大连接数
- :param pool_recycle: 连接回收时间(秒),防止MySQL 8小时超时
- :param connect_timeout: 连接超时(秒)
- :param write_timeout: 写入超时(秒)
- :param log: 自定义日志记录器
- """
- self.log = log or logger
- self.min_size = min_size
- self.max_size = max_size
- self.pool_recycle = pool_recycle
- self.connect_timeout = connect_timeout
- self.write_timeout = write_timeout
- self._pool = None
- # 从 YamlLoader 读取配置
- from YamlLoader import readYaml
- yaml = readYaml()
- mysql_yaml = yaml.get("mysql")
- self.host = mysql_yaml.getValueAsString("host")
- self.port = mysql_yaml.getValueAsInt("port")
- self.user = mysql_yaml.getValueAsString("username")
- self.password = mysql_yaml.getValueAsString("password")
- self.database = mysql_yaml.getValueAsString("db")

    async def init(self):
        """Initialize the connection pool."""
        if self._pool is not None:
            return
        try:
            self._pool = await aiomysql.create_pool(
                minsize=self.min_size,
                maxsize=self.max_size,
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                db=self.database,
                autocommit=True,  # auto-commit so each statement does not need a manual commit
                charset='utf8mb4',
                connect_timeout=self.connect_timeout,
                # aiomysql.connect() does not accept pymysql's write_timeout
                # kwarg, so self.write_timeout is intentionally not passed here.
                pool_recycle=self.pool_recycle,
            )
            self.log.info(
                f"Async MySQL pool initialized: {self.host}:{self.port}/{self.database}, "
                f"pool size: {self.min_size}-{self.max_size}"
            )
        except Exception as e:
            self.log.error(f"Failed to initialize async MySQL pool: {e}")
            raise

    async def close(self):
        """Close the connection pool."""
        if self._pool:
            self._pool.close()
            await self._pool.wait_closed()
            self._pool = None
            self.log.info("Async MySQL pool closed")

    @asynccontextmanager
    async def get_connection(self):
        """
        Acquire a pooled connection; intended for use with ``async with``.

        Wrapped in an async context manager: the original returned
        ``self._pool.acquire()`` from a coroutine, which cannot be used
        directly in an ``async with`` statement.
        """
        if self._pool is None:
            await self.init()
        async with self._pool.acquire() as conn:
            yield conn

    async def check_pool_health(self):
        """Check pool health with a lightweight SELECT 1."""
        try:
            async with self.get_connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute("SELECT 1")
                    result = await cursor.fetchone()
                    return result[0] == 1
        except Exception as e:
            self.log.error(f"Pool health check failed: {e}")
            return False

    @staticmethod
    def _safe_identifier(name):
        """Validate a SQL identifier (guards against injection via table/column names)."""
        if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name):
            raise ValueError(f"Invalid SQL identifier: {name}")
        return name
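
    # Illustrative examples (not executed): identifiers must match
    # ^[a-zA-Z_][a-zA-Z0-9_]*$, so:
    #
    #   AsyncMySQLPool._safe_identifier("card_id")        # -> "card_id"
    #   AsyncMySQLPool._safe_identifier("users; DROP")    # -> raises ValueError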

    async def execute(self, query, args=None, commit=True):
        """
        Execute a statement.

        :param query: SQL statement
        :param args: SQL parameters
        :param commit: whether to commit (ignored when autocommit=True)
        :return: (rowcount, lastrowid)

        Returns plain values rather than the cursor: once the connection is
        released back to the pool the cursor is closed, so fetching from a
        returned cursor would fail.
        """
        async with self.get_connection() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, args)
                if commit and not conn.get_autocommit():
                    await conn.commit()
                self.log.debug(f"SQL executed: {query[:80]}..., rows affected: {cursor.rowcount}")
                return cursor.rowcount, cursor.lastrowid

    async def select_one(self, query, args=None):
        """Fetch a single row as a dict (fetched before the connection is released)."""
        async with self.get_connection() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, args)
                return await cursor.fetchone()

    async def select_all(self, query, args=None):
        """Fetch all rows as a list of dicts (fetched before the connection is released)."""
        async with self.get_connection() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, args)
                return await cursor.fetchall()

    async def insert_one(self, query, args):
        """Insert a single row and return its auto-increment id."""
        _, lastrowid = await self.execute(query, args)
        return lastrowid

    async def insert_one_or_dict(self, table, data, ignore=False, commit=True):
        """
        Insert a single row given as a dict.

        :param table: table name
        :param data: dict of {column: value}
        :param ignore: whether to use INSERT IGNORE
        :param commit: whether to commit
        """
        if not isinstance(data, dict):
            raise ValueError("data must be a dictionary")
        keys = ', '.join([self._safe_identifier(k) for k in data.keys()])
        values = ', '.join(['%s'] * len(data))
        ignore_clause = "IGNORE" if ignore else ""
        query = f"INSERT {ignore_clause} INTO {self._safe_identifier(table)} ({keys}) VALUES ({values})"
        args = tuple(data.values())
        try:
            _, lastrowid = await self.execute(query, args, commit=commit)
            return lastrowid
        except aiomysql.IntegrityError as e:
            if "Duplicate entry" in str(e):
                self.log.debug(f"Skipping duplicate entry: {e}")
                return -1
            raise
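
    # Illustrative: for data={"card_id": 1, "card_name": "x"} the generated
    # statement has the shape
    #   INSERT INTO <table> (card_id, card_name) VALUES (%s, %s)
    # executed with args (1, "x"); with ignore=True the verb becomes
    # INSERT IGNORE.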

    async def insert_many(
        self,
        table,
        data_list,
        batch_size=2000,
        ignore=False,
        commit=True
    ):
        """
        Bulk insert a list of dicts (high-performance path).

        Optimizations:
        1. Batches are committed together, reducing I/O round trips
        2. The default batch size of 2000 reduces Python-level loop iterations
        3. On failure, falls back to row-by-row inserts (duplicates are skipped
           without aborting the whole batch)

        :param table: table name
        :param data_list: list of dicts [{column: value}, ...]
        :param batch_size: rows per batch (default 2000)
        :param ignore: whether to use INSERT IGNORE
        :param commit: whether to commit
        :return: number of rows successfully inserted
        """
        if not data_list:
            return 0
        if not isinstance(data_list[0], dict):
            raise ValueError("data_list must be list of dictionaries")
        keys = ', '.join([self._safe_identifier(k) for k in data_list[0].keys()])
        values_placeholder = ', '.join(['%s'] * len(data_list[0]))
        ignore_clause = "IGNORE" if ignore else ""
        query = f"INSERT {ignore_clause} INTO {self._safe_identifier(table)} ({keys}) VALUES ({values_placeholder})"
        total_inserted = 0
        total_skipped = 0
        async with self.get_connection() as conn:
            async with conn.cursor() as cursor:
                try:
                    for i in range(0, len(data_list), batch_size):
                        batch = data_list[i:i + batch_size]
                        args_list = [tuple(d.values()) for d in batch]
                        try:
                            await cursor.executemany(query, args_list)
                            if commit:
                                await conn.commit()
                            total_inserted += cursor.rowcount
                        except aiomysql.IntegrityError as e:
                            await conn.rollback()
                            if "Duplicate entry" in str(e) and not ignore:
                                # Fall back to row-by-row inserts
                                inserted, skipped = await self._insert_batch_one_by_one(
                                    cursor, query, batch, commit
                                )
                                total_inserted += inserted
                                total_skipped += skipped
                            else:
                                raise
                        except Exception as e:
                            await conn.rollback()
                            self.log.error(f"Bulk insert failed: {e}")
                            # Fall back to row-by-row inserts
                            inserted, skipped = await self._insert_batch_one_by_one(
                                cursor, query, batch, commit
                            )
                            total_inserted += inserted
                            total_skipped += skipped
                except Exception as e:
                    self.log.exception(f"Bulk insert ultimately failed: {e}")
                    raise
        if total_skipped > 0:
            self.log.info(f"Insert finished: {total_inserted} inserted, {total_skipped} duplicates skipped")
        else:
            self.log.info(f"Insert finished: {total_inserted} inserted")
        return total_inserted

    async def _insert_batch_one_by_one(self, cursor, query, batch, commit):
        """
        Insert rows one at a time (fallback path).

        :return: (inserted count, skipped count)
        """
        inserted = 0
        skipped = 0
        for data in batch:
            try:
                args = tuple(data.values())
                await cursor.execute(query, args)
                if commit:
                    # Only the cursor is in scope here, so commit via SQL
                    await cursor.execute("COMMIT")
                inserted += 1
            except aiomysql.IntegrityError as e:
                if "Duplicate entry" in str(e):
                    skipped += 1
                else:
                    self.log.error(f"Insert failed: {e}")
            except Exception as e:
                self.log.error(f"Insert failed: {e}")
        return inserted, skipped

    async def insert_raw_many(self, query, args_list, batch_size=2000, commit=True):
        """
        Bulk insert with a raw SQL template (high-performance path).

        :param query: parameterized SQL statement using %s placeholders
        :param args_list: list of parameter tuples [(val1, val2), ...]
        :param batch_size: rows per batch
        :param commit: whether to commit
        :return: number of rows successfully inserted
        """
        if not args_list:
            return 0
        total = 0
        async with self.get_connection() as conn:
            async with conn.cursor() as cursor:
                for i in range(0, len(args_list), batch_size):
                    batch = args_list[i:i + batch_size]
                    try:
                        await cursor.executemany(query, batch)
                        if commit:
                            await conn.commit()
                        total += cursor.rowcount
                    except Exception as e:
                        await conn.rollback()
                        self.log.error(f"Bulk insert failed [{i}]: {e}")
                        # Fall back to row-by-row inserts
                        for args in batch:
                            try:
                                await cursor.execute(query, args)
                                if commit:
                                    await conn.commit()
                                total += 1
                            except Exception as e2:
                                if "Duplicate" not in str(e2):
                                    self.log.error(f"Single-row insert failed: {e2}")
        self.log.debug(f"insert_raw_many finished: {total} rows")
        return total
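
    # Illustrative usage (column names are placeholders):
    #
    #   sql = "INSERT INTO one_piece_record (card_id, card_name) VALUES (%s, %s)"
    #   await pool.insert_raw_many(sql, [(1, "Luffy"), (2, "Zoro")])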

    async def insert_or_update(self, table, data, update_columns=None, commit=True):
        """
        Insert or update (INSERT ... ON DUPLICATE KEY UPDATE).

        :param table: table name
        :param data: dict of {column: value}
        :param update_columns: columns to update on conflict; None updates every inserted column
        :param commit: whether to commit
        """
        if not isinstance(data, dict):
            raise ValueError("data must be a dictionary")
        keys = ', '.join([self._safe_identifier(k) for k in data.keys()])
        values_placeholder = ', '.join(['%s'] * len(data))
        query = f"INSERT INTO {self._safe_identifier(table)} ({keys}) VALUES ({values_placeholder})"
        if update_columns:
            update_clause = ', '.join([f"{self._safe_identifier(col)} = VALUES({self._safe_identifier(col)})"
                                       for col in update_columns])
        else:
            update_clause = ', '.join([f"{self._safe_identifier(k)} = VALUES({self._safe_identifier(k)})"
                                       for k in data.keys()])
        query += f" ON DUPLICATE KEY UPDATE {update_clause}"
        _, lastrowid = await self.execute(query, tuple(data.values()), commit=commit)
        return lastrowid
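
    # Illustrative: for data={"card_id": 1, "card_name": "x"} and
    # update_columns=["card_name"], the generated statement is
    #   INSERT INTO <table> (card_id, card_name) VALUES (%s, %s)
    #   ON DUPLICATE KEY UPDATE card_name = VALUES(card_name)
    # (the VALUES() form is deprecated in MySQL 8.0.20+ but still accepted).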

    async def insert_or_update_many(
        self,
        table,
        data_list,
        update_columns=None,
        batch_size=2000,
        commit=True
    ):
        """
        Bulk insert-or-update, well suited to crawler dedup scenarios.

        :param table: table name
        :param data_list: list of dicts
        :param update_columns: columns to update on conflict
        :param batch_size: rows per batch
        :param commit: whether to commit
        """
        if not data_list or not isinstance(data_list[0], dict):
            raise ValueError("data_list must be non-empty list of dictionaries")
        keys = ', '.join([self._safe_identifier(k) for k in data_list[0].keys()])
        values_placeholder = ', '.join(['%s'] * len(data_list[0]))
        query = f"INSERT INTO {self._safe_identifier(table)} ({keys}) VALUES ({values_placeholder})"
        if update_columns:
            update_clause = ', '.join([f"{self._safe_identifier(col)} = VALUES({self._safe_identifier(col)})"
                                       for col in update_columns])
        else:
            update_clause = ', '.join([f"{self._safe_identifier(k)} = VALUES({self._safe_identifier(k)})"
                                       for k in data_list[0].keys()])
        query += f" ON DUPLICATE KEY UPDATE {update_clause}"
        total = 0
        async with self.get_connection() as conn:
            async with conn.cursor() as cursor:
                for i in range(0, len(data_list), batch_size):
                    batch = data_list[i:i + batch_size]
                    args_list = [tuple(d.values()) for d in batch]
                    try:
                        await cursor.executemany(query, args_list)
                        if commit:
                            await conn.commit()
                        total += cursor.rowcount
                    except Exception as e:
                        await conn.rollback()
                        self.log.error(f"Bulk insert/update failed: {e}")
                        raise
        self.log.info(f"insert_or_update_many finished: {total} affected rows (MySQL counts an updated row as 2)")
        return total

    async def update(self, query, args=None, commit=True):
        """Execute an UPDATE statement and return the affected row count."""
        rowcount, _ = await self.execute(query, args, commit=commit)
        return rowcount

    async def update_one_or_dict(self, table, data, condition, commit=True):
        """
        Update a single row given as a dict.

        :param table: table name
        :param data: columns to update {column: value}
        :param condition: update condition, either a dict like {"id": 1}
                          (parameterized, ANDed) or a raw string like "id = 1"
        """
        set_clause = ', '.join([f"{self._safe_identifier(k)} = %s" for k in data.keys()])
        query = f"UPDATE {self._safe_identifier(table)} SET {set_clause}"
        if isinstance(condition, dict):
            where_clause = ' AND '.join([f"{self._safe_identifier(k)} = %s" for k in condition.keys()])
            query += f" WHERE {where_clause}"
            args = list(data.values()) + list(condition.values())
        else:
            query += f" WHERE {condition}"
            args = list(data.values())
        rowcount, _ = await self.execute(query, args, commit=commit)
        return rowcount
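
    # Illustrative: a dict condition is parameterized, so
    #   await pool.update_one_or_dict("one_piece_record",
    #                                 {"card_name": "new"}, {"card_id": 1})
    # renders "UPDATE one_piece_record SET card_name = %s WHERE card_id = %s"
    # with args ["new", 1]; a string condition is interpolated verbatim, so
    # the caller must ensure it is safe.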


class AsyncMySQLBatchWriter:
    """
    Async batch writer, designed for crawlers.

    Usage:
    1. Call add() from the crawler's main loop to enqueue rows
    2. Rows are flushed to the database in timed batches
    3. Call flush() before shutdown to make sure everything is written

    Advantages:
    - Non-blocking writes that do not slow the crawl
    - Automatic batching, reducing database I/O
    - Automatic dedup (keyed on the first column, assumed to be the primary key)
    """

    def __init__(
        self,
        pool: AsyncMySQLPool,
        table: str,
        batch_size: int = 2000,
        flush_interval: float = 2.0,
        dedup: bool = True,
        update_columns: list = None
    ):
        """
        :param pool: AsyncMySQLPool instance
        :param table: table name
        :param batch_size: rows per flushed batch
        :param flush_interval: auto-flush interval in seconds
        :param dedup: whether to dedup buffered rows (latest value wins)
        :param update_columns: columns to update on conflict; None updates all
        """
        self.pool = pool
        self.table = table
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.dedup = dedup
        self.update_columns = update_columns
        self._buffer = {}  # dict keyed for dedup
        self._keys = None
        # time.monotonic() rather than asyncio.get_event_loop().time():
        # the latter is deprecated outside a running loop on Python 3.10+
        self._last_flush = time.monotonic()
        self._lock = asyncio.Lock()

    def add(self, data: dict):
        """
        Add a row to the buffer.

        Note: not thread-safe; call from the event-loop thread only
        (the lock is taken only by flush()).

        :param data: row as a dict
        """
        if self._keys is None:
            self._keys = tuple(data.keys())
        if self.dedup:
            # Dedup on the first column, which is assumed to be the primary key
            key = data.get(self._keys[0])
            self._buffer[key] = data
        else:
            self._buffer[len(self._buffer)] = data

    async def _should_flush(self) -> bool:
        """Check whether the buffer should be flushed."""
        if len(self._buffer) >= self.batch_size:
            return True
        elapsed = time.monotonic() - self._last_flush
        if elapsed >= self.flush_interval and self._buffer:
            return True
        return False

    async def flush(self):
        """Force-flush the buffer."""
        async with self._lock:
            if not self._buffer:
                return 0
            data_list = list(self._buffer.values())
            self._buffer.clear()
            self._last_flush = time.monotonic()
            try:
                if self.update_columns:
                    return await self.pool.insert_or_update_many(
                        self.table, data_list, self.update_columns
                    )
                else:
                    return await self.pool.insert_many(self.table, data_list)
            except Exception as e:
                logger.error(f"Batch write failed: {e}")
                raise

    async def auto_flush(self):
        """Auto-flush coroutine; run it as a background task."""
        while True:
            await asyncio.sleep(0.5)  # poll every 0.5 seconds
            if await self._should_flush():
                try:
                    await self.flush()
                except Exception as e:
                    logger.error(f"Auto-flush failed: {e}")


# ============================================================
# Synchronous wrapper layer, for use from synchronous code
# ============================================================
import threading
from concurrent.futures import ThreadPoolExecutor

# Module-level executor (currently unused by SyncWrapper)
_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="mysql_async_")


def _run_async(coro):
    """Run a coroutine to completion on a fresh event loop (one-shot helper)."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coro)
    finally:
        loop.close()
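
# Illustrative one-shot use of _run_async: create, use, and close the pool
# inside one coroutine so everything lives on the same temporary loop.
#
#   async def job():
#       pool = AsyncMySQLPool()
#       await pool.init()
#       try:
#           return await pool.select_all("SELECT 1 AS ok")
#       finally:
#           await pool.close()
#
#   rows = _run_async(job())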


class SyncWrapper:
    """
    Synchronous wrapper that exposes the async pool to synchronous code.

    Usage example:
        pool = AsyncMySQLPool()
        sync_pool = SyncWrapper(pool)
        sync_pool.init()  # synchronous initialization
        # Synchronous calls
        sync_pool.insert_many("table", [{"col": "val"}])
        sync_pool.close()
    """

    def __init__(self, async_pool: AsyncMySQLPool):
        self.async_pool = async_pool
        self._loop = None
        self._thread = None
        self._running = False

    def _run_loop(self):
        """Run the event loop inside the worker thread."""
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

    def init(self):
        """Synchronous initialization."""
        if self._running:
            return
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()
        self._running = True
        # Initialize the pool on the worker thread's loop
        asyncio.run_coroutine_threadsafe(self.async_pool.init(), self._loop).result()

    def close(self):
        """Synchronous shutdown."""
        if not self._running:
            return
        asyncio.run_coroutine_threadsafe(self.async_pool.close(), self._loop).result()
        if self._loop:
            self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join(timeout=2)
        self._running = False

    def insert_many(self, table, data_list, **kwargs):
        """Synchronous bulk insert."""
        coro = self.async_pool.insert_many(table, data_list, **kwargs)
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

    def insert_one_or_dict(self, table, data, **kwargs):
        """Synchronous single-row insert."""
        coro = self.async_pool.insert_one_or_dict(table, data, **kwargs)
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

    def select_all(self, query, args=None):
        """Synchronous fetch-all."""
        coro = self.async_pool.select_all(query, args)
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

    def select_one(self, query, args=None):
        """Synchronous fetch-one."""
        coro = self.async_pool.select_one(query, args)
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()


if __name__ == '__main__':
    async def test():
        pool = AsyncMySQLPool(min_size=5, max_size=20)
        await pool.init()
        # Health check
        health = await pool.check_pool_health()
        print(f"Pool healthy: {health}")
        # Test bulk insert
        test_data = [
            {"card_id": 1, "card_name": "Test card 1", "card_type": "Character"},
            {"card_id": 2, "card_name": "Test card 2", "card_type": "Character"},
        ]
        result = await pool.insert_many("one_piece_record", test_data, ignore=True)
        print(f"Inserted: {result}")
        # Test insert-or-update
        await pool.insert_or_update_many(
            "one_piece_record",
            [{"card_id": 1, "card_name": "Updated card 1"}],
            update_columns=["card_name"]
        )
        await pool.close()
        print("Test finished!")

    asyncio.run(test())