solr_similar_match_udf.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334
  1. import json
  2. import requests
  3. from pyspark.sql.functions import udf
  4. from pyspark.sql.types import StringType, ArrayType, StructType, StructField, BooleanType
  5. def edismax_call(collection: str, q_alt: str, q: str, qf: str, mm: str = '70%', rows: int = 1, stopwords: str = 'true',
  6. tie: float = 0.2, wt: str = 'json'):
  7. def_type: str = 'edismax'
  8. params = {"defType": def_type, "mm": mm, "q.alt": q_alt, "q": q, 'qf': qf, 'rows': rows, 'stopwords': stopwords,
  9. 'tie': tie, 'wt': wt}
  10. resp = requests.get(f'http://m2.node.dev:8886/solr/{collection}/select', params=params)
  11. return resp
  12. @udf(returnType=StructType([
  13. StructField("is_finded",BooleanType(),False),
  14. StructField("basic_arr",ArrayType(StringType()),True)
  15. ]))
  16. def get_china_company_name_match(raw_name:str,mm:str = '70%', rows: int = 1):
  17. solr_resp = edismax_call('ent_china_biz_basic', raw_name,
  18. raw_name, 'ent_name_en_abb^1.0',mm,rows)
  19. if solr_resp.status_code != 200:
  20. return False, None
  21. else:
  22. resp = json.loads(solr_resp.text)['response']
  23. if resp['numFound'] == 0:
  24. return False, None
  25. else:
  26. most_match_one = resp['docs'][0]
  27. return True, [most_match_one['ent_name_chn'],most_match_one['ent_name_en'],most_match_one['ent_name_en_abb'],most_match_one['unc_id']]
  28. if __name__ == '__main__':
  29. print(get_china_company_name_match('SAMSUNG ELECTRONICS CO. LTD.,. '))