spark_id_generate_udf.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. """
  2. 批量匹配tid
  3. """
  4. import hashlib
  5. import json
  6. from functools import lru_cache
  7. from dw_base.spark.udf.enterprise.unique.ent_offline_udf_america import generate_tid_usa, clean_company_name_usa
  8. from dw_base.spark.udf.enterprise.unique.ent_offline_udf_india import clean_company_name_ind, generate_tid_ind
  9. from dw_base.spark.udf.enterprise.unique.ent_offline_udf_indonesia import clean_company_name_idn, generate_tid_idn
  10. from dw_base.spark.udf.enterprise.unique.ent_offline_udf_russia import clean_company_name_rus, generate_tid_rus
  11. from dw_base.spark.udf.enterprise.unique.ent_offline_udf_turkey import clean_company_name_tur, generate_tid_tur
  12. from dw_base.utils.tid_utils import TidGeneratorFactory
# Process-wide lookup table filled by cache_tid(). Keys are
# '<name>--<country>' strings; a value is either the matched tid or '' to
# record that matching was attempted and nothing was found.
mapping = {}
# Matcher obtained from TidGeneratorFactory for the 'Enterprise' kind;
# cache_tid() calls its match_tid(name, country) method.
tid_generator = TidGeneratorFactory().createTidGenerator('Enterprise')
  15. def generate_tid(website, name, country_code3):
  16. if not name:
  17. return None
  18. if country_code3 in ['IDN']:
  19. cleaned_name = clean_company_name_idn(name)
  20. return match_tid(name, cleaned_name, country_code3)
  21. elif country_code3 in ['USA']:
  22. cleaned_name = clean_company_name_usa(name)
  23. return match_tid(name, cleaned_name, country_code3)
  24. elif country_code3 in ['TUR']:
  25. cleaned_name = clean_company_name_tur(name)
  26. return match_tid(name, cleaned_name, country_code3)
  27. elif country_code3 in ['IND']:
  28. cleaned_name = clean_company_name_ind(name)
  29. return match_tid(name, cleaned_name, country_code3)
  30. elif country_code3 in ['RUS']:
  31. cleaned_name = clean_company_name_rus(name)
  32. return match_tid(name, cleaned_name, country_code3)
  33. else:
  34. return old_generate_tid(website, name, country_code3)
  35. def generate_md5_hash(input_str: str):
  36. md5_hash = hashlib.md5(input_str.encode('utf-8'))
  37. return md5_hash.hexdigest()
  38. def old_generate_tid(website, name, country_code3):
  39. if not name:
  40. return None
  41. input_str = website if website else f"{name}-{country_code3 if country_code3 else ''}"
  42. return generate_md5_hash(input_str)
  43. def match_tid(name: str, cleaned_name: str, country: str):
  44. tid = cache_tid(name, cleaned_name, country)
  45. if not tid:
  46. if country == 'IDN':
  47. return generate_tid_idn(cleaned_name, None, None)
  48. elif country == 'USA':
  49. return generate_tid_usa(cleaned_name, None, None)
  50. elif country == 'TUR':
  51. return generate_tid_tur(cleaned_name, None)
  52. elif country == 'IND':
  53. return generate_tid_ind(cleaned_name, None)
  54. elif country == 'RUS':
  55. return generate_tid_rus(cleaned_name, None, None)
  56. return tid
  57. @lru_cache(maxsize=1000000)
  58. def cache_tid(name: str, cleaned_name: str, country: str):
  59. key = '%s--%s' % (
  60. name if name else "",
  61. country if country else ""
  62. )
  63. cleaned_key = '%s--%s' % (cleaned_name if cleaned_name else "",
  64. country if country else "")
  65. tid = mapping.get(key) or mapping.get(cleaned_key)
  66. if tid is None:
  67. # 如果mapping里没有该tid,进行匹配
  68. tid = tid_generator.match_tid(name, country)
  69. if tid is None:
  70. # 如果匹配结果为null,则向mapping写入一个空字符串
  71. tid = tid_generator.match_tid(cleaned_name, country)
  72. if tid is None:
  73. mapping[key] = ''
  74. mapping[cleaned_key] = ''
  75. else:
  76. mapping[cleaned_key] = tid
  77. else:
  78. mapping[key] = tid
  79. elif tid == '':
  80. # 对于第一次没有匹配到tid的公司,第二次进入该方法会得到一个空字符串,此时应返回null
  81. return None
  82. return tid
  83. if __name__ == '__main__':
  84. print(generate_tid('', 'KENCANA LINTASINDO INTERNASIONAL', 'IDN'))