| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194 |
- # Copyright 2014 Google Inc. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Test routines to generate dummy certificates."""
import BaseHTTPServer
import os
import shutil
import signal
import socket
import tempfile
import threading
import time
import unittest

import certutils
import sslproxy
class Client(object):
    """One-shot TLS client used to probe the test server's certificates."""

    def __init__(self, ca_cert_path, verify_cb, port, host_name='foo.com',
                 host='localhost'):
        """Store connection parameters; no network activity happens here.

        Args:
          ca_cert_path: certificate file loaded both as the client's own
            certificate and as the trusted verify location.
          verify_cb: callback handed to the SSL context's set_verify().
          port: server port to connect to.
          host_name: server name set via set_tlsext_host_name() (SNI).
          host: address to connect to.
        """
        self.host_name = host_name
        self.verify_cb = verify_cb
        self.ca_cert_path = ca_cert_path
        self.port = port
        self.host = host
        self.connection = None

    def run_request(self):
        """Connect to the server, send a blank request and close the connection.

        Raises:
          Any error raised by certutils/pyOpenSSL while loading certificates,
          connecting, setting the SNI name or sending; the tests expect
          certutils.Error for a bad host name or certificate.
        """
        context = certutils.get_ssl_context()
        context.set_verify(certutils.VERIFY_PEER, self.verify_cb)  # Demand a cert
        context.use_certificate_file(self.ca_cert_path)
        context.load_verify_locations(self.ca_cert_path)

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.connection = certutils.get_ssl_connection(context, sock)
        try:
            self.connection.connect((self.host, self.port))
            self.connection.set_tlsext_host_name(self.host_name)
            try:
                self.connection.send('\r\n\r\n')
            finally:
                # Only a connected session has anything to shut down.
                self.connection.shutdown()
        finally:
            # Always release the socket, even when connect(), SNI setup or
            # shutdown() raised; previously those paths leaked it.
            self.connection.close()
class Handler(BaseHTTPServer.BaseHTTPRequestHandler):
    """Bare-bones request handler that only consumes the request line."""

    # Override the protocol version BaseHTTPServer would otherwise report.
    protocol_version = 'HTTP/1.1'

    def handle_one_request(self):
        """Read one request line (at most 65537 bytes) and do nothing else.

        No headers are parsed and no response is written.
        """
        line = self.rfile.readline(65537)
        self.raw_requestline = line
class WrappedErrorHandler(Handler):
    """Handler that records sslproxy dummy-cert setup failures on the server.

    A certutils.Error raised during TLS setup is not propagated; instead the
    exception class is stored on ``self.server.error_function``.
    NOTE(review): nothing visible in this file reads error_function —
    confirm which caller relies on it.
    """

    def setup(self):
        """Run the base stream setup, then try the dummy-certificate TLS setup."""
        Handler.setup(self)
        try:
            sslproxy._SetUpUsingDummyCert(self)
        except certutils.Error:
            # Record the failure on the server instead of raising it here.
            self.server.error_function = certutils.Error

    def finish(self):
        """Run the base finish, then shut down and close the connection."""
        Handler.finish(self)
        conn = self.connection
        conn.shutdown()
        conn.close()
class DummyArchive(object):
    """Empty placeholder archive object; it carries no state or behavior."""
class DummyFetch(object):
    """Placeholder fetch object exposing only an ``http_archive`` attribute.

    Presumably read by sslproxy through the server's http_archive_fetch
    attribute — confirm against sslproxy.
    """

    def __init__(self):
        archive = DummyArchive()
        self.http_archive = archive
