| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185 |
- # -*- coding:utf-8 -*-
- import json
- import re
- from typing import Dict, List
- import pymysql
class MySQLColumn(object):
    """Metadata of one table column, as read from ``information_schema.COLUMNS``.

    The upper-case attribute names mirror the information_schema column
    names; ``str()`` renders the column as a JSON object with those keys.
    """

    def __init__(self,
                 column_name: str,
                 column_type: str,
                 column_comment: str,
                 ordinal_position: int,
                 is_nullable: str):
        """
        Args:
            column_name: COLUMN_NAME
            column_type: full COLUMN_TYPE, e.g. ``varchar(64)``
            column_comment: COLUMN_COMMENT (may be empty)
            ordinal_position: ORDINAL_POSITION, the 1-based position of the
                column in its table
            is_nullable: raw IS_NULLABLE value; MySQL reports ``'YES'`` or ``'NO'``
        """
        self.COLUMN_NAME = column_name
        self.COLUMN_TYPE = column_type
        self.COLUMN_COMMENT = column_comment
        self.ORDINAL_POSITION = ordinal_position
        self.IS_NULLABLE = is_nullable

    @property
    def _dict(self) -> dict:
        """Current attribute values keyed by their information_schema names.

        Built on every access (instead of once in ``__init__``) so that
        ``str()`` never shows stale values after an attribute is reassigned.
        """
        return {
            'COLUMN_NAME': self.COLUMN_NAME,
            'COLUMN_TYPE': self.COLUMN_TYPE,
            'COLUMN_COMMENT': self.COLUMN_COMMENT,
            'ORDINAL_POSITION': self.ORDINAL_POSITION,
            'IS_NULLABLE': self.IS_NULLABLE,
        }

    def __str__(self) -> str:
        # ensure_ascii=False keeps non-ASCII comments (e.g. Chinese) readable.
        return json.dumps(self._dict, ensure_ascii=False)

    def __repr__(self) -> str:
        return '%s(%r, %r, %r, %r, %r)' % (
            type(self).__name__,
            self.COLUMN_NAME,
            self.COLUMN_TYPE,
            self.COLUMN_COMMENT,
            self.ORDINAL_POSITION,
            self.IS_NULLABLE,
        )
