sql_utils.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. #!/usr/bin/env /usr/bin/python3
  2. # -*- coding:utf-8 -*-
  3. import re
  4. from typing import List
  5. from dw_base import NORM_GRN, NORM_YEL
  6. from dw_base.utils.file_utils import read_file_lines
  7. from dw_base.utils.log_utils import pretty_print
  8. def check_parameter_substituted(sql: str, ignore: bool = False):
  9. """
  10. 检查是否有未替换的参数
  11. Args:
  12. sql: 要检查的SQL
  13. ignore: 有未替换的参数时是否忽略
  14. Returns:
  15. """
  16. sql = re.sub(r'[\r\n]+', ' ', sql)
  17. match = re.findall('\\${(.*?)}', sql)
  18. if match and len(match) > 0:
  19. parameters = set(match)
  20. if ignore:
  21. pretty_print(f'{NORM_YEL}Parameter {NORM_GRN}{", ".join(parameters)}{NORM_YEL} is not provided')
  22. else:
  23. raise Exception(f'Parameter {", ".join(parameters)} is not provided')
  24. def get_sql_list_from_file(sql_file: str, trim_comment: bool = False) -> List[str]:
  25. """
  26. 从文件读取SQL语句列表
  27. Args:
  28. sql_file: SQL文件
  29. trim_comment: 是否去除注释
  30. Returns: SQL语句列表
  31. """
  32. sql_lines = read_file_lines(sql_file)
  33. sql_list = []
  34. sql_buffer = ''
  35. for line in sql_lines:
  36. if trim_comment:
  37. if line.strip() == '' or line.strip().startswith('--'):
  38. continue
  39. if line.__contains__('--'):
  40. cleaned_line = line[:line.rindex('--')].strip()
  41. if len(cleaned_line) > 0:
  42. sql_buffer += cleaned_line
  43. continue
  44. if line.strip() == ';':
  45. # 新行是分号
  46. if sql_buffer != '':
  47. sql_list.append(sql_buffer)
  48. sql_buffer = ''
  49. continue
  50. if line.strip().endswith(';'):
  51. # 新行以分号结尾
  52. sql_list.append(sql_buffer + line.strip().strip(';'))
  53. sql_buffer = ''
  54. continue
  55. # if line.strip().__contains__(';'):
  56. # # 新行含有分号(比较复杂的逻辑,如 like '%abc;def%'),如果分号左边的单引号个数是奇数个,应认为分号是作为参数的(没有实现,先一刀切认为是语句结尾吧)
  57. # parts = line.split(';')
  58. # sql_list.append(sql_buffer + parts[0])
  59. # for index in range(1, len(parts) - 1):
  60. # sql_list.append(parts[index])
  61. # sql_buffer = parts[-1]
  62. # continue
  63. sql_buffer += line
  64. if sql_buffer != '':
  65. sql_list.append(sql_buffer)
  66. return sql_list