class Server(BaseHTTPServer.HTTPServer):
    """SSL test server using sslproxy request handlers.

    Meant to be used as a context manager: entering starts serve_forever() on
    a daemon thread, exiting stops the loop and closes the listening socket.
    """

    def __init__(self, ca_cert_path, use_error_handler=False, port=0,
                 host='localhost'):
        """Load the CA certificate and bind the server.

        Args:
          ca_cert_path: path to the CA certificate; its contents are kept in
            ca_cert_str for get_certificate().
          use_error_handler: if True, serve with WrappedErrorHandler instead
            of the sslproxy-wrapped Handler.
          port: port to bind; 0 lets the OS choose (read back via
            server_port).
          host: interface to bind.

        Raises:
          RuntimeError: if the underlying HTTPServer could not be created.
        """
        self.ca_cert_path = ca_cert_path
        with open(ca_cert_path, 'r') as ca_file:
            self.ca_cert_str = ca_file.read()
        self.http_archive_fetch = DummyFetch()
        if use_error_handler:
            self.HANDLER = WrappedErrorHandler
        else:
            self.HANDLER = sslproxy.wrap_handler(Handler)
        try:
            BaseHTTPServer.HTTPServer.__init__(self, (host, port), self.HANDLER)
        except Exception as e:
            raise RuntimeError('Could not start HTTPSServer on port %d: %s'
                               % (port, e))

    def __enter__(self):
        """Start serving on a background daemon thread and return self."""
        thread = threading.Thread(target=self.serve_forever)
        thread.daemon = True  # never keep the test process alive
        thread.start()
        return self

    def cleanup(self):
        """Stop the serve_forever() loop and release the listening socket."""
        try:
            self.shutdown()
        except KeyboardInterrupt:
            pass
        else:
            # Without this the bound socket stays open until the Server
            # object is collected, leaking one descriptor per instance.
            # Skipped if shutdown() was interrupted, so the socket is not
            # closed under a loop that may still be running.
            self.server_close()

    def __exit__(self, type_, value_, traceback_):
        self.cleanup()

    def get_certificate(self, host):
        """Return a certificate for ``host`` generated from this server's CA."""
        return certutils.generate_cert(self.ca_cert_str, '', host)
class TestClient(unittest.TestCase):
    """End-to-end tests of sslproxy certificate handling over real sockets."""

    _temp_dir = None

    def setUp(self):
        self._temp_dir = tempfile.mkdtemp(prefix='sslproxy_', dir='/tmp')
        # mkdtemp() returns a path without a trailing separator, so plain
        # string concatenation put these files next to (not inside) the temp
        # dir, where tearDown's rmtree never removed them.
        self.ca_cert_path = os.path.join(self._temp_dir, 'testCA.pem')
        self.cert_path = os.path.join(self._temp_dir, 'testCA-cert.cer')
        self.wrong_ca_cert_path = os.path.join(self._temp_dir, 'wrong.pem')
        self.wrong_cert_path = os.path.join(self._temp_dir, 'wrong-cert.cer')
        # Write both pem and cer files for two independent CAs: the one the
        # server signs with, and an unrelated "wrong" one for test_wrong_cert.
        # Presumably write_dummy_ca_cert derives the '-cert.cer' file from
        # cert_path — confirm in certutils.
        certutils.write_dummy_ca_cert(*certutils.generate_dummy_ca_cert(),
                                      cert_path=self.ca_cert_path)
        certutils.write_dummy_ca_cert(*certutils.generate_dummy_ca_cert(),
                                      cert_path=self.wrong_ca_cert_path)

    def tearDown(self):
        if self._temp_dir:
            shutil.rmtree(self._temp_dir)

    def verify_cb(self, conn, cert, errnum, depth, ok):
        """A callback that verifies the certificate authentication worked.

        Args:
          conn: Connection object
          cert: x509 object
          errnum: possible error number
          depth: error depth
          ok: 1 if the authentication worked, 0 if it didn't.

        Returns:
          ``ok`` unchanged, so OpenSSL's own verification verdict is kept.
        """
        self.assertFalse(cert.has_expired())
        # Both sides use the YYYYMMDDhhmmssZ layout, so string comparison
        # orders them chronologically.
        self.assertGreater(time.strftime('%Y%m%d%H%M%SZ', time.gmtime()),
                           cert.get_notBefore())
        return ok

    def test_no_host(self):
        with Server(self.ca_cert_path) as server:
            c = Client(self.cert_path, self.verify_cb, server.server_port, '')
            self.assertRaises(certutils.Error, c.run_request)

    def test_client_connection(self):
        with Server(self.ca_cert_path) as server:
            c = Client(self.cert_path, self.verify_cb, server.server_port,
                       'foo.com')
            c.run_request()

            c = Client(self.cert_path, self.verify_cb, server.server_port,
                       'random.host')
            c.run_request()

    def test_wrong_cert(self):
        with Server(self.ca_cert_path, True) as server:
            c = Client(self.wrong_cert_path, self.verify_cb, server.server_port,
                       'foo.com')
            self.assertRaises(certutils.Error, c.run_request)
if __name__ == '__main__':
    # Restore the default SIGINT action so Ctrl-C terminates the run at once.
    default_sigint = signal.SIG_DFL
    signal.signal(signal.SIGINT, default_sigint)
    unittest.main()
|