dnsproxy.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. #!/usr/bin/env python
  2. # Copyright 2010 Google Inc. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import daemonserver
  16. import errno
  17. import logging
  18. import socket
  19. import SocketServer
  20. import threading
  21. import time
  22. from third_party.dns import flags
  23. from third_party.dns import message
  24. from third_party.dns import rcode
  25. from third_party.dns import resolver
  26. from third_party.dns import rdatatype
  27. from third_party import ipaddr
  28. class DnsProxyException(Exception):
  29. pass
  30. class RealDnsLookup(object):
  31. def __init__(self, name_servers):
  32. if '127.0.0.1' in name_servers:
  33. raise DnsProxyException(
  34. 'Invalid nameserver: 127.0.0.1 (causes an infinte loop)')
  35. self.resolver = resolver.get_default_resolver()
  36. self.resolver.nameservers = name_servers
  37. self.dns_cache_lock = threading.Lock()
  38. self.dns_cache = {}
  39. @staticmethod
  40. def _IsIPAddress(hostname):
  41. try:
  42. socket.inet_aton(hostname)
  43. return True
  44. except socket.error:
  45. return False
  46. def __call__(self, hostname, rdtype=rdatatype.A):
  47. """Return real IP for a host.
  48. Args:
  49. host: a hostname ending with a period (e.g. "www.google.com.")
  50. rdtype: the query type (1 for 'A', 28 for 'AAAA')
  51. Returns:
  52. the IP address as a string (e.g. "192.168.25.2")
  53. """
  54. if self._IsIPAddress(hostname):
  55. return hostname
  56. self.dns_cache_lock.acquire()
  57. ip = self.dns_cache.get(hostname)
  58. self.dns_cache_lock.release()
  59. if ip:
  60. return ip
  61. try:
  62. answers = self.resolver.query(hostname, rdtype)
  63. except resolver.NXDOMAIN:
  64. return None
  65. except resolver.NoNameservers:
  66. logging.debug('_real_dns_lookup(%s) -> No nameserver.',
  67. hostname)
  68. return None
  69. except (resolver.NoAnswer, resolver.Timeout) as ex:
  70. logging.debug('_real_dns_lookup(%s) -> None (%s)',
  71. hostname, ex.__class__.__name__)
  72. return None
  73. if answers:
  74. ip = str(answers[0])
  75. self.dns_cache_lock.acquire()
  76. self.dns_cache[hostname] = ip
  77. self.dns_cache_lock.release()
  78. return ip
  79. def ClearCache(self):
  80. """Clear the dns cache."""
  81. self.dns_cache_lock.acquire()
  82. self.dns_cache.clear()
  83. self.dns_cache_lock.release()
  84. class ReplayDnsLookup(object):
  85. """Resolve DNS requests to replay host."""
  86. def __init__(self, replay_ip, filters=None):
  87. self.replay_ip = replay_ip
  88. self.filters = filters or []
  89. def __call__(self, hostname):
  90. ip = self.replay_ip
  91. for f in self.filters:
  92. ip = f(hostname, default_ip=ip)
  93. return ip
  94. class PrivateIpFilter(object):
  95. """Resolve private hosts to their real IPs and others to the Web proxy IP.
  96. Hosts in the given http_archive will resolve to the Web proxy IP without
  97. checking the real IP.
  98. This only supports IPv4 lookups.
  99. """
  100. def __init__(self, real_dns_lookup, http_archive):
  101. """Initialize PrivateIpDnsLookup.
  102. Args:
  103. real_dns_lookup: a function that resolves a host to an IP.
  104. http_archive: an instance of a HttpArchive
  105. Hosts is in the archive will always resolve to the web_proxy_ip
  106. """
  107. self.real_dns_lookup = real_dns_lookup
  108. self.http_archive = http_archive
  109. self.InitializeArchiveHosts()
  110. def __call__(self, host, default_ip):
  111. """Return real IPv4 for private hosts and Web proxy IP otherwise.
  112. Args:
  113. host: a hostname ending with a period (e.g. "www.google.com.")
  114. Returns:
  115. IP address as a string or None (if lookup fails)
  116. """
  117. ip = default_ip
  118. if host not in self.archive_hosts:
  119. real_ip = self.real_dns_lookup(host)
  120. if real_ip:
  121. if ipaddr.IPAddress(real_ip).is_private:
  122. ip = real_ip
  123. else:
  124. ip = None
  125. return ip
  126. def InitializeArchiveHosts(self):
  127. """Recompute the archive_hosts from the http_archive."""
  128. self.archive_hosts = set('%s.' % req.host.split(':')[0]
  129. for req in self.http_archive)
  130. class DelayFilter(object):
  131. """Add a delay to replayed lookups."""
  132. def __init__(self, is_record_mode, delay_ms):
  133. self.is_record_mode = is_record_mode
  134. self.delay_ms = int(delay_ms)
  135. def __call__(self, host, default_ip):
  136. if not self.is_record_mode:
  137. time.sleep(self.delay_ms * 1000.0)
  138. return default_ip
  139. def SetRecordMode(self):
  140. self.is_record_mode = True
  141. def SetReplayMode(self):
  142. self.is_record_mode = False
  143. class UdpDnsHandler(SocketServer.DatagramRequestHandler):
  144. """Resolve DNS queries to localhost.
  145. Possible alternative implementation:
  146. http://howl.play-bow.org/pipermail/dnspython-users/2010-February/000119.html
  147. """
  148. STANDARD_QUERY_OPERATION_CODE = 0
  149. def handle(self):
  150. """Handle a DNS query.
  151. IPv6 requests (with rdtype AAAA) receive mismatched IPv4 responses
  152. (with rdtype A). To properly support IPv6, the http proxy would
  153. need both types of addresses. By default, Windows XP does not
  154. support IPv6.
  155. """
  156. self.data = self.rfile.read()
  157. self.transaction_id = self.data[0]
  158. self.flags = self.data[1]
  159. self.qa_counts = self.data[4:6]
  160. self.domain = ''
  161. operation_code = (ord(self.data[2]) >> 3) & 15
  162. if operation_code == self.STANDARD_QUERY_OPERATION_CODE:
  163. self.wire_domain = self.data[12:]
  164. self.domain = self._domain(self.wire_domain)
  165. else:
  166. logging.debug("DNS request with non-zero operation code: %s",
  167. operation_code)
  168. ip = self.server.dns_lookup(self.domain)
  169. if ip is None:
  170. logging.debug('dnsproxy: %s -> NXDOMAIN', self.domain)
  171. response = self.get_dns_no_such_name_response()
  172. else:
  173. if ip == self.server.server_address[0]:
  174. logging.debug('dnsproxy: %s -> %s (replay web proxy)', self.domain, ip)
  175. else:
  176. logging.debug('dnsproxy: %s -> %s', self.domain, ip)
  177. response = self.get_dns_response(ip)
  178. self.wfile.write(response)
  179. @classmethod
  180. def _domain(cls, wire_domain):
  181. domain = ''
  182. index = 0
  183. length = ord(wire_domain[index])
  184. while length:
  185. domain += wire_domain[index + 1:index + length + 1] + '.'
  186. index += length + 1
  187. length = ord(wire_domain[index])
  188. return domain
  189. def get_dns_response(self, ip):
  190. packet = ''
  191. if self.domain:
  192. packet = (
  193. self.transaction_id +
  194. self.flags +
  195. '\x81\x80' + # standard query response, no error
  196. self.qa_counts * 2 + '\x00\x00\x00\x00' + # Q&A counts
  197. self.wire_domain +
  198. '\xc0\x0c' # pointer to domain name
  199. '\x00\x01' # resource record type ("A" host address)
  200. '\x00\x01' # class of the data
  201. '\x00\x00\x00\x3c' # ttl (seconds)
  202. '\x00\x04' + # resource data length (4 bytes for ip)
  203. socket.inet_aton(ip)
  204. )
  205. return packet
  206. def get_dns_no_such_name_response(self):
  207. query_message = message.from_wire(self.data)
  208. response_message = message.make_response(query_message)
  209. response_message.flags |= flags.AA | flags.RA
  210. response_message.set_rcode(rcode.NXDOMAIN)
  211. return response_message.to_wire()
  212. class DnsProxyServer(SocketServer.ThreadingUDPServer,
  213. daemonserver.DaemonServer):
  214. # Increase the request queue size. The default value, 5, is set in
  215. # SocketServer.TCPServer (the parent of BaseHTTPServer.HTTPServer).
  216. # Since we're intercepting many domains through this single server,
  217. # it is quite possible to get more than 5 concurrent requests.
  218. request_queue_size = 256
  219. # Allow sockets to be reused. See
  220. # http://svn.python.org/projects/python/trunk/Lib/SocketServer.py for more
  221. # details.
  222. allow_reuse_address = True
  223. # Don't prevent python from exiting when there is thread activity.
  224. daemon_threads = True
  225. def __init__(self, host='', port=53, dns_lookup=None):
  226. """Initialize DnsProxyServer.
  227. Args:
  228. host: a host string (name or IP) to bind the dns proxy and to which
  229. DNS requests will be resolved.
  230. port: an integer port on which to bind the proxy.
  231. dns_lookup: a list of filters to apply to lookup.
  232. """
  233. try:
  234. SocketServer.ThreadingUDPServer.__init__(
  235. self, (host, port), UdpDnsHandler)
  236. except socket.error, (error_number, msg):
  237. if error_number == errno.EACCES:
  238. raise DnsProxyException(
  239. 'Unable to bind DNS server on (%s:%s)' % (host, port))
  240. raise
  241. self.dns_lookup = dns_lookup or (lambda host: self.server_address[0])
  242. self.server_port = self.server_address[1]
  243. logging.warning('DNS server started on %s:%d', self.server_address[0],
  244. self.server_address[1])
  245. def cleanup(self):
  246. try:
  247. self.shutdown()
  248. self.server_close()
  249. except KeyboardInterrupt, e:
  250. pass
  251. logging.info('Stopped DNS server')