# spark_common_udf.py
  1. #!/usr/bin/env /usr/bin/python3
  2. # -*- coding:utf-8 -*-
  3. import difflib
  4. import json
  5. import random
  6. import re
  7. import traceback
  8. from datetime import datetime
  9. from typing import Union, List, Dict
  10. import pygeohash
  11. from pyspark.sql.functions import udf
  12. from pyspark.sql.types import StringType, ArrayType, BooleanType, FloatType, LongType, MapType
  13. from dw_base.utils.datetime_utils import parse_datetime
  14. def add_random_number_prefix(datum: str, separator: str, floor: int, ceiling: int) -> str:
  15. """
  16. 为字段添加随机数字前缀
  17. Args:
  18. datum:
  19. separator: 原数据与随机前缀的分隔符
  20. floor: 随机数字前缀下限
  21. ceiling: 随机数字前缀上限
  22. Returns:
  23. """
  24. return f'{random.randint(floor, ceiling)}{separator}{datum}'
  25. def append_to_json_array(json_array_string: str, new_element, remove_duplicate: bool = False) -> str:
  26. """
  27. 向JSON array添加元素
  28. Args:
  29. json_array_string: JSON array字符串
  30. new_element: 要添加的元素
  31. remove_duplicate: 是否去重
  32. Returns:
  33. """
  34. if not new_element:
  35. return json_array_string
  36. if not json_array_string:
  37. return json.dumps([new_element], ensure_ascii=False)
  38. json_array = json.loads(json_array_string) # type: list
  39. json_array.append(new_element)
  40. if remove_duplicate is True:
  41. result = []
  42. for elem in json_array:
  43. if result.__contains__(elem):
  44. continue
  45. result.append(elem)
  46. return json.dumps(result, ensure_ascii=False)
  47. return json.dumps(json_array, ensure_ascii=False)
  48. def array_append(array: List, new_element,
  49. ignore_null: bool = False,
  50. remove_duplicate: bool = False,
  51. need_sort: bool = False) -> List:
  52. if not array or len(array) == 0:
  53. if new_element or ignore_null is not True:
  54. return [new_element]
  55. return []
  56. if not new_element:
  57. if ignore_null is True:
  58. return array
  59. else:
  60. if array.__contains__(new_element) and remove_duplicate is True:
  61. return array
  62. array.append(new_element)
  63. if need_sort:
  64. array.sort()
  65. return array
  66. def field_merge(delimiter: str, *fields_values):
  67. """
  68. 两个字段合并,如果相同只取一个,不同用delimiter分隔
  69. Args:
  70. delimiter:
  71. *fields_values:
  72. Returns:
  73. """
  74. if not fields_values:
  75. return None
  76. result = []
  77. [result.append(value.strip()) for value in fields_values if value and value.strip() not in result]
  78. return delimiter.join(result)
  79. def flatten_json(json_str: str, reserve_parent: bool = True) -> str:
  80. """
  81. 展平json
  82. Args:
  83. json_str: 待展平的json
  84. reserve_parent: 是否保留父key,默认保留
  85. Returns:
  86. """
  87. def flatten_json_node(parent, json_element) -> Union[float, int, str, Dict, List]:
  88. if isinstance(json_element, dict):
  89. result = {}
  90. if parent and reserve_parent and reserve_parent is True:
  91. for key, value in json_element.items():
  92. result[f'{parent}.{key}'] = value
  93. else:
  94. for key, value in json_element.items():
  95. result.update(flatten_json_node(key, value))
  96. return result
  97. elif isinstance(json_element, list):
  98. result = []
  99. if parent and reserve_parent and reserve_parent is True:
  100. for index in range(len(json_element)):
  101. result.append(flatten_json_node(f'{parent}.[{index}]', json_element[index]))
  102. else:
  103. for index in range(len(json_element)):
  104. result.append(flatten_json_node(None, json_element[index]))
  105. return result
  106. else:
  107. return {parent: json_element}
  108. if not json_str:
  109. return json_str
  110. try:
  111. json_node = json.loads(json_str)
  112. flattened_json = flatten_json_node(None, json_node)
  113. return json.dumps(flattened_json, ensure_ascii=False)
  114. except Exception as e:
  115. traceback.format_exc(e)
  116. return json_str
  117. def geo_hash(latitude: float, longitude: float, precision: int) -> str:
  118. return pygeohash.encode(latitude, longitude, precision)
  119. @udf(returnType=BooleanType())
  120. def has_chinese(datum: str) -> bool:
  121. if datum:
  122. pattern = re.compile(u'[\u4e00-\u9fa5]')
  123. match = pattern.search(datum)
  124. if match:
  125. return True
  126. return False
  127. @udf(returnType=BooleanType())
  128. def is_json(data) -> bool:
  129. try:
  130. json.loads(data)
  131. except:
  132. return False
  133. return True
  134. def json_array_subset(json_array_string: str,
  135. subset_fields: Union[List, str],
  136. as_list: bool = False,
  137. skip_null: bool = False) -> str:
  138. """
  139. 获取json object array string的子集
  140. Args:
  141. json_array_string:
  142. subset_fields: 子集字段
  143. as_list: 如果子集字段只有1个,是否以list返回
  144. skip_null: 字段的值是None,是否添加在返回的数据中
  145. Returns: 子集数组的字符串
  146. """
  147. if not json_array_string:
  148. return None
  149. if not subset_fields:
  150. return None
  151. if isinstance(subset_fields, str):
  152. subset_field_list = subset_fields.split(',')
  153. else:
  154. subset_field_list = subset_fields
  155. if len(subset_field_list) == 0:
  156. return None
  157. try:
  158. json_array = json.loads(json_array_string)
  159. except:
  160. json_array = eval(json_array_string)
  161. list_subset = []
  162. if len(subset_field_list) == 1 and as_list:
  163. only_subset_field = subset_field_list[0]
  164. for element in json_array: # type:Dict
  165. if isinstance(element, dict):
  166. field_value = element.get(only_subset_field)
  167. if field_value or not skip_null:
  168. list_subset.append(field_value)
  169. else:
  170. for element in json_array: # type:Dict
  171. subset_of_element = {}
  172. if isinstance(element, dict):
  173. for field in subset_field_list:
  174. field_value = element.get(field)
  175. if field_value or not skip_null:
  176. subset_of_element[field] = field_value
  177. list_subset.append(subset_of_element)
  178. return json.dumps(list_subset, ensure_ascii=False)
  179. @udf(returnType=ArrayType(StringType()))
  180. def json_object_keys(json_str: str) -> List[str]:
  181. if not json_str:
  182. return None
  183. try:
  184. json_dict = json.loads(json_str) # type:dict
  185. return [k for k in json_dict.keys()]
  186. except:
  187. return None
  188. def max_value(*args):
  189. maxi_value = None
  190. for elem in args:
  191. if not elem:
  192. continue
  193. if not maxi_value or elem > maxi_value:
  194. maxi_value = elem
  195. return maxi_value
  196. def min_value(*args):
  197. mini_value = None
  198. for elem in args:
  199. if not elem:
  200. continue
  201. if not mini_value or elem < mini_value:
  202. mini_value = elem
  203. return mini_value
  204. def millis_timestamp_to_str(ts: int, str_format: str = None) -> str:
  205. date_time = datetime.fromtimestamp(ts / 1000.0)
  206. if str_format:
  207. return date_time.strftime(str_format)
  208. return date_time.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
