Эх сурвалжийг харах

refactor(udf): 合并通用 UDF 并补单测

tianyu.chu 2 долоо хоног өмнө
parent
commit
a437262176

+ 103 - 136
dw_base/udf/common/spark_common_udf.py

@@ -11,44 +11,85 @@ import html
 import json
 import random
 import re
-import traceback
-from collections import Counter
+from ast import literal_eval
 from datetime import datetime
 from typing import Dict, List, Union
 
 from pyspark.sql.functions import udf
 from pyspark.sql.types import (
-    ArrayType, BooleanType, FloatType, IntegerType, LongType, MapType,
-    StringType, StructField, StructType,
+    ArrayType, BooleanType, FloatType, LongType, MapType, StringType,
 )
 
 from dw_base.utils.datetime_utils import parse_datetime
 
 
+def _load_json_or_default(data, default=None):
+    """Parse ``data`` as JSON, returning ``default`` when parsing fails."""
+    try:
+        return json.loads(data)
+    except (TypeError, ValueError):
+        return default
+
+
+def _load_json_or_literal(data, default=None):
+    """Parse ``data`` as JSON first; fall back to a Python literal on failure."""
+    parsed = _load_json_or_default(data, default=None)
+    if parsed is not None:
+        return parsed
+    try:
+        return literal_eval(data)
+    except (ValueError, SyntaxError, TypeError):
+        return default
+
+
+def _dedupe_keep_order(values: List) -> List:
+    """Remove duplicates while keeping first-seen order (list membership test supports unhashable elements such as dicts)."""
+    result = []
+    for value in values:
+        if value not in result:
+            result.append(value)
+    return result
+
+
+def _merge_non_empty_values(*arrays: List) -> List[str]:
+    """Merge several arrays into one list, dropping None and empty strings; result order is unspecified (set-based)."""
+    result = set()
+    for array in arrays:
+        if array is None:
+            continue
+        for item in array:
+            if item is not None and item != "":
+                result.add(item)
+    return list(result)
+
+
 # ==================== JSON ====================
 
+# UDF-01 JSON validation: check whether the input is a valid JSON string.
 @udf(returnType=BooleanType())
 def is_json(data) -> bool:
+    """Return True when the input is a valid JSON string."""
     try:
         json.loads(data)
-    except:
+    except (TypeError, ValueError):
         return False
     return True
 
 
+# UDF-02 JSON keys: extract the key list of a JSON object.
 @udf(returnType=ArrayType(StringType()))
 def json_object_keys(json_str: str) -> List[str]:
+    """Extract the key list of a JSON object string; None for non-objects."""
     if not json_str:
         return None
-    try:
-        json_dict = json.loads(json_str)  # type:dict
-        return [k for k in json_dict.keys()]
-    except:
+    json_dict = _load_json_or_default(json_str, default=None)  # type:dict
+    if not isinstance(json_dict, dict):
         return None
+    return [k for k in json_dict.keys()]
 
 
 def flatten_json(json_str: str, reserve_parent: bool = True) -> str:
-    """展平 json,reserve_parent 控制是否保留父 key"""
+    """展平 JSON 字符串,`reserve_parent` 控制是否保留父级 key。"""
 
     def flatten_json_node(parent, json_element) -> Union[float, int, str, Dict, List]:
         if isinstance(json_element, dict):
@@ -78,13 +119,12 @@ def flatten_json(json_str: str, reserve_parent: bool = True) -> str:
         json_node = json.loads(json_str)
         flattened_json = flatten_json_node(None, json_node)
         return json.dumps(flattened_json, ensure_ascii=False)
-    except Exception as e:
-        traceback.format_exc(e)
+    except (TypeError, ValueError):
         return json_str
 
 
 def remove_empty_key(info):
-    """递归删除 json 中 value 为空的 key"""
+    """递归删除 JSON 中 value 为空的 key。"""
     json_info = json.loads(info)
 
     def internal_remove(json_info):
