mysql_pool_async.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667
  1. # -*- coding: utf-8 -*-
  2. # Author: Charley
  3. # Python: 3.10+
  4. # Date: 2026/04/16
  5. # Desc: 异步MySQL连接池 - 专为爬虫高并发场景优化
import asyncio
import re
import time
from contextlib import asynccontextmanager

import aiomysql
from loguru import logger
  10. class AsyncMySQLPool:
  11. """
  12. 异步MySQL连接池 - 专为爬虫高并发场景优化
  13. 相比同步版本(pymysql)的优势:
  14. 1. 真正的异步非阻塞I/O,爬虫并发爬取时不会因DB写入而阻塞
  15. 2. 支持连接池复用,减少连接创建开销
  16. 3. 批量插入更高效,支持更大的批次
  17. 使用示例:
  18. pool = AsyncMySQLPool()
  19. await pool.init()
  20. # 批量插入
  21. await pool.insert_many("table_name", [{"col1": "val1"}, {"col2": "val2"}])
  22. # 或者在async with中使用
  23. async with pool.get_connection() as conn:
  24. async with conn.cursor() as cursor:
  25. await cursor.execute("SELECT * FROM table LIMIT 10")
  26. results = await cursor.fetchall()
  27. await pool.close()
  28. """
  29. def __init__(
  30. self,
  31. min_size=5,
  32. max_size=20,
  33. pool_recycle=3600,
  34. connect_timeout=10,
  35. write_timeout=30,
  36. log=None
  37. ):
  38. """
  39. 初始化异步连接池
  40. :param min_size: 池中最小连接数
  41. :param max_size: 池中最大连接数
  42. :param pool_recycle: 连接回收时间(秒),防止MySQL 8小时超时
  43. :param connect_timeout: 连接超时(秒)
  44. :param write_timeout: 写入超时(秒)
  45. :param log: 自定义日志记录器
  46. """
  47. self.log = log or logger
  48. self.min_size = min_size
  49. self.max_size = max_size
  50. self.pool_recycle = pool_recycle
  51. self.connect_timeout = connect_timeout
  52. self.write_timeout = write_timeout
  53. self._pool = None
  54. # 从 YamlLoader 读取配置
  55. from YamlLoader import readYaml
  56. yaml = readYaml()
  57. mysql_yaml = yaml.get("mysql")
  58. self.host = mysql_yaml.getValueAsString("host")
  59. self.port = mysql_yaml.getValueAsInt("port")
  60. self.user = mysql_yaml.getValueAsString("username")
  61. self.password = mysql_yaml.getValueAsString("password")
  62. self.database = mysql_yaml.getValueAsString("db")
  63. async def init(self):
  64. """初始化连接池"""
  65. if self._pool is not None:
  66. return
  67. try:
  68. self._pool = await aiomysql.create_pool(
  69. minsize=self.min_size,
  70. maxsize=self.max_size,
  71. host=self.host,
  72. port=self.port,
  73. user=self.user,
  74. password=self.password,
  75. db=self.database,
  76. autocommit=True, # 自动提交,避免每次手动commit
  77. charset='utf8mb4',
  78. connect_timeout=self.connect_timeout,
  79. write_timeout=self.write_timeout,
  80. pool_recycle=self.pool_recycle,
  81. )
  82. self.log.info(
  83. f"异步MySQL连接池初始化成功: {self.host}:{self.port}/{self.database}, "
  84. f"连接池大小: {self.min_size}-{self.max_size}"
  85. )
  86. except Exception as e:
  87. self.log.error(f"异步MySQL连接池初始化失败: {e}")
  88. raise
  89. async def close(self):
  90. """关闭连接池"""
  91. if self._pool:
  92. self._pool.close()
  93. await self._pool.wait_closed()
  94. self._pool = None
  95. self.log.info("异步MySQL连接池已关闭")
  96. async def get_connection(self):
  97. """
  98. 获取连接(推荐使用 async with)
  99. :return: async with 可用的连接对象
  100. """
  101. if self._pool is None:
  102. await self.init()
  103. return self._pool.acquire()
  104. async def check_pool_health(self):
  105. """检查连接池健康状态"""
  106. try:
  107. async with self.get_connection() as conn:
  108. async with conn.cursor() as cursor:
  109. await cursor.execute("SELECT 1")
  110. result = await cursor.fetchone()
  111. return result[0] == 1
  112. except Exception as e:
  113. self.log.error(f"连接池健康检查失败: {e}")
  114. return False
  115. @staticmethod
  116. def _safe_identifier(name):
  117. """SQL标识符安全校验"""
  118. if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name):
  119. raise ValueError(f"Invalid SQL identifier: {name}")
  120. return name
  121. async def execute(self, query, args=None, commit=True):
  122. """
  123. 执行SQL
  124. :param query: SQL语句
  125. :param args: SQL参数
  126. :param commit: 是否提交(autocommit=True时忽略)
  127. :return: cursor
  128. """
  129. async with self.get_connection() as conn:
  130. async with conn.cursor(aiomysql.DictCursor) as cursor:
  131. await cursor.execute(query, args)
  132. if commit and not conn.get_autocommit():
  133. await conn.commit()
  134. self.log.debug(f"SQL执行: {query[:80]}..., 影响行数: {cursor.rowcount}")
  135. return cursor
  136. async def select_one(self, query, args=None):
  137. """查询单条"""
  138. cursor = await self.execute(query, args, commit=False)
  139. return await cursor.fetchone()
  140. async def select_all(self, query, args=None):
  141. """查询全部"""
  142. cursor = await self.execute(query, args, commit=False)
  143. return await cursor.fetchall()
  144. async def insert_one(self, query, args):
  145. """插入单条"""
  146. cursor = await self.execute(query, args)
  147. return cursor.lastrowid
  148. async def insert_one_or_dict(self, table, data, ignore=False, commit=True):
  149. """
  150. 字典格式插入单条
  151. :param table: 表名
  152. :param data: 字典 {列名: 值}
  153. :param ignore: 是否使用INSERT IGNORE
  154. :param commit: 是否提交
  155. """
  156. if not isinstance(data, dict):
  157. raise ValueError("data must be a dictionary")
  158. keys = ', '.join([self._safe_identifier(k) for k in data.keys()])
  159. values = ', '.join(['%s'] * len(data))
  160. ignore_clause = "IGNORE" if ignore else ""
  161. query = f"INSERT {ignore_clause} INTO {self._safe_identifier(table)} ({keys}) VALUES ({values})"
  162. args = tuple(data.values())
  163. try:
  164. cursor = await self.execute(query, args, commit=commit)
  165. return cursor.lastrowid
  166. except aiomysql.IntegrityError as e:
  167. if "Duplicate entry" in str(e):
  168. self.log.debug(f"跳过重复条目: {e}")
  169. return -1
  170. raise
  171. async def insert_many(
  172. self,
  173. table,
  174. data_list,
  175. batch_size=2000,
  176. ignore=False,
  177. commit=True
  178. ):
  179. """
  180. 批量插入(字典列表) - 高性能版本
  181. 优化点:
  182. 1. 使用事务批量提交,减少I/O次数
  183. 2. 批次大小增加到2000,减少Python循环次数
  184. 3. 失败时自动降级为逐条插入(仅跳过重复,不阻塞)
  185. :param table: 表名
  186. :param data_list: 字典列表 [{列名: 值}, ...]
  187. :param batch_size: 每批插入数量(默认2000)
  188. :param ignore: 是否使用INSERT IGNORE
  189. :param commit: 是否提交
  190. :return: 成功插入的行数
  191. """
  192. if not data_list:
  193. return 0
  194. if not isinstance(data_list[0], dict):
  195. raise ValueError("data_list must be list of dictionaries")
  196. keys = ', '.join([self._safe_identifier(k) for k in data_list[0].keys()])
  197. values_placeholder = ', '.join(['%s'] * len(data_list[0]))
  198. ignore_clause = "IGNORE" if ignore else ""
  199. query = f"INSERT {ignore_clause} INTO {self._safe_identifier(table)} ({keys}) VALUES ({values_placeholder})"
  200. total_inserted = 0
  201. total_skipped = 0
  202. async with self.get_connection() as conn:
  203. async with conn.cursor() as cursor:
  204. try:
  205. for i in range(0, len(data_list), batch_size):
  206. batch = data_list[i:i + batch_size]
  207. args_list = [tuple(d.values()) for d in batch]
  208. try:
  209. await cursor.executemany(query, args_list)
  210. if commit:
  211. await conn.commit()
  212. total_inserted += cursor.rowcount
  213. except aiomysql.IntegrityError as e:
  214. await conn.rollback()
  215. if "Duplicate entry" in str(e) and not ignore:
  216. # 降级为逐条插入
  217. inserted, skipped = await self._insert_batch_one_by_one(
  218. cursor, query, batch, commit
  219. )
  220. total_inserted += inserted
  221. total_skipped += skipped
  222. else:
  223. raise
  224. except Exception as e:
  225. await conn.rollback()
  226. self.log.error(f"批量插入失败: {e}")
  227. # 降级为逐条插入
  228. inserted, skipped = await self._insert_batch_one_by_one(
  229. cursor, query, batch, commit
  230. )
  231. total_inserted += inserted
  232. total_skipped += skipped
  233. except Exception as e:
  234. self.log.exception(f"批量插入最终失败: {e}")
  235. raise
  236. if total_skipped > 0:
  237. self.log.info(f"插入完成: 成功{total_inserted}条, 跳过重复{total_skipped}条")
  238. else:
  239. self.log.info(f"插入完成: 成功{total_inserted}条")
  240. return total_inserted
  241. async def _insert_batch_one_by_one(self, cursor, query, batch, commit):
  242. """
  243. 逐条插入(降级方案)
  244. :return: (插入数, 跳过数)
  245. """
  246. inserted = 0
  247. skipped = 0
  248. for data in batch:
  249. try:
  250. args = tuple(data.values())
  251. await cursor.execute(query, args)
  252. if commit:
  253. await cursor.execute("COMMIT")
  254. inserted += 1
  255. except aiomysql.IntegrityError as e:
  256. if "Duplicate entry" in str(e):
  257. skipped += 1
  258. else:
  259. self.log.error(f"插入失败: {e}")
  260. except Exception as e:
  261. self.log.error(f"插入失败: {e}")
  262. return inserted, skipped
  263. async def insert_raw_many(self, query, args_list, batch_size=2000, commit=True):
  264. """
  265. 批量插入(原始SQL) - 高性能版本
  266. :param query: 预编译SQL语句,使用 %s 占位符
  267. :param args_list: 参数列表 [(val1, val2), ...]
  268. :param batch_size: 每批数量
  269. :param commit: 是否提交
  270. :return: 成功插入行数
  271. """
  272. if not args_list:
  273. return 0
  274. total = 0
  275. async with self.get_connection() as conn:
  276. async with conn.cursor() as cursor:
  277. for i in range(0, len(args_list), batch_size):
  278. batch = args_list[i:i + batch_size]
  279. try:
  280. await cursor.executemany(query, batch)
  281. if commit:
  282. await conn.commit()
  283. total += cursor.rowcount
  284. except Exception as e:
  285. await conn.rollback()
  286. self.log.error(f"批量插入失败 [{i}]: {e}")
  287. # 降级为逐条
  288. for args in batch:
  289. try:
  290. await cursor.execute(query, args)
  291. if commit:
  292. await conn.commit()
  293. total += 1
  294. except Exception as e2:
  295. if "Duplicate" not in str(e2):
  296. self.log.error(f"单条插入失败: {e2}")
  297. self.log.debug(f"insert_raw_many 完成: {total}条")
  298. return total
  299. async def insert_or_update(self, table, data, update_columns=None, commit=True):
  300. """
  301. 插入或更新 (INSERT ... ON DUPLICATE KEY UPDATE)
  302. :param table: 表名
  303. :param data: 字典 {列名: 值}
  304. :param update_columns: 更新时更新的列,None表示更新所有非主键列
  305. :param commit: 是否提交
  306. """
  307. if not isinstance(data, dict):
  308. raise ValueError("data must be a dictionary")
  309. keys = ', '.join([self._safe_identifier(k) for k in data.keys()])
  310. values_placeholder = ', '.join(['%s'] * len(data))
  311. query = f"INSERT INTO {self._safe_identifier(table)} ({keys}) VALUES ({values_placeholder})"
  312. if update_columns:
  313. update_clause = ', '.join([f"{self._safe_identifier(col)} = VALUES({self._safe_identifier(col)})"
  314. for col in update_columns])
  315. else:
  316. update_clause = ', '.join([f"{self._safe_identifier(k)} = VALUES({self._safe_identifier(k)})"
  317. for k in data.keys()])
  318. query += f" ON DUPLICATE KEY UPDATE {update_clause}"
  319. cursor = await self.execute(query, tuple(data.values()), commit=commit)
  320. return cursor.lastrowid
  321. async def insert_or_update_many(
  322. self,
  323. table,
  324. data_list,
  325. update_columns=None,
  326. batch_size=2000,
  327. commit=True
  328. ):
  329. """
  330. 批量插入或更新 - 适合爬虫去重场景
  331. :param table: 表名
  332. :param data_list: 字典列表
  333. :param update_columns: 更新时更新的列
  334. :param batch_size: 每批数量
  335. :param commit: 是否提交
  336. """
  337. if not data_list or not isinstance(data_list[0], dict):
  338. raise ValueError("data_list must be non-empty list of dictionaries")
  339. keys = ', '.join([self._safe_identifier(k) for k in data_list[0].keys()])
  340. values_placeholder = ', '.join(['%s'] * len(data_list[0]))
  341. query = f"INSERT INTO {self._safe_identifier(table)} ({keys}) VALUES ({values_placeholder})"
  342. if update_columns:
  343. update_clause = ', '.join([f"{self._safe_identifier(col)} = VALUES({self._safe_identifier(col)})"
  344. for col in update_columns])
  345. else:
  346. update_clause = ', '.join([f"{self._safe_identifier(k)} = VALUES({self._safe_identifier(k)})"
  347. for k in data_list[0].keys()])
  348. query += f" ON DUPLICATE KEY UPDATE {update_clause}"
  349. total = 0
  350. async with self.get_connection() as conn:
  351. async with conn.cursor() as cursor:
  352. for i in range(0, len(data_list), batch_size):
  353. batch = data_list[i:i + batch_size]
  354. args_list = [tuple(d.values()) for d in batch]
  355. try:
  356. await cursor.executemany(query, args_list)
  357. if commit:
  358. await conn.commit()
  359. total += cursor.rowcount
  360. except Exception as e:
  361. await conn.rollback()
  362. self.log.error(f"批量插入/更新失败: {e}")
  363. raise
  364. self.log.info(f"insert_or_update_many 完成: {total}条(包含更新)")
  365. return total
  366. async def update(self, query, args=None, commit=True):
  367. """执行更新"""
  368. cursor = await self.execute(query, args, commit=commit)
  369. return cursor.rowcount
  370. async def update_one_or_dict(self, table, data, condition, commit=True):
  371. """
  372. 字典格式更新单条
  373. :param table: 表名
  374. :param data: 要更新的数据 {列名: 值}
  375. :param condition: 更新条件 {"id": 1} 或 "id = 1"
  376. """
  377. set_clause = ', '.join([f"{self._safe_identifier(k)} = %s" for k in data.keys()])
  378. query = f"UPDATE {self._safe_identifier(table)} SET {set_clause}"
  379. if isinstance(condition, dict):
  380. where_clause = ' AND '.join([f"{self._safe_identifier(k)} = %s" for k in condition.keys()])
  381. query += f" WHERE {where_clause}"
  382. args = list(data.values()) + list(condition.values())
  383. else:
  384. query += f" WHERE {condition}"
  385. args = list(data.values())
  386. cursor = await self.execute(query, args, commit=commit)
  387. return cursor.rowcount
  388. class AsyncMySQLBatchWriter:
  389. """
  390. 异步批量写入器 - 专为爬虫设计
  391. 使用方式:
  392. 1. 在爬虫主循环中调用 add() 添加数据
  393. 2. 内部自动定时批量写入数据库
  394. 3. 程序结束时调用 flush() 确保数据全部写入
  395. 优势:
  396. - 非阻塞写入,不影响爬虫速度
  397. - 自动批量合并,减少数据库I/O
  398. - 自动去重(基于主键)
  399. """
  400. def __init__(
  401. self,
  402. pool: AsyncMySQLPool,
  403. table: str,
  404. batch_size: int = 2000,
  405. flush_interval: float = 2.0,
  406. dedup: bool = True,
  407. update_columns: list = None
  408. ):
  409. """
  410. :param pool: AsyncMySQLPool 实例
  411. :param table: 表名
  412. :param batch_size: 每批写入数量
  413. :param flush_interval: 自动刷新间隔(秒)
  414. :param dedup: 是否基于字典去重(仅保留最新)
  415. :param update_columns: 更新时更新的列,None表示全部更新
  416. """
  417. self.pool = pool
  418. self.table = table
  419. self.batch_size = batch_size
  420. self.flush_interval = flush_interval
  421. self.dedup = dedup
  422. self.update_columns = update_columns
  423. self._buffer = {} # 使用dict去重
  424. self._keys = None
  425. self._last_flush = asyncio.get_event_loop().time()
  426. self._lock = asyncio.Lock()
  427. def add(self, data: dict):
  428. """
  429. 添加数据到缓冲区(线程安全)
  430. :param data: 字典格式数据
  431. """
  432. if self._keys is None:
  433. self._keys = tuple(data.keys())
  434. if self.dedup:
  435. # 基于第一个字段去重(假设是主键)
  436. key = data.get(self._keys[0])
  437. self._buffer[key] = data
  438. else:
  439. self._buffer[len(self._buffer)] = data
  440. async def _should_flush(self) -> bool:
  441. """检查是否需要刷新"""
  442. if len(self._buffer) >= self.batch_size:
  443. return True
  444. elapsed = asyncio.get_event_loop().time() - self._last_flush
  445. if elapsed >= self.flush_interval and self._buffer:
  446. return True
  447. return False
  448. async def flush(self):
  449. """强制刷新缓冲区"""
  450. async with self._lock:
  451. if not self._buffer:
  452. return 0
  453. data_list = list(self._buffer.values())
  454. self._buffer.clear()
  455. self._last_flush = asyncio.get_event_loop().time()
  456. try:
  457. if self.update_columns:
  458. return await self.pool.insert_or_update_many(
  459. self.table, data_list, self.update_columns
  460. )
  461. else:
  462. return await self.pool.insert_many(self.table, data_list)
  463. except Exception as e:
  464. logger.error(f"批量写入失败: {e}")
  465. raise
  466. async def auto_flush(self):
  467. """自动刷新协程 - 应在后台运行"""
  468. while True:
  469. await asyncio.sleep(0.5) # 每0.5秒检查一次
  470. if await self._should_flush():
  471. try:
  472. await self.flush()
  473. except Exception as e:
  474. logger.error(f"自动刷新失败: {e}")
  475. # ============================================================
  476. # 同步封装层 - 方便在同步代码中调用
  477. # ============================================================
  478. import threading
  479. from concurrent.futures import ThreadPoolExecutor
  480. _executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="mysql_async_")
  481. def _run_async(coro):
  482. """在新线程中运行异步协程"""
  483. loop = asyncio.new_event_loop()
  484. asyncio.set_event_loop(loop)
  485. try:
  486. return loop.run_until_complete(coro)
  487. finally:
  488. loop.close()
  489. class SyncWrapper:
  490. """
  491. 同步封装 - 让异步连接池可以在同步代码中使用
  492. 使用示例:
  493. pool = AsyncMySQLPool()
  494. sync_pool = SyncWrapper(pool)
  495. sync_pool.init() # 同步初始化
  496. # 同步调用
  497. sync_pool.insert_many("table", [{"col": "val"}])
  498. sync_pool.close()
  499. """
  500. def __init__(self, async_pool: AsyncMySQLPool):
  501. self.async_pool = async_pool
  502. self._loop = None
  503. self._thread = None
  504. self._running = False
  505. def _run_loop(self):
  506. """在线程中运行事件循环"""
  507. asyncio.set_event_loop(self._loop)
  508. self._loop.run_forever()
  509. def init(self):
  510. """同步初始化"""
  511. if self._running:
  512. return
  513. self._loop = asyncio.new_event_loop()
  514. self._thread = threading.Thread(target=self._run_loop, daemon=True)
  515. self._thread.start()
  516. self._running = True
  517. # 在新线程中初始化
  518. asyncio.run_coroutine_threadsafe(self.async_pool.init(), self._loop).result()
  519. def close(self):
  520. """同步关闭"""
  521. if not self._running:
  522. return
  523. asyncio.run_coroutine_threadsafe(self.async_pool.close(), self._loop).result()
  524. if self._loop:
  525. self._loop.call_soon_threadsafe(self._loop.stop)
  526. self._thread.join(timeout=2)
  527. self._running = False
  528. def insert_many(self, table, data_list, **kwargs):
  529. """同步批量插入"""
  530. coro = self.async_pool.insert_many(table, data_list, **kwargs)
  531. return asyncio.run_coroutine_threadsafe(coro, self._loop).result()
  532. def insert_one_or_dict(self, table, data, **kwargs):
  533. """同步单条插入"""
  534. coro = self.async_pool.insert_one_or_dict(table, data, **kwargs)
  535. return asyncio.run_coroutine_threadsafe(coro, self._loop).result()
  536. def select_all(self, query, args=None):
  537. """同步查询全部"""
  538. coro = self.async_pool.select_all(query, args)
  539. return asyncio.run_coroutine_threadsafe(coro, self._loop).result()
  540. def select_one(self, query, args=None):
  541. """同步查询单条"""
  542. coro = self.async_pool.select_one(query, args)
  543. return asyncio.run_coroutine_threadsafe(coro, self._loop).result()
  544. if __name__ == '__main__':
  545. import asyncio
  546. async def test():
  547. pool = AsyncMySQLPool(min_size=5, max_size=20)
  548. await pool.init()
  549. # 健康检查
  550. health = await pool.check_pool_health()
  551. print(f"连接池健康: {health}")
  552. # 测试插入
  553. test_data = [
  554. {"card_id": 1, "card_name": "测试卡牌1", "card_type": "角色"},
  555. {"card_id": 2, "card_name": "测试卡牌2", "card_type": "角色"},
  556. ]
  557. result = await pool.insert_many("one_piece_record", test_data, ignore=True)
  558. print(f"插入结果: {result}")
  559. # 测试插入或更新
  560. await pool.insert_or_update_many(
  561. "one_piece_record",
  562. [{"card_id": 1, "card_name": "更新卡牌1"}],
  563. update_columns=["card_name"]
  564. )
  565. await pool.close()
  566. print("测试完成!")
  567. asyncio.run(test())