spark_common_udf.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. #!/usr/bin/env /usr/bin/python3
  2. # -*- coding:utf-8 -*-
  3. """
  4. 通用 UDF —— 与业务无关的数据类型 / 格式操作(JSON / Array / String / Numeric / Date / Hash)
  5. SparkSQL 入口自动 ADD FILE 注册;业务专用 UDF 请放到 dw_base/udf/business/ 下按需加载
  6. """
  7. import difflib
  8. import hashlib
  9. import html
  10. import json
  11. import random
  12. import re
  13. import traceback
  14. from collections import Counter
  15. from datetime import datetime
  16. from typing import Dict, List, Union
  17. from pyspark.sql.functions import udf
  18. from pyspark.sql.types import (
  19. ArrayType, BooleanType, FloatType, IntegerType, LongType, MapType,
  20. StringType, StructField, StructType,
  21. )
  22. from dw_base.utils.datetime_utils import parse_datetime
  23. # ==================== JSON ====================
  24. @udf(returnType=BooleanType())
  25. def is_json(data) -> bool:
  26. try:
  27. json.loads(data)
  28. except:
  29. return False
  30. return True
  31. @udf(returnType=ArrayType(StringType()))
  32. def json_object_keys(json_str: str) -> List[str]:
  33. if not json_str:
  34. return None
  35. try:
  36. json_dict = json.loads(json_str) # type:dict
  37. return [k for k in json_dict.keys()]
  38. except:
  39. return None
  40. def flatten_json(json_str: str, reserve_parent: bool = True) -> str:
  41. """展平 json,reserve_parent 控制是否保留父 key"""
  42. def flatten_json_node(parent, json_element) -> Union[float, int, str, Dict, List]:
  43. if isinstance(json_element, dict):
  44. result = {}
  45. if parent and reserve_parent and reserve_parent is True:
  46. for key, value in json_element.items():
  47. result[f'{parent}.{key}'] = value
  48. else:
  49. for key, value in json_element.items():
  50. result.update(flatten_json_node(key, value))
  51. return result
  52. elif isinstance(json_element, list):
  53. result = []
  54. if parent and reserve_parent and reserve_parent is True:
  55. for index in range(len(json_element)):
  56. result.append(flatten_json_node(f'{parent}.[{index}]', json_element[index]))
  57. else:
  58. for index in range(len(json_element)):
  59. result.append(flatten_json_node(None, json_element[index]))
  60. return result
  61. else:
  62. return {parent: json_element}
  63. if not json_str:
  64. return json_str
  65. try:
  66. json_node = json.loads(json_str)
  67. flattened_json = flatten_json_node(None, json_node)
  68. return json.dumps(flattened_json, ensure_ascii=False)
  69. except Exception as e:
  70. traceback.format_exc(e)
  71. return json_str
  72. def remove_empty_key(info):
  73. """递归删除 json 中 value 为空的 key"""
  74. json_info = json.loads(info)
  75. def internal_remove(json_info):
  76. try:
  77. if isinstance(json_info, dict):
  78. info_re = dict()
  79. for key, value in json_info.items():
  80. if isinstance(value, dict) or isinstance(value, list):
  81. re = internal_remove(value)
  82. if len(re):
  83. info_re[key] = re
  84. elif value not in ['', {}, [], 'null', None]:
  85. info_re[key] = str(value)
  86. return info_re
  87. elif isinstance(json_info, list):
  88. info_re = list()
  89. for value in json_info:
  90. if isinstance(value, dict) or isinstance(value, list):
  91. re = internal_remove(value)
  92. if len(re):
  93. info_re.append(re)
  94. elif value not in ['', {}, [], 'null', None]:
  95. info_re.append(str(value))
  96. return info_re
  97. else:
  98. return None
  99. except Exception as e:
  100. return None
  101. return json.dumps(internal_remove(json_info), ensure_ascii=False)
  102. def append_to_json_array(json_array_string: str, new_element, remove_duplicate: bool = False) -> str:
  103. """向 JSON array 追加元素,可选去重"""
  104. if not new_element:
  105. return json_array_string
  106. if not json_array_string:
  107. return json.dumps([new_element], ensure_ascii=False)
  108. json_array = json.loads(json_array_string) # type: list
  109. json_array.append(new_element)
  110. if remove_duplicate is True:
  111. result = []
  112. for elem in json_array:
  113. if result.__contains__(elem):
  114. continue
  115. result.append(elem)
  116. return json.dumps(result, ensure_ascii=False)
  117. return json.dumps(json_array, ensure_ascii=False)
  118. def json_array_subset(json_array_string: str,
  119. subset_fields: Union[List, str],
  120. as_list: bool = False,
  121. skip_null: bool = False) -> str:
  122. """按字段提取 json object array 的子集"""
  123. if not json_array_string:
  124. return None
  125. if not subset_fields:
  126. return None
  127. if isinstance(subset_fields, str):
  128. subset_field_list = subset_fields.split(',')
  129. else:
  130. subset_field_list = subset_fields
  131. if len(subset_field_list) == 0:
  132. return None
  133. try:
  134. json_array = json.loads(json_array_string)
  135. except:
  136. json_array = eval(json_array_string)
  137. list_subset = []
  138. if len(subset_field_list) == 1 and as_list:
  139. only_subset_field = subset_field_list[0]
  140. for element in json_array: # type:Dict
  141. if isinstance(element, dict):
  142. field_value = element.get(only_subset_field)
  143. if field_value or not skip_null:
  144. list_subset.append(field_value)
  145. else:
  146. for element in json_array: # type:Dict
  147. subset_of_element = {}
  148. if isinstance(element, dict):
  149. for field in subset_field_list:
  150. field_value = element.get(field)
  151. if field_value or not skip_null:
  152. subset_of_element[field] = field_value
  153. list_subset.append(subset_of_element)
  154. return json.dumps(list_subset, ensure_ascii=False)
  155. @udf(returnType=ArrayType(StructType([
  156. StructField("idx", IntegerType(), False),
  157. StructField("obj", StringType(), False),
  158. ])))
  159. def parse_jsonarr_to_arr(s: str):
  160. return [(i + 1, json.dumps(obj)) for i, obj in enumerate(json.loads(s))]
  161. @udf(returnType=ArrayType(StructType([
  162. StructField("idx", IntegerType(), False),
  163. StructField("obj", StringType(), False),
  164. ])))
  165. def parse_jsonarr_to_strarr(s: str):
  166. return [(i + 1, obj) for i, obj in enumerate(json.loads(s))]
  167. # ==================== ARRAY ====================
  168. @udf(returnType=ArrayType(StringType()))
  169. def array_intersect(arr1, arr2):
  170. return list(set(arr1) & set(arr2))
  171. def array_append(array: List, new_element,
  172. ignore_null: bool = False,
  173. remove_duplicate: bool = False,
  174. need_sort: bool = False) -> List:
  175. if not array or len(array) == 0:
  176. if new_element or ignore_null is not True:
  177. return [new_element]
  178. return []
  179. if not new_element:
  180. if ignore_null is True:
  181. return array
  182. else:
  183. if array.__contains__(new_element) and remove_duplicate is True:
  184. return array
  185. array.append(new_element)
  186. if need_sort:
  187. array.sort()
  188. return array
  189. @udf(ArrayType(StringType()))
  190. def array_slice(input_array, start, end):
  191. if input_array:
  192. return input_array[start:end]
  193. return []
  194. @udf(returnType=ArrayType(StringType()))
  195. def merge_list(arr_list: List):
  196. res = set()
  197. for e in arr_list:
  198. if e is not None:
  199. for i in e:
  200. if i is not None and i != "":
  201. res.add(i)
  202. return list(res)
  203. @udf(returnType=ArrayType(StringType()))
  204. def merge_source(incr_source: List, old_source: List):
  205. res = set()
  206. if incr_source is not None:
  207. for i in incr_source:
  208. if i is not None and i != "":
  209. res.add(i)
  210. if old_source is not None:
  211. for i in old_source:
  212. if i is not None and i != "":
  213. res.add(i)
  214. return list(res)
  215. @udf(returnType=StructType([
  216. StructField("k", ArrayType(StringType()), False),
  217. StructField("kv", StringType()),
  218. ]))
  219. def parse_arr_and_count(arr, tag: str, return_count: int = -1):
  220. ele_cnt_dict = Counter(arr)
  221. json_list = sorted([{"code": key, "num": value} for key, value in ele_cnt_dict.items()], key=lambda x: x["num"], reverse=True)
  222. if return_count < 0:
  223. return [obj['code'] for obj in json_list], ",".join(['{' + f'{i["code"]},{tag}:{i["num"]}' + '}' for i in json_list])
  224. list_len = len(json_list)
  225. index = list_len if return_count >= list_len else return_count
  226. return [obj['code'] for obj in json_list][:index], ",".join(['{' + f'{i["code"]},{tag}:{i["num"]}' + '}' for i in json_list[:index]])
  227. @udf(returnType=StructType([
  228. StructField("sum", FloatType(), False),
  229. StructField("list", StringType()),
  230. ]))
  231. def parse_arr_and_sum(struct_arr, tag: str):
  232. sum_dict = {}
  233. for s in struct_arr:
  234. key = s[0]
  235. value: float = s[1]
  236. if key not in sum_dict:
  237. sum_dict[key] = 0.0
  238. if value is not None:
  239. sum_dict[key] += value
  240. json_list = sorted([{"code": key, "num": value} for key, value in sum_dict.items()], key=lambda x: x["num"], reverse=True)
  241. total = 0.0
  242. for obj in json_list:
  243. total += obj["num"]
  244. return round(total, 2), ",".join(['{' + f'{i["code"]},{tag}:{round(i["num"], 2)}' + '}' for i in json_list])
  245. # ==================== STRING ====================
  246. @udf(returnType=BooleanType())
  247. def has_chinese(datum: str) -> bool:
  248. if datum:
  249. pattern = re.compile(u'[\u4e00-\u9fa5]')
  250. if pattern.search(datum):
  251. return True
  252. return False
  253. @udf(returnType=FloatType())
  254. def similarity(left: str, right: str) -> float:
  255. return difflib.SequenceMatcher(None, left, right).quick_ratio()
  256. @udf(returnType=ArrayType(StringType()))
  257. def regexp_extract_all(col: str, ptn: str, g: int = 0):
  258. return [e.group(g) for e in re.compile(ptn).finditer(col if col else '')]
  259. def add_random_number_prefix(datum: str, separator: str, floor: int, ceiling: int) -> str:
  260. return f'{random.randint(floor, ceiling)}{separator}{datum}'
  261. def field_merge(delimiter: str, *fields_values):
  262. """多字段合并,相同仅保留一个,不同用 delimiter 分隔"""
  263. if not fields_values:
  264. return None
  265. result = []
  266. [result.append(value.strip()) for value in fields_values if value and value.strip() not in result]
  267. return delimiter.join(result)
  268. def space2null(text):
  269. if text and not text.isspace():
  270. return text
  271. return None
  272. def merge_ws(text: str):
  273. if text:
  274. return ' '.join(text.split())
  275. return None
  276. def remove_special_char(text, char):
  277. if text is not None and text.endswith(char):
  278. return text[:-1]
  279. return text
  280. @udf(returnType=ArrayType(StringType()))
  281. def explode_str_to_arr(text: str) -> list:
  282. """大于 8 位时,从后往前每次少一位截取子串入数组(用于前缀匹配场景)"""
  283. if text is None:
  284. return []
  285. if len(text) <= 8:
  286. return [text]
  287. return [text[:i] for i in range(len(text), 7, -1)]
  288. def html_unescape(text):
  289. return html.unescape(text)
  290. # ==================== NUMERIC / DATE / HASH ====================
  291. def max_value(*args):
  292. maxi_value = None
  293. for elem in args:
  294. if not elem:
  295. continue
  296. if not maxi_value or elem > maxi_value:
  297. maxi_value = elem
  298. return maxi_value
  299. def min_value(*args):
  300. mini_value = None
  301. for elem in args:
  302. if not elem:
  303. continue
  304. if not mini_value or elem < mini_value:
  305. mini_value = elem
  306. return mini_value
  307. def millis_timestamp_to_str(ts: int, str_format: str = None) -> str:
  308. date_time = datetime.fromtimestamp(ts / 1000.0)
  309. if str_format:
  310. return date_time.strftime(str_format)
  311. return date_time.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
  312. @udf(returnType=LongType())
  313. def parse_datetime_to_timestamp(date_time: str, in_milli_seconds: bool = False, original_format: str = None) -> int:
  314. """字符串日期 → 时间戳;支持 YY.MM.DD / YYYY年M月D日 启发式识别"""
  315. try:
  316. if date_time:
  317. d = date_time.split('.')
  318. if len(date_time) == 8 and len(d) == 3 and len(d[0]) == 2:
  319. date_time = '20' + date_time
  320. ret = re.match(r'(\d+)年(\d+)月(\d+)日', date_time)
  321. if ret:
  322. date_time = ret.group().replace('年', '-').replace('月', '-').replace('日', '')
  323. parsed_date = parse_datetime(date_time, original_format)
  324. if not parsed_date:
  325. return None
  326. if in_milli_seconds is True:
  327. return int(parsed_date.timestamp() * 1000)
  328. return int(parsed_date.timestamp())
  329. except:
  330. try:
  331. date_time = int(date_time)
  332. if datetime.now().timestamp() < date_time:
  333. return date_time if in_milli_seconds else int(date_time / 1000)
  334. return date_time * 1000 if in_milli_seconds else date_time
  335. except Exception:
  336. return None
  337. @udf(returnType=StringType())
  338. def get_md5(*cols: str) -> str:
  339. """多列拼接(带长度前缀防碰撞)后取 md5"""
  340. col_and_len_list = []
  341. for col in cols:
  342. if col is not None:
  343. col_and_len_list.append(str(len(col)))
  344. col_and_len_list.append(col)
  345. key = ''.join(col_and_len_list)
  346. if not key:
  347. return ''
  348. md5 = hashlib.md5()
  349. md5.update(key.encode("utf-8"))
  350. return md5.hexdigest()
  351. # ==================== CROSS-TYPE CONVERTERS ====================
  352. def array_to_json(arr: List):
  353. return json.dumps(arr, ensure_ascii=False)
  354. def map_to_json(map: dict):
  355. return json.dumps(map, ensure_ascii=False)
  356. def struct_to_json(struct):
  357. json_dict = {key: struct[key] for key in struct.__dict__["__fields__"]}
  358. return json.dumps(json_dict, ensure_ascii=False)
  359. def num_to_str(number):
  360. if isinstance(number, float) and number.is_integer():
  361. return '{:.0f}'.format(number)
  362. return str(int(number)) if isinstance(number, int) else str(number)
  363. @udf(returnType=ArrayType(StringType()))
  364. def str_to_arr(json_str: str) -> list:
  365. if json_str:
  366. return json.loads(json_str)
  367. return []
  368. @udf(returnType=ArrayType(StringType()))
  369. def str_to_json_arr(json_str):
  370. """JSON array 字符串 → list of json strings(每个元素再 json.dumps)"""
  371. if json_str:
  372. try:
  373. str_arr = json.loads(json_str)
  374. if isinstance(str_arr, list):
  375. return [json.dumps(sm) for sm in str_arr]
  376. except json.JSONDecodeError:
  377. return []
  378. return []
  379. @udf(returnType=ArrayType(MapType(StringType(), StringType())))
  380. def str_to_map_arr(json_str: str) -> list:
  381. if json_str:
  382. return json.loads(json_str)
  383. return []
  384. @udf(returnType=StringType())
  385. def split_str_to_jsonstr(str_list: List):
  386. """每个元素按 ':' 切成 k:v,聚合成 JSON 字符串"""
  387. res = []
  388. for kv_str in str_list:
  389. arr = kv_str.split(':')
  390. if len(arr) == 2:
  391. res.append({arr[0]: arr[1]})
  392. return json.dumps(res, ensure_ascii=False)
  393. @udf(returnType=MapType(StringType(), ArrayType(StringType())))
  394. def split_str_to_maparr(str_list: List):
  395. """每个元素按 ':' 切成 k:v,同 key 追加到 list"""
  396. res = {}
  397. for kv_str in str_list:
  398. arr = kv_str.split(':')
  399. if len(arr) == 2:
  400. if arr[0] not in res:
  401. res[arr[0]] = [arr[1]]
  402. else:
  403. res[arr[0]].append(arr[1])
  404. return res