| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551 |
# -*- coding: utf-8 -*-
# Author : Charley
# Python : 3.10.8
# Date : 2025/3/25 14:14
import re

import pymysql
from dbutils.pooled_db import PooledDB
from loguru import logger

import YamlLoader

# Load the MySQL connection settings once, from the project YAML config.
yaml = YamlLoader.readYaml()
mysqlYaml = yaml.get("mysql")
sql_host = mysqlYaml.getValueAsString("host")
sql_port = mysqlYaml.getValueAsInt("port")
sql_user = mysqlYaml.getValueAsString("username")
sql_password = mysqlYaml.getValueAsString("password")
sql_db = mysqlYaml.getValueAsString("db")
class MySQLConnectionPool:
    """
    MySQL connection pool built on DBUtils' PooledDB + pymysql.
    """

    def __init__(self, mincached=4, maxcached=5, maxconnections=10, log=None):
        """
        Initialise the underlying connection pool.

        :param mincached: connections opened at start-up (0 = create none)
        :param maxcached: max idle connections kept in the pool (0/None = unlimited)
        :param maxconnections: max total connections allowed (0/None = unlimited)
        :param log: optional custom logger; defaults to loguru's logger
        """
        # Prefer a caller-supplied logger, otherwise fall back to loguru.
        self.log = log or logger
        self.pool = PooledDB(
            creator=pymysql,
            mincached=mincached,
            maxcached=maxcached,
            maxconnections=maxconnections,
            # When the pool is exhausted: True = block until a connection is
            # free, False = raise immediately.
            blocking=True,
            host=sql_host,
            port=sql_port,
            user=sql_user,
            password=sql_password,
            database=sql_db,
            # Liveness check policy: 0 = never ping, 1 = ping before each
            # query, 2 = ping before every execute.
            ping=0,
        )
