mqttsas.py
Go to the documentation of this file.
1 """
2 *******************************************************************
3  Copyright (c) 2013, 2020 IBM Corp.
4 
5  All rights reserved. This program and the accompanying materials
6  are made available under the terms of the Eclipse Public License v2.0
7  and Eclipse Distribution License v1.0 which accompany this distribution.
8 
9  The Eclipse Public License is available at
10  https://www.eclipse.org/legal/epl-2.0/
11  and the Eclipse Distribution License is available at
12  http://www.eclipse.org/org/documents/edl-v10.php.
13 
14  Contributors:
15  Ian Craggs - initial implementation and/or documentation
16  Ian Craggs - add MQTTV5 support
17 *******************************************************************
18 """
19 from __future__ import print_function
20 
21 import socket
22 import sys
23 import select
24 import traceback
25 import datetime
26 import os
27 import base64
28 import hashlib
29 import logging
30 try:
31  import socketserver
32  import MQTTV311 # Trace MQTT traffic - Python 3 version
33  import MQTTV5
34 except:
35  traceback.print_exc()
36  import SocketServer as socketserver
37  import MQTTV3112 as MQTTV311 # Trace MQTT traffic - Python 2 version
38  import MQTTV5
39 
# Protocol module used to decode MQTT packets (MQTTV311 or MQTTV5).
# NOTE: the second name rebinds `logging`, shadowing the stdlib module imported
# above; inside this module `logging` is therefore just a boolean flag.
# myWindow is not used by the code visible here -- presumably assigned by an
# optional GUI front end (TODO confirm).
MQTT, logging, myWindow = MQTTV311, True, None
43 
44 
class BufferedSockets:
    """Wrap a connected client socket with a push-back buffer and WebSocket framing.

    While ``websockets`` is False, ``recv``/``send`` move raw bytes (serving any
    bytes pushed back with ``rebuffer`` first).  Once it is True, ``recv`` returns
    the de-framed payload of incoming WebSocket frames, which must be masked as
    RFC 6455 requires for client-to-server frames, and ``send`` wraps each call's
    data in one binary frame.  Other attributes (``fileno`` ...) are delegated to
    the wrapped socket, so an instance can be passed to ``select.select``.
    """

    def __init__(self, socket):
        self.socket = socket
        # Bytes received but not yet returned by recv(): bytes pushed back with
        # rebuffer() or, in WebSocket mode, de-framed payload not yet consumed.
        self.buffer = bytearray()
        self.websockets = False

    def close(self):
        """Shut down and close the wrapped socket.

        ``shutdown`` fails if the peer has already gone away; the socket is
        closed anyway so the descriptor is never leaked.
        """
        try:
            self.socket.shutdown(socket.SHUT_RDWR)
        except (socket.error, OSError):
            pass  # best effort: already disconnected
        finally:
            self.socket.close()

    def rebuffer(self, data):
        """Push ``data`` back so the next recv() returns it before anything else."""
        self.buffer = bytearray(data) + self.buffer

    def _recvexactly(self, count):
        """Read exactly ``count`` bytes from the wrapped socket; return None on EOF."""
        data = bytearray()
        while len(data) < count:
            chunk = self.socket.recv(count - len(data))
            if not chunk:
                return None
            data += chunk
        return data

    def wsrecv(self):
        """Read one WebSocket frame and append any data payload to the buffer.

        :return: True if a frame was consumed (the caller may read again), False
            when the connection has ended: EOF, a socket error while waiting for
            a new frame, or a close frame.
        :raises ValueError: if the frame is not masked (client frames must be).

        Ping frames are answered with a pong carrying the same payload; pong and
        other control frames are consumed without touching the buffer, so their
        payload is never mistaken for MQTT data.
        """
        try:
            header = self._recvexactly(2)
        except (socket.error, OSError):
            return False
        if header is None:
            return False
        opcode = header[0] & 0x0f
        masked = (header[1] & 0x80) == 0x80
        length = header[1] & 0x7f  # 0..125: the payload length itself
        if length >= 126:
            # 126: 16-bit extended length follows; 127: 64-bit (network byte order)
            extended = self._recvexactly(2 if length == 126 else 8)
            if extended is None:
                return False
            length = 0
            for byte in extended:
                length = length * 256 + byte
        if not masked:
            raise ValueError("unmasked WebSocket frame from client")
        mask = self._recvexactly(4)
        if mask is None:
            return False
        payload = self._recvexactly(length)
        if payload is None:
            return False
        for pos in range(len(payload)):
            payload[pos] ^= mask[pos % 4]
        if opcode == 0x8:  # close frame: treat as end of stream
            return False
        if opcode == 0x9:  # ping: reply with a pong (FIN + opcode 0xA)
            self._sendall(self._wsheader(0x8A, len(payload)) + payload)
        elif opcode < 0x8:  # continuation, text or binary data frame
            self.buffer += payload
        return True

    def recv(self, bufsize):
        """Return up to ``bufsize`` bytes, serving buffered data first.

        In WebSocket mode, frames are read until ``bufsize`` payload bytes are
        available or the connection ends; after the end, whatever is buffered is
        returned, possibly fewer bytes or an empty bytearray (EOF, as with a plain
        socket).  In raw mode at most one socket recv() call is made.
        """
        if self.websockets:
            while len(self.buffer) < bufsize:
                if not self.wsrecv():
                    break
        elif len(self.buffer) < bufsize:
            self.buffer += self.socket.recv(bufsize - len(self.buffer))
        out = self.buffer[:bufsize]
        self.buffer = self.buffer[bufsize:]
        return out

    def __getattr__(self, name):
        """Delegate unknown attributes (fileno, setsockopt, ...) to the wrapped socket."""
        if name == "socket":
            # Only reached when __init__ has not run: avoid infinite recursion.
            raise AttributeError(name)
        return getattr(self.socket, name)

    def _wsheader(self, first_byte, length):
        """Build an unmasked (server-to-client) frame header for a ``length``-byte payload.

        ``first_byte`` carries the FIN bit and opcode, e.g. 0x82 for a final binary frame.
        """
        header = bytearray([first_byte])
        if length < 126:
            header.append(length)
        elif length < 65536:
            # 126: the next 2 bytes are a 16-bit unsigned length
            header += bytearray([126, (length >> 8) & 0xff, length & 0xff])
        else:
            # 127: the next 8 bytes are a 64-bit unsigned length, most significant first
            header.append(127)
            header += bytearray((length >> shift) & 0xff for shift in range(56, -8, -8))
        return header

    def _sendall(self, data):
        """Write all of ``data`` to the wrapped socket; return the number of bytes sent."""
        sent = self.socket.send(data)
        while sent < len(data):
            sent += self.socket.send(data[sent:])
        return sent

    def send(self, data):
        """Send ``data`` as one binary WebSocket frame in WebSocket mode, raw otherwise.

        :return: total bytes written, including any frame header.
        """
        if self.websockets:
            data = self._wsheader(0x82, len(data)) + data  # 0x82: FIN + binary opcode
        return self._sendall(data)
