# -*- coding:utf-8 -*-
import inspect
import os
import re
import time
from importlib import import_module
from typing import List, Union, Dict, Any, Tuple

from pyspark.sql import Row, SparkSession, DataFrame

from dw_base import *
from dw_base.utils import common_utils
from dw_base.utils.datetime_utils import get_today
from dw_base.utils.file_utils import get_abs_path
from dw_base.utils.log_utils import pretty_print
from dw_base.utils.sql_utils import get_sql_list_from_file, check_parameter_substituted

HDFS_EXPORT_DATA_PATH = '/hdfs-mnt/export-data'


class SparkSQL(object):
    """
    Wrapper class for running Spark operations. Notes on the relevant parameters:

    In Spark 2.0+ use the SparkSession to set the number of executors dynamically (from within the program):
        spark.conf.set('spark.executor.instances', 4)
        spark.conf.set('spark.executor.cores', 4)
    In the above case at most 16 tasks will be executed at any given time.

    The other option is dynamic allocation of executors:
        spark.conf.set('spark.dynamicAllocation.enabled', 'true')
        spark.conf.set('spark.executor.cores', 4)
        spark.conf.set('spark.dynamicAllocation.minExecutors', '1')
        spark.conf.set('spark.dynamicAllocation.maxExecutors', '5')

    spark.yarn.executor.memoryOverhead: default is executorMemory * 0.07, with a minimum of 384
    spark.yarn.driver.memoryOverhead: default is driverMemory * 0.07, with a minimum of 384
    spark.yarn.am.memoryOverhead: default is AM memory * 0.07, with a minimum of 384
    """

    REGISTERED_UDF_FILES = []
    ADDED_RESOURCE_FILES = []
    REGISTERED_UDF = {}
    IGNORED_UDF = ['attr', 'udf']

    def __init__(self,
                 session_name: str = 'spark',
                 master: str = 'yarn',
                 spark_yarn_queue: str = 'spark',
                 spark_driver_memory: str = '2g',
                 spark_executor_memory: str = '6g',
                 spark_executor_memory_overhead: str = '512',
                 spark_driver_cores: int = 2,
                 spark_executor_cores: int = 2,
                 spark_executor_instances: int = 15,
                 spark_driver_max_result_size='4g',
                 spark_shuffle_partitions=200,
                 spark_default_parallelism=200,
                 extra_spark_config: Dict[str, Any] = None,
                 udf_files: List[str] = None,
                 resource_files: List[str] = None):
        """
        Args:
            session_name: Spark application name
            master: Spark master, defaults to yarn
            spark_yarn_queue: YARN queue
            spark_driver_memory: driver memory
            spark_executor_memory: executor memory
            spark_executor_memory_overhead: executor memory overhead
            spark_driver_cores: number of driver cores
            spark_executor_cores: number of cores per executor
            spark_executor_instances: number of executor instances
            spark_driver_max_result_size: maximum size of results collected to the driver
            spark_shuffle_partitions: value for spark.sql.shuffle.partitions
            spark_default_parallelism: value for spark.default.parallelism
            extra_spark_config: extra Spark configs, taking precedence over Spark configs defined in SQL files
            udf_files: list of UDF files; must be relative paths, otherwise they are not recognised
            resource_files: list of resource files; must be relative paths, otherwise they are not recognised
        """
        self._spark_session = None
        self._session_name = session_name
        self._master = master
        self._spark_driver_memory = spark_driver_memory
        self._spark_executor_memory = spark_executor_memory
        self._spark_executor_memory_overhead = spark_executor_memory_overhead
        self._spark_driver_cores = spark_driver_cores
        self._spark_executor_cores = spark_executor_cores
        self._spark_executor_instances = spark_executor_instances
        self._spark_yarn_queue = spark_yarn_queue
        self._spark_driver_max_result_size = spark_driver_max_result_size
        self._spark_shuffle_partitions = spark_shuffle_partitions
        self._spark_default_parallelism = spark_default_parallelism
        self._extra_spark_config = extra_spark_config
        self._final_spark_config = {}
        self.limit = 20
        if not udf_files:
            udf_files = []
        self._udf_files = udf_files
        self._py_files = []
        if not resource_files:
            resource_files = []
        self._resource_files = resource_files
        self.global_start_time = None
        self.start_time = None
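    # A hedged construction sketch (the session name, config key and UDF path below are illustrative
    # placeholders, not values required by this project). UDF and resource files must be given as
    # paths relative to the project root.
    #
    #   spark_sql = SparkSQL(
    #       session_name='my-report-job',
    #       extra_spark_config={'spark.sql.autoBroadcastJoinThreshold': -1},
    #       udf_files=['udf/my_udf.py'],
    #   )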
    def __add_spark_config(self, sql_file: str, spark_config_def: str):
        if '=' not in spark_config_def:
            pretty_print(f'{NORM_YEL}Invalid Spark config {NORM_GRN}{spark_config_def}')
            return
        # Split on the first '=' only, so config values that themselves contain '=' stay intact
        spark_config, config_value = [e.strip() for e in spark_config_def.split('=', 1)]
        if spark_config in self._final_spark_config:
            pretty_print(
                f'{NORM_YEL}Spark config {NORM_GRN}{spark_config}{NORM_YEL} defined in SQL file '
                f'{NORM_GRN}{sql_file}{NORM_YEL} is provided more than once, the old value '
                f'{NORM_GRN}{self._final_spark_config[spark_config]}{NORM_YEL} will be overridden by '
                f'{NORM_GRN}{config_value}')
        self._final_spark_config[spark_config] = config_value

    def __init_spark_session(self):
        if self._spark_session:
            return
        pretty_print(f'{NORM_MGT}Creating SparkSession as user {NORM_GRN}{USER}')
        # for element in os.environ:
        #     pretty_print(f'{NORM_MGT}Environment {NORM_GRN}{element} => {os.environ[element]}')
        builder = SparkSession.builder \
            .appName(self._session_name) \
            .master(self._master) \
            .config('hive.exec.orc.default.block.size', 134217728) \
            .config('spark.debug.maxToStringFields', 5000) \
            .config('spark.default.parallelism', self._spark_default_parallelism) \
            .config('spark.driver.cores', self._spark_driver_cores) \
            .config('spark.driver.maxResultSize', self._spark_driver_max_result_size) \
            .config('spark.driver.memory', self._spark_driver_memory) \
            .config('spark.dynamicAllocation.enabled', False) \
            .config('spark.files.ignoreCorruptFiles', True) \
            .config('spark.executor.cores', self._spark_executor_cores) \
            .config('spark.executor.instances', self._spark_executor_instances) \
            .config('spark.executor.memory', self._spark_executor_memory) \
            .config('spark.executor.memoryOverhead', self._spark_executor_memory_overhead) \
            .config('spark.sql.adaptive.enabled', 'true') \
            .config('spark.sql.broadcastTimeout', -1) \
            .config('spark.sql.codegen.wholeStage', 'false') \
            .config('spark.sql.execution.arrow.enabled', True) \
            .config('spark.sql.execution.arrow.fallback.enabled', True) \
            .config('spark.sql.files.ignoreCorruptFiles', True) \
            .config('spark.sql.shuffle.partitions', self._spark_shuffle_partitions) \
            .config('spark.sql.statistics.fallBackToHdfs', True) \
            .config('spark.yarn.queue', self._spark_yarn_queue) \
            .config('spark.port.maxRetries', 999)
        if self._extra_spark_config:
            for spark_config, config_value in self._extra_spark_config.items():
                if spark_config in self._final_spark_config:
                    pretty_print(f'{NORM_YEL}Spark config {NORM_GRN}{spark_config} => {config_value} '
                                 f'{NORM_YEL}passed to the constructor overrides the value defined in the SQL file: '
                                 f'{NORM_GRN}{spark_config} => {self._final_spark_config[spark_config]}')
                self._final_spark_config[spark_config] = config_value
        if self._final_spark_config:
            for key, value in self._final_spark_config.items():
                pretty_print(f'{NORM_MGT}Adding custom Spark config {NORM_GRN}{key} => {str(value)}')
                builder.config(key, value)
        pretty_print(f'{NORM_MGT}Creating SparkSession')
        self._spark_session = builder.enableHiveSupport().getOrCreate()
        self._spark_session.sparkContext._jsc.hadoopConfiguration().set('mapred.max.split.size', '33554432')
        self._spark_session.sparkContext._jsc.hadoopConfiguration().set(
            'mapreduce.fileoutputcommitter.marksuccessfuljobs', 'false')
        first_py_file = 'dw_base.zip'
        if IS_RUN_IN_RELEASE_DIR and IS_RUN_BY_RELEASE_USER:
            command = f'cd {PROJECT_ROOT_PATH} && if [ ! -f {first_py_file} ]; then zip -qr {first_py_file} dw_base; fi'
        else:
            command = f'cd {PROJECT_ROOT_PATH} && rm -f {first_py_file} && zip -qr {first_py_file} dw_base'
        pretty_print(f'{NORM_MGT}Running shell command {NORM_GRN}{command}')
        os.system(command)
        self._spark_session.sparkContext.addPyFile(get_abs_path(first_py_file))
        self.register_udf_files(self._udf_files)
        for py_file in self._py_files:
            self.add_py_file(py_file)
        for resource_file in self._resource_files:
            self.add_resource_file(resource_file)
    @staticmethod
    def add_parameters(sql_or_file: str, parameters: Dict, key: str, value: Any):
        if key in parameters:
            pretty_print(f'{NORM_YEL}SQL parameter {NORM_GRN}{key} '
                         f'{NORM_YEL}defined in SQL file {NORM_GRN}{sql_or_file} '
                         f'{NORM_YEL}has changed, the new value {NORM_GRN}{value} '
                         f'{NORM_YEL}replaces the old value {NORM_GRN}{parameters[key]}')
        parameters[key] = value

    def add_py_file(self, py_file: str) -> bool:
        abs_py_file = get_abs_path(py_file)
        if abs_py_file in self.REGISTERED_UDF_FILES:
            return False
        if not self._spark_session:
            self._py_files.append(py_file)
            return True
        pretty_print(f'{NORM_MGT}Adding Python file {NORM_GRN}{py_file}')
        self._spark_session.sparkContext.addPyFile(abs_py_file)
        self.REGISTERED_UDF_FILES.append(abs_py_file)
        return True

    def add_resource_file(self, resource_file: str) -> bool:
        abs_resource_file = get_abs_path(resource_file)
        if abs_resource_file in self.ADDED_RESOURCE_FILES:
            return False
        if not self._spark_session:
            self._resource_files.append(resource_file)
            return True
        pretty_print(f'{NORM_MGT}Adding resource file {NORM_GRN}{resource_file}')
        self._spark_session.sparkContext.addFile(abs_resource_file)
        self.ADDED_RESOURCE_FILES.append(abs_resource_file)
        return True

    def execute(self, sql_or_file: str, check_parameter: bool = False, silent: bool = False, fill_null=None,
                **kwargs):
        """
        Execute only, without returning anything.

        Args:
            sql_or_file: the SQL statement to execute, or the path (relative or absolute) of a file containing SQL statements
            check_parameter: whether to check parameters before running
            silent: whether to suppress logging; True means no logging
            fill_null: value used to fill nulls
            **kwargs: parameters used by the SQL statement or SQL file

        Returns:

        """
        # Run the SQL statement or SQL file
        data_frame, is_select = self.query(sql_or_file, check_parameter, silent, fill_null, **kwargs)
        if is_select:
            self.show_data_frame(data_frame, show_number=self.limit, truncate=USER == RELEASE_USER)
        end_time = time.time()
        cost = round(float(end_time - self.start_time), 2)
        pretty_print(f'{NORM_MGT}SQL statement finished in {NORM_GRN}{str(cost)}{NORM_MGT} seconds')
        if sql_or_file.endswith('.sql'):
            total_cost = round(float(end_time - self.global_start_time), 2)
            pretty_print(f'{NORM_MGT}SQL file {NORM_GRN}{sql_or_file}'
                         f'{NORM_MGT} finished in a total of {NORM_GRN}{str(total_cost)}{NORM_MGT} seconds')
            self.global_start_time = None
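    # A hedged usage sketch for ``execute`` (the file path and the ``dt`` parameter are illustrative
    # placeholders, not part of this project): keyword arguments are substituted into the ``${...}``
    # placeholders of the SQL before it runs.
    #
    #   spark_sql.execute('workspace/report.sql', dt='2024-01-01')
    #   # inside workspace/report.sql:  SELECT * FROM some_db.some_table WHERE dt = '${dt}'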
    def export_data(self, data_set_name: str, sql_or_file: str, show_number: int = 100,
                    truncate: Union[bool, int] = 40, delimiter: str = ',', partition: int = 1, **kwargs) -> str:
        """
        Export data. By default it is stored under the HDFS directory (/hdfs-mnt/export-data),
        which is mapped to the local directory (/opt/hdfs-mnt/export-data).

        Args:
            data_set_name: name of the exported data set (used as the directory name)
            sql_or_file: the SQL statement to execute, or the full path of a file containing SQL statements
            show_number: number of rows to display
            truncate: If set to ``True``, truncate strings longer than 20 chars by default.
                If set to a number greater than one, truncates long strings to length ``truncate``
                and align cells right.
            delimiter: field delimiter
            partition: number of exported files
            **kwargs: parameters used by the SQL statement or SQL file

        Returns:
            The HDFS path the data was exported to.
        """
        start_time = time.time()
        data_frame, _ = self.query(sql_or_file, **kwargs)
        data_frame.persist()
        self.show_data_frame(data_frame, data_set_name, show_number, truncate)
        today = get_today()
        # Note: this is a directory in the HDFS file system, not on the Linux server
        hdfs_directory = f'{data_set_name}.{str(int(time.time()))}'
        hdfs_export_path = os.path.join(HDFS_EXPORT_DATA_PATH, USER, today, hdfs_directory)
        # Directory on the Linux server where the exported files are stored
        # local_export_path = os.path.join(LOCAL_EXPORT_DATA_PATH, USER, today, data_set_name)
        pretty_print(f'{NORM_MGT}Exporting files to HDFS directory: {NORM_GRN}{hdfs_export_path}')
        if not delimiter:
            delimiter = ','
        # Save the data to HDFS first. Adding option("escape", "\"") prevents commas inside quoted
        # field values from being treated as field separators (which would shift/expand columns).
        data_frame \
            .repartition(partition) \
            .write \
            .mode('overwrite') \
            .format('com.databricks.spark.csv') \
            .option('header', 'true') \
            .option('escape', '"') \
            .option('delimiter', delimiter) \
            .save(hdfs_export_path)
        cost = round(float(time.time() - start_time), 2)
        pretty_print(f'{NORM_MGT}Data export finished in {NORM_GRN}{str(cost)}{NORM_MGT} seconds')
        return hdfs_export_path

    def get_columns(self, hive_table_name: str):
        desc_df, _ = self.query(f'desc {hive_table_name}', silent=True)
        desc_rows = desc_df.collect()
        hive_columns = {}
        partition_col_start = False
        for row in desc_rows:  # type: Row
            if row[0] == '# Partition Information':
                partition_col_start = True
                continue
            if row[0] == '# col_name':
                continue
            if partition_col_start:
                # Partition columns are listed again at the end of the regular column list,
                # so drop the duplicated trailing entries.
                del hive_columns[list(hive_columns.keys())[-1]]
            else:
                hive_columns[row[0]] = (row[1], row[2])
        return hive_columns

    def get_partition_columns(self, hive_table_name: str):
        desc_df, _ = self.query(f'desc {hive_table_name}', silent=True)
        desc_rows = desc_df.collect()
        partition_columns = {}
        partitioned = False
        for row in desc_rows:  # type: Row
            # Only rows after the '# Partition Information' marker describe partition columns
            if row[0] == '# Partition Information':
                partitioned = True
                continue
            if not partitioned:
                continue
            if not row[0] or row[0].startswith('#'):
                continue
            partition_columns[row[0]] = (row[1], row[2])
        return partitioned, partition_columns

    def get_spark_session(self):
        if not self._spark_session:
            self.__init_spark_session()
        self.start_time = time.time()
        if not self.global_start_time:
            self.global_start_time = time.time()
        return self._spark_session

    def jdbc_read(self, jdbc_url: str, table_name: str) -> DataFrame:
        return self.get_spark_session().read.jdbc(jdbc_url, table_name)

    def list_tables(self, hive_database_name: str = 'default', include_regex: str = None, exclude_regex: str = None):
        if not hive_database_name:
            hive_database_name = 'default'
        tables_df, _ = self.query(f'show tables in {hive_database_name}')
        tables = []
        for row in tables_df.collect():
            if hive_database_name == 'default':
                table_name = row.asDict()['tableName']
            else:
                table_name = f"{hive_database_name}.{row.asDict()['tableName']}"
            if exclude_regex and re.match(exclude_regex, table_name):
                continue
            if include_regex:
                if re.match(include_regex, table_name):
                    tables.append(table_name)
                continue
            tables.append(table_name)
        return tables
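    # Directives recognised inside a ``.sql`` file by ``query``/``execute`` below (a hedged sketch;
    # the file contents here are illustrative placeholders, not part of this project):
    #
    #   SET spark.sql.autoBroadcastJoinThreshold=-1   -> keys containing dots become Spark configs
    #                                                    (only effective before the session is created)
    #   SET SPARK_VAR: dt=2024-01-01                  -> defines ${dt} for the statements that follow
    #   SET LIMIT=50                                  -> rows shown for the final SELECT
    #   ADD FILE udf/my_udf.py                        -> .py files are registered as UDF modules,
    #                                                    other files are added as Spark resources
    #   SELECT * FROM some_db.some_table WHERE dt='${dt}'
    #
    # External parameters passed via **kwargs take precedence over SET SPARK_VAR definitions.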
    def query(self, sql_or_file: str, check_parameter: bool = False, silent: bool = False, fill_null=None,
              **kwargs) -> Tuple[DataFrame, bool]:
        """
        Run a query and return a DataFrame together with a flag telling whether it is a SELECT.

        Args:
            sql_or_file: the SQL statement to execute, or the full path of a file containing SQL statements
            check_parameter: whether to check parameters before running
            silent: whether to suppress logging; True means no logging
            fill_null: value used to fill nulls
            **kwargs: parameters used by the SQL statement or SQL file

        Returns:
            Tuple[DataFrame, bool]
        """
        if sql_or_file.endswith('.sql'):
            if not silent:
                pretty_print(f'{NORM_MGT}Preparing to run SQL file {NORM_GRN}{sql_or_file}')
            sql_list = get_sql_list_from_file(sql_or_file, trim_comment=True)
        else:
            sql_list = [sql_or_file]
        cleaned_sql_list = []
        # Externally supplied parameters
        external_parameters = {}
        if kwargs:
            for key, value in kwargs.items():
                if not silent:
                    pretty_print(f'{NORM_MGT}Received external SQL parameter {NORM_GRN}{key} => {value}')
                external_parameters[key] = value
        # Parameters defined inside the SQL file
        internal_parameters = {}
        for sql in sql_list:
            upper_sql = sql.upper()
            if upper_sql.startswith('ADD FILES'):
                file_names = sql[len('ADD FILES'):].strip().split(' ')
                for file_name in file_names:
                    file_name = file_name.strip()
                    if not file_name:
                        continue
                    if file_name.endswith('.py'):
                        self.register_udf_file(file_name)
                    else:
                        self.add_resource_file(file_name)
            elif upper_sql.startswith('ADD FILE'):
                file_name = sql[len('ADD FILE'):].strip()
                if file_name.endswith('.py'):
                    self.register_udf_file(file_name)
                else:
                    self.add_resource_file(file_name)
            elif upper_sql.startswith('ADD SPARK_UDF'):
                self.register_udf_file(sql[len('ADD SPARK_UDF'):].strip())
            elif upper_sql.startswith('SET SPARK_VAR:'):
                var_key, var_value = sql[len('SET SPARK_VAR:'):].strip().split('=')
                var_key = var_key.strip()
                var_value = var_value.strip()
                self.add_parameters(sql_or_file, internal_parameters, var_key, var_value)
            elif upper_sql.startswith('SET LIMIT='):
                if not self._spark_session:
                    self.limit = int(sql.split('=')[1].strip())
            elif upper_sql.startswith('SET '):
                set_expr = sql[len('SET '):].strip()
                set_expr_splits = set_expr.split('=')
                set_key = set_expr_splits[0].strip()
                set_value = set_expr_splits[1].strip()
                if re.match(r'.+\..+(\..+)*', set_key):
                    # Keys containing dots are treated as Spark configs; they can only take effect
                    # before the SparkSession has been created
                    if not self._spark_session:
                        self.__add_spark_config(sql_or_file, sql[len('SET '):].strip())
                else:
                    self.add_parameters(sql_or_file, internal_parameters, set_key, set_value)
            else:
                # External parameters take precedence
                for key, value in external_parameters.items():
                    sql = sql.replace('${%s}' % key, str(value))
                # Parameters inside the SQL file must be defined before they are used (which also means
                # the same parameter can be redefined with different values for different statements)
                for key, value in internal_parameters.items():
                    if key in external_parameters:
                        # If an external parameter with the same name exists, it wins (it was already
                        # substituted in the previous step; this message is only informational)
                        pretty_print(f'{NORM_YEL}SQL parameter {NORM_GRN}{key} '
                                     f'{NORM_YEL}defined in SQL file {NORM_GRN}{sql_or_file} '
                                     f'{NORM_YEL}is preempted by an external parameter, the new value '
                                     f'{NORM_GRN}{external_parameters[key]} '
                                     f'{NORM_YEL}replaces the old value {NORM_GRN}{internal_parameters[key]}')
                        continue
                    sql = sql.replace('${%s}' % key, value)
                cleaned_sql_list.append(sql)
        if len(cleaned_sql_list) == 0:
            return self.get_spark_session().sql("SELECT ''"), False
        for index in range(len(cleaned_sql_list) - 1):
            sql = cleaned_sql_list[index]
            self.execute(sql)
        last_or_only_sql = cleaned_sql_list[-1]
        check_parameter_substituted(last_or_only_sql, check_parameter)
        self.get_spark_session()
        if not silent:
            if '\n' in last_or_only_sql:
                pretty_print(f'{NORM_MGT}Running SQL statement: \n{NORM_GRN}{last_or_only_sql}')
            else:
                pretty_print(f'{NORM_MGT}Running SQL statement: {NORM_GRN}{last_or_only_sql}')
        data_frame = self._spark_session.sql(last_or_only_sql)
        if fill_null:
            # fillna returns a new DataFrame, so the result has to be reassigned
            data_frame = data_frame.fillna(fill_null)
        return data_frame, len(data_frame.schema.fields) > 0

    def query_scalar(self, sql_or_file: str, check_parameter: bool = False, silent: bool = False, fill_null=None,
                     **kwargs):
        data, is_select = self.query(sql_or_file, check_parameter, silent, fill_null, **kwargs)
        if not is_select:
            raise Exception('Get scalar from non-query statement')
        res = data.collect()
        return res[0][0]

    def register_udf_files(self, udf_file_list: List[str]):
        if not udf_file_list:
            return
        for udf_file in udf_file_list:
            self.register_udf_file(udf_file)

    def register_udf_file(self, udf_file: str):
        if not udf_file:
            return
        if not self.add_py_file(udf_file):
            pretty_print(f'{NORM_YEL}Python file {NORM_GRN}{udf_file}{NORM_YEL} has already been added (duplicate)')
            return
        if not self._spark_session:
            self._udf_files.append(udf_file)
            return
        pretty_print(f'{NORM_MGT}Registering Python UDFs from file {NORM_GRN}{udf_file}')
        udf_module_name = udf_file.replace('/', '.').replace('.py', '')
        module = import_module(udf_module_name)
        for (name, attr) in inspect.getmembers(module):
            if name in self.IGNORED_UDF:
                continue
            if name in self.REGISTERED_UDF:
                pretty_print(f'{NORM_YEL}A Python UDF named {NORM_GRN}{name}{NORM_YEL} '
                             f'{NORM_YEL}was already registered from file {NORM_GRN}{self.REGISTERED_UDF[name]} '
                             f'{NORM_YEL}(when UDF names collide, the first occurrence wins)')
                continue
            if inspect.isfunction(attr):
                # todo: register the return type of each udf
                pretty_print(f'{NORM_MGT}Registering Python UDF {NORM_GRN}{name}')
                self._spark_session.udf.register(name, attr)
                self.REGISTERED_UDF[name] = udf_file
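    # The ``__main__`` block below uses this class as a context manager, but no ``__enter__``/``__exit__``
    # pair appears in this file. The following is a minimal sketch under the assumption that the
    # underlying SparkSession should simply be stopped when the ``with`` block exits.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Stop the SparkSession (if one was created) and clear the reference.
        if self._spark_session:
            self._spark_session.stop()
            self._spark_session = None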
    @staticmethod
    def show_data_frame(data_frame: DataFrame, data_frame_name: str = None, show_number: int = 20,
                        truncate: Union[bool, int] = False):
        """
        Print the schema of a ``DataFrame`` and the given number of rows.

        Args:
            data_frame: DataFrame
            data_frame_name: name of the DataFrame
            show_number: number of rows to print
            truncate: If set to ``True``, truncate strings longer than 20 chars by default.
                If set to a number greater than one, truncates long strings to length ``truncate``
                and align cells right.

        Returns:

        """
        if data_frame_name:
            pretty_print(f'{NORM_MGT}%s %s DataFrame Information: %s ' % ('=' * 30, data_frame_name, '=' * 30))
        else:
            pretty_print(f'{NORM_MGT}%s DataFrame Information: %s ' % ('=' * 30, '=' * 30))
        pretty_print(f'{NORM_MGT}Schema :')
        data_frame.printSchema()
        pretty_print(f'{NORM_MGT}Data sample :')
        data_frame.show(show_number, truncate=truncate)
        pretty_print(f'{NORM_MGT}=' * 70)


if __name__ == '__main__':
    with SparkSQL('spark-sql') as spark_sql:
        # spark_sql.execute('workspace/x.sql')
        # spark_sql._spark_session.conf.set('app.name', 'tmp.sa33')
        # spark_sql.execute('workspace/y.sql')
        # spark_sql.query('show databases')[0].show(100, truncate=False)
        # spark_sql.query('show functions')[0].show(1000, truncate=False)
        spark_sql.get_columns('crl_ads.ads_crl_dim_source_stat')