# -*- coding:utf-8 -*-
"""
@author xunxu

Utility for initializing a SparkSession from spark-submit-style command-line
arguments, e.g.:

xxx.py \
    --name "spark-sql-test-job" \
    --queue cts \
    --num-executors 2 \
    --executor-memory 1g \
    --executor-cores 1 \
    --conf spark.sql.shuffle.partitions=300 \
    --conf spark.default.parallelism=300 \
    --conf spark.dynamicAllocation.enabled=true \
    --py-files dw_base/spark/udf/customs/common_clean.py,dw_base/spark/udf/spark_eng_ent_name_clean.py \
    -dt="" -cdt="" -ydt="" ...
"""
from typing import Dict, List, Tuple, Union

from dw_base.spark.spark_sql import SparkSQL
from dw_base.utils.config_utils import parse_args
from pyspark.sql import SparkSession


def get_spark(argv: List[str]) -> Tuple[SparkSQL, SparkSession]:
    """
    Build the tendata SparkSQL wrapper and its SparkSession from
    spark-submit-style command-line arguments.

    Args:
        argv: sys.argv as received on the command line (argv[0] is the
            script name, argv[1:] are the flags).
    Returns:
        A (SparkSQL, SparkSession) tuple.
    """
    conf_args: Dict[str, Union[str, bool, List[str]]]
    conf_args, _ = parse_args(argv[1:])
    # Sensible defaults; anything passed via --conf below overrides them.
    spark_conf_dict = {
        "hive.exec.dynamic.partition": "true",
        "hive.exec.dynamic.partition.mode": "nonstrict",
        "spark.dynamicAllocation.enabled": "true",
    }
    # Fold every --conf key=value setting into the extra Spark config.
    # parse_args yields a list when --conf is repeated and a plain string
    # when it appears exactly once.
    if "conf" in conf_args:
        spark_conf = conf_args["conf"]
        if isinstance(spark_conf, list):
            # Split on the first "=" only: values such as
            # "-Dfile.encoding=UTF-8" may themselves contain "=".
            spark_conf_dict.update(kv.split("=", 1) for kv in spark_conf)
        elif isinstance(spark_conf, str):
            k, v = spark_conf.split("=", 1)
            spark_conf_dict[k] = v
    td_spark = SparkSQL(
        session_name=conf_args.get("name", argv[0]),
        master=conf_args.get("master", "yarn"),
        spark_yarn_queue=conf_args.get("queue", "default"),
        spark_driver_memory=conf_args.get("driver-memory", "1g"),
        # spark-submit's flag is --driver-cores; the lookup key must match it
        # (the original read "driver-core", which never matched).
        spark_driver_cores=conf_args.get("driver-cores", 1),
        spark_executor_instances=conf_args.get("num-executors", 2),
        spark_executor_cores=conf_args.get("executor-cores", 2),
        spark_executor_memory=conf_args.get("executor-memory", "6g"),
        extra_spark_config=spark_conf_dict,
        udf_files=conf_args["py-files"].split(",") if "py-files" in conf_args else None,
    )
    spark: SparkSession = td_spark.get_spark_session()
    return td_spark, spark


# if __name__ == "__main__":
#     import sys
#
#     spark: SparkSession
#     td_spark, spark = get_spark(sys.argv)
#     spark.sql("show databases").show(100, truncate=False)
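

# Hedged, self-contained sketch (added for illustration, not part of the
# original API): demonstrates the --conf folding used in get_spark() and why
# splitting on the first "=" only matters, since Spark values can themselves
# contain "=". It runs without a cluster; every name below is illustrative.
def _demo_conf_split() -> None:
    raw_confs = [
        "spark.sql.shuffle.partitions=300",
        "spark.driver.extraJavaOptions=-Dfile.encoding=UTF-8",  # value holds "="
    ]
    # Same folding strategy as get_spark(): the first "=" separates key
    # from value, so anything after it is preserved verbatim.
    conf_dict = dict(kv.split("=", 1) for kv in raw_confs)
    assert conf_dict["spark.sql.shuffle.partitions"] == "300"
    assert conf_dict["spark.driver.extraJavaOptions"] == "-Dfile.encoding=UTF-8"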