141 
142 
def timestamp(now=None):
    """Return a local-time timestamp such as "20201206 034809.005000".

    :param now: the datetime to format; defaults to datetime.datetime.now().
    :return: "YYYYmmdd HHMMSS.ffffff" with six zero-padded microsecond digits.
    """
    if now is None:
        now = datetime.datetime.now()
    # %f zero-pads the microseconds; building the fraction from
    # float("." + str(microsecond)) lost the leading zeros (5000 us -> ".5").
    return now.strftime('%Y%m%d %H%M%S.%f')
146 
147 
# Client wrappers whose packets are no longer read or forwarded. Module level,
# so it is shared by every handler thread of the server.
suspended = list()
149 
150 
class MyHandler(socketserver.StreamRequestHandler):
    """Per-connection handler: relays MQTT packets between one client and the broker.

    Each packet is printed with a timestamp and its direction ("C to S" or
    "S to C").  A client may use plain TCP or WebSockets; a first byte of "G"
    (an HTTP GET) starts the WebSocket handshake.  A PUBLISH on topic
    "MQTTSAS topic" with payload TERMINATE closes both connections; with payload
    TERMINATE_SERVER the client is added to the module-level ``suspended`` list,
    after which no further packets are read from it.
    """

    def getheaders(self, data):
        """Parse the header lines of an HTTP request.

        :param data: the request text; its first (request) line is skipped.
        :return: dict of header values keyed by upper-cased name, so lookups are
            case insensitive.  Whitespace around names and values is stripped,
            so both "Name: value" and "Name:value" are accepted.
        """
        headers = {}
        for curline in data.splitlines()[1:]:
            key, sep, value = curline.partition(":")
            if sep:
                headers[key.strip().upper()] = value.strip()
        return headers

    def handshake(self, client):
        """Answer a WebSocket opening handshake on ``client``.

        Reads the HTTP request up to the blank line ending its headers, then
        replies "101 Switching Protocols" with the "mqtt" subprotocol and the
        Sec-WebSocket-Accept value derived from the client's key.

        :return: the value returned by ``client.send`` for the response.
        :raises KeyError: if the request has no Sec-WebSocket-Key header.
        """
        GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
        request = bytearray()
        # The request may arrive in several segments.  RFC 6455 clients wait for
        # the 101 response before sending frames, so nothing beyond it is read.
        while b"\r\n\r\n" not in request:
            chunk = client.recv(1024)
            if not chunk:
                break  # peer closed early; the header lookup below then fails
            request += chunk
        headers = self.getheaders(request.decode('utf-8'))
        digest = base64.b64encode(hashlib.sha1(
            (headers['SEC-WEBSOCKET-KEY'] + GUID).encode("utf-8")).digest())
        resp = b"HTTP/1.1 101 Switching Protocols\r\n" +\
            b"Upgrade: websocket\r\n" +\
            b"Connection: Upgrade\r\n" +\
            b"Sec-WebSocket-Protocol: mqtt\r\n" +\
            b"Sec-WebSocket-Accept: " + digest + b"\r\n\r\n"
        return client.send(resp)

    def _protocolfor(self, inbuf):
        """Pick the protocol module for a CONNECT packet.

        :param inbuf: the complete CONNECT packet.
        :return: MQTTV5 if the byte after the "MQTT" protocol name is 5;
            otherwise (including when no "MQTT" name is found) MQTTV311.
        """
        data = bytearray(inbuf)
        pos = data.find(b"MQTT")
        if pos != -1 and pos + 4 < len(data) and data[pos + 4] == 5:
            return MQTTV5
        return MQTTV311

    def _clientid(self, clients):
        """Return the client identifier recorded at CONNECT, or "<unknown>" before it."""
        return self.ids.get(id(clients), "<unknown>")

    def handle(self):
        """Relay and trace packets between this client and a new broker connection.

        Connects to the broker at the module globals ``brokerhost``/``brokerport``
        and loops on select(), forwarding one MQTT packet per ready socket.
        Returns when either side closes, after a TERMINATE request, or after
        printing an unexpected error; the broker socket is always closed.
        """
        global MQTT
        if not hasattr(self, "ids"):
            self.ids = {}       # id(client wrapper) -> MQTT client identifier
        if not hasattr(self, "versions"):
            self.versions = {}  # id(client wrapper) -> MQTT major version (3 or 5)
        # Protocol module for this connection only: ThreadingTCPServer runs
        # handlers concurrently, so decoding must not depend on what another
        # connection's CONNECT did to the module-level MQTT.
        mqtt = MQTT
        inbuf = True
        first = True
        i = o = e = None
        clients = brokers = None
        try:
            clients = BufferedSockets(self.request)
            sock_no = clients.fileno()
            brokers = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            brokers.connect((brokerhost, brokerport))
            terminated = False
            while inbuf is not None and not terminated:
                if clients.buffer and clients not in suspended:
                    # Bytes already buffered in the wrapper (e.g. a second MQTT
                    # packet from the same WebSocket frame) never wake select().
                    (i, o, e) = ([clients], [], [])
                else:
                    (i, o, e) = select.select([clients, brokers], [], [])
                for s in i:
                    if s in suspended:
                        print("suspended")
                    if s is clients and s not in suspended:
                        if first:
                            first = False
                            char = clients.recv(1)
                            clients.rebuffer(char)
                            if char == b"G":  # an HTTP GET: should be a websocket connection
                                self.handshake(clients)
                                clients.websockets = True
                                print("Switching to websockets for socket %d" % sock_no)
                        inbuf = mqtt.getPacket(clients)  # get one packet
                        if inbuf is None:
                            break
                        try:
                            if bytearray(inbuf[:1])[0] >> 4 == 1:
                                # CONNECT: its protocol level says MQTT 3 or 5
                                mqtt = self._protocolfor(inbuf)
                                MQTT = mqtt  # keep the module global in step, as before
                            packet = mqtt.unpackPacket(inbuf)
                            if hasattr(packet.fh, "MessageType"):
                                packet_type = packet.fh.MessageType
                                publish_type = mqtt.PUBLISH
                                connect_type = mqtt.CONNECT
                            else:
                                packet_type = packet.fh.PacketType
                                publish_type = mqtt.PacketTypes.PUBLISH
                                connect_type = mqtt.PacketTypes.CONNECT
                            is_control = (packet_type == publish_type and
                                          packet.topicName == "MQTTSAS topic")
                            if is_control and packet.data == b"TERMINATE":
                                print("Terminating client", self._clientid(clients))
                                brokers.close()
                                clients.close()
                                terminated = True
                                break
                            elif is_control and packet.data == b"TERMINATE_SERVER":
                                print("Suspending client ", self._clientid(clients))
                                suspended.append(clients)
                            elif packet_type == connect_type:
                                self.ids[id(clients)] = packet.ClientIdentifier
                                self.versions[id(clients)] = 5 if mqtt is MQTTV5 else 3
                            print(timestamp(), "C to S", self._clientid(clients), str(packet))
                        except Exception:
                            traceback.print_exc()
                        brokers.sendall(inbuf)  # pass it on; send() may write only part
                    elif s is brokers:
                        inbuf = mqtt.getPacket(brokers)  # get one packet
                        if inbuf is None:
                            break
                        try:
                            print(timestamp(), "S to C", self._clientid(clients),
                                  str(mqtt.unpackPacket(inbuf)))
                        except Exception:
                            traceback.print_exc()
                        clients.send(inbuf)
            print(timestamp() + " client " + str(self._clientid(clients)) + " connection closing")
        except Exception:
            print(repr((i, o, e)), repr(inbuf))
            traceback.print_exc()
        finally:
            if brokers is not None:
                brokers.close()
            if clients is not None:
                # Remove both entries; either may be missing if no CONNECT was seen.
                self.ids.pop(id(clients), None)
                self.versions.pop(id(clients), None)
