Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
__pycache__
.vscode
smallchat.dSYM
smallchat-server
Expand Down
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ smallchat-server: smallchat-server.c chatlib.c
smallchat-client: smallchat-client.c chatlib.c
$(CC) smallchat-client.c chatlib.c -o smallchat-client $(CFLAGS)

test: smallchat-server smallchat-client
python3 -m unittest process.py -v
python3 -m unittest tests.py -v

clean:
rm -f smallchat-server
rm -f smallchat-client
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Smallchat

Inspired by Salvatore Sanfilippo's series on creating a chat
Implemented in python


TLDR: This is just a programming example for a few friends of mine. It somehow turned into a set of programming videos, continuing one project I started some time ago: Writing System Software videos series.

1. [First episode](https://www.youtube.com/watch?v=eT02gzeLmF0), how the basic server works.
Expand Down
116 changes: 116 additions & 0 deletions process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import subprocess
import sys
import unittest


class Process:
def __init__(self, args):
self.proc = subprocess.Popen(
args, stdout=subprocess.PIPE, stdin=subprocess.PIPE, bufsize=0)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def close(self):
self.stop()
self.proc.stdin.close()
self.proc.stdout.close()

def read(self):
return self.proc.stdout.readline().strip()

def stop(self):
self.terminate()
self.wait()

def terminate(self):
self.proc.terminate()

def wait(self):
self.proc.wait()

def write(self, msg):
# print(f"Process.write msg: {msg}", file=sys.stderr)
self.proc.stdin.write(msg + b"\n")
self.proc.stdin.flush()


class TestProcess(unittest.TestCase):
def setUp(self):
self.p = Process([sys.executable, __file__])

def tearDown(self):
self.p.close()

def test_stdout(self):
line = self.p.read()
self.assertEqual(line, b"started")
self.p.stop()
line = self.p.read()
self.assertFalse(line)

def test_stdin(self):
line = self.p.read()
self.assertEqual(line, b"started")
self.p.write(b"test-request")
line = self.p.read()
self.assertEqual(line, b"test-request")
self.p.stop()
line = self.p.read()
self.assertFalse(line)

def test_cycle(self):
line = self.p.read()
self.assertEqual(line, b"started")
self.p.write(b"test-request-1")
line = self.p.read()
self.assertEqual(line, b"test-request-1")
self.p.write(b"test-request-2")
line = self.p.read()
self.assertEqual(line, b"test-request-2")
self.p.write(b"/exit")
self.p.wait()
line = self.p.read()
self.assertFalse(line)


def test_long(self):
line = self.p.read()
self.assertEqual(line, b"started")
request = b"0123456789" * 10000
self.p.write(request)
line = self.p.read()
self.assertEqual(line, request)
self.p.stop()
line = self.p.read()
self.assertFalse(line)


class TestProcessContext(unittest.TestCase):
def test_stdout(self):
with Process([sys.executable, __file__]) as p:
line = p.read()
self.assertEqual(line, b"started")
p.stop()
line = p.read()
self.assertFalse(line)


def test_main():
sys.stdout.write("started\n")
sys.stdout.flush()
while True:
# print("read", file=sys.stderr)
request = sys.stdin.readline().strip()
# print("request", request, file=sys.stderr)
if request == '/exit':
break
sys.stdout.write(request + "\n")
sys.stdout.flush()


if __name__ == '__main__':
test_main(*sys.argv[1:])
187 changes: 187 additions & 0 deletions smallchat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import select
import socket
import sys

WELCOME = b"Welcome to Simple Chat! Use /nick <nick> to set your nick."
PREFIX = b"/nick "


class Client:
def __init__(self, conn, protocol_cls, notify_receive, notify_close):
self.conn = conn
self.notify_receive = notify_receive
self.notify_close = notify_close
self.fd = conn.fileno()
self.protocol = protocol_cls(self.notify_receive)
self.out_buffer = bytearray()

def raw_receive(self):
data = self.conn.recv(1024)
if not data:
self.notify_close(self)
self.conn.close()
else:
self.protocol.decode(data)

def raw_send(self):
if self.out_buffer:
sent = self.conn.send(self.out_buffer)
self.out_buffer = self.out_buffer[sent:]

def send(self, msg):
# print(f"send msg: {msg}", file=sys.stderr)
self.out_buffer += self.protocol.encode(msg)


class Clients:
def __init__(self, inputs, outputs):
self.inputs = inputs
self.outputs = outputs
self.clients = {}

def add(self, client):
self.clients[client.fd] = client
self.inputs.append(client.conn)
self.outputs.append(client.conn)

def delete(self, client):
self.clients.pop(client.fd)
self.inputs.remove(client.conn)
self.outputs.remove(client.conn)

def get(self, conn):
return self.clients[conn.fileno()]


class ChatClient(Client):
def __init__(self, conn, protocol_cls, publish, notify_close):
super().__init__(conn, protocol_cls, self._received, notify_close)
self.publish = publish
self.nick = f"user:{conn.fileno()}"

def _received(self, msg):
# print(f"received msg: {msg}", file=sys.stderr)
if msg.startswith(PREFIX):
self.nick = msg[len(PREFIX):].decode()
else:
self.publish(self, msg)


class ChatClients(Clients):
def add(self, client):
super().add(client)
# print(f"Connected client fd={client.fd}, nick={client.nick}")
client.send(WELCOME)

def delete(self, client):
super().delete(client)
# print(f"Disconnected client fd={client.fd}, nick={client.nick}")

def publish(self, sender, msg):
response = sender.nick.encode() + b"> " + msg
for client in self.clients.values():
if client != sender:
client.send(response)


class Protocol:
END = b"\n"

def __init__(self, notify, end=None):
self.notify = notify
self.end = end or self.END
self.buff = bytearray()

@classmethod
def encode(cls, msg):
assert not cls.END in msg
return msg + cls.END

def decode(self, data):
for car in data:
if car == ord(self.END):
self.notify(self.buff)
self.buff.clear()
else:
self.buff.append(car)


class Stream:
def __init__(self, stdin, stdout):
self.stdin = stdin
self.stdout = stdout
self.closed = False
self.send = None

def close(self):
self.closed = True

def raw_receive(self):
msg = self.stdin.readline().rstrip()
self.send(msg.encode())

def receive(self, msg):
self.stdout.write(msg.decode() + "\n")
self.stdout.flush()

def raw_send(self):
pass # noop


def _main_client(address):
stream = Stream(sys.stdin, sys.stdout)

with socket.socket() as conn:
conn.connect(address)
client = Client(conn, Protocol, notify_receive=stream.receive, notify_close=stream.close)
stream.protocol = Protocol(client.send, "\n")
stream.send = client.send
inputs = [conn, sys.stdin]
outputs = [conn, sys.stdout]
clients = {conn: client, sys.stdin: stream, sys.stdout: stream}
while not stream.closed:
inputready, outputready, exceptready = select.select(inputs, outputs, [])
for s in inputready:
clients.get(s).raw_receive()
for s in outputready:
if s.fileno() <= 0:
# sockets already closed during reception/recv
continue
clients.get(s).raw_send()


def _main_server(address):
with socket.socket() as sl:
sl.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sl.bind(address)
sl.listen()
sys.stdout.write(f"Server started address={address}\n")
sys.stdout.flush()
inputs = [sl]
outputs = []
clients = ChatClients(inputs, outputs)
while True:
inputready, outputready, exceptready = select.select(inputs, outputs, [])
for s in inputready:
if s == sl:
conn, addr = sl.accept()
client = ChatClient(conn, Protocol, clients.publish, clients.delete)
clients.add(client)
else:
clients.get(s).raw_receive()
for s in outputready:
if s.fileno() <= 0:
# sockets already closed during reception/recv
continue
clients.get(s).raw_send()


def main(role, host, port):
fun = globals()["_main_" + role]
address = (host, int(port))
fun(address)


if __name__ == '__main__':
main(*sys.argv[1:])

Loading