@udf(returnType=LongType())
def parse_datetime_to_timestamp(date_time: str, in_milli_seconds: bool = False, original_format: str = None) -> int:
    """
    Spark UDF: convert a date string to a unix timestamp.

    Args:
        date_time: date string; numeric timestamp strings are also tolerated
            via the fallback branch below.
        in_milli_seconds: return milliseconds instead of seconds.
        original_format: explicit source format; auto-detected when omitted
            (detection is done by the project helper parse_datetime).

    Returns:
        The timestamp as an int, or None when nothing could be parsed.
        Implicitly returns None for falsy input.
    """
    try:
        if date_time:
            d = date_time.split('.')
            # 8-char 'yy.mm.dd' style dates: prepend the century -> '20yy.mm.dd'.
            if len(date_time) == 8 and len(d) == 3 and len(d[0]) == 2:
                date_time = '20' + date_time
            # Normalize Chinese-style dates like '2021年3月4日' to '2021-3-4'.
            ret = re.match('(\d+)年(\d+)月(\d+)日', date_time)
            if ret:
                date_time = ret.group().replace('年', '-').replace('月', '-').replace('日', '')
            parsed_date = parse_datetime(date_time, original_format)
            if not parsed_date:
                return None
            if in_milli_seconds is True:
                return int(parsed_date.timestamp() * 1000)
            return int(parsed_date.timestamp())
    except:
        # Fallback: treat the input as a raw numeric timestamp string.
        try:
            date_time = int(date_time)
            # If the value exceeds "now" in seconds, assume it is already in milliseconds.
            if datetime.now().timestamp() < date_time:
                if in_milli_seconds is True:
                    return date_time
                else:
                    return int(date_time / 1000)
            else:
                if in_milli_seconds is True:
                    return date_time * 1000
                else:
                    return date_time
        except Exception as e:
            return None
  250. @udf(returnType=FloatType())
  251. def similarity(left: str, right: str) -> float:
  252. """
  253. 计算两个字符串的相似度
  254. Args:
  255. left:
  256. right:
  257. Returns:
  258. """
  259. return difflib.SequenceMatcher(None, left, right).quick_ratio()
  260. def remove_empty_key(info):
  261. """
  262. 删除json中value为空的key
  263. Returns: json
  264. """
  265. json_info = json.loads(info)
  266. def internal_remove(json_info):
  267. try:
  268. if isinstance(json_info, dict):
  269. info_re = dict()
  270. for key, value in json_info.items():
  271. if isinstance(value, dict) or isinstance(value, list):
  272. re = internal_remove(value)
  273. if len(re):
  274. info_re[key] = re
  275. elif value not in ['', {}, [], 'null', None]:
  276. info_re[key] = str(value)
  277. return info_re
  278. elif isinstance(json_info, list):
  279. info_re = list()
  280. for value in json_info:
  281. if isinstance(value, dict) or isinstance(value, list):
  282. re = internal_remove(value)
  283. if len(re):
  284. info_re.append(re)
  285. elif value not in ['', {}, [], 'null', None]:
  286. info_re.append(str(value))
  287. return info_re
  288. else:
  289. return None
  290. except Exception as e:
  291. return None
  292. return json.dumps(internal_remove(json_info), ensure_ascii=False)
  293. @udf(returnType=ArrayType(StringType()))
  294. def regexp_extract_all(col: str, ptn: str, g: int = 0):
  295. return [e.group(g) for e in re.compile(ptn).finditer(col if col else '')]
