spark_common_udf.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. #!/usr/bin/env /usr/bin/python3
  2. # -*- coding:utf-8 -*-
  3. """
  4. 通用 UDF —— 与业务无关的数据类型 / 格式操作(JSON / Array / String / Numeric / Date / Hash)
  5. SparkSQL 入口自动 ADD FILE 注册;业务专用 UDF 请放到 dw_base/udf/business/ 下按需加载
  6. """
  7. import difflib
  8. import hashlib
  9. import html
  10. import json
  11. import random
  12. import re
  13. from ast import literal_eval
  14. from datetime import datetime
  15. from typing import Dict, List, Union
  16. from pyspark.sql.functions import udf
  17. from pyspark.sql.types import (
  18. ArrayType, BooleanType, FloatType, LongType, MapType, StringType,
  19. )
  20. from dw_base.utils.datetime_utils import parse_datetime
  21. def _load_json_or_default(data, default=None):
  22. """优先按 JSON 解析,失败时返回默认值。"""
  23. try:
  24. return json.loads(data)
  25. except (TypeError, ValueError):
  26. return default
  27. def _load_json_or_literal(data, default=None):
  28. """先按 JSON 解析,失败后再按 Python 字面量兜底解析。"""
  29. parsed = _load_json_or_default(data, default=None)
  30. if parsed is not None:
  31. return parsed
  32. try:
  33. return literal_eval(data)
  34. except (ValueError, SyntaxError, TypeError):
  35. return default
  36. def _dedupe_keep_order(values: List) -> List:
  37. """按原始顺序去重。"""
  38. result = []
  39. for value in values:
  40. if value not in result:
  41. result.append(value)
  42. return result
  43. def _merge_non_empty_values(*arrays: List) -> List[str]:
  44. """合并多个数组,并过滤 None 与空字符串。"""
  45. result = set()
  46. for array in arrays:
  47. if array is None:
  48. continue
  49. for item in array:
  50. if item is not None and item != "":
  51. result.add(item)
  52. return list(result)
  53. # ==================== JSON ====================
  54. # UDF-01 JSON校验:判断输入是否为合法 JSON 字符串。
  55. @udf(returnType=BooleanType())
  56. def is_json(data) -> bool:
  57. """判断输入是否为合法 JSON 字符串。"""
  58. try:
  59. json.loads(data)
  60. except (TypeError, ValueError):
  61. return False
  62. return True
  63. # UDF-02 JSON取键:提取 JSON object 的 key 列表。
  64. @udf(returnType=ArrayType(StringType()))
  65. def json_object_keys(json_str: str) -> List[str]:
  66. """提取 JSON object 的 key 列表。"""
  67. if not json_str:
  68. return None
  69. json_dict = _load_json_or_default(json_str, default=None) # type:dict
  70. if not isinstance(json_dict, dict):
  71. return None
  72. return [k for k in json_dict.keys()]
  73. def flatten_json(json_str: str, reserve_parent: bool = True) -> str:
  74. """展平 JSON 字符串,`reserve_parent` 控制是否保留父级 key。"""
  75. def flatten_json_node(parent, json_element) -> Union[float, int, str, Dict, List]:
  76. if isinstance(json_element, dict):
  77. result = {}
  78. if parent and reserve_parent and reserve_parent is True:
  79. for key, value in json_element.items():
  80. result[f'{parent}.{key}'] = value
  81. else:
  82. for key, value in json_element.items():
  83. result.update(flatten_json_node(key, value))
  84. return result
  85. elif isinstance(json_element, list):
  86. result = []
  87. if parent and reserve_parent and reserve_parent is True:
  88. for index in range(len(json_element)):
  89. result.append(flatten_json_node(f'{parent}.[{index}]', json_element[index]))
  90. else:
  91. for index in range(len(json_element)):
  92. result.append(flatten_json_node(None, json_element[index]))
  93. return result
  94. else:
  95. return {parent: json_element}
  96. if not json_str:
  97. return json_str
  98. try:
  99. json_node = json.loads(json_str)
  100. flattened_json = flatten_json_node(None, json_node)
  101. return json.dumps(flattened_json, ensure_ascii=False)
  102. except (TypeError, ValueError):
  103. return json_str
  104. def remove_empty_key(info):
  105. """递归删除 JSON 中 value 为空的 key。"""
  106. json_info = json.loads(info)
  107. def internal_remove(json_info):
  108. try:
  109. if isinstance(json_info, dict):
  110. info_re = dict()
  111. for key, value in json_info.items():
  112. if isinstance(value, dict) or isinstance(value, list):
  113. re = internal_remove(value)
  114. if len(re):
  115. info_re[key] = re
  116. elif value not in ['', {}, [], 'null', None]:
  117. info_re[key] = str(value)
  118. return info_re
  119. elif isinstance(json_info, list):
  120. info_re = list()
  121. for value in json_info:
  122. if isinstance(value, dict) or isinstance(value, list):
  123. re = internal_remove(value)
  124. if len(re):
  125. info_re.append(re)
  126. elif value not in ['', {}, [], 'null', None]:
  127. info_re.append(str(value))
  128. return info_re
  129. else:
  130. return None
  131. except Exception as e:
  132. return None
  133. return json.dumps(internal_remove(json_info), ensure_ascii=False)
  134. def append_to_json_array(json_array_string: str, new_element, remove_duplicate: bool = False) -> str:
  135. """向 JSON array 末尾追加元素,可选去重。"""
  136. if not new_element:
  137. return json_array_string
  138. if not json_array_string:
  139. return json.dumps([new_element], ensure_ascii=False)
  140. json_array = _load_json_or_default(json_array_string, default=None) # type: list
  141. if not isinstance(json_array, list):
  142. return json_array_string
  143. json_array.append(new_element)
  144. if remove_duplicate is True:
  145. return json.dumps(_dedupe_keep_order(json_array), ensure_ascii=False)
  146. return json.dumps(json_array, ensure_ascii=False)
  147. def json_array_subset(json_array_string: str,
  148. subset_fields: Union[List, str],
  149. as_list: bool = False,
  150. skip_null: bool = False) -> str:
  151. """按字段提取 JSON object array 的子集。"""
  152. if not json_array_string:
  153. return None
  154. if not subset_fields:
  155. return None
  156. if isinstance(subset_fields, str):
  157. subset_field_list = subset_fields.split(',')
  158. else:
  159. subset_field_list = subset_fields
  160. if len(subset_field_list) == 0:
  161. return None
  162. json_array = _load_json_or_literal(json_array_string, default=None)
  163. if not isinstance(json_array, list):
  164. return None
  165. list_subset = []
  166. if len(subset_field_list) == 1 and as_list:
  167. only_subset_field = subset_field_list[0]
  168. for element in json_array: # type:Dict
  169. if isinstance(element, dict):
  170. field_value = element.get(only_subset_field)
  171. if field_value or not skip_null:
  172. list_subset.append(field_value)
  173. else:
  174. for element in json_array: # type:Dict
  175. subset_of_element = {}
  176. if isinstance(element, dict):
  177. for field in subset_field_list:
  178. field_value = element.get(field)
  179. if field_value or not skip_null:
  180. subset_of_element[field] = field_value
  181. list_subset.append(subset_of_element)
  182. return json.dumps(list_subset, ensure_ascii=False)
  183. # ==================== ARRAY ====================
  184. # UDF-21 数组交集:计算两个数组的交集。
  185. @udf(returnType=ArrayType(StringType()))
  186. def array_intersect(arr1, arr2):
  187. """计算两个数组的交集。"""
  188. return list(set(arr1) & set(arr2))
  189. def array_append(array: List, new_element,
  190. ignore_null: bool = False,
  191. remove_duplicate: bool = False,
  192. need_sort: bool = False) -> List:
  193. """向数组追加元素,可按现有规则控制空值、去重和排序。"""
  194. if not array or len(array) == 0:
  195. if new_element or ignore_null is not True:
  196. return [new_element]
  197. return []
  198. if not new_element:
  199. if ignore_null is True:
  200. return array
  201. else:
  202. if array.__contains__(new_element) and remove_duplicate is True:
  203. return array
  204. array.append(new_element)
  205. if need_sort:
  206. array.sort()
  207. return array
  208. # UDF-22 数组切片:按起止下标截取数组。
  209. @udf(ArrayType(StringType()))
  210. def array_slice(input_array, start, end):
  211. """截取数组切片,行为与 Python 切片一致。"""
  212. if input_array:
  213. return input_array[start:end]
  214. return []
  215. # UDF-23 数组合并:合并二维数组,并过滤 None 与空字符串。
  216. @udf(returnType=ArrayType(StringType()))
  217. def merge_list(arr_list: List):
  218. """合并二维数组,并过滤 None 与空字符串。"""
  219. return _merge_non_empty_values(*(arr_list or []))
  220. # ==================== STRING ====================
  221. # UDF-31 中文检测:判断字符串中是否包含中文字符。
  222. @udf(returnType=BooleanType())
  223. def has_chinese(datum: str) -> bool:
  224. """判断字符串中是否包含中文字符。"""
  225. if datum:
  226. pattern = re.compile(u'[\u4e00-\u9fa5]')
  227. if pattern.search(datum):
  228. return True
  229. return False
  230. # UDF-32 相似度计算:计算两个字符串的快速相似度。
  231. @udf(returnType=FloatType())
  232. def similarity(left: str, right: str) -> float:
  233. """计算两个字符串的快速相似度。"""
  234. return difflib.SequenceMatcher(None, left, right).quick_ratio()
  235. # UDF-33 正则全提取:提取正则表达式的全部匹配结果。
  236. @udf(returnType=ArrayType(StringType()))
  237. def regexp_extract_all(col: str, ptn: str, g: int = 0):
  238. """提取正则表达式的全部匹配结果。"""
  239. return [e.group(g) for e in re.compile(ptn).finditer(col if col else '')]
  240. def add_random_number_prefix(datum: str, separator: str, floor: int, ceiling: int) -> str:
  241. """给字符串追加随机数字前缀。"""
  242. return f'{random.randint(floor, ceiling)}{separator}{datum}'
  243. def field_merge(delimiter: str, *fields_values):
  244. """合并多个字段值,去重后用指定分隔符拼接。"""
  245. if not fields_values:
  246. return None
  247. result = []
  248. for value in fields_values:
  249. if value and value.strip() not in result:
  250. result.append(value.strip())
  251. return delimiter.join(result)
  252. def space2null(text):
  253. """把空白字符串规范化为 None。"""
  254. if text and not text.isspace():
  255. return text
  256. return None
  257. def merge_ws(text: str):
  258. """压缩多余空白符,只保留单个空格。"""
  259. if text:
  260. return ' '.join(text.split())
  261. return None
  262. def remove_special_char(text, char):
  263. """如果字符串以指定字符结尾,则移除最后一个字符。"""
  264. if text is not None and text.endswith(char):
  265. return text[:-1]
  266. return text
  267. def html_unescape(text):
  268. """反转义 HTML 实体。"""
  269. return html.unescape(text)
  270. # ==================== NUMERIC / DATE / HASH ====================
  271. def max_value(*args):
  272. """按现有真假值规则返回最大值。"""
  273. maxi_value = None
  274. for elem in args:
  275. if not elem:
  276. continue
  277. if not maxi_value or elem > maxi_value:
  278. maxi_value = elem
  279. return maxi_value
  280. def min_value(*args):
  281. """按现有真假值规则返回最小值。"""
  282. mini_value = None
  283. for elem in args:
  284. if not elem:
  285. continue
  286. if not mini_value or elem < mini_value:
  287. mini_value = elem
  288. return mini_value
  289. def millis_timestamp_to_str(ts: int, str_format: str = None) -> str:
  290. """把毫秒时间戳转换为时间字符串。"""
  291. date_time = datetime.fromtimestamp(ts / 1000.0)
  292. if str_format:
  293. return date_time.strftime(str_format)
  294. return date_time.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
  295. # UDF-41 时间解析:把日期字符串解析为时间戳。
  296. @udf(returnType=LongType())
  297. def parse_datetime_to_timestamp(date_time: str, in_milli_seconds: bool = False, original_format: str = None) -> int:
  298. """字符串日期 → 时间戳;支持 YY.MM.DD / YYYY年M月D日 启发式识别"""
  299. try:
  300. if date_time:
  301. d = date_time.split('.')
  302. if len(date_time) == 8 and len(d) == 3 and len(d[0]) == 2:
  303. date_time = '20' + date_time
  304. ret = re.match(r'(\d+)年(\d+)月(\d+)日', date_time)
  305. if ret:
  306. date_time = ret.group().replace('年', '-').replace('月', '-').replace('日', '')
  307. parsed_date = parse_datetime(date_time, original_format)
  308. if not parsed_date:
  309. return None
  310. if in_milli_seconds is True:
  311. return int(parsed_date.timestamp() * 1000)
  312. return int(parsed_date.timestamp())
  313. except:
  314. try:
  315. date_time = int(date_time)
  316. if datetime.now().timestamp() < date_time:
  317. return date_time if in_milli_seconds else int(date_time / 1000)
  318. return date_time * 1000 if in_milli_seconds else date_time
  319. except Exception:
  320. return None
  321. # UDF-42 MD5摘要:把多列值按长度前缀拼接后计算 MD5。
  322. @udf(returnType=StringType())
  323. def get_md5(*cols: str) -> str:
  324. """多列拼接(带长度前缀防碰撞)后取 md5"""
  325. col_and_len_list = []
  326. for col in cols:
  327. if col is not None:
  328. col_and_len_list.append(str(len(col)))
  329. col_and_len_list.append(col)
  330. key = ''.join(col_and_len_list)
  331. if not key:
  332. return ''
  333. md5 = hashlib.md5()
  334. md5.update(key.encode("utf-8"))
  335. return md5.hexdigest()
  336. # ==================== CROSS-TYPE CONVERTERS ====================
  337. def array_to_json(arr: List):
  338. """把数组序列化为 JSON 字符串。"""
  339. return json.dumps(arr, ensure_ascii=False)
  340. def map_to_json(map: dict):
  341. """把字典序列化为 JSON 字符串。"""
  342. return json.dumps(map, ensure_ascii=False)
  343. def struct_to_json(struct):
  344. """把结构体对象转换为 JSON 字符串。"""
  345. json_dict = {key: struct[key] for key in struct.__dict__["__fields__"]}
  346. return json.dumps(json_dict, ensure_ascii=False)
  347. def num_to_str(number):
  348. """把数值转换成字符串,整数型浮点数去掉小数位。"""
  349. if isinstance(number, float) and number.is_integer():
  350. return '{:.0f}'.format(number)
  351. return str(int(number)) if isinstance(number, int) else str(number)
  352. # UDF-51 字符串转数组:把 JSON array 字符串转换为 Python list。
  353. @udf(returnType=ArrayType(StringType()))
  354. def str_to_arr(json_str: str) -> list:
  355. """把 JSON array 字符串转换为 Python list。"""
  356. if json_str:
  357. parsed = _load_json_or_default(json_str, default=[])
  358. return parsed if isinstance(parsed, list) else []
  359. return []
  360. # UDF-52 字符串转JSON字符串数组:把 JSON array 转为 JSON 字符串数组。
  361. @udf(returnType=ArrayType(StringType()))
  362. def str_to_json_arr(json_str):
  363. """JSON array 字符串 → list of json strings(每个元素再 json.dumps)"""
  364. if json_str:
  365. try:
  366. str_arr = json.loads(json_str)
  367. if isinstance(str_arr, list):
  368. return [json.dumps(sm) for sm in str_arr]
  369. except json.JSONDecodeError:
  370. return []
  371. return []
  372. # UDF-53 字符串转Map数组:把 JSON array 字符串转换为 map 数组。
  373. @udf(returnType=ArrayType(MapType(StringType(), StringType())))
  374. def str_to_map_arr(json_str: str) -> list:
  375. """把 JSON array 字符串转换为 map 数组。"""
  376. if json_str:
  377. parsed = _load_json_or_default(json_str, default=[])
  378. return parsed if isinstance(parsed, list) else []
  379. return []