sslproxy_test.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. # Copyright 2014 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Test routines to generate dummy certificates."""
import BaseHTTPServer
import os
import shutil
import signal
import socket
import tempfile
import threading
import time
import unittest

import certutils
import sslproxy
  25. class Client(object):
  26. def __init__(self, ca_cert_path, verify_cb, port, host_name='foo.com',
  27. host='localhost'):
  28. self.host_name = host_name
  29. self.verify_cb = verify_cb
  30. self.ca_cert_path = ca_cert_path
  31. self.port = port
  32. self.host_name = host_name
  33. self.host = host
  34. self.connection = None
  35. def run_request(self):
  36. context = certutils.get_ssl_context()
  37. context.set_verify(certutils.VERIFY_PEER, self.verify_cb) # Demand a cert
  38. context.use_certificate_file(self.ca_cert_path)
  39. context.load_verify_locations(self.ca_cert_path)
  40. s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  41. self.connection = certutils.get_ssl_connection(context, s)
  42. self.connection.connect((self.host, self.port))
  43. self.connection.set_tlsext_host_name(self.host_name)
  44. try:
  45. self.connection.send('\r\n\r\n')
  46. finally:
  47. self.connection.shutdown()
  48. self.connection.close()
  49. class Handler(BaseHTTPServer.BaseHTTPRequestHandler):
  50. protocol_version = 'HTTP/1.1' # override BaseHTTPServer setting
  51. def handle_one_request(self):
  52. """Handle a single HTTP request."""
  53. self.raw_requestline = self.rfile.readline(65537)
  54. class WrappedErrorHandler(Handler):
  55. """Wraps handler to verify expected sslproxy errors are being raised."""
  56. def setup(self):
  57. Handler.setup(self)
  58. try:
  59. sslproxy._SetUpUsingDummyCert(self)
  60. except certutils.Error:
  61. self.server.error_function = certutils.Error
  62. def finish(self):
  63. Handler.finish(self)
  64. self.connection.shutdown()
  65. self.connection.close()
  66. class DummyArchive(object):
  67. def __init__(self):
  68. pass
  69. class DummyFetch(object):
  70. def __init__(self):
  71. self.http_archive = DummyArchive()
  72. class Server(BaseHTTPServer.HTTPServer):
  73. """SSL server."""
  74. def __init__(self, ca_cert_path, use_error_handler=False, port=0,
  75. host='localhost'):
  76. self.ca_cert_path = ca_cert_path
  77. with open(ca_cert_path, 'r') as ca_file:
  78. self.ca_cert_str = ca_file.read()
  79. self.http_archive_fetch = DummyFetch()
  80. if use_error_handler:
  81. self.HANDLER = WrappedErrorHandler
  82. else:
  83. self.HANDLER = sslproxy.wrap_handler(Handler)
  84. try:
  85. BaseHTTPServer.HTTPServer.__init__(self, (host, port), self.HANDLER)
  86. except Exception, e:
  87. raise RuntimeError('Could not start HTTPSServer on port %d: %s'
  88. % (port, e))
  89. def __enter__(self):
  90. thread = threading.Thread(target=self.serve_forever)
  91. thread.daemon = True
  92. thread.start()
  93. return self
  94. def cleanup(self):
  95. try:
  96. self.shutdown()
  97. except KeyboardInterrupt:
  98. pass
  99. def __exit__(self, type_, value_, traceback_):
  100. self.cleanup()
  101. def get_certificate(self, host):
  102. return certutils.generate_cert(self.ca_cert_str, '', host)
  103. class TestClient(unittest.TestCase):
  104. _temp_dir = None
  105. def setUp(self):
  106. self._temp_dir = tempfile.mkdtemp(prefix='sslproxy_', dir='/tmp')
  107. self.ca_cert_path = self._temp_dir + 'testCA.pem'
  108. self.cert_path = self._temp_dir + 'testCA-cert.cer'
  109. self.wrong_ca_cert_path = self._temp_dir + 'wrong.pem'
  110. self.wrong_cert_path = self._temp_dir + 'wrong-cert.cer'
  111. # Write both pem and cer files for certificates
  112. certutils.write_dummy_ca_cert(*certutils.generate_dummy_ca_cert(),
  113. cert_path=self.ca_cert_path)
  114. certutils.write_dummy_ca_cert(*certutils.generate_dummy_ca_cert(),
  115. cert_path=self.ca_cert_path)
  116. def tearDown(self):
  117. if self._temp_dir:
  118. shutil.rmtree(self._temp_dir)
  119. def verify_cb(self, conn, cert, errnum, depth, ok):
  120. """A callback that verifies the certificate authentication worked.
  121. Args:
  122. conn: Connection object
  123. cert: x509 object
  124. errnum: possible error number
  125. depth: error depth
  126. ok: 1 if the authentication worked 0 if it didnt.
  127. Returns:
  128. 1 or 0 depending on if the verification worked
  129. """
  130. self.assertFalse(cert.has_expired())
  131. self.assertGreater(time.strftime('%Y%m%d%H%M%SZ', time.gmtime()),
  132. cert.get_notBefore())
  133. return ok
  134. def test_no_host(self):
  135. with Server(self.ca_cert_path) as server:
  136. c = Client(self.cert_path, self.verify_cb, server.server_port, '')
  137. self.assertRaises(certutils.Error, c.run_request)
  138. def test_client_connection(self):
  139. with Server(self.ca_cert_path) as server:
  140. c = Client(self.cert_path, self.verify_cb, server.server_port, 'foo.com')
  141. c.run_request()
  142. c = Client(self.cert_path, self.verify_cb, server.server_port,
  143. 'random.host')
  144. c.run_request()
  145. def test_wrong_cert(self):
  146. with Server(self.ca_cert_path, True) as server:
  147. c = Client(self.wrong_cert_path, self.verify_cb, server.server_port,
  148. 'foo.com')
  149. self.assertRaises(certutils.Error, c.run_request)
  150. if __name__ == '__main__':
  151. signal.signal(signal.SIGINT, signal.SIG_DFL) # Exit on Ctrl-C
  152. unittest.main()