@@ -118,20 +158,17 @@ def remove_empty_key(info):
 
 
 def append_to_json_array(json_array_string: str, new_element, remove_duplicate: bool = False) -> str:
-    """向 JSON array 追加元素,可选去重"""
+    """Append an element to the end of a JSON array string, optionally de-duplicating."""
     if not new_element:
         return json_array_string
     if not json_array_string:
         return json.dumps([new_element], ensure_ascii=False)
-    json_array = json.loads(json_array_string)  # type: list
+    json_array = _load_json_or_default(json_array_string, default=None)  # type: list
+    if not isinstance(json_array, list):
+        return json_array_string
     json_array.append(new_element)
     if remove_duplicate is True:
-        result = []
-        for elem in json_array:
-            if result.__contains__(elem):
-                continue
-            result.append(elem)
-        return json.dumps(result, ensure_ascii=False)
+        return json.dumps(_dedupe_keep_order(json_array), ensure_ascii=False)
     return json.dumps(json_array, ensure_ascii=False)
 
 
@@ -139,7 +176,7 @@ def json_array_subset(json_array_string: str,
                       subset_fields: Union[List, str],
                       as_list: bool = False,
                       skip_null: bool = False) -> str:
-    """按字段提取 json object array 的子集"""
+    """按字段提取 JSON object array 的子集。"""
     if not json_array_string:
         return None
     if not subset_fields:
