tid_utils.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # coding=utf-8
  2. """
  3. udf
  4. """
  5. import json
  6. import logging
  7. from dw_base.database.mongodb_utils import MongoDBHandler
  8. class TidGenerator(object):
  9. def match_pid(self,
  10. company_name: str, country: str) -> str:
  11. raise Exception("not implemented yet")
  12. class MongoTidGenerator(TidGenerator):
  13. def __init__(self):
  14. self.tid_field = None
  15. self.alias_field = None
  16. self.country_field = None
  17. self.company_aliases = None
  18. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  19. def match_tid(self,
  20. company_name: str, country: str) -> str:
  21. pid = None
  22. if company_name:
  23. pid = self.__match_by_company_name(company_name, country)
  24. return pid
  25. def __match_by_company_name(self,
  26. company_name: str, country: str) -> str:
  27. if not company_name:
  28. return None
  29. documents = []
  30. find_result_by_name = self.company_aliases.find({self.alias_field: {"$eq": company_name}})
  31. for document in find_result_by_name:
  32. tid_value = document.get(self.tid_field)
  33. if tid_value and tid_value[:3] == country:
  34. documents.append(document)
  35. if len(documents) > 0:
  36. max_document = max(documents, key=lambda x: x.get(self.tid_field, 0))
  37. return max_document.get(self.tid_field)
  38. return None
  39. class EnterpriseTidGenerator(MongoTidGenerator):
  40. def __init__(self):
  41. super().__init__()
  42. self.uri = 'mongodb://tendata_corp:TD_corpqyk22@192.168.11.27:21868/?authSource=tendata_corp'
  43. self.database = "tendata_corp"
  44. self.collection_alias = "company_aliases"
  45. self.tid_field = 'tid'
  46. self.alias_field = 'alias'
  47. self.country_field = 'country_code3'
  48. self.mongo_client = MongoDBHandler(self.uri).mongo_client
  49. self.company_aliases = self.mongo_client.get_database(self.database).get_collection(self.collection_alias)
  50. class HBaseTidGenerator(TidGenerator):
  51. def __init__(self):
  52. raise Exception("not implemented yet")
  53. # 自定义异常类
  54. class UnsupportedDimensionError(Exception):
  55. def __init__(self, dimension):
  56. self.dimension = dimension
  57. super().__init__(f"Not supported generator dimension: {dimension}")
  58. class TidGeneratorFactory(object):
  59. @staticmethod
  60. def createTidGenerator(dimension: str):
  61. if dimension is None:
  62. raise ValueError("Dimension cannot be None")
  63. switch_generator_dict = {
  64. 'Enterprise': EnterpriseTidGenerator(),
  65. }
  66. if dimension not in switch_generator_dict:
  67. raise UnsupportedDimensionError(dimension)
  68. return switch_generator_dict[dimension]
  69. if __name__ == '__main__':
  70. pass