spark_common_udf.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500
  1. #!/usr/bin/env /usr/bin/python3
  2. # -*- coding:utf-8 -*-
  3. """
  4. 通用 UDF —— 与业务无关的数据类型 / 格式操作(JSON / Array / String / Numeric / Date / Hash)
  5. SparkSQL 入口自动 ADD FILE 注册;业务专用 UDF 请放到 dw_base/spark/udf/business/ 下按需加载
  6. """
  7. import difflib
  8. import hashlib
  9. import html
  10. import json
  11. import random
  12. import re
  13. import traceback
  14. from collections import Counter
  15. from datetime import datetime
  16. from typing import Dict, List, Union
  17. import pygeohash
  18. from pyspark.sql.functions import udf
  19. from pyspark.sql.types import (
  20. ArrayType, BooleanType, FloatType, IntegerType, LongType, MapType,
  21. StringType, StructField, StructType,
  22. )
  23. from dw_base.utils.datetime_utils import parse_datetime
  24. # ==================== JSON ====================
  25. @udf(returnType=BooleanType())
  26. def is_json(data) -> bool:
  27. try:
  28. json.loads(data)
  29. except:
  30. return False
  31. return True
  32. @udf(returnType=ArrayType(StringType()))
  33. def json_object_keys(json_str: str) -> List[str]:
  34. if not json_str:
  35. return None
  36. try:
  37. json_dict = json.loads(json_str) # type:dict
  38. return [k for k in json_dict.keys()]
  39. except:
  40. return None
  41. def flatten_json(json_str: str, reserve_parent: bool = True) -> str:
  42. """展平 json,reserve_parent 控制是否保留父 key"""
  43. def flatten_json_node(parent, json_element) -> Union[float, int, str, Dict, List]:
  44. if isinstance(json_element, dict):
  45. result = {}
  46. if parent and reserve_parent and reserve_parent is True:
  47. for key, value in json_element.items():
  48. result[f'{parent}.{key}'] = value
  49. else:
  50. for key, value in json_element.items():
  51. result.update(flatten_json_node(key, value))
  52. return result
  53. elif isinstance(json_element, list):
  54. result = []
  55. if parent and reserve_parent and reserve_parent is True:
  56. for index in range(len(json_element)):
  57. result.append(flatten_json_node(f'{parent}.[{index}]', json_element[index]))
  58. else:
  59. for index in range(len(json_element)):
  60. result.append(flatten_json_node(None, json_element[index]))
  61. return result
  62. else:
  63. return {parent: json_element}
  64. if not json_str:
  65. return json_str
  66. try:
  67. json_node = json.loads(json_str)
  68. flattened_json = flatten_json_node(None, json_node)
  69. return json.dumps(flattened_json, ensure_ascii=False)
  70. except Exception as e:
  71. traceback.format_exc(e)
  72. return json_str
  73. def remove_empty_key(info):
  74. """递归删除 json 中 value 为空的 key"""
  75. json_info = json.loads(info)
  76. def internal_remove(json_info):
  77. try:
  78. if isinstance(json_info, dict):
  79. info_re = dict()
  80. for key, value in json_info.items():
  81. if isinstance(value, dict) or isinstance(value, list):
  82. re = internal_remove(value)
  83. if len(re):
  84. info_re[key] = re
  85. elif value not in ['', {}, [], 'null', None]:
  86. info_re[key] = str(value)
  87. return info_re
  88. elif isinstance(json_info, list):
  89. info_re = list()
  90. for value in json_info:
  91. if isinstance(value, dict) or isinstance(value, list):
  92. re = internal_remove(value)
  93. if len(re):
  94. info_re.append(re)
  95. elif value not in ['', {}, [], 'null', None]:
  96. info_re.append(str(value))
  97. return info_re
  98. else:
  99. return None
  100. except Exception as e:
  101. return None
  102. return json.dumps(internal_remove(json_info), ensure_ascii=False)
  103. def append_to_json_array(json_array_string: str, new_element, remove_duplicate: bool = False) -> str:
  104. """向 JSON array 追加元素,可选去重"""
  105. if not new_element:
  106. return json_array_string
  107. if not json_array_string:
  108. return json.dumps([new_element], ensure_ascii=False)
  109. json_array = json.loads(json_array_string) # type: list
  110. json_array.append(new_element)
  111. if remove_duplicate is True:
  112. result = []
  113. for elem in json_array:
  114. if result.__contains__(elem):
  115. continue
  116. result.append(elem)
  117. return json.dumps(result, ensure_ascii=False)
  118. return json.dumps(json_array, ensure_ascii=False)
  119. def json_array_subset(json_array_string: str,
  120. subset_fields: Union[List, str],
  121. as_list: bool = False,
  122. skip_null: bool = False) -> str:
  123. """按字段提取 json object array 的子集"""
  124. if not json_array_string:
  125. return None
  126. if not subset_fields:
  127. return None
  128. if isinstance(subset_fields, str):
  129. subset_field_list = subset_fields.split(',')
  130. else:
  131. subset_field_list = subset_fields
  132. if len(subset_field_list) == 0:
  133. return None
  134. try:
  135. json_array = json.loads(json_array_string)
  136. except:
  137. json_array = eval(json_array_string)
  138. list_subset = []
  139. if len(subset_field_list) == 1 and as_list:
  140. only_subset_field = subset_field_list[0]
  141. for element in json_array: # type:Dict
  142. if isinstance(element, dict):
  143. field_value = element.get(only_subset_field)
  144. if field_value or not skip_null:
  145. list_subset.append(field_value)
  146. else:
  147. for element in json_array: # type:Dict
  148. subset_of_element = {}
  149. if isinstance(element, dict):
  150. for field in subset_field_list:
  151. field_value = element.get(field)
  152. if field_value or not skip_null:
  153. subset_of_element[field] = field_value
  154. list_subset.append(subset_of_element)
  155. return json.dumps(list_subset, ensure_ascii=False)
  156. @udf(returnType=ArrayType(StructType([
  157. StructField("idx", IntegerType(), False),
  158. StructField("obj", StringType(), False),
  159. ])))
  160. def parse_jsonarr_to_arr(s: str):
  161. return [(i + 1, json.dumps(obj)) for i, obj in enumerate(json.loads(s))]
  162. @udf(returnType=ArrayType(StructType([
  163. StructField("idx", IntegerType(), False),
  164. StructField("obj", StringType(), False),
  165. ])))
  166. def parse_jsonarr_to_strarr(s: str):
  167. return [(i + 1, obj) for i, obj in enumerate(json.loads(s))]
  168. # ==================== ARRAY ====================
  169. @udf(returnType=ArrayType(StringType()))
  170. def array_intersect(arr1, arr2):
  171. return list(set(arr1) & set(arr2))
  172. def array_append(array: List, new_element,
  173. ignore_null: bool = False,
  174. remove_duplicate: bool = False,
  175. need_sort: bool = False) -> List:
  176. if not array or len(array) == 0:
  177. if new_element or ignore_null is not True:
  178. return [new_element]
  179. return []
  180. if not new_element:
  181. if ignore_null is True:
  182. return array
  183. else:
  184. if array.__contains__(new_element) and remove_duplicate is True:
  185. return array
  186. array.append(new_element)
  187. if need_sort:
  188. array.sort()
  189. return array
  190. @udf(ArrayType(StringType()))
  191. def array_slice(input_array, start, end):
  192. if input_array:
  193. return input_array[start:end]
  194. return []
  195. @udf(returnType=ArrayType(StringType()))
  196. def merge_list(arr_list: List):
  197. res = set()
  198. for e in arr_list:
  199. if e is not None:
  200. for i in e:
  201. if i is not None and i != "":
  202. res.add(i)
  203. return list(res)
  204. @udf(returnType=ArrayType(StringType()))
  205. def merge_source(incr_source: List, old_source: List):
  206. res = set()
  207. if incr_source is not None:
  208. for i in incr_source:
  209. if i is not None and i != "":
  210. res.add(i)
  211. if old_source is not None:
  212. for i in old_source:
  213. if i is not None and i != "":
  214. res.add(i)
  215. return list(res)
  216. @udf(returnType=StructType([
  217. StructField("k", ArrayType(StringType()), False),
  218. StructField("kv", StringType()),
  219. ]))
  220. def parse_arr_and_count(arr, tag: str, return_count: int = -1):
  221. ele_cnt_dict = Counter(arr)
  222. json_list = sorted([{"code": key, "num": value} for key, value in ele_cnt_dict.items()], key=lambda x: x["num"], reverse=True)
  223. if return_count < 0:
  224. return [obj['code'] for obj in json_list], ",".join(['{' + f'{i["code"]},{tag}:{i["num"]}' + '}' for i in json_list])
  225. list_len = len(json_list)
  226. index = list_len if return_count >= list_len else return_count
  227. return [obj['code'] for obj in json_list][:index], ",".join(['{' + f'{i["code"]},{tag}:{i["num"]}' + '}' for i in json_list[:index]])
  228. @udf(returnType=StructType([
  229. StructField("sum", FloatType(), False),
  230. StructField("list", StringType()),
  231. ]))
  232. def parse_arr_and_sum(struct_arr, tag: str):
  233. sum_dict = {}
  234. for s in struct_arr:
  235. key = s[0]
  236. value: float = s[1]
  237. if key not in sum_dict:
  238. sum_dict[key] = 0.0
  239. if value is not None:
  240. sum_dict[key] += value
  241. json_list = sorted([{"code": key, "num": value} for key, value in sum_dict.items()], key=lambda x: x["num"], reverse=True)
  242. total = 0.0
  243. for obj in json_list:
  244. total += obj["num"]
  245. return round(total, 2), ",".join(['{' + f'{i["code"]},{tag}:{round(i["num"], 2)}' + '}' for i in json_list])
  246. # ==================== STRING ====================
  247. @udf(returnType=BooleanType())
  248. def has_chinese(datum: str) -> bool:
  249. if datum:
  250. pattern = re.compile(u'[\u4e00-\u9fa5]')
  251. if pattern.search(datum):
  252. return True
  253. return False
  254. @udf(returnType=FloatType())
  255. def similarity(left: str, right: str) -> float:
  256. return difflib.SequenceMatcher(None, left, right).quick_ratio()
  257. @udf(returnType=ArrayType(StringType()))
  258. def regexp_extract_all(col: str, ptn: str, g: int = 0):
  259. return [e.group(g) for e in re.compile(ptn).finditer(col if col else '')]
  260. def add_random_number_prefix(datum: str, separator: str, floor: int, ceiling: int) -> str:
  261. return f'{random.randint(floor, ceiling)}{separator}{datum}'
  262. def field_merge(delimiter: str, *fields_values):
  263. """多字段合并,相同仅保留一个,不同用 delimiter 分隔"""
  264. if not fields_values:
  265. return None
  266. result = []
  267. [result.append(value.strip()) for value in fields_values if value and value.strip() not in result]
  268. return delimiter.join(result)
  269. def space2null(text):
  270. if text and not text.isspace():
  271. return text
  272. return None
  273. def merge_ws(text: str):
  274. if text:
  275. return ' '.join(text.split())
  276. return None
  277. def remove_special_char(text, char):
  278. if text is not None and text.endswith(char):
  279. return text[:-1]
  280. return text
  281. @udf(returnType=ArrayType(StringType()))
  282. def explode_str_to_arr(text: str) -> list:
  283. """大于 8 位时,从后往前每次少一位截取子串入数组(用于前缀匹配场景)"""
  284. if text is None:
  285. return []
  286. if len(text) <= 8:
  287. return [text]
  288. return [text[:i] for i in range(len(text), 7, -1)]
  289. def html_unescape(text):
  290. return html.unescape(text)
  291. # ==================== NUMERIC / DATE / HASH ====================
  292. def max_value(*args):
  293. maxi_value = None
  294. for elem in args:
  295. if not elem:
  296. continue
  297. if not maxi_value or elem > maxi_value:
  298. maxi_value = elem
  299. return maxi_value
  300. def min_value(*args):
  301. mini_value = None
  302. for elem in args:
  303. if not elem:
  304. continue
  305. if not mini_value or elem < mini_value:
  306. mini_value = elem
  307. return mini_value