269  del self.versions[id(clients)]
270 
271 
class ThreadingTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
    """TCP server that runs each connection's handler in its own thread.

    ThreadingMixIn is listed first so that its process_request() overrides
    the one inherited from TCPServer.
    """
274 
275 
def run():
    """Parse command-line arguments and run the proxy until interrupted.

    Usage: mqttsas.py [brokerhost [brokerport [listenport]]]

    * brokerhost -- broker address (default 127.0.0.1)
    * brokerport -- broker port (default 1883)
    * listenport -- port this proxy listens on; defaults to brokerport + 1 when
      the broker is on 127.0.0.1 (so the two do not collide), else 1883

    The proxy always listens on 127.0.0.1.  brokerhost and brokerport are
    stored as module globals, where MyHandler.handle reads them.
    """
    global brokerhost, brokerport
    myhost = '127.0.0.1'
    brokerhost = sys.argv[1] if len(sys.argv) > 1 else '127.0.0.1'
    brokerport = int(sys.argv[2]) if len(sys.argv) > 2 else 1883
    if len(sys.argv) > 3:
        myport = int(sys.argv[3])
    elif brokerhost == myhost:
        myport = brokerport + 1  # same host as the broker: avoid its port
    else:
        myport = 1883

    print("Listening on port", str(myport)+", broker on port", brokerport)
    s = ThreadingTCPServer((myhost, myport), MyHandler)
    try:
        s.serve_forever()
    finally:
        s.server_close()  # release the listening socket, e.g. after Ctrl-C


if __name__ == "__main__":
    run()
def recv(self, bufsize)
Definition: mqttsas.py:93
def __init__(self, socket)
Definition: mqttsas.py:47
def rebuffer(self, data)
Definition: mqttsas.py:56
def send(self, data)
Definition: mqttsas.py:112
def encode(x)
Definition: MQTTV5.py:229
def handshake(self, client)
Definition: mqttsas.py:163
def getheaders(self, data)
Definition: mqttsas.py:153
def handle(self)
Definition: mqttsas.py:176
void print(std::FILE *f, const S &format_str, Args &&...args)
Definition: core.h:2101
def run()
Definition: mqttsas.py:276
def __getattr__(self, name)
Definition: mqttsas.py:109
def decode(buffer)
Definition: MQTTV5.py:247
def timestamp()
Definition: mqttsas.py:143
Definition: format.h:3618
int len
Definition: utf-8.c:46


plotjuggler
Author(s): Davide Faconti
autogenerated on Sun Dec 6 2020 03:48:09