class MySQLHandler:
    """Helper around a single pymysql connection for metadata lookups.

    It reads MySQL's ``information_schema`` and, through the ``hive.*``
    tables, a Hive metastore stored in the same MySQL instance. The
    connection runs in autocommit mode; call ``close()`` (or use the handler
    in a ``with`` statement) when done.
    """

    def __init__(self, host: str, port: int, username: str, password: str, database: str = None):
        """
        Open an autocommit connection to a MySQL instance.

        Args:
            host: instance address
            port: port
            username: user name
            password: password
            database: optional default database for the connection; also used
                by ``list_tables``/``list_columns`` when they receive
                ``database=None``
        """
        self.jdbcUrl = "jdbc:mysql://%s:%s" % (host, port)
        self.username = username
        self.password = password
        self.database = database
        self.connection = pymysql.connect(
            host=host,
            port=port,
            user=username,
            password=password,
            database=database,
            charset='utf8',
        )
        self.connection.autocommit(True)

    def close(self) -> None:
        """Close the underlying connection; calling it again is a no-op."""
        if self.connection.open:
            self.connection.close()

    def __enter__(self) -> 'MySQLHandler':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    def _cursor(self):
        """Return a new cursor after re-asserting the session character set.

        The caller owns the cursor; use it as ``with self._cursor() as curs:``
        so it is closed even when a query fails.
        """
        curs = self.connection.cursor()
        try:
            curs.execute('SET NAMES utf8')
        except Exception:
            curs.close()
            raise
        return curs

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Quote a MySQL identifier (e.g. a database name) with backticks.

        Identifiers cannot be bound as query parameters, so embedded
        backticks are doubled to keep the name from escaping the quotes.
        """
        return '`%s`' % name.replace('`', '``')

    def _resolve_database(self, database: str = None) -> str:
        """Return ``database``, or the constructor's database when it is None.

        Raises:
            ValueError: if neither is set.
        """
        if database is None:
            database = self.database
        if database is None:
            raise ValueError('database must be given (no default database was set on the handler)')
        return database

    @staticmethod
    def _filter_tables(rows, exclude_regex: List[str] = None, table_regex: List[str] = None) -> Dict[str, str]:
        """Apply exclude/include patterns to (table_name, table_comment) rows.

        Patterns use ``re.match`` semantics, i.e. they are anchored at the
        start of the table name. A table is dropped if any exclude pattern
        matches; when include patterns are given, it is kept only if one of
        them matches. Empty or None pattern lists disable that filter.
        """
        # Compile once instead of re-parsing every pattern for every row.
        excludes = [re.compile(pattern) for pattern in exclude_regex or ()]
        includes = [re.compile(pattern) for pattern in table_regex or ()]
        tables = {}
        for table_name, table_comment in rows:
            if any(pattern.match(table_name) for pattern in excludes):
                continue
            if includes and not any(pattern.match(table_name) for pattern in includes):
                continue
            tables[table_name] = table_comment
        return tables

    @staticmethod
    def _hive_db_filter(db_name: str = None):
        """Build the SQL pieces that restrict ``hive.TBLS t`` to one Hive database.

        Returns: (join clause, extra WHERE condition, extra bind args); all
            empty when ``db_name`` is None.
        """
        if db_name is None:
            return '', '', ()
        return ' JOIN hive.DBS d ON t.DB_ID = d.DB_ID', ' AND d.NAME = %s', (db_name,)

    def list_tables(self,
                    database: str = None,
                    exclude_regex: List[str] = None,
                    table_regex: List[str] = None) -> Dict[str, str]:
        """
        List the base tables of a database together with their comments.

        Side effect: switches the connection's current database to
        ``database`` (``USE``).

        Args:
            database: database name; defaults to the one given to the constructor
            exclude_regex: patterns (``re.match``, anchored at the start of the
                name) of tables to drop
            table_regex: patterns of tables to keep; when empty or None, every
                table that is not excluded is kept
        Returns: mapping of table name -> table comment
        Raises:
            ValueError: if no database is given and none was set on the handler
        """
        database = self._resolve_database(database)
        sql = ("SELECT TABLE_NAME, TABLE_COMMENT"
               " FROM information_schema.TABLES"
               " WHERE TABLE_SCHEMA = %s AND TABLE_TYPE = 'BASE TABLE'")
        with self._cursor() as curs:
            curs.execute('USE ' + self._quote_identifier(database))
            # Bound parameter instead of string formatting: no SQL injection.
            curs.execute(sql, (database,))
            rows = curs.fetchall()
        return self._filter_tables(rows, exclude_regex, table_regex)

    def list_columns(self, database: str, table_name: str) -> List['MySQLColumn']:
        """
        List the columns of a table with their type, comment, position and nullability.

        Side effect: switches the connection's current database to
        ``database`` (``USE``).

        Args:
            database: database name; None falls back to the constructor's database
            table_name: table name
        Returns: one MySQLColumn per column, ordered by ORDINAL_POSITION
        Raises:
            ValueError: if no database can be determined or ``table_name`` is None
        """
        database = self._resolve_database(database)
        if table_name is None:
            raise ValueError('table_name must not be None')
        sql = ("SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_COMMENT, ORDINAL_POSITION, IS_NULLABLE"
               " FROM information_schema.COLUMNS"
               " WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s"
               # Without ORDER BY the row order is not guaranteed by MySQL.
               " ORDER BY ORDINAL_POSITION")
        with self._cursor() as curs:
            curs.execute('USE ' + self._quote_identifier(database))
            curs.execute(sql, (database, table_name))
            rows = curs.fetchall()
        # The SELECT order matches the MySQLColumn constructor's parameter order.
        return [MySQLColumn(*row) for row in rows]

    def query(self, sql: str, args=None):
        """
        Run an arbitrary statement and return all result rows as tuples.

        Args:
            sql: statement to execute
            args: optional bind parameters for ``%s`` placeholders in ``sql``;
                when None, ``sql`` is sent as-is (no ``%`` interpolation)
        """
        with self._cursor() as curs:
            curs.execute(sql, args)
            return curs.fetchall()

    def query_column_hive_metadata(self, table_name: str, db_name: str = None):
        """
        Read the column definitions of a Hive table from the metastore.

        Args:
            table_name: Hive table name, matched exactly against ``TBLS.TBL_NAME``
            db_name: optional Hive database name; without it, same-named tables
                from every Hive database are returned together
        Returns: rows of (TBL_NAME, COLUMN_NAME, TYPE_NAME, COMMENT) ordered by
            table id, then column position. Because of the LEFT JOINs, a table
            without a storage descriptor/columns yields one row with NULL
            column fields.
        """
        db_join, db_condition, db_args = self._hive_db_filter(db_name)
        sql = ('SELECT t.TBL_NAME, c.COLUMN_NAME, c.TYPE_NAME, c.`COMMENT`'
               ' FROM hive.TBLS t'
               ' LEFT JOIN hive.SDS s ON t.SD_ID = s.SD_ID'
               ' LEFT JOIN hive.COLUMNS_V2 c ON s.CD_ID = c.CD_ID'
               + db_join +
               ' WHERE t.TBL_NAME = %s' + db_condition +
               ' ORDER BY t.TBL_ID, c.INTEGER_IDX')
        with self._cursor() as curs:
            curs.execute(sql, (table_name,) + db_args)
            return curs.fetchall()

    def query_tbl_hive_metadata(self, table_name: str, db_name: str = None):
        """
        Read the table parameters (``TABLE_PARAMS``) of a Hive table from the metastore.

        Args:
            table_name: Hive table name, matched exactly against ``TBLS.TBL_NAME``
            db_name: optional Hive database name; without it, parameters of
                same-named tables from every Hive database are returned together
        Returns: rows of (PARAM_KEY, PARAM_VALUE) ordered by table id, then key
        """
        db_join, db_condition, db_args = self._hive_db_filter(db_name)
        # An inner JOIN is equivalent to the LEFT JOIN + WHERE on t it replaces.
        sql = ('SELECT tp.PARAM_KEY, tp.PARAM_VALUE'
               ' FROM hive.TABLE_PARAMS tp'
               ' JOIN hive.TBLS t ON tp.TBL_ID = t.TBL_ID'
               + db_join +
               ' WHERE t.TBL_NAME = %s' + db_condition +
               ' ORDER BY t.TBL_ID, tp.PARAM_KEY')
        with self._cursor() as curs:
            curs.execute(sql, (table_name,) + db_args)
            return curs.fetchall()
if __name__ == '__main__':
    import os

    # Connection settings come from the environment instead of being
    # hard-coded in source control.
    # NOTE(review): a real host/user/password used to be committed here;
    # treat that password as leaked and rotate it.
    mysql_handler = MySQLHandler(
        os.environ['MYSQL_HOST'],
        int(os.environ.get('MYSQL_PORT', '3306')),
        os.environ['MYSQL_USER'],
        os.environ['MYSQL_PASSWORD'],
    )
    database_name = os.environ.get('MYSQL_DATABASE', 'ik_bms_production')
    try:
        tables = mysql_handler.list_tables(database_name)
        for table_name, table_comment in tables.items():
            print(f'{table_name}\t{table_comment}')
            for col in mysql_handler.list_columns(database_name, table_name):
                print(col)
            # Demo only: inspect the first table, then stop.
            break
    finally:
        # Release the server connection even if a query fails.
        mysql_handler.connection.close()
|