# td_spark_init.py

# -*- coding:utf-8 -*-
import sys
from typing import Dict, List, Tuple, Union

from dw_base.spark.spark_sql import SparkSQL
from dw_base.utils.config_utils import parse_args
from pyspark.sql import SparkSession

"""
@author xunxu
A helper for initializing a SparkSession from spark-submit style
command-line arguments, e.g.:

xxx.py \
--name "spark-sql-test-job" \
--queue cts \
--num-executors 2 \
--executor-memory 1g \
--executor-cores 1 \
--conf spark.sql.shuffle.partitions=300 \
--conf spark.default.parallelism=300 \
--conf spark.dynamicAllocation.enabled=true \
--py-files dw_base/spark/udf/customs/common_clean.py,dw_base/spark/udf/spark_eng_ent_name_clean.py \
-dt="" -cdt="" -ydt="" ...
"""


def get_spark(argv: list) -> Tuple[SparkSQL, SparkSession]:
    """
    Args:
        argv: sys.argv as received from the command line
    Returns:
        A tendata SparkSQL wrapper and its underlying SparkSession, as a tuple.
    """
    conf_args: Dict[str, Union[str, bool, List[str]]]
    conf_args, _ = parse_args(argv[1:])
    spark_conf_dict = {
        "hive.exec.dynamic.partition": "true",
        "hive.exec.dynamic.partition.mode": "nonstrict",
        "spark.dynamicAllocation.enabled": "true"
    }
    # Fold every --conf key=value pair into extra_spark_conf; user-supplied
    # values override the defaults above because update() runs last.
    if 'conf' in conf_args:
        spark_conf = conf_args['conf']
        if isinstance(spark_conf, list):
            # split("=", 1) keeps values that themselves contain "=" intact,
            # e.g. --conf spark.driver.extraJavaOptions=-Dkey=value
            spark_conf_dict.update(
                dict(kv_str.split("=", 1) for kv_str in spark_conf)
            )
        elif isinstance(spark_conf, str):
            k, v = spark_conf.split("=", 1)
            spark_conf_dict[k] = v
    td_spark = SparkSQL(
        session_name=conf_args.get("name", argv[0]),
        master=conf_args.get("master", "yarn"),
        spark_yarn_queue=conf_args.get("queue", "default"),
        spark_driver_memory=conf_args.get("driver-memory", "1g"),
        # matches the spark-submit flag name --driver-cores
        spark_driver_cores=conf_args.get("driver-cores", 1),
        spark_executor_instances=conf_args.get("num-executors", 2),
        spark_executor_cores=conf_args.get("executor-cores", 2),
        spark_executor_memory=conf_args.get("executor-memory", "6g"),
        extra_spark_config=spark_conf_dict,
        udf_files=conf_args['py-files'].split(",") if 'py-files' in conf_args else None
    )
    spark: SparkSession = td_spark.get_spark_session()
    return td_spark, spark


if __name__ == "__main__":
    # Smoke test: bring up a session and list the visible databases.
    td_spark, spark = get_spark(sys.argv)
    spark.sql("show databases").show(100, truncate=False)
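
# Usage sketch for a downstream job script (hypothetical file name my_job.py;
# assumes this module is importable as td_spark_init):
#
#   import sys
#   from td_spark_init import get_spark
#
#   td_spark, spark = get_spark(sys.argv)
#   try:
#       spark.sql("SELECT 1 AS ok").show()
#   finally:
#       spark.stop()
#
# Launched with plain python plus the spark-submit style flags shown in the
# module docstring, e.g.:
#   python my_job.py --name "demo-job" --queue cts --num-executors 2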