test_sync_template_gen.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # -*- coding:utf-8 -*-
  2. """
  3. datax-sync-template-gen 模板渲染 + JDBC URL 解析单测。
  4. 不连真 PG(query_columns_full 走 mock conn)。
  5. 脚本路径含连字符,用 importlib.util 动态加载为模块。
  6. """
  7. import importlib.util
  8. import os
  9. import sys
  10. from unittest.mock import MagicMock
  11. import pytest
  12. PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
  13. SCRIPT_PATH = os.path.join(PROJECT_ROOT, 'bin', 'datax-sync-template-gen.py')
  14. def _load_script():
  15. spec = importlib.util.spec_from_file_location('datax_sync_template_gen', SCRIPT_PATH)
  16. mod = importlib.util.module_from_spec(spec)
  17. sys.modules['datax_sync_template_gen'] = mod
  18. spec.loader.exec_module(mod)
  19. return mod
# Load the hyphen-named generator script once, at import time; every test
# below reaches the functions under test through this module object. If the
# script cannot be loaded, importing this test module raises, so pytest
# reports a collection error rather than individual test failures.
GEN = _load_script()
  21. def test_parse_jdbc_url_with_port():
  22. host, port, db = GEN.parse_jdbc_url('jdbc:postgresql://10.0.0.1:5433/hobby_stocks')
  23. assert host == '10.0.0.1'
  24. assert port == 5433
  25. assert db == 'hobby_stocks'
  26. def test_parse_jdbc_url_default_port():
  27. host, port, db = GEN.parse_jdbc_url('jdbc:postgresql://pg.example.com/mydb')
  28. assert host == 'pg.example.com'
  29. assert port == 5432
  30. assert db == 'mydb'
  31. def test_parse_jdbc_url_invalid():
  32. with pytest.raises(ValueError, match='无法解析'):
  33. GEN.parse_jdbc_url('mysql://10.0.0.1:3306/foo')
  34. def test_render_template_includes_required_fields():
  35. columns = [('id', 'id'), ('name', '姓名'), ('create_time', '创建时间')]
  36. out = GEN.render_template(
  37. ds_ref='postgresql/prod-hobby',
  38. database='hobby_stocks',
  39. schema='public',
  40. table='users',
  41. columns=columns,
  42. pk='id',
  43. )
  44. assert 'dataSource = postgresql/prod-hobby' in out
  45. assert 'database = hobby_stocks' in out
  46. assert 'table = public.users' in out
  47. assert 'column = id,name,create_time' in out
  48. assert 'splitPk = id' in out
  49. assert "where = update_time >= '${start_date}' AND update_time < '${stop_date}'" in out
  50. assert 'path = /user/hive/warehouse/raw.db/users_TODO_d/dt=${dt}/' in out
  51. assert 'fileName = users_TODO_d' in out
  52. # 不传 mask_methods 时不渲染 [mask] section header
  53. assert '\n[mask]\n' not in out
  54. def test_render_template_with_mask_methods():
  55. columns = [('id', 'id'), ('user_name', '用户名'), ('phone', '手机号')]
  56. out = GEN.render_template(
  57. ds_ref='postgresql/prod-hobby', database='db', schema='public',
  58. table='users', columns=columns, pk='id',
  59. mask_methods={'user_name': 'mask_middle', 'phone': 'md5'},
  60. )
  61. # [mask] section header 在 [reader] 后 [writer] 前
  62. assert '\n[mask]\n' in out
  63. assert 'user_name = mask_middle' in out
  64. assert 'phone = md5' in out
  65. reader_idx = out.index('\n[reader]\n')
  66. mask_idx = out.index('\n[mask]\n')
  67. writer_idx = out.index('\n[writer]\n')
  68. assert reader_idx < mask_idx < writer_idx
  69. def test_query_columns_full_returns_full_metadata():
  70. conn = MagicMock()
  71. cur = conn.cursor.return_value
  72. cur.fetchall.return_value = [
  73. (1, 'id', 'id', 'bigint', 'PK'),
  74. (2, 'name', '名称', 'character varying', ''),
  75. ]
  76. rows = GEN.query_columns_full(conn, 'public', 'orders')
  77. assert rows == [
  78. (1, 'id', 'id', 'bigint', 'PK'),
  79. (2, 'name', '名称', 'character varying', ''),
  80. ]
  81. def test_render_schema_md_no_mask_dict_blank_column():
  82. rows = [
  83. (1, 'id', 'id', 'bigint', 'PK'),
  84. (2, 'user_name', '用户名', 'character varying', ''),
  85. (3, 'create_time', None, 'timestamp without time zone', ''),
  86. ]
  87. out = GEN.render_schema_md(rows)
  88. assert '| 序号 | 字段名 | 中文名 | 数据类型 | 主键标识 | 脱敏类型 |' in out
  89. assert '| 1 | `id` | id | bigint | PK | |' in out
  90. assert '| 2 | `user_name` | 用户名 | character varying | | |' in out
  91. assert '| 3 | `create_time` | | timestamp without time zone | | |' in out
  92. def test_render_schema_md_with_mask_dict():
  93. rows = [
  94. (1, 'id', 'id', 'bigint', 'PK'),
  95. (2, 'user_name', '用户名', 'character varying', ''),
  96. (3, 'phone', '手机号', 'character varying', ''),
  97. (4, 'merchant_open', '商家代开', 'smallint', ''),
  98. ]
  99. mask_dict = {'phone': 'md5', 'merchant_open': 'trim', 'user_name': 'mask_middle'}
  100. out = GEN.render_schema_md(rows, mask_dict)
  101. assert '| 1 | `id` | id | bigint | PK | |' in out
  102. assert '| 2 | `user_name` | 用户名 | character varying | | mask_middle |' in out
  103. assert '| 3 | `phone` | 手机号 | character varying | | md5 |' in out
  104. assert '| 4 | `merchant_open` | 商家代开 | smallint | | trim |' in out
  105. def test_load_mask_conf_basic(tmp_path):
  106. p = tmp_path / 't.mask.ini'
  107. p.write_text(
  108. '[mask]\n'
  109. 'payment_num = trim\n'
  110. 'phone = md5\n'
  111. 'name = mask_middle\n',
  112. encoding='utf-8',
  113. )
  114. assert GEN.load_mask_conf(str(p)) == {
  115. 'payment_num': 'trim',
  116. 'phone': 'md5',
  117. 'name': 'mask_middle',
  118. }
  119. def test_load_mask_conf_no_section_returns_empty(tmp_path):
  120. p = tmp_path / 't.mask.ini'
  121. p.write_text('[other]\nfoo = bar\n', encoding='utf-8')
  122. assert GEN.load_mask_conf(str(p)) == {}
  123. def test_render_template_empty_pk():
  124. out = GEN.render_template(
  125. ds_ref='postgresql/prod-hobby', database='db', schema='public',
  126. table='t', columns=[('a', '')], pk='',
  127. )
  128. assert 'splitPk = \n' in out