def geo_hash(latitude: float, longitude: float, precision: int) -> str:
    """Geohash-encode a (latitude, longitude) pair to *precision* characters via pygeohash."""
    return pygeohash.encode(latitude, longitude, precision)
  310. def millis_timestamp_to_str(ts: int, str_format: str = None) -> str:
  311. date_time = datetime.fromtimestamp(ts / 1000.0)
  312. if str_format:
  313. return date_time.strftime(str_format)
  314. return date_time.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
@udf(returnType=LongType())
def parse_datetime_to_timestamp(date_time: str, in_milli_seconds: bool = False, original_format: str = None) -> int:
    """Parse a date string into a unix timestamp (seconds, or ms when in_milli_seconds).

    Heuristics applied before delegating to parse_datetime:
      * 'YY.MM.DD' (8 chars, two dots, 2-digit first part) gets a '20' century
        prefix — assumes 21st-century dates;
      * a leading 'YYYY年MM月DD日' Chinese date is rewritten to 'YYYY-MM-DD'.
    If anything in that path raises, the input is retried as a raw numeric
    timestamp: a value larger than the current epoch-seconds is assumed to
    already be in milliseconds and converted accordingly. Returns None when
    input is falsy or nothing parses.
    NOTE(review): accepted string formats ultimately depend on the project
    helper parse_datetime — not visible here.
    """
    try:
        if date_time:
            d = date_time.split('.')
            # 'YY.MM.DD' -> '20YY.MM.DD'
            if len(date_time) == 8 and len(d) == 3 and len(d[0]) == 2:
                date_time = '20' + date_time
            ret = re.match(r'(\d+)年(\d+)月(\d+)日', date_time)
            if ret:
                date_time = ret.group().replace('年', '-').replace('月', '-').replace('日', '')
            parsed_date = parse_datetime(date_time, original_format)
            if not parsed_date:
                return None
            if in_milli_seconds is True:
                return int(parsed_date.timestamp() * 1000)
            return int(parsed_date.timestamp())
    except:  # deliberately broad: any parse failure falls through to the numeric retry below
        try:
            date_time = int(date_time)
            # Larger than "now" in seconds -> the value is already milliseconds.
            if datetime.now().timestamp() < date_time:
                return date_time if in_milli_seconds else int(date_time / 1000)
            return date_time * 1000 if in_milli_seconds else date_time
        except Exception:
            return None
  340. @udf(returnType=StringType())
  341. def get_md5(*cols: str) -> str:
  342. """多列拼接(带长度前缀防碰撞)后取 md5"""
  343. col_and_len_list = []
  344. for col in cols:
  345. if col is not None:
  346. col_and_len_list.append(str(len(col)))
  347. col_and_len_list.append(col)
  348. key = ''.join(col_and_len_list)
  349. if not key:
  350. return ''
  351. md5 = hashlib.md5()
  352. md5.update(key.encode("utf-8"))
  353. return md5.hexdigest()
  354. # ==================== CROSS-TYPE CONVERTERS ====================
  355. def array_to_json(arr: List):
  356. return json.dumps(arr, ensure_ascii=False)
  357. def map_to_json(map: dict):
  358. return json.dumps(map, ensure_ascii=False)