@@ -150,10 +187,9 @@ def json_array_subset(json_array_string: str,
         subset_field_list = subset_fields
     if len(subset_field_list) == 0:
         return None
-    try:
-        json_array = json.loads(json_array_string)
-    except:
-        json_array = eval(json_array_string)
+    json_array = _load_json_or_literal(json_array_string, default=None)
+    if not isinstance(json_array, list):
+        return None
     list_subset = []
     if len(subset_field_list) == 1 and as_list:
         only_subset_field = subset_field_list[0]
@@ -174,26 +210,12 @@ def json_array_subset(json_array_string: str,
     return json.dumps(list_subset, ensure_ascii=False)
 
 
-@udf(returnType=ArrayType(StructType([
-    StructField("idx", IntegerType(), False),
-    StructField("obj", StringType(), False),
-])))
-def parse_jsonarr_to_arr(s: str):
-    return [(i + 1, json.dumps(obj)) for i, obj in enumerate(json.loads(s))]
-
-
-@udf(returnType=ArrayType(StructType([
-    StructField("idx", IntegerType(), False),
-    StructField("obj", StringType(), False),
-])))
-def parse_jsonarr_to_strarr(s: str):
-    return [(i + 1, obj) for i, obj in enumerate(json.loads(s))]
-
-
 # ==================== ARRAY ====================
 
+# UDF-21 Array intersection: compute the intersection of two arrays.
 @udf(returnType=ArrayType(StringType()))
 def array_intersect(arr1, arr2):
+    """Return the intersection of two arrays (set-based, order unspecified)."""
     return list(set(arr1) & set(arr2))
 
 
@@ -201,6 +223,7 @@ def array_append(array: List, new_element,
                  ignore_null: bool = False,
                  remove_duplicate: bool = False,
                  need_sort: bool = False) -> List:
+    """向数组追加元素,可按现有规则控制空值、去重和排序。"""
     if not array or len(array) == 0:
         if new_element or ignore_null is not True:
             return [new_element]
@@ -217,76 +240,28 @@ def array_append(array: List, new_element,
     return array
 
 
+# UDF-22 Array slice: take a sub-array by start/end index.
 @udf(ArrayType(StringType()))
 def array_slice(input_array, start, end):
+    """Slice the array; behaves like Python list slicing, [] for falsy input."""
     if input_array:
         return input_array[start:end]
     return []
 
 
+# UDF-23 Array merge: merge a 2-D array, filtering out None and empty strings.
 @udf(returnType=ArrayType(StringType()))
 def merge_list(arr_list: List):
-    res = set()
-    for e in arr_list:
-        if e is not None:
-            for i in e:
-                if i is not None and i != "":
-                    res.add(i)
-    return list(res)
-
-
-@udf(returnType=ArrayType(StringType()))
-def merge_source(incr_source: List, old_source: List):
-    res = set()
-    if incr_source is not None:
-        for i in incr_source:
-            if i is not None and i != "":
-                res.add(i)
-    if old_source is not None:
-        for i in old_source:
-            if i is not None and i != "":
-                res.add(i)
-    return list(res)
-
-
-@udf(returnType=StructType([
-    StructField("k", ArrayType(StringType()), False),
-    StructField("kv", StringType()),
-]))
-def parse_arr_and_count(arr, tag: str, return_count: int = -1):
-    ele_cnt_dict = Counter(arr)
-    json_list = sorted([{"code": key, "num": value} for key, value in ele_cnt_dict.items()], key=lambda x: x["num"], reverse=True)
-    if return_count < 0:
-        return [obj['code'] for obj in json_list], ",".join(['{' + f'{i["code"]},{tag}:{i["num"]}' + '}' for i in json_list])
-    list_len = len(json_list)
-    index = list_len if return_count >= list_len else return_count
-    return [obj['code'] for obj in json_list][:index], ",".join(['{' + f'{i["code"]},{tag}:{i["num"]}' + '}' for i in json_list[:index]])
-
-
-@udf(returnType=StructType([
-    StructField("sum", FloatType(), False),
-    StructField("list", StringType()),
-]))
-def parse_arr_and_sum(struct_arr, tag: str):
-    sum_dict = {}
-    for s in struct_arr:
-        key = s[0]
-        value: float = s[1]
-        if key not in sum_dict:
-            sum_dict[key] = 0.0
-        if value is not None:
-            sum_dict[key] += value
-    json_list = sorted([{"code": key, "num": value} for key, value in sum_dict.items()], key=lambda x: x["num"], reverse=True)
-    total = 0.0
-    for obj in json_list:
-        total += obj["num"]
-    return round(total, 2), ",".join(['{' + f'{i["code"]},{tag}:{round(i["num"], 2)}' + '}' for i in json_list])
+    """Merge a 2-D array into a flat list, filtering out None and empty strings."""
+    return _merge_non_empty_values(*(arr_list or []))
 
 
 # ==================== STRING ====================
 
+# UDF-31 中文检测:判断字符串中是否包含中文字符。
 @udf(returnType=BooleanType())
 def has_chinese(datum: str) -> bool:
+    """判断字符串中是否包含中文字符。"""
     if datum:
         pattern = re.compile(u'[\u4e00-\u9fa5]')
         if pattern.search(datum):
@@ -294,64 +269,66 @@ def has_chinese(datum: str) -> bool:
     return False
 
 
+# UDF-32 Similarity: compute a quick similarity ratio between two strings.
 @udf(returnType=FloatType())
 def similarity(left: str, right: str) -> float:
+    """Compute a quick (upper-bound) similarity ratio between two strings."""
     return difflib.SequenceMatcher(None, left, right).quick_ratio()
 
 
+# UDF-33 Regex extract-all: return every match of a regex pattern.
 @udf(returnType=ArrayType(StringType()))
 def regexp_extract_all(col: str, ptn: str, g: int = 0):
+    """Return every match of the regex pattern, taking capture group ``g``."""
     return [e.group(g) for e in re.compile(ptn).finditer(col if col else '')]
 
 
 def add_random_number_prefix(datum: str, separator: str, floor: int, ceiling: int) -> str:
+    """Prepend a random integer in [floor, ceiling] plus the separator to the string."""
     return f'{random.randint(floor, ceiling)}{separator}{datum}'
 
 
 def field_merge(delimiter: str, *fields_values):
-    """多字段合并,相同仅保留一个,不同用 delimiter 分隔"""
+    """Merge multiple field values, de-duplicate stripped values, join with ``delimiter``."""
     if not fields_values:
         return None
     result = []
-    [result.append(value.strip()) for value in fields_values if value and value.strip() not in result]
+    for value in fields_values:
+        if value and value.strip() not in result:
+            result.append(value.strip())
     return delimiter.join(result)
 
 
 def space2null(text):
+    """Normalize empty or whitespace-only strings to None."""
     if text and not text.isspace():
         return text
     return None
 
 
 def merge_ws(text: str):
+    """Collapse runs of whitespace into single spaces; None for falsy input."""
     if text:
         return ' '.join(text.split())
     return None
 
 
 def remove_special_char(text, char):
+    """Drop the last character when the string ends with ``char``."""
     if text is not None and text.endswith(char):
         return text[:-1]
     return text
 
 
-@udf(returnType=ArrayType(StringType()))
-def explode_str_to_arr(text: str) -> list:
-    """大于 8 位时,从后往前每次少一位截取子串入数组(用于前缀匹配场景)"""
-    if text is None:
-        return []
-    if len(text) <= 8:
-        return [text]
-    return [text[:i] for i in range(len(text), 7, -1)]
-
-
 def html_unescape(text):
+    """Unescape HTML entities in the text."""
     return html.unescape(text)
 
 
 # ==================== NUMERIC / DATE / HASH ====================
 
 def max_value(*args):
+    """按现有真假值规则返回最大值。"""
     maxi_value = None
     for elem in args:
         if not elem:
@@ -362,6 +339,7 @@ def max_value(*args):
 
 
 def min_value(*args):
+    """按现有真假值规则返回最小值。"""
     mini_value = None
     for elem in args:
         if not elem:
@@ -372,12 +350,14 @@ def min_value(*args):
 
 
 def millis_timestamp_to_str(ts: int, str_format: str = None) -> str:
+    """Convert a millisecond timestamp to a time string (default keeps milliseconds)."""
     date_time = datetime.fromtimestamp(ts / 1000.0)
     if str_format:
         return date_time.strftime(str_format)
     return date_time.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
 
 
+# UDF-41 时间解析:把日期字符串解析为时间戳。
 @udf(returnType=LongType())
 def parse_datetime_to_timestamp(date_time: str, in_milli_seconds: bool = False, original_format: str = None) -> int:
     """字符串日期 → 时间戳;支持 YY.MM.DD / YYYY年M月D日 启发式识别"""
@@ -406,6 +386,7 @@ def parse_datetime_to_timestamp(date_time: str, in_milli_seconds: bool = False,
             return None
 
 
+# UDF-42 MD5摘要:把多列值按长度前缀拼接后计算 MD5。
 @udf(returnType=StringType())
 def get_md5(*cols: str) -> str:
     """多列拼接(带长度前缀防碰撞)后取 md5"""
@@ -425,31 +406,39 @@ def get_md5(*cols: str) -> str:
 # ==================== CROSS-TYPE CONVERTERS ====================
 
 def array_to_json(arr: List):
+    """Serialize an array to a JSON string."""
     return json.dumps(arr, ensure_ascii=False)
 
 
 def map_to_json(map: dict):
+    """Serialize a dict to a JSON string. NOTE(review): parameter shadows the builtin ``map``."""
     return json.dumps(map, ensure_ascii=False)
 
 
 def struct_to_json(struct):
+    """Convert a struct-like object exposing ``__fields__`` to a JSON string."""
     json_dict = {key: struct[key] for key in struct.__dict__["__fields__"]}
     return json.dumps(json_dict, ensure_ascii=False)
 
 
 def num_to_str(number):
+    """Convert a number to a string; integral floats drop the decimal part."""
     if isinstance(number, float) and number.is_integer():
         return '{:.0f}'.format(number)
     return str(int(number)) if isinstance(number, int) else str(number)
 
 
+# UDF-51 String to array: convert a JSON array string to a Python list.
 @udf(returnType=ArrayType(StringType()))
 def str_to_arr(json_str: str) -> list:
+    """Convert a JSON array string to a Python list; [] for invalid/non-list input."""
     if json_str:
-        return json.loads(json_str)
+        parsed = _load_json_or_default(json_str, default=[])
+        return parsed if isinstance(parsed, list) else []
     return []
 
 
+# UDF-52 字符串转JSON字符串数组:把 JSON array 转为 JSON 字符串数组。
 @udf(returnType=ArrayType(StringType()))
 def str_to_json_arr(json_str):
     """JSON array 字符串 → list of json strings(每个元素再 json.dumps)"""
@@ -463,33 +452,11 @@ def str_to_json_arr(json_str):
     return []
 
 
+# UDF-53 String to map array: convert a JSON array string to an array of maps.
 @udf(returnType=ArrayType(MapType(StringType(), StringType())))
 def str_to_map_arr(json_str: str) -> list:
+    """Convert a JSON array string to an array of maps; [] for invalid/non-list input."""
     if json_str:
-        return json.loads(json_str)
+        parsed = _load_json_or_default(json_str, default=[])
+        return parsed if isinstance(parsed, list) else []
     return []
-
-
-@udf(returnType=StringType())
-def split_str_to_jsonstr(str_list: List):
-    """每个元素按 ':' 切成 k:v,聚合成 JSON 字符串"""
-    res = []
-    for kv_str in str_list:
-        arr = kv_str.split(':')
-        if len(arr) == 2:
-            res.append({arr[0]: arr[1]})
-    return json.dumps(res, ensure_ascii=False)
-
-
-@udf(returnType=MapType(StringType(), ArrayType(StringType())))
-def split_str_to_maparr(str_list: List):
-    """每个元素按 ':' 切成 k:v,同 key 追加到 list"""
-    res = {}
-    for kv_str in str_list:
-        arr = kv_str.split(':')
-        if len(arr) == 2:
-            if arr[0] not in res:
-                res[arr[0]] = [arr[1]]
-            else:
-                res[arr[0]].append(arr[1])
-    return res

+ 139 - 0
tests/unit/udf/test_spark_common_udf.py

@@ -0,0 +1,139 @@
+import json
+from datetime import datetime
+
+from dw_base.udf.common import spark_common_udf as udf_module
+
+
+def test_json_object_keys_returns_keys_for_json_object():
+    assert udf_module.json_object_keys.func('{"a": 1, "b": 2}') == ["a", "b"]
+
+
+def test_json_array_subset_supports_python_literal_without_eval():
+    data = "[{'name': 'alice', 'age': 18}, {'name': 'bob', 'age': 20}]"
+
+    result = udf_module.json_array_subset(data, "name", as_list=True)
+
+    assert json.loads(result) == ["alice", "bob"]
+
+
+def test_json_array_subset_returns_none_for_invalid_input():
+    assert udf_module.json_array_subset("not-json", "name") is None
+
+
+def test_append_to_json_array_returns_original_when_source_is_invalid_json():
+    assert udf_module.append_to_json_array("not-json", "x") == "not-json"
+
+
+def test_append_to_json_array_can_remove_duplicates():
+    result = udf_module.append_to_json_array('["a", "b"]', "a", remove_duplicate=True)
+
+    assert json.loads(result) == ["a", "b"]
+
+
+def test_flatten_json_returns_original_text_for_invalid_json():
+    assert udf_module.flatten_json("not-json") == "not-json"
+
+
+def test_remove_empty_key_removes_empty_values_recursively():
+    source = json.dumps({
+        "a": "",
+        "b": None,
+        "c": {"d": "", "e": 1},
+        "f": ["", {"g": "x"}],
+    })
+
+    assert json.loads(udf_module.remove_empty_key(source)) == {"c": {"e": "1"}, "f": [{"g": "x"}]}
+
+
+def test_merge_list_keeps_existing_semantics():
+    merged_list = sorted(udf_module.merge_list.func([["a", "", None], ["b", "a"], None]))
+
+    assert merged_list == ["a", "b"]
+
+
+def test_array_intersect_returns_common_items():
+    assert sorted(udf_module.array_intersect.func(["a", "b"], ["b", "c"])) == ["b"]
+
+
+def test_array_append_respects_existing_semantics():
+    assert udf_module.array_append(["a"], "a", remove_duplicate=True) == ["a"]
+    assert udf_module.array_append(["b"], "a", need_sort=True) == ["a", "b"]
+
+
+def test_array_slice_returns_sub_list():
+    assert udf_module.array_slice.func(["a", "b", "c"], 1, 3) == ["b", "c"]
+
+
+def test_has_chinese_detects_chinese_characters():
+    assert udf_module.has_chinese.func("abc中文") is True
+    assert udf_module.has_chinese.func("abc") is False
+
+
+def test_similarity_returns_high_score_for_identical_strings():
+    assert udf_module.similarity.func("abc", "abc") == 1.0
+
+
+def test_regexp_extract_all_extracts_all_matches():
+    assert udf_module.regexp_extract_all.func("a1b22c333", r"\d+") == ["1", "22", "333"]
+
+
+def test_field_merge_deduplicates_values():
+    assert udf_module.field_merge(",", " a ", "b", "a", None) == "a,b"
+
+
+def test_space2null_and_merge_ws_and_remove_special_char():
+    assert udf_module.space2null("   ") is None
+    assert udf_module.space2null(" a ") == " a "
+    assert udf_module.merge_ws("a   b\tc") == "a b c"
+    assert udf_module.remove_special_char("abc,", ",") == "abc"
+
+
+def test_html_unescape_restores_html_entities():
+    assert udf_module.html_unescape("&lt;div&gt;Tom &amp; Jerry&lt;/div&gt;") == "<div>Tom & Jerry</div>"
+
+
+def test_max_value_and_min_value_keep_existing_truthy_semantics():
+    assert udf_module.max_value(None, 2, 1) == 2
+    assert udf_module.min_value(None, 2, 1) == 1
+
+
+def test_millis_timestamp_to_str_formats_milliseconds():
+    expected = datetime.fromtimestamp(0).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
+    assert udf_module.millis_timestamp_to_str(0) == expected
+
+
+def test_parse_datetime_to_timestamp_supports_seconds_and_milliseconds():
+    expected_seconds = int(datetime(2024, 1, 2, 3, 4, 5).timestamp())
+    expected_milliseconds = expected_seconds * 1000
+
+    assert udf_module.parse_datetime_to_timestamp.func("2024-01-02 03:04:05") == expected_seconds
+    assert udf_module.parse_datetime_to_timestamp.func(str(expected_milliseconds)) == expected_seconds
+    assert udf_module.parse_datetime_to_timestamp.func(str(expected_seconds), in_milli_seconds=True) == expected_milliseconds
+
+
+def test_get_md5_is_stable_for_same_inputs():
+    assert udf_module.get_md5.func("ab", "cd") == udf_module.get_md5.func("ab", "cd")
+    assert udf_module.get_md5.func(None) == ""
+
+
+def test_array_to_json_and_map_to_json_and_num_to_str():
+    assert json.loads(udf_module.array_to_json(["a", 1])) == ["a", 1]
+    assert json.loads(udf_module.map_to_json({"a": 1})) == {"a": 1}
+    assert udf_module.num_to_str(1.0) == "1"
+    assert udf_module.num_to_str(2) == "2"
+
+
+def test_str_to_arr_returns_empty_when_json_is_invalid():
+    assert udf_module.str_to_arr.func("not-json") == []
+
+
+def test_str_to_json_arr_returns_json_strings():
+    assert udf_module.str_to_json_arr.func('[{"a": 1}, {"b": 2}]') == ['{"a": 1}', '{"b": 2}']
+
+
+def test_str_to_map_arr_returns_empty_when_json_is_not_list():
+    assert udf_module.str_to_map_arr.func('{"a": 1}') == []
+
+
+def test_is_json_handles_none():
+    assert udf_module.is_json.func(None) is False