___websocket.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. """
  2. REF:: https://github.com/mtah/python-websocket (+ /.setup/websocket.py.patch)
  3. This program is free software: you can redistribute it and/or modify
  4. it under the terms of the GNU General Public License as published by
  5. the Free Software Foundation, either version 3 of the License, or
  6. (at your option) any later version.
  7. This program is distributed in the hope that it will be useful,
  8. but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. GNU General Public License for more details.
  11. You should have received a copy of the GNU General Public License
  12. along with this program. If not, see <http://www.gnu.org/licenses/>
  13. """
  14. import sys, re, socket, asyncore
  15. if sys.version_info[0] < 3:
  16. import urlparse as urlparse
  17. else:
  18. import urllib.parse as urlparse
  class WebSocket(object):
      """Client for the early-draft WebSocket protocol, driven by asyncore.

      Text frames are sent as 0x00, UTF-8 payload, 0xFF. The handshake does
      not use Sec-WebSocket-Key headers. Event callbacks are supplied as
      keyword arguments:

      - ``onopen()``
      - ``onmessage(data)``
      - ``onerror(exc)``
      - ``onclose()``
      """

      def __init__(self, url, **kwargs):
          """Start connecting to *url* (``ws://host[:port][/path][?query]``).

          Accepted keyword arguments are ``protocol``, ``cookie_jar``,
          ``onopen``, ``onmessage``, ``onerror`` and ``onclose``.

          Raises:
              ValueError: for a malformed URL or any other keyword argument.
              NotImplementedError: for ``wss://`` URLs.
          """
          self.host, self.port, self.resource, self.secure = WebSocket._parse_url(url)
          self.protocol = kwargs.pop('protocol', None)
          self.cookie_jar = kwargs.pop('cookie_jar', None)
          self.onopen = kwargs.pop('onopen', None)
          self.onmessage = kwargs.pop('onmessage', None)
          self.onerror = kwargs.pop('onerror', None)
          self.onclose = kwargs.pop('onclose', None)
          if kwargs:
              # Report the offending argument *names*. The values are
              # arbitrary objects and may not even be strings.
              raise ValueError('Unexpected argument(s): %s' % ', '.join(sorted(kwargs)))
          self._dispatcher = _Dispatcher(self)

      def send(self, data):
          """Queue *data* (text) as one frame: 0x00, UTF-8 payload, 0xFF."""
          # Byte literals keep the frame bytes-typed. On Python 3, _utf8()
          # returns bytes, and str + bytes would raise TypeError.
          self._dispatcher.write(b'\x00' + _utf8(data) + b'\xff')

      def close(self):
          """Close the connection and run the onclose callback, if any."""
          self._dispatcher.handle_close()

      @classmethod
      def _parse_url(cls, url):
          """Split a ``ws://`` URL into ``(host, port, resource, secure)``.

          *resource* is the path (default ``'/'``) plus ``'?query'`` when a
          query string is present. The port defaults to 80.

          Raises:
              ValueError: for relative URLs, fragments or non-ws schemes.
              NotImplementedError: for ``wss``.
          """
          p = urlparse.urlparse(url)
          if not p.hostname:
              raise ValueError('URL must be absolute')
          host = p.hostname
          if p.fragment:
              raise ValueError('URL must not contain a fragment component')
          if p.scheme == 'ws':
              secure = False
              port = p.port or 80
          elif p.scheme == 'wss':
              raise NotImplementedError('Secure WebSocket not yet supported')
          else:
              raise ValueError('Invalid URL scheme')
          resource = p.path or '/'
          if p.query:
              resource += '?' + p.query
          return (host, port, resource, secure)
  class WebSocketError(Exception):
      """Raised when the server's handshake or frame stream is invalid.

      The detail passed to the constructor is kept on ``value`` and is what
      ``str()`` returns.
      """

      def __init__(self, value):
          # Let Exception record ``args`` too, so the error reprs and pickles
          # like any other exception.
          Exception.__init__(self, value)
          self.value = value

      def __str__(self):
          return str(self.value)
  class _Dispatcher(asyncore.dispatcher):
      """asyncore channel behind a WebSocket.

      Sends the opening handshake and validates the server's reply. It then
      splits the incoming stream into 0x00 ... 0xFF frames and hands each
      payload to ``ws.onmessage``. Wire data is kept as bytes throughout, so
      the same code runs on Python 2 (where bytes is str) and on Python 3.
      """

      def __init__(self, ws):
          asyncore.dispatcher.__init__(self)
          # Initialise state before connect(). asyncore may call handlers
          # synchronously if the connection completes immediately.
          self.ws = ws
          self._read_buffer = b''
          self._write_buffer = b''
          self._handshake_complete = False
          self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
          self.connect((ws.host, ws.port))
          self.write(_utf8(self._build_handshake()))

      def _build_handshake(self):
          """Return the client's opening handshake request as text."""
          ws = self.ws
          if ws.port != 80:
              hostport = '%s:%d' % (ws.host, ws.port)
          else:
              hostport = ws.host
          fields = [
              'Upgrade: WebSocket',
              'Connection: Upgrade',
              'Host: ' + hostport,
              'Origin: http://' + hostport,
          ]
          if ws.protocol:
              # ``fields`` holds complete header lines, so append one.
              fields.append('Sec-WebSocket-Protocol: ' + ws.protocol)
          if ws.cookie_jar:
              domain = _eff_host(ws.host)
              pairs = ['%s=%s' % (c.name, c.value)
                       for c in ws.cookie_jar
                       if _cookie_for_domain(c, domain)
                       and _cookie_for_path(c, ws.resource)
                       and not c.is_expired()]
              if pairs:
                  # RFC 6265 section 5.4: all cookies go in a single Cookie header.
                  fields.append('Cookie: ' + '; '.join(pairs))
          return 'GET %s HTTP/1.1\r\n%s\r\n\r\n' % (ws.resource, '\r\n'.join(fields))

      def handle_connect(self):
          # The handshake is already queued in the write buffer. This override
          # only silences asyncore's default "unhandled connect event" warning.
          pass

      def handle_expt(self):
          # Out-of-band data has no meaning in this protocol; treat it as an error.
          self.handle_error()

      # Alias kept for any caller of the old misspelled name.
      handl_expt = handle_expt

      def handle_error(self):
          """Close the channel, then report the active exception.

          The exception goes to ``ws.onerror`` when that callback is set.
          Otherwise this falls back to ``asyncore.dispatcher.handle_error``.
          """
          self.close()
          t, e, trace = sys.exc_info()
          if self.ws.onerror:
              self.ws.onerror(e)
          else:
              asyncore.dispatcher.handle_error(self)

      def handle_close(self):
          """Close the socket and run the ``ws.onclose`` callback, if set."""
          self.close()
          if self.ws.onclose:
              self.ws.onclose()

      def handle_read(self):
          self._read_buffer += self.recv(4096)
          self._drain_read_buffer()

      def _drain_read_buffer(self):
          """Dispatch every complete unit (handshake reply or frame) buffered.

          The delimiter is re-chosen for each unit. Frames that arrive in the
          same read as the handshake reply are therefore delivered at once.
          The loop is iterative, so a burst of many small frames cannot hit
          the recursion limit.
          """
          while True:
              if self._handshake_complete:
                  delimiter, callback = b'\xff', self._handle_frame
              else:
                  delimiter, callback = b'\r\n\r\n', self._handle_header
              end = self._read_buffer.find(delimiter)
              if end < 0:
                  return
              end += len(delimiter)
              unit = self._read_buffer[:end]
              self._read_buffer = self._read_buffer[end:]
              callback(unit)

      def handle_write(self):
          sent = self.send(self._write_buffer)
          self._write_buffer = self._write_buffer[sent:]

      def writable(self):
          return len(self._write_buffer) > 0

      def write(self, data):
          """Append *data* (bytes) to the outgoing buffer."""
          # NOTE: frames queued by send() before the server's handshake reply
          # arrives are transmitted right after the handshake request.
          # Consider a separate pre-handshake queue.
          self._write_buffer += data

      def _handle_frame(self, frame):
          """Validate one 0x00 ... 0xFF frame and pass its payload to onmessage."""
          # Slices (not indexing) give bytes on both Python 2 and Python 3.
          assert frame[-1:] == b'\xff'
          if frame[:1] != b'\x00':
              raise WebSocketError('WebSocket stream error')
          payload = frame[1:-1]
          if not isinstance(payload, str):
              # Python 3: mirror send(), which encodes text as UTF-8.
              payload = payload.decode('utf-8')
          if self.ws.onmessage:
              self.ws.onmessage(payload)
          # TODO: else raise WebSocketError('No message handler defined')

      def _handle_header(self, header):
          """Check the server's handshake reply, then switch to frame mode."""
          assert header[-4:] == b'\r\n\r\n'
          if not isinstance(header, str):
              # Python 3: parse the reply as text. latin-1 maps every byte 1:1.
              header = header.decode('latin-1')
          start_line, fields = _parse_http_header(header)
          if start_line != 'HTTP/1.1 101 Web Socket Protocol Handshake' or \
                  fields.get('Connection', None) != 'Upgrade' or \
                  fields.get('Upgrade', None) != 'WebSocket':
              raise WebSocketError('Invalid server handshake')
          self._handshake_complete = True
          if self.ws.onopen:
              self.ws.onopen()
  160. _IPV4_RE = re.compile(r'\.\d+$')
  161. _PATH_SEP = re.compile(r'/+')
  162. def _parse_http_header(header):
  163. def split_field(field):
  164. k, v = field.split(':', 1)
  165. return (k, v.strip())
  166. lines = header.strip().split('\r\n')
  167. if len(lines) > 0:
  168. start_line = lines[0]
  169. else:
  170. start_line = None
  171. return (start_line, dict(map(split_field, lines[1:])))
  172. def _eff_host(host):
  173. if host.find('.') == -1 and not _IPV4_RE.search(host):
  174. return host + '.local'
  175. return host
  176. def _cookie_for_path(cookie, path):
  177. if not cookie.path or path == '' or path == '/':
  178. return True
  179. path = _PATH_SEP.split(path)[1:]
  180. cookie_path = _PATH_SEP.split(cookie.path)[1:]
  181. for p1, p2 in map(lambda *ps: ps, path, cookie_path):
  182. if p1 == None:
  183. return True
  184. elif p1 != p2:
  185. return False
  186. return True
  187. def _cookie_for_domain(cookie, domain):
  188. if not cookie.domain:
  189. return True
  190. elif cookie.domain[0] == '.':
  191. return domain.endswith(cookie.domain)
  192. else:
  193. return cookie.domain == domain
  194. def _utf8(s):
  195. return s.encode('utf-8')