# -*- coding:utf-8 -*-
"""
@author xunxu
A SparkSession initialization utility that accepts spark-submit-style
command-line arguments, e.g.:
xxx.py \
--name "spark-sql-test-job" \
--queue cts \
--num-executors 2 \
--executor-memory 1g \
--executor-cores 1 \
--conf spark.sql.shuffle.partitions=300 \
--conf spark.default.parallelism=300 \
--conf spark.dynamicAllocation.enabled=true \
--py-files dw_base/spark/udf/customs/common_clean.py,dw_base/spark/udf/spark_eng_ent_name_clean.py \
-dt="" -cdt="" -ydt="" ...
"""
from typing import Dict, List, Tuple, Union

from dw_base.utils.config_utils import parse_args
from dw_base.spark.spark_sql import SparkSQL
from pyspark.sql import SparkSession
def get_spark(argv: list) -> Tuple[SparkSQL, SparkSession]:
    """
    Args:
        argv: sys.argv from the command line
    Returns:
        a tuple of the tendata SparkSQL wrapper and the SparkSession
    """
    conf_args: Dict[str, Union[str, bool, List[str]]]
    conf_args, _ = parse_args(argv[1:])  # drop the script name (argv[0])
    spark_conf_dict = {
        "hive.exec.dynamic.partition": "true",
        "hive.exec.dynamic.partition.mode": "nonstrict",
        "spark.dynamicAllocation.enabled": "true"
    }
    # Merge every --conf key=value pair into extra_spark_config
    if 'conf' in conf_args:
        spark_conf = conf_args['conf']
        if isinstance(spark_conf, list):
            # split on the first '=' only, so values may themselves contain '='
            spark_conf_dict.update(
                dict(kv_str.split("=", 1) for kv_str in spark_conf)
            )
        elif isinstance(spark_conf, str):
            k, v = spark_conf.split("=", 1)
            spark_conf_dict[k] = v
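    # e.g. --conf spark.sql.shuffle.partitions=300 --conf spark.default.parallelism=300
    #   -> spark_conf_dict gains {"spark.sql.shuffle.partitions": "300",
    #                             "spark.default.parallelism": "300"}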
    td_spark = SparkSQL(
        session_name=conf_args.get("name", argv[0]),
        master=conf_args.get("master", "yarn"),
        spark_yarn_queue=conf_args.get("queue", "default"),
        spark_driver_memory=conf_args.get("driver-memory", "1g"),
        spark_driver_cores=conf_args.get("driver-cores", 1),
        spark_executor_instances=conf_args.get("num-executors", 2),
        spark_executor_cores=conf_args.get("executor-cores", 2),
        spark_executor_memory=conf_args.get("executor-memory", "6g"),
        extra_spark_config=spark_conf_dict,
        udf_files=conf_args['py-files'].split(",") if 'py-files' in conf_args else None
    )
    spark: SparkSession = td_spark.get_spark_session()
    return td_spark, spark
# if __name__ == "__main__":
#     import sys
#     spark: SparkSession
#     td_spark, spark = get_spark(sys.argv)
#     spark.sql("show databases").show(100, truncate=False)
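
# ---------------------------------------------------------------------------
# Hypothetical sketch (NOT the dw_base implementation): the contract this
# module assumes from parse_args. A "--flag value" pair seen once maps to a
# str, a repeated flag (e.g. --conf) accumulates into a List[str], a bare
# "--flag" becomes True, and leftover positional tokens are returned as the
# second tuple element. The "-dt=..." short form is omitted for brevity.
# ---------------------------------------------------------------------------
def _parse_args_sketch(args: List[str]) -> Tuple[Dict[str, Union[str, bool, List[str]]], List[str]]:
    parsed: Dict[str, Union[str, bool, List[str]]] = {}
    rest: List[str] = []
    i = 0
    while i < len(args):
        token = args[i]
        if token.startswith("--"):
            key = token[2:]
            if i + 1 < len(args) and not args[i + 1].startswith("--"):
                value: Union[str, bool] = args[i + 1]  # "--flag value" form
                i += 2
            else:
                value = True  # bare "--flag" with no value
                i += 1
            if key in parsed:
                # repeated flag: promote the existing value to a list and append
                existing = parsed[key]
                parsed[key] = existing + [value] if isinstance(existing, list) else [existing, value]
            else:
                parsed[key] = value
        else:
            rest.append(token)
            i += 1
    return parsed, rest

# e.g. _parse_args_sketch(["--name", "job", "--conf", "a=1", "--conf", "b=2"])
#   -> ({"name": "job", "conf": ["a=1", "b=2"]}, [])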