spark_sql.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534
  1. # -*- coding:utf-8 -*-
  2. import inspect
  3. import re
  4. from importlib import import_module
  5. from typing import List, Optional, Union, Dict, Any, Tuple
  6. from pyspark.sql import Row, SparkSession, DataFrame
  7. from dw_base import *
  8. from dw_base.utils import common_utils
  9. from dw_base.utils.datetime_utils import get_today
  10. from dw_base.utils.file_utils import get_abs_path
  11. from dw_base.utils.log_utils import pretty_print
  12. from dw_base.utils.sql_utils import get_sql_list_from_file, check_parameter_substituted
# Root HDFS directory for exported data sets; SparkSQL.export_data() joins
# USER, today's date and a per-export directory name onto this path.
HDFS_EXPORT_DATA_PATH = '/hdfs-mnt/export-data'
  14. def _load_spark_conf_file(path: str) -> Dict[str, str]:
  15. """读 Spark 原生 conf(每行 `key value`,# 注释,空白分隔)。文件缺失或行非法时返回空 dict,不抛错。"""
  16. if not os.path.isfile(path):
  17. return {}
  18. config = {}
  19. with open(path, 'r') as f:
  20. for line in f:
  21. line = line.strip()
  22. if not line or line.startswith('#'):
  23. continue
  24. parts = line.split(None, 1)
  25. if len(parts) != 2:
  26. continue
  27. config[parts[0]] = parts[1]
  28. return config
  29. class SparkSQL(object):
  30. """封装 Spark 会话与 SQL 执行;参数三级覆盖见 conf/spark-defaults.conf + spark-tuning.conf 与 kb/00 §4.2。"""
  31. REGISTERED_UDF_FILES = []
  32. ADDED_RESOURCE_FILES = []
  33. REGISTERED_UDF = {}
  34. IGNORED_UDF = ['attr', 'udf']
  35. def __init__(self,
  36. session_name: str = 'spark',
  37. master: str = 'yarn',
  38. spark_yarn_queue: Optional[str] = None,
  39. spark_driver_memory: Optional[str] = None,
  40. spark_executor_memory: Optional[str] = None,
  41. spark_executor_memory_overhead: Optional[str] = None,
  42. spark_driver_cores: Optional[int] = None,
  43. spark_executor_cores: Optional[int] = None,
  44. spark_executor_instances: Optional[int] = None,
  45. spark_driver_max_result_size: Optional[str] = None,
  46. spark_shuffle_partitions: Optional[int] = None,
  47. spark_default_parallelism: Optional[int] = None,
  48. extra_spark_config: Dict[str, Any] = None,
  49. udf_files: List[str] = None,
  50. resource_files: List[str] = None):
  51. """
  52. Args:
  53. session_name:
  54. master:
  55. spark_yarn_queue:
  56. spark_driver_memory:
  57. spark_executor_memory:
  58. spark_executor_memory_overhead:
  59. spark_driver_cores:
  60. spark_executor_cores:
  61. spark_executor_instances:
  62. spark_driver_max_result_size:
  63. spark_shuffle_partitions:
  64. spark_default_parallelism:
  65. extra_spark_config: 额外的Spark配置,优先级高于SQL文件中定义的Spark配置
  66. udf_files: udf文件列表,必须传相对路径,否则不识别
  67. resource_files: 资源文件列表,必须传相对路径,否则不识别
  68. """
  69. self._spark_session = None
  70. self._session_name = session_name
  71. self._master = master
  72. self._spark_driver_memory = spark_driver_memory
  73. self._spark_executor_memory = spark_executor_memory
  74. self._spark_executor_memory_overhead = spark_executor_memory_overhead
  75. self._spark_driver_cores = spark_driver_cores
  76. self._spark_executor_cores = spark_executor_cores
  77. self._spark_executor_instances = spark_executor_instances
  78. self._spark_yarn_queue = spark_yarn_queue
  79. self._spark_driver_max_result_size = spark_driver_max_result_size
  80. self._spark_shuffle_partitions = spark_shuffle_partitions
  81. self._spark_default_parallelism = spark_default_parallelism
  82. self._extra_spark_config = extra_spark_config
  83. self._final_spark_config = {}
  84. self.limit = 20
  85. if not udf_files:
  86. udf_files = []
  87. self._udf_files = udf_files
  88. self._py_files = []
  89. if not resource_files:
  90. resource_files = []
  91. self._resource_files = resource_files
  92. self.global_start_time = None
  93. self.start_time = None
  94. def __add_spark_config(self, sql_file: str, spark_config_def: str):
  95. if not spark_config_def.__contains__('='):
  96. pretty_print(f'{NORM_YEL}无效的 Spark 配置 {NORM_GRN}{spark_config_def}')
  97. return
  98. spark_config, config_value = [e.strip() for e in spark_config_def.split('=')]
  99. if self._final_spark_config.__contains__(spark_config):
  100. pretty_print(
  101. f'{NORM_YEL}SQL 文件 {NORM_GRN}{sql_file}{NORM_YEL} 中定义的 Spark 配置 {NORM_GRN}{spark_config} '
  102. f'{NORM_YEL}重复提供,原值 {NORM_GRN}{self._final_spark_config[spark_config]} '
  103. f'{NORM_YEL}将被覆盖为新值 {NORM_GRN}{config_value}')
  104. else:
  105. self._final_spark_config[spark_config] = config_value
  106. self._final_spark_config[spark_config] = config_value
  107. def __init_spark_session(self):
  108. if self._spark_session:
  109. return
  110. pretty_print(f'{NORM_MGT}基于用户 {NORM_GRN}{USER}{NORM_MGT} 创建 SparkSession')
  111. builder = SparkSession.builder \
  112. .appName(self._session_name) \
  113. .master(self._master)
  114. # L1:conf/spark-defaults.conf(底层)+ conf/spark-tuning.conf(调优,相同 key 覆盖 defaults)
  115. l1_defaults = {}
  116. l1_defaults.update(_load_spark_conf_file(f'{PROJECT_ROOT_PATH}/conf/spark-defaults.conf'))
  117. l1_defaults.update(_load_spark_conf_file(f'{PROJECT_ROOT_PATH}/conf/spark-tuning.conf'))
  118. pretty_print(f'{NORM_MGT}L1 加载 {NORM_GRN}{len(l1_defaults)}{NORM_MGT} 条 conf 默认')
  119. for key, value in l1_defaults.items():
  120. pretty_print(f'{NORM_MGT}L1 应用 conf 默认 {NORM_GRN}{key} => {str(value)}')
  121. builder.config(key, value)
  122. # L2:SQL 内 SET(query() 预扫描,session 启动前写入 builder)
  123. for key, value in self._final_spark_config.items():
  124. pretty_print(f'{NORM_MGT}L2 应用 SQL SET {NORM_GRN}{key} => {str(value)}')
  125. builder.config(key, value)
  126. # L3:构造函数显式传参 + extra_spark_config
  127. l3_overrides: Dict[str, Any] = {}
  128. for conf_key, attr_val in (
  129. ('spark.yarn.queue', self._spark_yarn_queue),
  130. ('spark.driver.memory', self._spark_driver_memory),
  131. ('spark.executor.memory', self._spark_executor_memory),
  132. ('spark.executor.memoryOverhead', self._spark_executor_memory_overhead),
  133. ('spark.driver.cores', self._spark_driver_cores),
  134. ('spark.executor.cores', self._spark_executor_cores),
  135. ('spark.executor.instances', self._spark_executor_instances),
  136. ('spark.driver.maxResultSize', self._spark_driver_max_result_size),
  137. ('spark.sql.shuffle.partitions', self._spark_shuffle_partitions),
  138. ('spark.default.parallelism', self._spark_default_parallelism),
  139. ):
  140. if attr_val is not None:
  141. l3_overrides[conf_key] = attr_val
  142. if self._extra_spark_config:
  143. l3_overrides.update(self._extra_spark_config)
  144. for key, value in l3_overrides.items():
  145. pretty_print(f'{NORM_MGT}L3 应用构造参数/extra {NORM_GRN}{key} => {str(value)}')
  146. builder.config(key, value)
  147. pretty_print(f'{NORM_MGT}创建 SparkSession')
  148. self._spark_session = builder.enableHiveSupport().getOrCreate()
  149. self._spark_session.sparkContext._jsc.hadoopConfiguration().set('mapred.max.split.size', '33554432')
  150. self._spark_session.sparkContext._jsc.hadoopConfiguration().set(
  151. 'mapreduce.fileoutputcommitter.marksuccessfuljobs', 'false')
  152. first_py_file = 'dw_base.zip'
  153. if IS_RUN_IN_RELEASE_DIR and IS_RUN_BY_RELEASE_USER:
  154. command = f'cd {PROJECT_ROOT_PATH} && if [ ! -f {first_py_file} ];then zip -qr {first_py_file} dw_base; fi'
  155. else:
  156. command = f'cd {PROJECT_ROOT_PATH} && rm -f {first_py_file} && zip -qr {first_py_file} dw_base'
  157. pretty_print(f'{NORM_MGT}执行 Shell 命令 {NORM_GRN}{command}')
  158. os.system(command)
  159. self._spark_session.sparkContext.addPyFile(get_abs_path(first_py_file))
  160. self.register_udf_files(self._udf_files)
  161. for py_file in self._py_files:
  162. self.add_py_file(py_file)
  163. for resource_file in self._resource_files:
  164. self.add_resource_file(resource_file)
  165. @staticmethod
  166. def add_parameters(sql_or_file: str, parameters: Dict, key: str, value: Any):
  167. if parameters.__contains__(key):
  168. pretty_print(f'{NORM_YEL}SQL 文件 {NORM_GRN}{sql_or_file} '
  169. f'{NORM_YEL}中定义的 SQL 参数 {NORM_GRN}{key} '
  170. f'{NORM_YEL}发生变更,将使用新值 {NORM_GRN}{Any} '
  171. f'{NORM_YEL}替代原值 {NORM_GRN}{parameters[key]}')
  172. parameters[key] = value
  173. def add_py_file(self, py_file: str) -> bool:
  174. abs_py_file = get_abs_path(py_file)
  175. if self.REGISTERED_UDF_FILES.__contains__(abs_py_file):
  176. return False
  177. if not self._spark_session:
  178. self._py_files.append(py_file)
  179. return True
  180. pretty_print(f'{NORM_MGT}添加 Python 文件 {NORM_GRN}{py_file}')
  181. self._spark_session.sparkContext.addPyFile(abs_py_file)
  182. self.REGISTERED_UDF_FILES.append(abs_py_file)
  183. return True
  184. def add_resource_file(self, resource_file: str) -> bool:
  185. abs_resource_file = get_abs_path(resource_file)
  186. if self.ADDED_RESOURCE_FILES.__contains__(abs_resource_file):
  187. return False
  188. if not self._spark_session:
  189. self._resource_files.append(resource_file)
  190. return True
  191. pretty_print(f'{NORM_MGT}添加资源文件 {NORM_GRN}{resource_file}')
  192. self._spark_session.sparkContext.addFile(abs_resource_file)
  193. self.ADDED_RESOURCE_FILES.append(abs_resource_file)
  194. return True
  195. def execute(self,
  196. sql_or_file: str,
  197. check_parameter: bool = False,
  198. silent: bool = False,
  199. fill_null=None,
  200. **kwargs):
  201. """
  202. 仅执行,不返回任何值
  203. Args:
  204. sql_or_file: 需要执行的 sql 语句或者包含 sql 语句的文件路径(相对或绝对都可以)
  205. check_parameter: 运行前是否进行参数检查
  206. silent: 执行时是否打印日志,True则不打印
  207. fill_null: 填充null值
  208. **kwargs: sql 语句或 sql 文件所用到的参数
  209. Returns:
  210. """
  211. # 运行SQL语句或SQL文件
  212. start_time = time.time()
  213. data_frame, is_select = self.query(sql_or_file, check_parameter, silent, fill_null, **kwargs)
  214. if is_select:
  215. self.show_data_frame(data_frame, show_number=self.limit, truncate=USER == RELEASE_USER)
  216. end_time = time.time()
  217. cost = round(float(end_time - self.start_time), 2)
  218. pretty_print(f'{NORM_MGT}SQL 语句执行完毕,耗时 {NORM_GRN}{str(cost)}{NORM_MGT} 秒')
  219. if sql_or_file.endswith('.sql'):
  220. total_cost = round(float(end_time - self.global_start_time), 2)
  221. pretty_print(f'{NORM_MGT}SQL 文件 {NORM_GRN}{sql_or_file}'
  222. f'{NORM_MGT} 执行完毕,共耗时 {NORM_GRN}{str(total_cost)}{NORM_MGT} 秒')
  223. self.global_start_time = None
  224. def export_data(self,
  225. data_set_name: str,
  226. sql_or_file: str,
  227. show_number: int = 100,
  228. truncate: Union[bool, int] = 40,
  229. delimiter: str = ',',
  230. partition: int = 1,
  231. **kwargs) -> str:
  232. """
  233. 导出数据,默认存储于HDFS目录(/hdfs-mnt/export-data),映射本地目录(/opt/hdfs-mnt/export-data)
  234. Args:
  235. data_set_name: 导出的数据集名称标识(作为目录)
  236. sql_or_file: 需要执行的 sql 语句或者包含 sql 语句的文件全路径
  237. show_number: 显示的行数
  238. truncate: If set to ``True``, truncate strings longer than 20 chars by default.
  239. If set to a number greater than one, truncates long strings to length ``truncate``
  240. and align cells right.
  241. delimiter: 分隔符
  242. partition: 导出文件个数
  243. **kwargs: sql 语句或 sql 文件所用到的参数
  244. Returns:
  245. """
  246. start_time = time.time()
  247. data_frame, _ = self.query(sql_or_file, **kwargs)
  248. data_frame.persist()
  249. self.show_data_frame(data_frame, data_set_name, show_number, truncate)
  250. today = get_today()
  251. # 注意: 此目录不是linux服务器的目录, 而是hdfs文件系统的目录
  252. hdfs_directory = f'{data_set_name}.{str(int(time.time()))}'
  253. hdfs_export_path = os.path.join(HDFS_EXPORT_DATA_PATH, USER, today, hdfs_directory)
  254. # linux服务器导出文件存放的目录
  255. # local_export_path = os.path.join(LOCAL_EXPORT_DATA_PATH, USER, today, data_set_name)
  256. pretty_print(f'{NORM_MGT}准备导出文件到HDFS目录: {NORM_GRN}{hdfs_export_path}')
  257. if not delimiter:
  258. delimiter = ','
  259. # 先将数据保存至hdfs文件系统, 加 option("escape", "\"") 不会导致csv文件“字段文本内容中的英文逗号来分割字段(列膨胀错位)”
  260. data_frame \
  261. .repartition(partition) \
  262. .write \
  263. .mode('overwrite') \
  264. .format('com.databricks.spark.csv') \
  265. .option('header', 'true') \
  266. .option("escape", "\"") \
  267. .option('delimiter', delimiter) \
  268. .save(hdfs_export_path)
  269. cost = round(float(time.time() - start_time), 2)
  270. pretty_print(f'{NORM_MGT}数据导出完毕,耗时 {NORM_GRN}{str(cost)}{NORM_MGT} 秒')
  271. return hdfs_export_path
  272. def get_columns(self, hive_table_name: str):
  273. desc_df, _ = self.query(f'desc {hive_table_name}', silent=True)
  274. desc_rows = desc_df.collect()
  275. hive_columns = {}
  276. partition_col_start = False
  277. for row in desc_rows: # type: Row
  278. if row[0] == '# Partition Information':
  279. partition_col_start = True
  280. continue
  281. if row[0] == '# col_name':
  282. continue
  283. if partition_col_start:
  284. del hive_columns[list(hive_columns.keys())[len(hive_columns.keys()) - 1]]
  285. else:
  286. hive_columns[row[0]] = (row[1], row[2])
  287. return hive_columns
  288. def get_partition_columns(self, hive_table_name: str):
  289. desc_df, _ = self.query(f'desc {hive_table_name}', silent=True)
  290. desc_rows = desc_df.collect()
  291. partition_columns = {}
  292. partitioned = False
  293. for row in desc_rows: # type: Row
  294. if row[0] != '# Partition Information':
  295. continue
  296. if row[0].startswith('#'):
  297. continue
  298. partitioned = True
  299. partition_columns[row[0]] = (row[1], row[2])
  300. return partitioned, partition_columns
  301. def get_spark_session(self):
  302. if not self._spark_session:
  303. self.__init_spark_session()
  304. self.start_time = time.time()
  305. if not self.global_start_time:
  306. self.global_start_time = time.time()
  307. return self._spark_session
  308. def jdbc_read(self, jdbc_url: str, table_name: str) -> DataFrame:
  309. return self.get_spark_session().read.jdbc(jdbc_url, table_name)
  310. def list_tables(self, hive_database_name: str = 'default', include_regex: str = None, exclude_regex: str = None):
  311. if not hive_database_name:
  312. hive_database_name = 'default'
  313. tables_df, _ = self.query(f'show tables in {hive_database_name}')
  314. tables = []
  315. for row in tables_df.collect():
  316. if hive_database_name == 'default':
  317. table_name = row.asDict()['tableName']
  318. else:
  319. table_name = f"{hive_database_name}.{row.asDict()['tableName']}"
  320. if exclude_regex and re.match(exclude_regex, table_name):
  321. continue
  322. if include_regex:
  323. if re.match(include_regex, table_name):
  324. tables.append(table_name)
  325. continue
  326. else:
  327. tables.append(table_name)
  328. return tables
  329. def query(self,
  330. sql_or_file: str,
  331. check_parameter: bool = False,
  332. silent: bool = False,
  333. fill_null=None,
  334. **kwargs) -> Tuple[DataFrame, bool]:
  335. """
  336. 执行查询,返回一个DataFrame和该DataFrame是否是SELECT
  337. Args:
  338. sql_or_file: 需要执行的 sql 语句或者包含 sql 语句的文件全路径
  339. check_parameter: 运行前是否进行参数检查
  340. silent: 执行时是否打印日志,True则不打印
  341. fill_null: 填充null值
  342. **kwargs: sql 语句或 sql 文件所用到的参数
  343. Returns: Tuple[DataFrame, bool]
  344. """
  345. if sql_or_file.endswith('.sql'):
  346. if not silent:
  347. pretty_print(f'{NORM_MGT}准备执行 SQL 文件 {NORM_GRN}{sql_or_file}')
  348. sql_list = get_sql_list_from_file(sql_or_file, trim_comment=True)
  349. else:
  350. sql_list = [sql_or_file]
  351. cleaned_sql_list = []
  352. # 外部传入的参数
  353. external_parameters = {}
  354. if kwargs:
  355. for key, value in kwargs.items():
  356. if not silent:
  357. pretty_print(f'{NORM_MGT}收到外部 SQL 参数 {NORM_GRN}{key} => {value}')
  358. external_parameters[key] = value
  359. # SQL文件内部定义的参数
  360. internal_parameters = {}
  361. for sql in sql_list:
  362. upper_sql = sql.upper()
  363. if upper_sql.startswith('ADD FILES'):
  364. file_names = sql[len('ADD FILES'):].strip().split(' ')
  365. for file_name in file_names:
  366. file_name = file_name.strip()
  367. if not file_name:
  368. continue
  369. if file_name.endswith('.py'):
  370. self.register_udf_file(file_name)
  371. else:
  372. self.add_resource_file(file_name)
  373. elif upper_sql.startswith('ADD FILE'):
  374. file_name = sql[len('ADD FILE'):].strip()
  375. if file_name.endswith('.py'):
  376. self.register_udf_file(file_name)
  377. else:
  378. self.add_resource_file(file_name)
  379. elif upper_sql.startswith('ADD SPARK_UDF'):
  380. self.register_udf_file(sql[len('ADD SPARK_UDF'):].strip())
  381. elif upper_sql.startswith('SET SPARK_VAR:'):
  382. var_key, var_value = sql[len('SET SPARK_VAR:'):].strip().split('=')
  383. var_key = var_key.strip()
  384. var_value = var_value.strip()
  385. self.add_parameters(sql_or_file, internal_parameters, var_key, var_value)
  386. elif upper_sql.startswith('SET LIMIT='):
  387. if not self._spark_session:
  388. self.limit = int(sql.split('=')[1].strip())
  389. elif upper_sql.startswith('SET '):
  390. set_expr = sql[len('SET '):].strip()
  391. set_expr_splits = set_expr.split('=')
  392. set_key = set_expr_splits[0].strip()
  393. set_value = set_expr_splits[1].strip()
  394. if re.match(r'.+\..+(\..+)*', set_key):
  395. if not self._spark_session:
  396. self.__add_spark_config(sql_or_file, sql[len('SET '):].strip())
  397. else:
  398. self.add_parameters(sql_or_file, internal_parameters, set_key, set_value)
  399. else:
  400. # 优先使用外部参数
  401. for key, value in external_parameters.items():
  402. sql = sql.replace('${%s}' % key, value)
  403. # SQL文件里的参数必须在使用前定义(同时也意味着相同的参数可以多次定义不同值,供多个SQL使用)
  404. for key, value in internal_parameters.items():
  405. if external_parameters.__contains__(key):
  406. # 如果外部参数中包括当前参数,则使用外部参数(实际上已在上一步替换,此处仅为提示)
  407. pretty_print(f'{NORM_YEL}SQL 文件 {NORM_GRN}{sql_or_file} '
  408. f'{NORM_YEL}中定义的 SQL 参数 {NORM_GRN}{key} '
  409. f'{NORM_YEL}已由外部参数抢占,将使用新值 {NORM_GRN}{external_parameters[key]} '
  410. f'{NORM_YEL}替代原值 {NORM_GRN}{internal_parameters[key]}')
  411. continue
  412. sql = sql.replace('${%s}' % key, value)
  413. cleaned_sql_list.append(sql)
  414. if len(cleaned_sql_list) == 0:
  415. return self.get_spark_session().sql("SELECT ''"), False
  416. for index in range(len(cleaned_sql_list) - 1):
  417. sql = cleaned_sql_list[index]
  418. self.execute(sql)
  419. last_or_only_sql = cleaned_sql_list[-1]
  420. check_parameter_substituted(last_or_only_sql, check_parameter)
  421. self.get_spark_session()
  422. if not silent:
  423. if last_or_only_sql.__contains__('\n'):
  424. pretty_print(f'{NORM_MGT}开始执行 SQL 语句: \n{NORM_GRN}{last_or_only_sql}')
  425. else:
  426. pretty_print(f'{NORM_MGT}开始执行 SQL 语句: {NORM_GRN}{last_or_only_sql}')
  427. data_frame = self._spark_session.sql(last_or_only_sql)
  428. if fill_null:
  429. data_frame.fillna(fill_null)
  430. return data_frame, len(data_frame.schema.fields) > 0
  431. def query_scalar(self,
  432. sql_or_file: str,
  433. check_parameter: bool = False,
  434. silent: bool = False,
  435. fill_null=None,
  436. **kwargs):
  437. data, is_select = self.query(sql_or_file, check_parameter, silent, fill_null, **kwargs)
  438. if is_select is not True:
  439. raise Exception('Get scalar from non-query statement')
  440. res = data.collect()
  441. res1 = res[0]
  442. return res[0][0]
  443. def register_udf_files(self, udf_file_list: List[str]):
  444. if not udf_file_list:
  445. return
  446. for udf_file in udf_file_list:
  447. self.register_udf_file(udf_file)
  448. def register_udf_file(self, udf_file: str):
  449. if not udf_file:
  450. return
  451. if not self.add_py_file(udf_file):
  452. pretty_print(f'{NORM_YEL}Python文件 {NORM_GRN}{udf_file}{NORM_YEL} 已被添加过(重复添加)')
  453. return
  454. if not self._spark_session:
  455. self._udf_files.append(udf_file)
  456. return
  457. pretty_print(f'{NORM_MGT}注册文件 {NORM_GRN}{udf_file}{NORM_MGT} 中的 Python UDF')
  458. udf_module_name = udf_file.replace('/', '.').replace('.py', '')
  459. module = import_module(udf_module_name)
  460. for (name, attr) in inspect.getmembers(module):
  461. if self.IGNORED_UDF.__contains__(name):
  462. continue
  463. if self.REGISTERED_UDF.__contains__(name):
  464. pretty_print(f'{NORM_YEL}名为 {NORM_GRN}{name}{NORM_YEL} 的 Python UDF '
  465. f'{NORM_YEL}已在文件 {NORM_GRN}{self.REGISTERED_UDF[name]} '
  466. f'{NORM_YEL}被注册(重名UDF以首次出现的为准)')
  467. continue
  468. if inspect.isfunction(attr):
  469. # todo: 注册每个udf的返回值类型
  470. pretty_print(f'{NORM_MGT}注册 Python UDF {NORM_GRN}{name}')
  471. self._spark_session.udf.register(name, attr)
  472. self.REGISTERED_UDF[name] = udf_file
  473. @staticmethod
  474. def show_data_frame(data_frame: DataFrame,
  475. data_frame_name: str = None,
  476. show_number: int = 20,
  477. truncate: Union[bool, int] = False):
  478. """
  479. 打印``DataFrame``的Schema和指定行数的数据
  480. Args:
  481. data_frame: DataFrame
  482. data_frame_name: DataFrame的名称
  483. show_number: 打印的行数
  484. truncate: If set to ``True``, truncate strings longer than 20 chars by default.
  485. If set to a number greater than one, truncates long strings to length ``truncate``
  486. and align cells right.
  487. Returns:
  488. """
  489. if data_frame_name:
  490. pretty_print(f'{NORM_MGT}%s %s DataFrame Information: %s ' % ('=' * 30, data_frame_name, '=' * 30))
  491. else:
  492. pretty_print(f'{NORM_MGT}%s DataFrame Information: %s ' % ('=' * 30, '=' * 30))
  493. pretty_print(f'{NORM_MGT}Schema :')
  494. data_frame.printSchema()
  495. pretty_print(f'{NORM_MGT}Data sample :')
  496. data_frame.show(show_number, truncate=truncate)
  497. pretty_print(f'{NORM_MGT}=' * 70)
  498. if __name__ == '__main__':
  499. with SparkSQL('spark-sql') as spark_sql:
  500. # spark_sql.execute('workspace/x.sql')
  501. # spark_sql._spark_session.conf.set('app.name', 'tmp.sa33')
  502. # spark_sql.execute('workspace/y.sql')
  503. # spark_sql.query('show databases')[0].show(100, truncate=False)
  504. # spark_sql.query('show functions')[0].show(1000, truncate=False)
  505. spark_sql.get_columns('crl_ads.ads_crl_dim_source_stat')