def struct_to_json(struct):
    """Serialize a struct-like object to a JSON string (non-ASCII kept unescaped).

    NOTE(review): reads struct.__dict__["__fields__"] for the field names and
    indexes the object by those names — presumably a pyspark.sql.Row; verify
    against callers before reusing elsewhere.
    """
    json_dict = {key: struct[key] for key in struct.__dict__["__fields__"]}
    return json.dumps(json_dict, ensure_ascii=False)
  362. def num_to_str(number):
  363. if isinstance(number, float) and number.is_integer():
  364. return '{:.0f}'.format(number)
  365. return str(int(number)) if isinstance(number, int) else str(number)
  366. @udf(returnType=ArrayType(StringType()))
  367. def str_to_arr(json_str: str) -> list:
  368. if json_str:
  369. return json.loads(json_str)
  370. return []
  371. @udf(returnType=ArrayType(StringType()))
  372. def str_to_json_arr(json_str):
  373. """JSON array 字符串 → list of json strings(每个元素再 json.dumps)"""
  374. if json_str:
  375. try:
  376. str_arr = json.loads(json_str)
  377. if isinstance(str_arr, list):
  378. return [json.dumps(sm) for sm in str_arr]
  379. except json.JSONDecodeError:
  380. return []
  381. return []
  382. @udf(returnType=ArrayType(MapType(StringType(), StringType())))
  383. def str_to_map_arr(json_str: str) -> list:
  384. if json_str:
  385. return json.loads(json_str)
  386. return []
  387. @udf(returnType=StringType())
  388. def split_str_to_jsonstr(str_list: List):
  389. """每个元素按 ':' 切成 k:v,聚合成 JSON 字符串"""
  390. res = []
  391. for kv_str in str_list:
  392. arr = kv_str.split(':')
  393. if len(arr) == 2:
  394. res.append({arr[0]: arr[1]})
  395. return json.dumps(res, ensure_ascii=False)
  396. @udf(returnType=MapType(StringType(), ArrayType(StringType())))
  397. def split_str_to_maparr(str_list: List):
  398. """每个元素按 ':' 切成 k:v,同 key 追加到 list"""
  399. res = {}
  400. for kv_str in str_list:
  401. arr = kv_str.split(':')
  402. if len(arr) == 2:
  403. if arr[0] not in res:
  404. res[arr[0]] = [arr[1]]
  405. else:
  406. res[arr[0]].append(arr[1])
  407. return res