spark_sql.py

# -*- coding:utf-8 -*-
import inspect
import os  # used throughout this module; may also be re-exported by `from dw_base import *`
import re
import time  # used throughout this module; may also be re-exported by `from dw_base import *`
from importlib import import_module
from typing import List, Optional, Union, Dict, Any, Tuple

from pyspark.sql import Row, SparkSession, DataFrame

from dw_base import *
from dw_base.utils import common_utils
from dw_base.utils.datetime_utils import get_today
from dw_base.utils.file_utils import get_abs_path
from dw_base.utils.log_utils import pretty_print
from dw_base.utils.sql_utils import get_sql_list_from_file, check_parameter_substituted

HDFS_EXPORT_DATA_PATH = '/hdfs-mnt/export-data'


def _load_spark_conf_file(path: str) -> Dict[str, str]:
    """Read a native Spark conf file (one `key value` per line, `#` comments, whitespace-separated).
    Returns an empty dict instead of raising when the file is missing or a line is malformed."""
    if not os.path.isfile(path):
        return {}
    config = {}
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            parts = line.split(None, 1)
            if len(parts) != 2:
                continue
            config[parts[0]] = parts[1]
    return config
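
# Example of the conf format _load_spark_conf_file expects (keys and values below are
# illustrative, not the project's actual defaults):
#
#     # conf/spark-defaults.conf
#     spark.serializer                org.apache.spark.serializer.KryoSerializer
#     spark.sql.shuffle.partitions    200
#
# For that file the function would return:
#     {'spark.serializer': 'org.apache.spark.serializer.KryoSerializer',
#      'spark.sql.shuffle.partitions': '200'}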


class SparkSQL(object):
    """
    A wrapper class for executing Spark-related operations. Notes on the relevant parameters:

    In Spark 2.0+ the number of executors can be set dynamically (from within the program)
    through the SparkSession:
        spark.conf.set("spark.executor.instances", 4)
        spark.conf.set("spark.executor.cores", 4)
    In the case above, at most 16 tasks will be executed at any given time.

    The other option is dynamic allocation of executors, as below:
        spark.conf.set("spark.dynamicAllocation.enabled", "true")
        spark.conf.set("spark.executor.cores", 4)
        spark.conf.set("spark.dynamicAllocation.minExecutors", "1")
        spark.conf.set("spark.dynamicAllocation.maxExecutors", "5")

    spark.yarn.executor.memoryOverhead: defaults to executorMemory * 0.07, with a minimum of 384
    spark.yarn.driver.memoryOverhead:   defaults to driverMemory * 0.07, with a minimum of 384
    spark.yarn.am.memoryOverhead:       defaults to AM memory * 0.07, with a minimum of 384
    """

    REGISTERED_UDF_FILES = []
    ADDED_RESOURCE_FILES = []
    REGISTERED_UDF = {}
    IGNORED_UDF = ['attr', 'udf']

    def __init__(self,
                 session_name: str = 'spark',
                 master: str = 'yarn',
                 spark_yarn_queue: Optional[str] = None,
                 spark_driver_memory: Optional[str] = None,
                 spark_executor_memory: Optional[str] = None,
                 spark_executor_memory_overhead: Optional[str] = None,
                 spark_driver_cores: Optional[int] = None,
                 spark_executor_cores: Optional[int] = None,
                 spark_executor_instances: Optional[int] = None,
                 spark_driver_max_result_size: Optional[str] = None,
                 spark_shuffle_partitions: Optional[int] = None,
                 spark_default_parallelism: Optional[int] = None,
                 extra_spark_config: Dict[str, Any] = None,
                 udf_files: List[str] = None,
                 resource_files: List[str] = None):
        """
        Args:
            session_name: Spark application name (appName)
            master: cluster master, e.g. 'yarn'
            spark_yarn_queue: value for spark.yarn.queue
            spark_driver_memory: value for spark.driver.memory
            spark_executor_memory: value for spark.executor.memory
            spark_executor_memory_overhead: value for spark.executor.memoryOverhead
            spark_driver_cores: value for spark.driver.cores
            spark_executor_cores: value for spark.executor.cores
            spark_executor_instances: value for spark.executor.instances
            spark_driver_max_result_size: value for spark.driver.maxResultSize
            spark_shuffle_partitions: value for spark.sql.shuffle.partitions
            spark_default_parallelism: value for spark.default.parallelism
            extra_spark_config: extra Spark config; takes precedence over Spark config defined in SQL files
            udf_files: list of UDF files; relative paths are required, otherwise they are not recognized
            resource_files: list of resource files; relative paths are required, otherwise they are not recognized
        """
        self._spark_session = None
        self._session_name = session_name
        self._master = master
        self._spark_driver_memory = spark_driver_memory
        self._spark_executor_memory = spark_executor_memory
        self._spark_executor_memory_overhead = spark_executor_memory_overhead
        self._spark_driver_cores = spark_driver_cores
        self._spark_executor_cores = spark_executor_cores
        self._spark_executor_instances = spark_executor_instances
        self._spark_yarn_queue = spark_yarn_queue
        self._spark_driver_max_result_size = spark_driver_max_result_size
        self._spark_shuffle_partitions = spark_shuffle_partitions
        self._spark_default_parallelism = spark_default_parallelism
        self._extra_spark_config = extra_spark_config
        self._final_spark_config = {}
        self.limit = 20
        if not udf_files:
            udf_files = []
        self._udf_files = udf_files
        self._py_files = []
        if not resource_files:
            resource_files = []
        self._resource_files = resource_files
        self.global_start_time = None
        self.start_time = None
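
    # Illustrative construction (queue name, sizes, and file path are placeholders):
    #     spark_sql = SparkSQL('my-job',
    #                          spark_yarn_queue='root.default',
    #                          spark_executor_memory='4g',
    #                          spark_executor_instances=10,
    #                          udf_files=['workspace/udf/my_udfs.py'])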

    def __add_spark_config(self, sql_file: str, spark_config_def: str):
        if '=' not in spark_config_def:
            pretty_print(f'{NORM_YEL}Invalid Spark config {NORM_GRN}{spark_config_def}')
            return
        spark_config, config_value = [e.strip() for e in spark_config_def.split('=', 1)]
        if spark_config in self._final_spark_config:
            pretty_print(
                f'{NORM_YEL}Spark config {NORM_GRN}{spark_config}{NORM_YEL} defined in SQL file {NORM_GRN}{sql_file} '
                f'{NORM_YEL}was provided more than once; the previous value {NORM_GRN}{self._final_spark_config[spark_config]} '
                f'{NORM_YEL}will be overridden with the new value {NORM_GRN}{config_value}')
        self._final_spark_config[spark_config] = config_value

    def __init_spark_session(self):
        if self._spark_session:
            return
        pretty_print(f'{NORM_MGT}Creating SparkSession based on user {NORM_GRN}{USER}')
        builder = SparkSession.builder \
            .appName(self._session_name) \
            .master(self._master)
        # L1: conf/spark-defaults.conf (base) + conf/spark-tuning.conf (tuning; same keys override the defaults)
        l1_defaults = {}
        l1_defaults.update(_load_spark_conf_file(f'{PROJECT_ROOT_PATH}/conf/spark-defaults.conf'))
        l1_defaults.update(_load_spark_conf_file(f'{PROJECT_ROOT_PATH}/conf/spark-tuning.conf'))
        pretty_print(f'{NORM_MGT}L1 loaded {NORM_GRN}{len(l1_defaults)}{NORM_MGT} conf defaults')
        for key, value in l1_defaults.items():
            pretty_print(f'{NORM_MGT}L1 applying conf default {NORM_GRN}{key} => {str(value)}')
            builder.config(key, value)
        # L2: SET statements inside SQL files (pre-scanned by query(), written to the builder before the session starts)
        for key, value in self._final_spark_config.items():
            pretty_print(f'{NORM_MGT}L2 applying SQL SET {NORM_GRN}{key} => {str(value)}')
            builder.config(key, value)
        # L3: explicit constructor arguments + extra_spark_config
        l3_overrides: Dict[str, Any] = {}
        for conf_key, attr_val in (
                ('spark.yarn.queue', self._spark_yarn_queue),
                ('spark.driver.memory', self._spark_driver_memory),
                ('spark.executor.memory', self._spark_executor_memory),
                ('spark.executor.memoryOverhead', self._spark_executor_memory_overhead),
                ('spark.driver.cores', self._spark_driver_cores),
                ('spark.executor.cores', self._spark_executor_cores),
                ('spark.executor.instances', self._spark_executor_instances),
                ('spark.driver.maxResultSize', self._spark_driver_max_result_size),
                ('spark.sql.shuffle.partitions', self._spark_shuffle_partitions),
                ('spark.default.parallelism', self._spark_default_parallelism),
        ):
            if attr_val is not None:
                l3_overrides[conf_key] = attr_val
        if self._extra_spark_config:
            l3_overrides.update(self._extra_spark_config)
        for key, value in l3_overrides.items():
            pretty_print(f'{NORM_MGT}L3 applying constructor argument/extra config {NORM_GRN}{key} => {str(value)}')
            builder.config(key, value)
        pretty_print(f'{NORM_MGT}Creating SparkSession')
        self._spark_session = builder.enableHiveSupport().getOrCreate()
        self._spark_session.sparkContext._jsc.hadoopConfiguration().set('mapred.max.split.size', '33554432')
        self._spark_session.sparkContext._jsc.hadoopConfiguration().set(
            'mapreduce.fileoutputcommitter.marksuccessfuljobs', 'false')
        first_py_file = 'dw_base.zip'
        if IS_RUN_IN_RELEASE_DIR and IS_RUN_BY_RELEASE_USER:
            command = f'cd {PROJECT_ROOT_PATH} && if [ ! -f {first_py_file} ];then zip -qr {first_py_file} dw_base; fi'
        else:
            command = f'cd {PROJECT_ROOT_PATH} && rm -f {first_py_file} && zip -qr {first_py_file} dw_base'
        pretty_print(f'{NORM_MGT}Running shell command {NORM_GRN}{command}')
        os.system(command)
        self._spark_session.sparkContext.addPyFile(get_abs_path(first_py_file))
        self.register_udf_files(self._udf_files)
        for py_file in self._py_files:
            self.add_py_file(py_file)
        for resource_file in self._resource_files:
            self.add_resource_file(resource_file)
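
    # Illustrative precedence (hypothetical values): if conf/spark-defaults.conf sets
    # `spark.executor.memory 2g` (L1), a SQL file contains `SET spark.executor.memory = 4g`
    # (L2), and the constructor is called with spark_executor_memory='8g' (L3), the
    # session starts with 8g, because each layer's builder.config() call overwrites the
    # value written by the layer before it.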

    @staticmethod
    def add_parameters(sql_or_file: str, parameters: Dict, key: str, value: Any):
        if key in parameters:
            pretty_print(f'{NORM_YEL}SQL parameter {NORM_GRN}{key} '
                         f'{NORM_YEL}defined in SQL file {NORM_GRN}{sql_or_file} '
                         f'{NORM_YEL}has changed; the new value {NORM_GRN}{value} '
                         f'{NORM_YEL}replaces the previous value {NORM_GRN}{parameters[key]}')
        parameters[key] = value

    def add_py_file(self, py_file: str) -> bool:
        abs_py_file = get_abs_path(py_file)
        if abs_py_file in self.REGISTERED_UDF_FILES:
            return False
        if not self._spark_session:
            self._py_files.append(py_file)
            return True
        pretty_print(f'{NORM_MGT}Adding Python file {NORM_GRN}{py_file}')
        self._spark_session.sparkContext.addPyFile(abs_py_file)
        self.REGISTERED_UDF_FILES.append(abs_py_file)
        return True

    def add_resource_file(self, resource_file: str) -> bool:
        abs_resource_file = get_abs_path(resource_file)
        if abs_resource_file in self.ADDED_RESOURCE_FILES:
            return False
        if not self._spark_session:
            self._resource_files.append(resource_file)
            return True
        pretty_print(f'{NORM_MGT}Adding resource file {NORM_GRN}{resource_file}')
        self._spark_session.sparkContext.addFile(abs_resource_file)
        self.ADDED_RESOURCE_FILES.append(abs_resource_file)
        return True

    def execute(self,
                sql_or_file: str,
                check_parameter: bool = False,
                silent: bool = False,
                fill_null=None,
                **kwargs):
        """
        Execute only; returns nothing.
        Args:
            sql_or_file: the SQL statement to execute, or the path (relative or absolute) of a file containing SQL statements
            check_parameter: whether to check parameters before running
            silent: whether to suppress logging during execution (True means no logging)
            fill_null: value used to fill nulls
            **kwargs: parameters used by the SQL statement or SQL file
        Returns:
        """
        # Run the SQL statement or SQL file
        data_frame, is_select = self.query(sql_or_file, check_parameter, silent, fill_null, **kwargs)
        if is_select:
            self.show_data_frame(data_frame, show_number=self.limit, truncate=USER == RELEASE_USER)
        end_time = time.time()
        cost = round(float(end_time - self.start_time), 2)
        pretty_print(f'{NORM_MGT}SQL statement finished, took {NORM_GRN}{str(cost)}{NORM_MGT} seconds')
        if sql_or_file.endswith('.sql'):
            total_cost = round(float(end_time - self.global_start_time), 2)
            pretty_print(f'{NORM_MGT}SQL file {NORM_GRN}{sql_or_file}'
                         f'{NORM_MGT} finished, took {NORM_GRN}{str(total_cost)}{NORM_MGT} seconds in total')
            self.global_start_time = None

    def export_data(self,
                    data_set_name: str,
                    sql_or_file: str,
                    show_number: int = 100,
                    truncate: Union[bool, int] = 40,
                    delimiter: str = ',',
                    partition: int = 1,
                    **kwargs) -> str:
        """
        Export data. By default it is stored in an HDFS directory (/hdfs-mnt/export-data),
        which is mapped to a local directory (/opt/hdfs-mnt/export-data).
        Args:
            data_set_name: name of the exported data set (used as the directory name)
            sql_or_file: the SQL statement to execute, or the full path of a file containing SQL statements
            show_number: number of rows to display
            truncate: If set to ``True``, truncate strings longer than 20 chars by default.
                If set to a number greater than one, truncates long strings to length ``truncate``
                and align cells right.
            delimiter: field delimiter
            partition: number of exported files
            **kwargs: parameters used by the SQL statement or SQL file
        Returns:
        """
        start_time = time.time()
        data_frame, _ = self.query(sql_or_file, **kwargs)
        data_frame.persist()
        self.show_data_frame(data_frame, data_set_name, show_number, truncate)
        today = get_today()
        # Note: this is a directory in the HDFS file system, not a directory on the Linux server
        hdfs_directory = f'{data_set_name}.{str(int(time.time()))}'
        hdfs_export_path = os.path.join(HDFS_EXPORT_DATA_PATH, USER, today, hdfs_directory)
        # Directory on the Linux server where exported files would be placed
        # local_export_path = os.path.join(LOCAL_EXPORT_DATA_PATH, USER, today, data_set_name)
        pretty_print(f'{NORM_MGT}Exporting files to HDFS directory: {NORM_GRN}{hdfs_export_path}')
        if not delimiter:
            delimiter = ','
        # Save the data to HDFS first. Adding option("escape", "\"") prevents commas inside quoted
        # field text from being treated as field separators (which would shift/expand columns) in the CSV file.
        data_frame \
            .repartition(partition) \
            .write \
            .mode('overwrite') \
            .format('com.databricks.spark.csv') \
            .option('header', 'true') \
            .option('escape', '"') \
            .option('delimiter', delimiter) \
            .save(hdfs_export_path)
        cost = round(float(time.time() - start_time), 2)
        pretty_print(f'{NORM_MGT}Data export finished, took {NORM_GRN}{str(cost)}{NORM_MGT} seconds')
        return hdfs_export_path
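
    # Illustrative usage (dataset name, table, and date are placeholders):
    #     path = spark_sql.export_data('order_stats',
    #                                  'SELECT * FROM some_db.some_table WHERE dt = ${dt}',
    #                                  delimiter='\t', dt='2024-01-01')
    #     # -> /hdfs-mnt/export-data/<user>/<today>/order_stats.<timestamp>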

    def get_columns(self, hive_table_name: str):
        desc_df, _ = self.query(f'desc {hive_table_name}', silent=True)
        desc_rows = desc_df.collect()
        hive_columns = {}
        partition_col_start = False
        for row in desc_rows:  # type: Row
            if row[0] == '# Partition Information':
                partition_col_start = True
                continue
            if row[0] == '# col_name':
                continue
            if partition_col_start:
                # Partition columns are also listed in the regular column section above,
                # so drop them from the tail of the dict.
                del hive_columns[list(hive_columns.keys())[-1]]
            else:
                hive_columns[row[0]] = (row[1], row[2])
        return hive_columns

    def get_partition_columns(self, hive_table_name: str):
        desc_df, _ = self.query(f'desc {hive_table_name}', silent=True)
        desc_rows = desc_df.collect()
        partition_columns = {}
        partitioned = False
        for row in desc_rows:  # type: Row
            if row[0] == '# Partition Information':
                partitioned = True
                continue
            if not partitioned:
                continue
            if not row[0] or row[0].startswith('#'):
                continue
            partition_columns[row[0]] = (row[1], row[2])
        return partitioned, partition_columns
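
    # For a table partitioned by `dt`, `desc` output typically looks like this (illustrative):
    #     col_name                data_type   comment
    #     user_id                 bigint
    #     dt                      string
    #     # Partition Information
    #     # col_name              data_type   comment
    #     dt                      string
    # so get_columns() would return {'user_id': ('bigint', ...)} (partition columns removed from the tail)
    # and get_partition_columns() would return (True, {'dt': ('string', ...)}).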

    def get_spark_session(self):
        if not self._spark_session:
            self.__init_spark_session()
        self.start_time = time.time()
        if not self.global_start_time:
            self.global_start_time = time.time()
        return self._spark_session

    def jdbc_read(self, jdbc_url: str, table_name: str) -> DataFrame:
        return self.get_spark_session().read.jdbc(jdbc_url, table_name)

    def list_tables(self, hive_database_name: str = 'default', include_regex: str = None, exclude_regex: str = None):
        if not hive_database_name:
            hive_database_name = 'default'
        tables_df, _ = self.query(f'show tables in {hive_database_name}')
        tables = []
        for row in tables_df.collect():
            if hive_database_name == 'default':
                table_name = row.asDict()['tableName']
            else:
                table_name = f"{hive_database_name}.{row.asDict()['tableName']}"
            if exclude_regex and re.match(exclude_regex, table_name):
                continue
            if include_regex:
                if re.match(include_regex, table_name):
                    tables.append(table_name)
                continue
            else:
                tables.append(table_name)
        return tables

    def query(self,
              sql_or_file: str,
              check_parameter: bool = False,
              silent: bool = False,
              fill_null=None,
              **kwargs) -> Tuple[DataFrame, bool]:
        """
        Run a query and return a DataFrame plus a flag indicating whether the statement was a SELECT.
        Args:
            sql_or_file: the SQL statement to execute, or the full path of a file containing SQL statements
            check_parameter: whether to check parameters before running
            silent: whether to suppress logging during execution (True means no logging)
            fill_null: value used to fill nulls
            **kwargs: parameters used by the SQL statement or SQL file
        Returns: Tuple[DataFrame, bool]
        """
        if sql_or_file.endswith('.sql'):
            if not silent:
                pretty_print(f'{NORM_MGT}About to execute SQL file {NORM_GRN}{sql_or_file}')
            sql_list = get_sql_list_from_file(sql_or_file, trim_comment=True)
        else:
            sql_list = [sql_or_file]
        cleaned_sql_list = []
        # Parameters passed in from outside
        external_parameters = {}
        if kwargs:
            for key, value in kwargs.items():
                if not silent:
                    pretty_print(f'{NORM_MGT}Received external SQL parameter {NORM_GRN}{key} => {value}')
                external_parameters[key] = value
        # Parameters defined inside the SQL file
        internal_parameters = {}
        for sql in sql_list:
            upper_sql = sql.upper()
            if upper_sql.startswith('ADD FILES'):
                file_names = sql[len('ADD FILES'):].strip().split(' ')
                for file_name in file_names:
                    file_name = file_name.strip()
                    if not file_name:
                        continue
                    if file_name.endswith('.py'):
                        self.register_udf_file(file_name)
                    else:
                        self.add_resource_file(file_name)
            elif upper_sql.startswith('ADD FILE'):
                file_name = sql[len('ADD FILE'):].strip()
                if file_name.endswith('.py'):
                    self.register_udf_file(file_name)
                else:
                    self.add_resource_file(file_name)
            elif upper_sql.startswith('ADD SPARK_UDF'):
                self.register_udf_file(sql[len('ADD SPARK_UDF'):].strip())
            elif upper_sql.startswith('SET SPARK_VAR:'):
                var_key, var_value = sql[len('SET SPARK_VAR:'):].strip().split('=', 1)
                var_key = var_key.strip()
                var_value = var_value.strip()
                self.add_parameters(sql_or_file, internal_parameters, var_key, var_value)
            elif upper_sql.startswith('SET LIMIT='):
                if not self._spark_session:
                    self.limit = int(sql.split('=')[1].strip())
            elif upper_sql.startswith('SET '):
                set_expr = sql[len('SET '):].strip()
                set_expr_splits = set_expr.split('=', 1)
                set_key = set_expr_splits[0].strip()
                set_value = set_expr_splits[1].strip()
                if re.match(r'.+\..+(\..+)*', set_key):
                    # Dotted keys are treated as Spark config; they only take effect before the session is created
                    if not self._spark_session:
                        self.__add_spark_config(sql_or_file, set_expr)
                else:
                    self.add_parameters(sql_or_file, internal_parameters, set_key, set_value)
            else:
                # External parameters take precedence
                for key, value in external_parameters.items():
                    sql = sql.replace('${%s}' % key, value)
                # Parameters in a SQL file must be defined before they are used (which also means the same
                # parameter can be redefined with different values for use by different statements)
                for key, value in internal_parameters.items():
                    if key in external_parameters:
                        # If an external parameter with the same name exists, it wins
                        # (the replacement already happened above; this is just a notice)
                        pretty_print(f'{NORM_YEL}SQL parameter {NORM_GRN}{key} '
                                     f'{NORM_YEL}defined in SQL file {NORM_GRN}{sql_or_file} '
                                     f'{NORM_YEL}is overridden by an external parameter; the new value {NORM_GRN}{external_parameters[key]} '
                                     f'{NORM_YEL}replaces the original value {NORM_GRN}{internal_parameters[key]}')
                        continue
                    sql = sql.replace('${%s}' % key, value)
                cleaned_sql_list.append(sql)
        if len(cleaned_sql_list) == 0:
            return self.get_spark_session().sql("SELECT ''"), False
        for index in range(len(cleaned_sql_list) - 1):
            sql = cleaned_sql_list[index]
            self.execute(sql)
        last_or_only_sql = cleaned_sql_list[-1]
        check_parameter_substituted(last_or_only_sql, check_parameter)
        self.get_spark_session()
        if not silent:
            if '\n' in last_or_only_sql:
                pretty_print(f'{NORM_MGT}Executing SQL statement: \n{NORM_GRN}{last_or_only_sql}')
            else:
                pretty_print(f'{NORM_MGT}Executing SQL statement: {NORM_GRN}{last_or_only_sql}')
        data_frame = self._spark_session.sql(last_or_only_sql)
        if fill_null:
            data_frame = data_frame.fillna(fill_null)
        return data_frame, len(data_frame.schema.fields) > 0
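
    # Illustrative SQL file that query() can process (table names, paths, values, and the use of
    # semicolons as statement separators are placeholders/assumptions):
    #
    #     ADD FILE workspace/udf/my_udfs.py;
    #     SET spark.sql.shuffle.partitions = 400;
    #     SET SPARK_VAR: dt = 2024-01-01;
    #     SET LIMIT=50;
    #     INSERT OVERWRITE TABLE some_db.some_table PARTITION (dt = '${dt}')
    #     SELECT * FROM some_db.some_source WHERE dt = '${dt}';
    #
    # Dotted SET keys become Spark config (only before the session starts), SET SPARK_VAR: and plain
    # SET keys become ${...} substitutions, and every remaining statement except the last is run via execute().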

    def query_scalar(self,
                     sql_or_file: str,
                     check_parameter: bool = False,
                     silent: bool = False,
                     fill_null=None,
                     **kwargs):
        data, is_select = self.query(sql_or_file, check_parameter, silent, fill_null, **kwargs)
        if is_select is not True:
            raise Exception('Get scalar from non-query statement')
        res = data.collect()
        # Return the first column of the first row as the scalar value
        return res[0][0]
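
    # Illustrative usage (table name is a placeholder):
    #     row_count = spark_sql.query_scalar('SELECT count(*) FROM some_db.some_table')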

    def register_udf_files(self, udf_file_list: List[str]):
        if not udf_file_list:
            return
        for udf_file in udf_file_list:
            self.register_udf_file(udf_file)

    def register_udf_file(self, udf_file: str):
        if not udf_file:
            return
        if not self.add_py_file(udf_file):
            pretty_print(f'{NORM_YEL}Python file {NORM_GRN}{udf_file}{NORM_YEL} has already been added (duplicate)')
            return
        if not self._spark_session:
            self._udf_files.append(udf_file)
            return
        pretty_print(f'{NORM_MGT}Registering Python UDFs from file {NORM_GRN}{udf_file}')
        udf_module_name = udf_file.replace('/', '.').replace('.py', '')
        module = import_module(udf_module_name)
        for (name, attr) in inspect.getmembers(module):
            if name in self.IGNORED_UDF:
                continue
            if name in self.REGISTERED_UDF:
                pretty_print(f'{NORM_YEL}A Python UDF named {NORM_GRN}{name}{NORM_YEL} '
                             f'{NORM_YEL}was already registered from file {NORM_GRN}{self.REGISTERED_UDF[name]} '
                             f'{NORM_YEL}(for duplicate names the first registration wins)')
                continue
            if inspect.isfunction(attr):
                # todo: register the return type of each UDF
                pretty_print(f'{NORM_MGT}Registering Python UDF {NORM_GRN}{name}')
                self._spark_session.udf.register(name, attr)
                self.REGISTERED_UDF[name] = udf_file

    @staticmethod
    def show_data_frame(data_frame: DataFrame,
                        data_frame_name: str = None,
                        show_number: int = 20,
                        truncate: Union[bool, int] = False):
        """
        Print the schema of a ``DataFrame`` and the specified number of rows of data.
        Args:
            data_frame: DataFrame
            data_frame_name: name of the DataFrame
            show_number: number of rows to print
            truncate: If set to ``True``, truncate strings longer than 20 chars by default.
                If set to a number greater than one, truncates long strings to length ``truncate``
                and align cells right.
        Returns:
        """
        if data_frame_name:
            pretty_print(f'{NORM_MGT}{"=" * 30} {data_frame_name} DataFrame Information: {"=" * 30} ')
        else:
            pretty_print(f'{NORM_MGT}{"=" * 30} DataFrame Information: {"=" * 30} ')
        pretty_print(f'{NORM_MGT}Schema :')
        data_frame.printSchema()
        pretty_print(f'{NORM_MGT}Data sample :')
        data_frame.show(show_number, truncate=truncate)
        pretty_print(f'{NORM_MGT}{"=" * 70}')
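
    # The `with SparkSQL(...)` usage in the __main__ block below implies the class is meant to be a
    # context manager, but this section does not show __enter__/__exit__. The methods below are a
    # minimal sketch (an assumption, not the original implementation) that stops the SparkSession on exit.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self._spark_session:
            self._spark_session.stop()
            self._spark_session = None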


if __name__ == '__main__':
    with SparkSQL('spark-sql') as spark_sql:
        # spark_sql.execute('workspace/x.sql')
        # spark_sql._spark_session.conf.set('app.name', 'tmp.sa33')
        # spark_sql.execute('workspace/y.sql')
        # spark_sql.query('show databases')[0].show(100, truncate=False)
        # spark_sql.query('show functions')[0].show(1000, truncate=False)
        spark_sql.get_columns('crl_ads.ads_crl_dim_source_stat')