- def _execute(self, query, args=None, commit=False):
- """
- 执行SQL
- :param query: SQL语句
- :param args: SQL参数
- :param commit: 是否提交事务
- :return: 查询结果
- """
- try:
- with self.pool.connection() as conn:
- with conn.cursor() as cursor:
- cursor.execute(query, args)
- if commit:
- conn.commit()
- self.log.debug(f"sql _execute, Query: {query}, Rows: {cursor.rowcount}")
- return cursor
- except Exception as e:
- if commit:
- conn.rollback()
- self.log.error(f"Error executing query: {e}, Query: {query}, Args: {args}")
- raise e
- def select_one(self, query, args=None):
- """
- 执行查询,返回单个结果
- :param query: 查询语句
- :param args: 查询参数
- :return: 查询结果
- """
- cursor = self._execute(query, args)
- return cursor.fetchone()
- def select_all(self, query, args=None):
- """
- 执行查询,返回所有结果
- :param query: 查询语句
- :param args: 查询参数
- :return: 查询结果
- """
- cursor = self._execute(query, args)
- return cursor.fetchall()
- def insert_one(self, query, args):
- """
- 执行单条插入语句
- :param query: 插入语句
- :param args: 插入参数
- """
- self.log.info('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>data insert_one 入库中>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
- cursor = self._execute(query, args, commit=True)
- return cursor.lastrowid # 返回插入的ID
- def insert_all(self, query, args_list):
- """
- 执行批量插入语句,如果失败则逐条插入
- :param query: 插入语句
- :param args_list: 插入参数列表
- """
- conn = None
- cursor = None
- try:
- conn = self.pool.connection()
- cursor = conn.cursor()
- cursor.executemany(query, args_list)
- conn.commit()
- self.log.debug(f"sql insert_all, SQL: {query}, Rows: {len(args_list)}")
- self.log.info('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>data insert_all 入库中>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
- except Exception as e:
- conn.rollback()
- self.log.error(f"Batch insertion failed after 5 attempts. Trying single inserts. Error: {e}")
- # 如果批量插入失败,则逐条插入
- rowcount = 0
- for args in args_list:
- self.insert_one(query, args)
- rowcount += 1
- self.log.debug(f"Batch insertion failed. Inserted {rowcount} rows individually.")
- finally:
- if cursor:
- cursor.close()
- if conn:
- conn.close()
- def insert_one_or_dict(self, table=None, data=None, query=None, args=None, commit=True):
- """
- 单条插入(支持字典或原始SQL)
- :param table: 表名(字典插入时必需)
- :param data: 字典数据 {列名: 值}
- :param query: 直接SQL语句(与data二选一)
- :param args: SQL参数(query使用时必需)
- :param commit: 是否自动提交
- :return: 最后插入ID
- """
- if data is not None:
- if not isinstance(data, dict):
- raise ValueError("Data must be a dictionary")
- keys = ', '.join([self._safe_identifier(k) for k in data.keys()])
- values = ', '.join(['%s'] * len(data))
- query = f"INSERT INTO {self._safe_identifier(table)} ({keys}) VALUES ({values})"
- args = tuple(data.values())
- elif query is None:
- raise ValueError("Either data or query must be provided")
- # cursor = self._execute(query, args, commit)
- # self.log.info(f"sql insert_one_or_dict, Table: {table}, Rows: {cursor.rowcount}")
- # self.log.info('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>data insert_one_or_dict 入库中>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
- # return cursor.lastrowid
- try:
- cursor = self._execute(query, args, commit)
- self.log.info(f"sql insert_one_or_dict, Table: {table}, Rows: {cursor.rowcount}")
- self.log.info('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>data insert_one_or_dict 入库中>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
- return cursor.lastrowid
- except pymysql.err.IntegrityError as e:
- if "Duplicate entry" in str(e):
- self.log.warning(f"插入失败:重复条目,已跳过。错误详情: {e}")
- # print("插入失败:重复条目", e)
- return -1 # 返回 -1 表示重复条目被跳过
- else:
- self.log.error(f"数据库完整性错误: {e}")
- # print("插入失败:完整性错误", e)
- raise e
- except Exception as e:
- self.log.error(f"未知错误: {e}", exc_info=True)
- # print("插入失败:未知错误", e)
- raise e
- def insert_many(self, table=None, data_list=None, query=None, args_list=None, batch_size=500, commit=True):
- """
- 批量插入(支持字典列表或原始SQL)
- :param table: 表名(字典插入时必需)
- :param data_list: 字典列表 [{列名: 值}]
- :param query: 直接SQL语句(与data_list二选一)
- :param args_list: SQL参数列表(query使用时必需)
- :param batch_size: 分批大小
- :param commit: 是否自动提交
- :return: 影响行数
- """
- if data_list is not None:
- if not data_list or not isinstance(data_list[0], dict):
- raise ValueError("Data_list must be a non-empty list of dictionaries")
- keys = ', '.join([self._safe_identifier(k) for k in data_list[0].keys()])
- values = ', '.join(['%s'] * len(data_list[0]))
- query = f"INSERT INTO {self._safe_identifier(table)} ({keys}) VALUES ({values})"
- args_list = [tuple(d.values()) for d in data_list]
- elif query is None:
- raise ValueError("Either data_list or query must be provided")
- total = 0
- for i in range(0, len(args_list), batch_size):
- batch = args_list[i:i + batch_size]
- try:
- with self.pool.connection() as conn:
- with conn.cursor() as cursor:
- cursor.executemany(query, batch)
- if commit:
- conn.commit()
- total += cursor.rowcount
- except pymysql.Error as e:
- if "Duplicate entry" in str(e):
- # self.log.warning(f"检测到重复条目,开始逐条插入。错误详情: {e}")
- raise e
- # rowcount = 0
- # for args in batch:
- # try:
- # self.insert_one_or_dict(table=table, data=dict(zip(data_list[0].keys(), args)),
- # commit=commit)
- # rowcount += 1
- # except pymysql.err.IntegrityError as e2:
- # if "Duplicate entry" in str(e2):
- # self.log.warning(f"跳过重复条目: {args}")
- # else:
- # self.log.error(f"插入失败: {e2}, 参数: {args}")
- # total += rowcount
- else:
- self.log.error(f"数据库错误: {e}")
- if commit:
- conn.rollback()
- raise e
- # 重新抛出异常,供外部捕获
- # 降级为单条插入
- # for args in batch:
- # try:
- # self.insert_one_or_dict(table=None, query=query, args=args, commit=commit)
- # total += 1
- # except Exception as e2:
- # self.log.error(f"Single insert failed: {e2}")
- # continue
- self.log.info(f"sql insert_many, Table: {table}, Total Rows: {total}")
- return total
- def insert_many_two(self, table=None, data_list=None, query=None, args_list=None, batch_size=500, commit=True):
- """
- 批量插入(支持字典列表或原始SQL)
- :param table: 表名(字典插入时必需)
- :param data_list: 字典列表 [{列名: 值}]
- :param query: 直接SQL语句(与data_list二选一)
- :param args_list: SQL参数列表(query使用时必需)
- :param batch_size: 分批大小
- :param commit: 是否自动提交
- :return: 影响行数
- """
- if data_list is not None:
- if not data_list or not isinstance(data_list[0], dict):
- raise ValueError("Data_list must be a non-empty list of dictionaries")
- keys = ', '.join([self._safe_identifier(k) for k in data_list[0].keys()])
- values = ', '.join(['%s'] * len(data_list[0]))
- query = f"INSERT INTO {self._safe_identifier(table)} ({keys}) VALUES ({values})"
- args_list = [tuple(d.values()) for d in data_list]
- elif query is None:
- raise ValueError("Either data_list or query must be provided")
- total = 0
- for i in range(0, len(args_list), batch_size):
- batch = args_list[i:i + batch_size]
- try:
- with self.pool.connection() as conn:
- with conn.cursor() as cursor:
- # 添加调试日志:输出 SQL 和参数示例
- # self.log.debug(f"Batch insert SQL: {query}")
- # self.log.debug(f"Sample args: {batch[0] if batch else 'None'}")
- cursor.executemany(query, batch)
- if commit:
- conn.commit()
- total += cursor.rowcount
- # self.log.debug(f"Batch insert succeeded. Rows: {cursor.rowcount}")
- except Exception as e: # 明确捕获数据库异常
- self.log.exception(f"Batch insert failed: {e}") # 使用 exception 记录堆栈
- self.log.error(f"Failed SQL: {query}, Args count: {len(batch)}")
- if commit:
- conn.rollback()
- # 降级为单条插入,并记录每个错误
- rowcount = 0
- for args in batch:
- try:
- self.insert_one(query, args)
- rowcount += 1
- except Exception as e2:
- self.log.error(f"Single insert failed: {e2}, Args: {args}")
- total += rowcount
- self.log.debug(f"Inserted {rowcount}/{len(batch)} rows individually.")
- self.log.info(f"sql insert_many, Table: {table}, Total Rows: {total}")
- return total
- def insert_too_many(self, query, args_list, batch_size=1000):
- """
- 执行批量插入语句,分片提交, 单次插入大于十万+时可用, 如果失败则降级为逐条插入
- :param query: 插入语句
- :param args_list: 插入参数列表
- :param batch_size: 每次插入的条数
- """
- for i in range(0, len(args_list), batch_size):
- batch = args_list[i:i + batch_size]
- try:
- with self.pool.connection() as conn:
- with conn.cursor() as cursor:
- cursor.executemany(query, batch)
- conn.commit()
- except Exception as e:
- self.log.error(f"insert_too_many error. Trying single insert. Error: {e}")
- # 当前批次降级为单条插入
- for args in batch:
- self.insert_one(query, args)
- def update_one(self, query, args):
- """
- 执行单条更新语句
- :param query: 更新语句
- :param args: 更新参数
- """
- self.log.info('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>data update_one 更新中>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
- return self._execute(query, args, commit=True)
- def update_all(self, query, args_list):
- """
- 执行批量更新语句,如果失败则逐条更新
- :param query: 更新语句
- :param args_list: 更新参数列表
- """
- conn = None
- cursor = None
- try:
- conn = self.pool.connection()
- cursor = conn.cursor()
- cursor.executemany(query, args_list)
- conn.commit()
- self.log.debug(f"sql update_all, SQL: {query}, Rows: {len(args_list)}")
- self.log.info('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>data update_all 更新中>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
- except Exception as e:
- conn.rollback()
- self.log.error(f"Error executing query: {e}")
- # 如果批量更新失败,则逐条更新
- rowcount = 0
- for args in args_list:
- self.update_one(query, args)
- rowcount += 1
- self.log.debug(f'Batch update failed. Updated {rowcount} rows individually.')
- finally:
- if cursor:
- cursor.close()
- if conn:
- conn.close()
- def update_one_or_dict(self, table=None, data=None, condition=None, query=None, args=None, commit=True):
- """
- 单条更新(支持字典或原始SQL)
- :param table: 表名(字典模式必需)
- :param data: 字典数据 {列名: 值}(与 query 二选一)
- :param condition: 更新条件,支持以下格式:
- - 字典: {"id": 1} → "WHERE id = %s"
- - 字符串: "id = 1" → "WHERE id = 1"(需自行确保安全)
- - 元组: ("id = %s", [1]) → "WHERE id = %s"(参数化查询)
- :param query: 直接SQL语句(与 data 二选一)
- :param args: SQL参数(query 模式下必需)
- :param commit: 是否自动提交
- :return: 影响行数
- :raises: ValueError 参数校验失败时抛出
- """
- # 参数校验
- if data is not None:
- if not isinstance(data, dict):
- raise ValueError("Data must be a dictionary")
- if table is None:
- raise ValueError("Table name is required for dictionary update")
- if condition is None:
- raise ValueError("Condition is required for dictionary update")
- # 构建 SET 子句
- set_clause = ", ".join([f"{self._safe_identifier(k)} = %s" for k in data.keys()])
- set_values = list(data.values())
- # 解析条件
- condition_clause, condition_args = self._parse_condition(condition)
- query = f"UPDATE {self._safe_identifier(table)} SET {set_clause} WHERE {condition_clause}"
- args = set_values + condition_args
- elif query is None:
- raise ValueError("Either data or query must be provided")
- # 执行更新
- cursor = self._execute(query, args, commit)
- # self.log.debug(
- # f"Updated table={table}, rows={cursor.rowcount}, query={query[:100]}...",
- # extra={"table": table, "rows": cursor.rowcount}
- # )
- return cursor.rowcount
- def _parse_condition(self, condition):
- """
- 解析条件为 (clause, args) 格式
- :param condition: 字典/字符串/元组
- :return: (str, list) SQL 子句和参数列表
- """
- if isinstance(condition, dict):
- clause = " AND ".join([f"{self._safe_identifier(k)} = %s" for k in condition.keys()])
- args = list(condition.values())
- elif isinstance(condition, str):
- clause = condition # 注意:需调用方确保安全
- args = []
- elif isinstance(condition, (tuple, list)) and len(condition) == 2:
- clause, args = condition[0], condition[1]
- if not isinstance(args, (list, tuple)):
- args = [args]
- else:
- raise ValueError("Condition must be dict/str/(clause, args)")
- return clause, args
- def update_many(self, table=None, data_list=None, condition_list=None, query=None, args_list=None, batch_size=500,
- commit=True):
- """
- 批量更新(支持字典列表或原始SQL)
- :param table: 表名(字典插入时必需)
- :param data_list: 字典列表 [{列名: 值}]
- :param condition_list: 条件列表(必须为字典,与data_list等长)
- :param query: 直接SQL语句(与data_list二选一)
- :param args_list: SQL参数列表(query使用时必需)
- :param batch_size: 分批大小
- :param commit: 是否自动提交
- :return: 影响行数
- """
- if data_list is not None:
- if not data_list or not isinstance(data_list[0], dict):
- raise ValueError("Data_list must be a non-empty list of dictionaries")
- if condition_list is None or len(data_list) != len(condition_list):
- raise ValueError("Condition_list must be provided and match the length of data_list")
- if not all(isinstance(cond, dict) for cond in condition_list):
- raise ValueError("All elements in condition_list must be dictionaries")
- # 获取第一个数据项和条件项的键
- first_data_keys = set(data_list[0].keys())
- first_cond_keys = set(condition_list[0].keys())
- # 构造基础SQL
- set_clause = ', '.join([self._safe_identifier(k) + ' = %s' for k in data_list[0].keys()])
- condition_clause = ' AND '.join([self._safe_identifier(k) + ' = %s' for k in condition_list[0].keys()])
- base_query = f"UPDATE {self._safe_identifier(table)} SET {set_clause} WHERE {condition_clause}"
- total = 0
- # 分批次处理
- for i in range(0, len(data_list), batch_size):
- batch_data = data_list[i:i + batch_size]
- batch_conds = condition_list[i:i + batch_size]
- batch_args = []
- # 检查当前批次的结构是否一致
- can_batch = True
- for data, cond in zip(batch_data, batch_conds):
- data_keys = set(data.keys())
- cond_keys = set(cond.keys())
- if data_keys != first_data_keys or cond_keys != first_cond_keys:
- can_batch = False
- break
- batch_args.append(tuple(data.values()) + tuple(cond.values()))
- if not can_batch:
- # 结构不一致,转为单条更新
- for data, cond in zip(batch_data, batch_conds):
- self.update_one_or_dict(table=table, data=data, condition=cond, commit=commit)
- total += 1
- continue
- # 执行批量更新
- try:
- with self.pool.connection() as conn:
- with conn.cursor() as cursor:
- cursor.executemany(base_query, batch_args)
- if commit:
- conn.commit()
- total += cursor.rowcount
- self.log.debug(f"Batch update succeeded. Rows: {cursor.rowcount}")
- except Exception as e:
- if commit:
- conn.rollback()
- self.log.error(f"Batch update failed: {e}")
- # 降级为单条更新
- for args, data, cond in zip(batch_args, batch_data, batch_conds):
- try:
- self._execute(base_query, args, commit=commit)
- total += 1
- except Exception as e2:
- self.log.error(f"Single update failed: {e2}, Data: {data}, Condition: {cond}")
- self.log.info(f"Total updated rows: {total}")
- return total
- elif query is not None:
- # 处理原始SQL和参数列表
- if args_list is None:
- raise ValueError("args_list must be provided when using query")
- total = 0
- for i in range(0, len(args_list), batch_size):
- batch_args = args_list[i:i + batch_size]
- try:
- with self.pool.connection() as conn:
- with conn.cursor() as cursor:
- cursor.executemany(query, batch_args)
- if commit:
- conn.commit()
- total += cursor.rowcount
- self.log.debug(f"Batch update succeeded. Rows: {cursor.rowcount}")
- except Exception as e:
- if commit:
- conn.rollback()
- self.log.error(f"Batch update failed: {e}")
- # 降级为单条更新
- for args in batch_args:
- try:
- self._execute(query, args, commit=commit)
- total += 1
- except Exception as e2:
- self.log.error(f"Single update failed: {e2}, Args: {args}")
- self.log.info(f"Total updated rows: {total}")
- return total
- else:
- raise ValueError("Either data_list or query must be provided")
- def check_pool_health(self):
- """
- 检查连接池中有效连接数
- # 使用示例
- # 配置 MySQL 连接池
- sql_pool = MySQLConnectionPool(log=log)
- if not sql_pool.check_pool_health():
- log.error("数据库连接池异常")
- raise RuntimeError("数据库连接池异常")
- """
- try:
- with self.pool.connection() as conn:
- conn.ping(reconnect=True)
- return True
- except Exception as e:
- self.log.error(f"Connection pool health check failed: {e}")
- return False
- @staticmethod
- def _safe_identifier(name):
- """SQL标识符安全校验"""
- if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name):
- raise ValueError(f"Invalid SQL identifier: {name}")
- return name
|