# spark_sql.py
# -*- coding:utf-8 -*-
import inspect
import os
import re
import time
from importlib import import_module
from typing import List, Union, Dict, Any, Tuple

from pyspark.sql import Row, SparkSession, DataFrame

from dw_base import *
from dw_base.utils import common_utils
from dw_base.utils.datetime_utils import get_today
from dw_base.utils.file_utils import get_abs_path
from dw_base.utils.log_utils import pretty_print
from dw_base.utils.sql_utils import get_sql_list_from_file, check_parameter_substituted
  13. HDFS_EXPORT_DATA_PATH = '/hdfs-mnt/export-data'
  14. class SparkSQL(object):
  15. """
  16. 封装执行 Spark 相关操作的类, 相关参数说明:
  17. In Spark 2.0+ version
  18. use spark session variable to set number of executors dynamically (from within program)
  19. spark.conf.set("spark.executor.instances', 4)
  20. spark.conf.set("spark.executor.cores', 4)
  21. In above case maximum 16 tasks will be executed at any given time.
  22. other option is dynamic allocation of executors as below -
  23. spark.conf.set("spark.dynamicAllocation.enabled', "true')
  24. spark.conf.set("spark.executor.cores', 4)
  25. spark.conf.set("spark.dynamicAllocation.minExecutors',"1')
  26. spark.conf.set("spark.dynamicAllocation.maxExecutors',"5')
  27. spark.yarn.executor.memoryOverhead:default is executorMemory * 0.07, with minimum of 384
  28. spark.yarn.driver.memoryOverhead:default is driverMemory * 0.07, with minimum of 384
  29. spark.yarn.am.memoryOverhead:default is AM memory * 0.07, with minimum of 384
  30. """
  31. REGISTERED_UDF_FILES = []
  32. ADDED_RESOURCE_FILES = []
  33. REGISTERED_UDF = {}
  34. IGNORED_UDF = ['attr', 'udf']
  35. def __init__(self,
  36. session_name: str = 'spark',
  37. master: str = 'yarn',
  38. spark_yarn_queue: str = 'spark',
  39. spark_driver_memory: str = '2g',
  40. spark_executor_memory: str = '6g',
  41. spark_executor_memory_overhead: str = '512',
  42. spark_driver_cores: int = 2,
  43. spark_executor_cores: int = 2,
  44. spark_executor_instances: int = 15,
  45. spark_driver_max_result_size='4g',
  46. spark_shuffle_partitions=200,
  47. spark_default_parallelism=200,
  48. extra_spark_config: Dict[str, Any] = None,
  49. udf_files: List[str] = None,
  50. resource_files: List[str] = None):
  51. """
  52. Args:
  53. session_name:
  54. master:
  55. spark_yarn_queue:
  56. spark_driver_memory:
  57. spark_executor_memory:
  58. spark_executor_memory_overhead:
  59. spark_driver_cores:
  60. spark_executor_cores:
  61. spark_executor_instances:
  62. spark_driver_max_result_size:
  63. spark_shuffle_partitions:
  64. spark_default_parallelism:
  65. extra_spark_config: 额外的Spark配置,优先级高于SQL文件中定义的Spark配置
  66. udf_files: udf文件列表,必须传相对路径,否则不识别
  67. resource_files: 资源文件列表,必须传相对路径,否则不识别
  68. """
  69. self._spark_session = None
  70. self._session_name = session_name
  71. self._master = master
  72. self._spark_driver_memory = spark_driver_memory
  73. self._spark_executor_memory = spark_executor_memory
  74. self._spark_executor_memory_overhead = spark_executor_memory_overhead
  75. self._spark_driver_cores = spark_driver_cores
  76. self._spark_executor_cores = spark_executor_cores
  77. self._spark_executor_instances = spark_executor_instances
  78. self._spark_yarn_queue = spark_yarn_queue
  79. self._spark_driver_max_result_size = spark_driver_max_result_size
  80. self._spark_shuffle_partitions = spark_shuffle_partitions
  81. self._spark_default_parallelism = spark_default_parallelism
  82. self._extra_spark_config = extra_spark_config
  83. self._final_spark_config = {}
  84. self.limit = 20
  85. if not udf_files:
  86. udf_files = []
  87. self._udf_files = udf_files
  88. self._py_files = []
  89. if not resource_files:
  90. resource_files = []
  91. self._resource_files = resource_files
  92. self.global_start_time = None
  93. self.start_time = None
  94. def __add_spark_config(self, sql_file: str, spark_config_def: str):
  95. if not spark_config_def.__contains__('='):
  96. pretty_print(f'{NORM_YEL}无效的 Spark 配置 {NORM_GRN}{spark_config_def}')
  97. return
  98. spark_config, config_value = [e.strip() for e in spark_config_def.split('=')]
  99. if self._final_spark_config.__contains__(spark_config):
  100. pretty_print(
  101. f'{NORM_YEL}SQL 文件 {NORM_GRN}{sql_file}{NORM_YEL} 中定义的 Spark 配置 {NORM_GRN}{spark_config} '
  102. f'{NORM_YEL}重复提供,原值 {NORM_GRN}{self._final_spark_config[spark_config]} '
  103. f'{NORM_YEL}将被覆盖为新值 {NORM_GRN}{config_value}')
  104. else:
  105. self._final_spark_config[spark_config] = config_value
  106. self._final_spark_config[spark_config] = config_value
  107. def __init_spark_session(self):
  108. if self._spark_session:
  109. return
  110. pretty_print(f'{NORM_MGT}基于用户 {NORM_GRN}{USER}{NORM_MGT} 创建 SparkSession')
  111. # for element in os.environ:
  112. # pretty_print(f'{NORM_MGT}Environment {NORM_GRN}{element} => {os.environ[element]}')
  113. builder = SparkSession.builder \
  114. .appName(self._session_name) \
  115. .master(self._master) \
  116. .config('hive.exec.orc.default.block.size', 134217728) \
  117. .config('spark.debug.maxToStringFields', 5000) \
  118. .config('spark.default.parallelism', self._spark_default_parallelism) \
  119. .config('spark.driver.cores', self._spark_driver_cores) \
  120. .config('spark.driver.maxResultSize', self._spark_driver_max_result_size) \
  121. .config('spark.driver.memory', self._spark_driver_memory) \
  122. .config('spark.dynamicAllocation.enabled', False) \
  123. .config('spark.files.ignoreCorruptFiles', True) \
  124. .config('spark.executor.cores', self._spark_executor_cores) \
  125. .config('spark.executor.instances', self._spark_executor_instances) \
  126. .config('spark.executor.memory', self._spark_executor_memory) \
  127. .config('spark.executor.memoryOverhead', self._spark_executor_memory_overhead) \
  128. .config('spark.sql.adaptive.enabled', 'true') \
  129. .config('spark.sql.broadcastTimeout', -1) \
  130. .config('spark.sql.codegen.wholeStage', 'false') \
  131. .config('spark.sql.execution.arrow.enabled', True) \
  132. .config('spark.sql.execution.arrow.fallback.enabled', True) \
  133. .config('spark.sql.files.ignoreCorruptFiles', True) \
  134. .config('spark.sql.shuffle.partitions', self._spark_shuffle_partitions) \
  135. .config('spark.sql.statistics.fallBackToHdfs', True) \
  136. .config('spark.yarn.queue', self._spark_yarn_queue) \
  137. .config('spark.port.maxRetries', 999)
  138. if self._extra_spark_config:
  139. for spark_config, config_value in self._extra_spark_config.items():
  140. if self._final_spark_config.__contains__(spark_config):
  141. pretty_print(f'{NORM_YEL}构造函数传入的 Spark 配置 {NORM_GRN}{spark_config} => {config_value} '
  142. f'{NORM_YEL}覆盖了在 SQL 文件中定义的配置 '
  143. f'{NORM_GRN}{spark_config} => {self._final_spark_config[spark_config]}')
  144. self._final_spark_config[spark_config] = config_value
  145. if self._final_spark_config:
  146. for key, value in self._final_spark_config.items():
  147. pretty_print(f'{NORM_MGT}添加自定义 Spark 配置 {NORM_GRN}{key} => {str(value)}')
  148. builder.config(key, value)
  149. pretty_print(f'{NORM_MGT}创建 SparkSession')
  150. self._spark_session = builder.enableHiveSupport().getOrCreate()
  151. self._spark_session.sparkContext._jsc.hadoopConfiguration().set('mapred.max.split.size', '33554432')
  152. self._spark_session.sparkContext._jsc.hadoopConfiguration().set(
  153. 'mapreduce.fileoutputcommitter.marksuccessfuljobs', 'false')
  154. first_py_file = 'dw_base.zip'
  155. if IS_RUN_IN_RELEASE_DIR and IS_RUN_BY_RELEASE_USER:
  156. command = f'cd {PROJECT_ROOT_PATH} && if [ ! -f {first_py_file} ];then zip -qr {first_py_file} dw_base; fi'
  157. else:
  158. command = f'cd {PROJECT_ROOT_PATH} && rm -f {first_py_file} && zip -qr {first_py_file} dw_base'
  159. pretty_print(f'{NORM_MGT}执行 Shell 命令 {NORM_GRN}{command}')
  160. os.system(command)
  161. self._spark_session.sparkContext.addPyFile(get_abs_path(first_py_file))
  162. self.register_udf_files(self._udf_files)
  163. for py_file in self._py_files:
  164. self.add_py_file(py_file)
  165. for resource_file in self._resource_files:
  166. self.add_resource_file(resource_file)
  167. @staticmethod
  168. def add_parameters(sql_or_file: str, parameters: Dict, key: str, value: Any):
  169. if parameters.__contains__(key):
  170. pretty_print(f'{NORM_YEL}SQL 文件 {NORM_GRN}{sql_or_file} '
  171. f'{NORM_YEL}中定义的 SQL 参数 {NORM_GRN}{key} '
  172. f'{NORM_YEL}发生变更,将使用新值 {NORM_GRN}{Any} '
  173. f'{NORM_YEL}替代原值 {NORM_GRN}{parameters[key]}')
  174. parameters[key] = value
  175. def add_py_file(self, py_file: str) -> bool:
  176. abs_py_file = get_abs_path(py_file)
  177. if self.REGISTERED_UDF_FILES.__contains__(abs_py_file):
  178. return False
  179. if not self._spark_session:
  180. self._py_files.append(py_file)
  181. return True
  182. pretty_print(f'{NORM_MGT}添加 Python 文件 {NORM_GRN}{py_file}')
  183. self._spark_session.sparkContext.addPyFile(abs_py_file)
  184. self.REGISTERED_UDF_FILES.append(abs_py_file)
  185. return True
  186. def add_resource_file(self, resource_file: str) -> bool:
  187. abs_resource_file = get_abs_path(resource_file)
  188. if self.ADDED_RESOURCE_FILES.__contains__(abs_resource_file):
  189. return False
  190. if not self._spark_session:
  191. self._resource_files.append(resource_file)
  192. return True
  193. pretty_print(f'{NORM_MGT}添加资源文件 {NORM_GRN}{resource_file}')
  194. self._spark_session.sparkContext.addFile(abs_resource_file)
  195. self.ADDED_RESOURCE_FILES.append(abs_resource_file)
  196. return True
  197. def execute(self,
  198. sql_or_file: str,
  199. check_parameter: bool = False,
  200. silent: bool = False,
  201. fill_null=None,
  202. **kwargs):
  203. """
  204. 仅执行,不返回任何值
  205. Args:
  206. sql_or_file: 需要执行的 sql 语句或者包含 sql 语句的文件路径(相对或绝对都可以)
  207. check_parameter: 运行前是否进行参数检查
  208. silent: 执行时是否打印日志,True则不打印
  209. fill_null: 填充null值
  210. **kwargs: sql 语句或 sql 文件所用到的参数
  211. Returns:
  212. """
  213. # 运行SQL语句或SQL文件
  214. start_time = time.time()
  215. data_frame, is_select = self.query(sql_or_file, check_parameter, silent, fill_null, **kwargs)
  216. if is_select:
  217. self.show_data_frame(data_frame, show_number=self.limit, truncate=USER == RELEASE_USER)
  218. end_time = time.time()
  219. cost = round(float(end_time - self.start_time), 2)
  220. pretty_print(f'{NORM_MGT}SQL 语句执行完毕,耗时 {NORM_GRN}{str(cost)}{NORM_MGT} 秒')
  221. if sql_or_file.endswith('.sql'):
  222. total_cost = round(float(end_time - self.global_start_time), 2)
  223. pretty_print(f'{NORM_MGT}SQL 文件 {NORM_GRN}{sql_or_file}'
  224. f'{NORM_MGT} 执行完毕,共耗时 {NORM_GRN}{str(total_cost)}{NORM_MGT} 秒')
  225. self.global_start_time = None
  226. def export_data(self,
  227. data_set_name: str,
  228. sql_or_file: str,
  229. show_number: int = 100,
  230. truncate: Union[bool, int] = 40,
  231. delimiter: str = ',',
  232. partition: int = 1,
  233. **kwargs) -> str:
  234. """
  235. 导出数据,默认存储于HDFS目录(/hdfs-mnt/export-data),映射本地目录(/opt/hdfs-mnt/export-data)
  236. Args:
  237. data_set_name: 导出的数据集名称标识(作为目录)
  238. sql_or_file: 需要执行的 sql 语句或者包含 sql 语句的文件全路径
  239. show_number: 显示的行数
  240. truncate: If set to ``True``, truncate strings longer than 20 chars by default.
  241. If set to a number greater than one, truncates long strings to length ``truncate``
  242. and align cells right.
  243. delimiter: 分隔符
  244. partition: 导出文件个数
  245. **kwargs: sql 语句或 sql 文件所用到的参数
  246. Returns:
  247. """
  248. start_time = time.time()
  249. data_frame, _ = self.query(sql_or_file, **kwargs)
  250. data_frame.persist()
  251. self.show_data_frame(data_frame, data_set_name, show_number, truncate)
  252. today = get_today()
  253. # 注意: 此目录不是linux服务器的目录, 而是hdfs文件系统的目录
  254. hdfs_directory = f'{data_set_name}.{str(int(time.time()))}'
  255. hdfs_export_path = os.path.join(HDFS_EXPORT_DATA_PATH, USER, today, hdfs_directory)
  256. # linux服务器导出文件存放的目录
  257. # local_export_path = os.path.join(LOCAL_EXPORT_DATA_PATH, USER, today, data_set_name)
  258. pretty_print(f'{NORM_MGT}准备导出文件到HDFS目录: {NORM_GRN}{hdfs_export_path}')
  259. if not delimiter:
  260. delimiter = ','
  261. # 先将数据保存至hdfs文件系统, 加 option("escape", "\"") 不会导致csv文件“字段文本内容中的英文逗号来分割字段(列膨胀错位)”
  262. data_frame \
  263. .repartition(partition) \
  264. .write \
  265. .mode('overwrite') \
  266. .format('com.databricks.spark.csv') \
  267. .option('header', 'true') \
  268. .option("escape", "\"") \
  269. .option('delimiter', delimiter) \
  270. .save(hdfs_export_path)
  271. cost = round(float(time.time() - start_time), 2)
  272. pretty_print(f'{NORM_MGT}数据导出完毕,耗时 {NORM_GRN}{str(cost)}{NORM_MGT} 秒')
  273. return hdfs_export_path
  274. def get_columns(self, hive_table_name: str):
  275. desc_df, _ = self.query(f'desc {hive_table_name}', silent=True)
  276. desc_rows = desc_df.collect()
  277. hive_columns = {}
  278. partition_col_start = False
  279. for row in desc_rows: # type: Row
  280. if row[0] == '# Partition Information':
  281. partition_col_start = True
  282. continue
  283. if row[0] == '# col_name':
  284. continue
  285. if partition_col_start:
  286. del hive_columns[list(hive_columns.keys())[len(hive_columns.keys()) - 1]]
  287. else:
  288. hive_columns[row[0]] = (row[1], row[2])
  289. return hive_columns
  290. def get_partition_columns(self, hive_table_name: str):
  291. desc_df, _ = self.query(f'desc {hive_table_name}', silent=True)
  292. desc_rows = desc_df.collect()
  293. partition_columns = {}
  294. partitioned = False
  295. for row in desc_rows: # type: Row
  296. if row[0] != '# Partition Information':
  297. continue
  298. if row[0].startswith('#'):
  299. continue
  300. partitioned = True
  301. partition_columns[row[0]] = (row[1], row[2])
  302. return partitioned, partition_columns
  303. def get_spark_session(self):
  304. if not self._spark_session:
  305. self.__init_spark_session()
  306. self.start_time = time.time()
  307. if not self.global_start_time:
  308. self.global_start_time = time.time()
  309. return self._spark_session
  310. def jdbc_read(self, jdbc_url: str, table_name: str) -> DataFrame:
  311. return self.get_spark_session().read.jdbc(jdbc_url, table_name)
  312. def list_tables(self, hive_database_name: str = 'default', include_regex: str = None, exclude_regex: str = None):
  313. if not hive_database_name:
  314. hive_database_name = 'default'
  315. tables_df, _ = self.query(f'show tables in {hive_database_name}')
  316. tables = []
  317. for row in tables_df.collect():
  318. if hive_database_name == 'default':
  319. table_name = row.asDict()['tableName']
  320. else:
  321. table_name = f"{hive_database_name}.{row.asDict()['tableName']}"
  322. if exclude_regex and re.match(exclude_regex, table_name):
  323. continue
  324. if include_regex:
  325. if re.match(include_regex, table_name):
  326. tables.append(table_name)
  327. continue
  328. else:
  329. tables.append(table_name)
  330. return tables
  331. def query(self,
  332. sql_or_file: str,
  333. check_parameter: bool = False,
  334. silent: bool = False,
  335. fill_null=None,
  336. **kwargs) -> Tuple[DataFrame, bool]:
  337. """
  338. 执行查询,返回一个DataFrame和该DataFrame是否是SELECT
  339. Args:
  340. sql_or_file: 需要执行的 sql 语句或者包含 sql 语句的文件全路径
  341. check_parameter: 运行前是否进行参数检查
  342. silent: 执行时是否打印日志,True则不打印
  343. fill_null: 填充null值
  344. **kwargs: sql 语句或 sql 文件所用到的参数
  345. Returns: Tuple[DataFrame, bool]
  346. """
  347. if sql_or_file.endswith('.sql'):
  348. if not silent:
  349. pretty_print(f'{NORM_MGT}准备执行 SQL 文件 {NORM_GRN}{sql_or_file}')
  350. sql_list = get_sql_list_from_file(sql_or_file, trim_comment=True)
  351. else:
  352. sql_list = [sql_or_file]
  353. cleaned_sql_list = []
  354. # 外部传入的参数
  355. external_parameters = {}
  356. if kwargs:
  357. for key, value in kwargs.items():
  358. if not silent:
  359. pretty_print(f'{NORM_MGT}收到外部 SQL 参数 {NORM_GRN}{key} => {value}')
  360. external_parameters[key] = value
  361. # SQL文件内部定义的参数
  362. internal_parameters = {}
  363. for sql in sql_list:
  364. upper_sql = sql.upper()
  365. if upper_sql.startswith('ADD FILES'):
  366. file_names = sql[len('ADD FILES'):].strip().split(' ')
  367. for file_name in file_names:
  368. file_name = file_name.strip()
  369. if not file_name:
  370. continue
  371. if file_name.endswith('.py'):
  372. self.register_udf_file(file_name)
  373. else:
  374. self.add_resource_file(file_name)
  375. elif upper_sql.startswith('ADD FILE'):
  376. file_name = sql[len('ADD FILE'):].strip()
  377. if file_name.endswith('.py'):
  378. self.register_udf_file(file_name)
  379. else:
  380. self.add_resource_file(file_name)
  381. elif upper_sql.startswith('ADD SPARK_UDF'):
  382. self.register_udf_file(sql[len('ADD SPARK_UDF'):].strip())
  383. elif upper_sql.startswith('SET SPARK_VAR:'):
  384. var_key, var_value = sql[len('SET SPARK_VAR:'):].strip().split('=')
  385. var_key = var_key.strip()
  386. var_value = var_value.strip()
  387. self.add_parameters(sql_or_file, internal_parameters, var_key, var_value)
  388. elif upper_sql.startswith('SET LIMIT='):
  389. if not self._spark_session:
  390. self.limit = int(sql.split('=')[1].strip())
  391. elif upper_sql.startswith('SET '):
  392. set_expr = sql[len('SET '):].strip()
  393. set_expr_splits = set_expr.split('=')
  394. set_key = set_expr_splits[0].strip()
  395. set_value = set_expr_splits[1].strip()
  396. if re.match(r'.+\..+(\..+)*', set_key):
  397. if not self._spark_session:
  398. self.__add_spark_config(sql_or_file, sql[len('SET '):].strip())
  399. else:
  400. self.add_parameters(sql_or_file, internal_parameters, set_key, set_value)
  401. else:
  402. # 优先使用外部参数
  403. for key, value in external_parameters.items():
  404. sql = sql.replace('${%s}' % key, value)
  405. # SQL文件里的参数必须在使用前定义(同时也意味着相同的参数可以多次定义不同值,供多个SQL使用)
  406. for key, value in internal_parameters.items():
  407. if external_parameters.__contains__(key):
  408. # 如果外部参数中包括当前参数,则使用外部参数(实际上已在上一步替换,此处仅为提示)
  409. pretty_print(f'{NORM_YEL}SQL 文件 {NORM_GRN}{sql_or_file} '
  410. f'{NORM_YEL}中定义的 SQL 参数 {NORM_GRN}{key} '
  411. f'{NORM_YEL}已由外部参数抢占,将使用新值 {NORM_GRN}{external_parameters[key]} '
  412. f'{NORM_YEL}替代原值 {NORM_GRN}{internal_parameters[key]}')
  413. continue
  414. sql = sql.replace('${%s}' % key, value)
  415. cleaned_sql_list.append(sql)
  416. if len(cleaned_sql_list) == 0:
  417. return self.get_spark_session().sql("SELECT ''"), False
  418. for index in range(len(cleaned_sql_list) - 1):
  419. sql = cleaned_sql_list[index]
  420. self.execute(sql)
  421. last_or_only_sql = cleaned_sql_list[-1]
  422. check_parameter_substituted(last_or_only_sql, check_parameter)
  423. self.get_spark_session()
  424. if not silent:
  425. if last_or_only_sql.__contains__('\n'):
  426. pretty_print(f'{NORM_MGT}开始执行 SQL 语句: \n{NORM_GRN}{last_or_only_sql}')
  427. else:
  428. pretty_print(f'{NORM_MGT}开始执行 SQL 语句: {NORM_GRN}{last_or_only_sql}')
  429. data_frame = self._spark_session.sql(last_or_only_sql)
  430. if fill_null:
  431. data_frame.fillna(fill_null)
  432. return data_frame, len(data_frame.schema.fields) > 0
  433. def query_scalar(self,
  434. sql_or_file: str,
  435. check_parameter: bool = False,
  436. silent: bool = False,
  437. fill_null=None,
  438. **kwargs):
  439. data, is_select = self.query(sql_or_file, check_parameter, silent, fill_null, **kwargs)
  440. if is_select is not True:
  441. raise Exception('Get scalar from non-query statement')
  442. res = data.collect()
  443. res1 = res[0]
  444. return res[0][0]
  445. def register_udf_files(self, udf_file_list: List[str]):
  446. if not udf_file_list:
  447. return
  448. for udf_file in udf_file_list:
  449. self.register_udf_file(udf_file)
  450. def register_udf_file(self, udf_file: str):
  451. if not udf_file:
  452. return
  453. if not self.add_py_file(udf_file):
  454. pretty_print(f'{NORM_YEL}Python文件 {NORM_GRN}{udf_file}{NORM_YEL} 已被添加过(重复添加)')
  455. return
  456. if not self._spark_session:
  457. self._udf_files.append(udf_file)
  458. return
  459. pretty_print(f'{NORM_MGT}注册文件 {NORM_GRN}{udf_file}{NORM_MGT} 中的 Python UDF')
  460. udf_module_name = udf_file.replace('/', '.').replace('.py', '')
  461. module = import_module(udf_module_name)
  462. for (name, attr) in inspect.getmembers(module):
  463. if self.IGNORED_UDF.__contains__(name):
  464. continue
  465. if self.REGISTERED_UDF.__contains__(name):
  466. pretty_print(f'{NORM_YEL}名为 {NORM_GRN}{name}{NORM_YEL} 的 Python UDF '
  467. f'{NORM_YEL}已在文件 {NORM_GRN}{self.REGISTERED_UDF[name]} '
  468. f'{NORM_YEL}被注册(重名UDF以首次出现的为准)')
  469. continue
  470. if inspect.isfunction(attr):
  471. # todo: 注册每个udf的返回值类型
  472. pretty_print(f'{NORM_MGT}注册 Python UDF {NORM_GRN}{name}')
  473. self._spark_session.udf.register(name, attr)
  474. self.REGISTERED_UDF[name] = udf_file
  475. @staticmethod
  476. def show_data_frame(data_frame: DataFrame,
  477. data_frame_name: str = None,
  478. show_number: int = 20,
  479. truncate: Union[bool, int] = False):
  480. """
  481. 打印``DataFrame``的Schema和指定行数的数据
  482. Args:
  483. data_frame: DataFrame
  484. data_frame_name: DataFrame的名称
  485. show_number: 打印的行数
  486. truncate: If set to ``True``, truncate strings longer than 20 chars by default.
  487. If set to a number greater than one, truncates long strings to length ``truncate``
  488. and align cells right.
  489. Returns:
  490. """
  491. if data_frame_name:
  492. pretty_print(f'{NORM_MGT}%s %s DataFrame Information: %s ' % ('=' * 30, data_frame_name, '=' * 30))
  493. else:
  494. pretty_print(f'{NORM_MGT}%s DataFrame Information: %s ' % ('=' * 30, '=' * 30))
  495. pretty_print(f'{NORM_MGT}Schema :')
  496. data_frame.printSchema()
  497. pretty_print(f'{NORM_MGT}Data sample :')
  498. data_frame.show(show_number, truncate=truncate)
  499. pretty_print(f'{NORM_MGT}=' * 70)
  500. if __name__ == '__main__':
  501. with SparkSQL('spark-sql') as spark_sql:
  502. # spark_sql.execute('workspace/x.sql')
  503. # spark_sql._spark_session.conf.set('app.name', 'tmp.sa33')
  504. # spark_sql.execute('workspace/y.sql')
  505. # spark_sql.query('show databases')[0].show(100, truncate=False)
  506. # spark_sql.query('show functions')[0].show(1000, truncate=False)
  507. spark_sql.get_columns('crl_ads.ads_crl_dim_source_stat')