@udf(returnType=ArrayType(StringType()))
def array_intersect(arr1, arr2):
    """
    Spark UDF: compute the intersection of two arrays (deduplicated).

    NOTE(review): the result's element order is set-iteration order, which
    is not stable across interpreter runs for strings (hash randomization)
    — confirm downstream consumers do not rely on ordering.
    :param arr1:
    :param arr2:
    :return:
    """
    return list(set(arr1) & set(arr2))
  305. def array_to_json(arr: List):
  306. """
  307. 数组转为jsonstring
  308. :param arr:
  309. :return:
  310. """
  311. return json.dumps(arr, ensure_ascii=False)
  312. def map_to_json(map: dict):
  313. """
  314. map转为jsonstring
  315. """
  316. return json.dumps(map, ensure_ascii=False)
  317. def struct_to_json(struct):
  318. json_dict = {key: struct[key] for key in struct.__dict__["__fields__"]}
  319. return json.dumps(json_dict, ensure_ascii=False)
  320. @udf(returnType=ArrayType(MapType(StringType(), StringType())))
  321. def str_to_map_arr(json_str: str) -> list:
  322. if json_str:
  323. return json.loads(json_str)
  324. return []
  325. def num_to_str(number):
  326. # 确保 number 是 float 类型
  327. if isinstance(number, float) and number.is_integer():
  328. return '{:.0f}'.format(number)
  329. else:
  330. return str(int(number)) if isinstance(number, int) else str(number)
  331. def space2null(text):
  332. if text and not text.isspace():
  333. return text
  334. return None
  335. if __name__ == '__main__':
  336. cases = [
  337. '',
  338. None,
  339. ' ',
  340. ' ',
  341. 'hello'
  342. ]
  343. for case in cases:
  344. print(space2null(case))