trafficshaper_test.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. #!/usr/bin/env python
  2. # Copyright 2011 Google Inc. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """System integration test for traffic shaping.
  16. Usage:
  17. $ sudo ./trafficshaper_test.py
  18. """
  19. import daemonserver
  20. import logging
  21. import platformsettings
  22. import socket
  23. import SocketServer
  24. import trafficshaper
  25. import unittest
  26. RESPONSE_SIZE_KEY = 'response-size:'
  27. TEST_DNS_PORT = 5555
  28. TEST_HTTP_PORT = 8888
  29. TIMER = platformsettings.timer
  30. def GetElapsedMs(start_time, end_time):
  31. """Return milliseconds elapsed between |start_time| and |end_time|.
  32. Args:
  33. start_time: seconds as a float (or string representation of float).
  34. end_time: seconds as a float (or string representation of float).
  35. Return:
  36. milliseconds elapsed as integer.
  37. """
  38. return int((float(end_time) - float(start_time)) * 1000)
  39. class TrafficShaperTest(unittest.TestCase):
  40. def testBadBandwidthRaises(self):
  41. self.assertRaises(trafficshaper.BandwidthValueError,
  42. trafficshaper.TrafficShaper,
  43. down_bandwidth='1KBit/s')
  44. class TimedUdpHandler(SocketServer.DatagramRequestHandler):
  45. """UDP handler that returns the time when the request was handled."""
  46. def handle(self):
  47. data = self.rfile.read()
  48. read_time = self.server.timer()
  49. self.wfile.write(str(read_time))
  50. class TimedTcpHandler(SocketServer.StreamRequestHandler):
  51. """Tcp handler that returns the time when the request was read.
  52. It can respond with the number of bytes specified in the request.
  53. The request looks like:
  54. request_data -> RESPONSE_SIZE_KEY num_response_bytes '\n' ANY_DATA
  55. """
  56. def handle(self):
  57. data = self.rfile.read()
  58. read_time = self.server.timer()
  59. contents = str(read_time)
  60. if data.startswith(RESPONSE_SIZE_KEY):
  61. num_response_bytes = int(data[len(RESPONSE_SIZE_KEY):data.index('\n')])
  62. contents = '%s\n%s' % (contents,
  63. '\x00' * (num_response_bytes - len(contents) - 1))
  64. self.wfile.write(contents)
  65. class TimedUdpServer(SocketServer.ThreadingUDPServer,
  66. daemonserver.DaemonServer):
  67. """A simple UDP server similar to dnsproxy."""
  68. # Override SocketServer.TcpServer setting to avoid intermittent errors.
  69. allow_reuse_address = True
  70. def __init__(self, host, port, timer=TIMER):
  71. SocketServer.ThreadingUDPServer.__init__(
  72. self, (host, port), TimedUdpHandler)
  73. self.timer = timer
  74. def cleanup(self):
  75. pass
  76. class TimedTcpServer(SocketServer.ThreadingTCPServer,
  77. daemonserver.DaemonServer):
  78. """A simple TCP server similar to httpproxy."""
  79. # Override SocketServer.TcpServer setting to avoid intermittent errors.
  80. allow_reuse_address = True
  81. def __init__(self, host, port, timer=TIMER):
  82. SocketServer.ThreadingTCPServer.__init__(
  83. self, (host, port), TimedTcpHandler)
  84. self.timer = timer
  85. def cleanup(self):
  86. try:
  87. self.shutdown()
  88. except KeyboardInterrupt, e:
  89. pass
  90. class TcpTestSocketCreator(object):
  91. """A TCP socket creator suitable for with-statement."""
  92. def __init__(self, host, port, timeout=1.0):
  93. self.address = (host, port)
  94. self.timeout = timeout
  95. def __enter__(self):
  96. self.socket = socket.create_connection(self.address, timeout=self.timeout)
  97. return self.socket
  98. def __exit__(self, *args):
  99. self.socket.close()
  100. class TimedTestCase(unittest.TestCase):
  101. def assertValuesAlmostEqual(self, expected, actual, tolerance=0.05):
  102. """Like the following with nicer default message:
  103. assertTrue(expected <= actual + tolerance &&
  104. expected >= actual - tolerance)
  105. """
  106. delta = tolerance * expected
  107. if actual > expected + delta or actual < expected - delta:
  108. self.fail('%s is not equal to expected %s +/- %s%%' % (
  109. actual, expected, 100 * tolerance))
  110. class TcpTrafficShaperTest(TimedTestCase):
  111. def setUp(self):
  112. self.host = platformsettings.get_server_ip_address()
  113. self.port = TEST_HTTP_PORT
  114. self.tcp_socket_creator = TcpTestSocketCreator(self.host, self.port)
  115. self.timer = TIMER
  116. def TrafficShaper(self, **kwargs):
  117. return trafficshaper.TrafficShaper(
  118. host=self.host, ports=(self.port,), **kwargs)
  119. def GetTcpSendTimeMs(self, num_bytes):
  120. """Return time in milliseconds to send |num_bytes|."""
  121. with self.tcp_socket_creator as s:
  122. start_time = self.timer()
  123. request_data = '\x00' * num_bytes
  124. s.sendall(request_data)
  125. # TODO(slamm): Figure out why partial is shutdown needed to make it work.
  126. s.shutdown(socket.SHUT_WR)
  127. read_time = s.recv(1024)
  128. return GetElapsedMs(start_time, read_time)
  129. def GetTcpReceiveTimeMs(self, num_bytes):
  130. """Return time in milliseconds to receive |num_bytes|."""
  131. with self.tcp_socket_creator as s:
  132. s.sendall('%s%s\n' % (RESPONSE_SIZE_KEY, num_bytes))
  133. # TODO(slamm): Figure out why partial is shutdown needed to make it work.
  134. s.shutdown(socket.SHUT_WR)
  135. num_remaining_bytes = num_bytes
  136. read_time = None
  137. while num_remaining_bytes > 0:
  138. response_data = s.recv(4096)
  139. num_remaining_bytes -= len(response_data)
  140. if not read_time:
  141. read_time, padding = response_data.split('\n')
  142. return GetElapsedMs(read_time, self.timer())
  143. def testTcpConnectToIp(self):
  144. """Verify that it takes |delay_ms| to establish a TCP connection."""
  145. if not platformsettings.has_ipfw():
  146. logging.warning('ipfw is not available in path. Skip the test')
  147. return
  148. with TimedTcpServer(self.host, self.port):
  149. for delay_ms in (100, 175):
  150. with self.TrafficShaper(delay_ms=delay_ms):
  151. start_time = self.timer()
  152. with self.tcp_socket_creator:
  153. connect_time = GetElapsedMs(start_time, self.timer())
  154. self.assertValuesAlmostEqual(delay_ms, connect_time, tolerance=0.12)
  155. def testTcpUploadShaping(self):
  156. """Verify that 'up' bandwidth is shaped on TCP connections."""
  157. if not platformsettings.has_ipfw():
  158. logging.warning('ipfw is not available in path. Skip the test')
  159. return
  160. num_bytes = 1024 * 100
  161. bandwidth_kbits = 2000
  162. expected_ms = 8.0 * num_bytes / bandwidth_kbits
  163. with TimedTcpServer(self.host, self.port):
  164. with self.TrafficShaper(up_bandwidth='%sKbit/s' % bandwidth_kbits):
  165. self.assertValuesAlmostEqual(expected_ms, self.GetTcpSendTimeMs(num_bytes))
  166. def testTcpDownloadShaping(self):
  167. """Verify that 'down' bandwidth is shaped on TCP connections."""
  168. if not platformsettings.has_ipfw():
  169. logging.warning('ipfw is not available in path. Skip the test')
  170. return
  171. num_bytes = 1024 * 100
  172. bandwidth_kbits = 2000
  173. expected_ms = 8.0 * num_bytes / bandwidth_kbits
  174. with TimedTcpServer(self.host, self.port):
  175. with self.TrafficShaper(down_bandwidth='%sKbit/s' % bandwidth_kbits):
  176. self.assertValuesAlmostEqual(expected_ms, self.GetTcpReceiveTimeMs(num_bytes))
  177. def testTcpInterleavedDownloads(self):
  178. # TODO(slamm): write tcp interleaved downloads test
  179. pass
  180. class UdpTrafficShaperTest(TimedTestCase):
  181. def setUp(self):
  182. self.host = platformsettings.get_server_ip_address()
  183. self.dns_port = TEST_DNS_PORT
  184. self.timer = TIMER
  185. def TrafficShaper(self, **kwargs):
  186. return trafficshaper.TrafficShaper(
  187. host=self.host, ports=(self.dns_port,), **kwargs)
  188. def GetUdpSendReceiveTimesMs(self):
  189. """Return time in milliseconds to send |num_bytes|."""
  190. start_time = self.timer()
  191. udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  192. udp_socket.sendto('test data\n', (self.host, self.dns_port))
  193. read_time = udp_socket.recv(1024)
  194. return (GetElapsedMs(start_time, read_time),
  195. GetElapsedMs(read_time, self.timer()))
  196. def testUdpDelay(self):
  197. if not platformsettings.has_ipfw():
  198. logging.warning('ipfw is not available in path. Skip the test')
  199. return
  200. for delay_ms in (100, 170):
  201. expected_ms = delay_ms / 2
  202. with TimedUdpServer(self.host, self.dns_port):
  203. with self.TrafficShaper(delay_ms=delay_ms):
  204. send_ms, receive_ms = self.GetUdpSendReceiveTimesMs()
  205. self.assertValuesAlmostEqual(expected_ms, send_ms, tolerance=0.10)
  206. self.assertValuesAlmostEqual(expected_ms, receive_ms, tolerance=0.10)
  207. def testUdpInterleavedDelay(self):
  208. # TODO(slamm): write udp interleaved udp delay test
  209. pass
  210. class TcpAndUdpTrafficShaperTest(TimedTestCase):
  211. # TODO(slamm): Test concurrent TCP and UDP traffic
  212. pass
  213. # TODO(slamm): Packet loss rate (try different ports)
  214. if __name__ == '__main__':
  215. #logging.getLogger().setLevel(logging.DEBUG)
  216. unittest.main()