Skoči na vsebino

RK - 2021/22 - LDN10 resitev Starc Aljaz

Adaptirana iz 2021-22-naloge/10-LDN10. Posodobljeni certifikati ter dodana logika za ukaz /list.

Client
client.py
#!/usr/bin/python3

import socket as Socket
import os as Os
import threading as Threading
import datetime as Datetime
import json as Json
import time as Time
import sys as Sys
import re as Re
import ssl as Ssl
import OpenSSL as OpenSSL
from typing import Dict



# Server endpoint; the PORT and HOST environment variables override the defaults.
SOCKET_PORT = int(Os.getenv('PORT', 1234))
SOCKET_HOST = Os.getenv('HOST', "127.0.0.1")

# Terminal dimensions used for cursor positioning. Os.get_terminal_size()
# raises OSError when stdout is not attached to a terminal (pipe, redirect),
# so fall back to a conventional 80x24 screen instead of crashing on startup.
try:
    TERM_COLS, TERM_ROWS = Os.get_terminal_size()
except OSError:
    TERM_COLS, TERM_ROWS = 80, 24

# Connection state shared by the functions in this script; both names are
# assigned later by the top-level startup code.
socket: Socket.socket = None
certificate = None


def log(level: str, line: str, prefix: str = ""):
    """
    Print one timestamped, colorized log line to stdout

    level selects the ANSI color and the upper-cased tag shown in brackets;
    it must be one of "info", "success", "warn", "error" or "log"
    (anything else raises KeyError). prefix is written verbatim before the
    timestamp.
    """
    ansi_by_level = {
        "info": u"\u001b[34m",
        "success": u"\u001b[32m",
        "warn": u"\u001b[33m",
        "error": u"\u001b[31m",
        "log": u"\u001b[37m"
    }
    timestamp = Datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    color = ansi_by_level[level]
    # The tag is centered in a 7 character wide column so messages line up
    tag = level.upper().center(7)
    rendered = u"%s%-12s [%s%s\u001b[0m] %s" % (prefix, timestamp, color, tag, line)
    print(rendered, flush=True)



def receiver(socket: Socket.socket):
    """
    Read newline-delimited JSON messages from the server and render them

    Runs until the connection reaches end-of-file or reading fails, then
    returns. "message" and "dm" actions are printed with their sender, "list"
    actions with a "list" label; every other action (for example the server
    healthcheck) is ignored.
    """
    rfile = socket.makefile()
    while True:
        try:
            raw = rfile.readline()
        except (OSError, ValueError) as e:
            # The connection is unusable (reset, closed file, ...); retrying
            # the read would only fail again in a tight loop.
            print("Error")
            print(e)
            return

        if not raw:
            # An empty read means end-of-file: the server closed the
            # connection, so further reads would return "" forever.
            return

        try:
            data: Dict = Json.loads(raw)
            action = data.get("action")

            if action == "message" or action == "dm":
                closer = ">" if action == "dm" else "]"
                print("\033[%d;%dH[%12s%s %s" % (TERM_ROWS - 1, 0, data.get("from"), closer, data.get("data")), flush=True)
            elif action == "list":
                print("\033[%d;%dH[%12s] %s" % (TERM_ROWS - 1, 0, "list", data.get("data")), flush=True)
            else:
                continue

            # Redraw the input prompt below the freshly printed line and put
            # the cursor back into the input column.
            print("[%12s] " % certificate.get_subject().CN, flush=True)
            print("\033[%d;%dH" % (TERM_ROWS - 1, 16), end='', flush=True)

        except Exception as e:
            # A single malformed message must not stop the receiver
            print("Error")
            print(e)



def send(data: Dict):
    """
    Serialize data as one JSON line and send it over the global socket

    sendall() is used because a plain send() may transmit only part of the
    buffer, which would leave the peer with a truncated JSON line.
    """
    bdata = bytes(Json.dumps(data) + "\n", "UTF-8")
    socket.sendall(bdata)



def command(msg: str):
    """
    Execute a slash command typed by the user

    msg is the raw input line including the leading "/". Supported commands
    are /help, /list and /msg <session> <content>; anything else prints an
    error. A /msg without a recipient or without content prints a usage
    hint instead of raising.
    """
    cmd, _, rest = msg[1:].partition(" ")
    if cmd == "help":
        print("""
\u001b[33mCommands\u001b[0m:
    \u001b[34m/help\u001b[0m
        Display this help menu

    \u001b[34m/list\u001b[0m
        Display a list of distinct usernames of connected clients

    \u001b[34m/msg\u001b[0m <\u001b[34msession\u001b[0m> <\u001b[34mcontent\u001b[0m>
        Send a direct message to a specific session
        <\u001b[34msession\u001b[0m> is the session identificator. You can get list of them with \u001b[34m/list\u001b[0m
        <\u001b[34mcontent\u001b[0m> is your message content
""")

    elif cmd == "msg":
        # The first word is the recipient, everything after it is the content
        recipient, _, content = rest.lstrip(" ").partition(" ")
        if not recipient or not content:
            print("\u001b[31mERROR!\u001b[0m Usage: \u001b[34m/msg\u001b[0m <session> <content>")
            return
        send({ "action": "dm", "to": recipient, "data": content })

    elif cmd == "list":
        send({ "action": "list" })

    else:
        print("\u001b[31mERROR!\u001b[0m Invalid command. Try \u001b[34m/help\u001b[0m for a list of commands")


# Ask for a certificate name until <name>.key and <name>.pem both exist, then
# build the TLS context from them and parse the certificate itself.
socket_ssl = None
try:
    while not (certificate and socket_ssl):
        # Clear the screen and place the prompt near the middle of the terminal
        print(chr(27) + "[2J")
        print("\033[%d;%dH" % (TERM_ROWS / 2, TERM_COLS / 2 - 30), flush=True, end="")
        vpis = input("Certificate name (client1): ")

        key_file = vpis + ".key"
        pem_file = vpis + ".pem"
        if Os.path.exists(key_file) and Os.path.exists(pem_file):
            socket_ssl = Ssl.SSLContext(Ssl.PROTOCOL_TLSv1_2)
            socket_ssl.verify_mode = Ssl.CERT_REQUIRED
            socket_ssl.load_cert_chain(certfile = pem_file, keyfile = key_file)
            socket_ssl.load_verify_locations('./rootCA.pem')
            socket_ssl.set_ciphers('ECDHE-RSA-AES128-GCM-SHA256')
            # Keep the parsed certificate; its subject CN is shown in the prompt
            with open(pem_file, "r") as pem_handle:
                pem_text = pem_handle.read()
            certificate = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem_text)

        print(chr(27) + "[2J")
except KeyboardInterrupt:
    # Ctrl+C at the prompt: finish the line and leave
    print()
    Sys.exit()

# Main client loop: connect to the chat server, start the receiver thread and
# forward user input until the user interrupts the program. An error while
# handling input restarts the whole connect cycle.
while True:
    try:
        print("\033[%d;%dH" % (TERM_ROWS - 1, 0), flush=True)
        print("[      \u001b[34msystem\u001b[0m] Use \u001b[34m/help\u001b[0m for a list of commands")
        print("[      \u001b[34msystem\u001b[0m] connecting to chat server ... ", end="", flush=True)

        connected = False
        while not connected:
            # A socket whose connect() or TLS handshake failed cannot be
            # reliably reused, so every attempt gets a fresh one.
            socket = socket_ssl.wrap_socket(Socket.socket(Socket.AF_INET, Socket.SOCK_STREAM))
            try:
                socket.connect((SOCKET_HOST, SOCKET_PORT))
                connected = True
            except Exception as e:
                print(e)
                socket.close()
                # Wait before retrying; no delay is needed after a success
                Time.sleep(1.0)
        print("\u001b[32mCONNECTED\u001b[0m!")

        thread = Threading.Thread(target=receiver, args=(socket,))
        thread.daemon = True
        thread.start()

        while True:
            try:
                print("[%12s] " % certificate.get_subject().CN)
                print("\033[%d;%dH" % (TERM_ROWS - 1, 16), end='', flush=True)
                vpis = input("")

                if not vpis:
                    # Empty line: erase it and prompt again
                    print ("\033[A" + " " * TERM_COLS + "\033[A", flush=True)
                    continue

                if vpis.startswith("/"):
                    command(vpis)
                else:
                    send({ "action": "message", "data": vpis })
            except EOFError:
                # stdin was closed; every later input() would fail the same
                # way, so exit instead of reconnecting forever.
                print()
                Sys.exit()
            except Exception as e:
                log("error", e)
                print("Something went wrong! Restarting...")
                break
    except KeyboardInterrupt:
        print()
        Sys.exit()
    except Exception as e:
        print(e)
        Sys.exit()
Server
server.py
#!/usr/bin/python3

import socket as Socket
import os as Os
import threading as Threading
import datetime as Datetime
import json as Json
import time as Time
import ssl as Ssl
import OpenSSL as OpenSSL

from typing import Dict, Set, TextIO



# Listening endpoint; the PORT and BIND environment variables override the defaults.
SOCKET_PORT = int(Os.environ.get('PORT', 1234))
SOCKET_BIND = Os.environ.get('BIND', "127.0.0.1")



def log(level: str, line: str, prefix: str = ""):
    """
    Write a single log entry to stdout

    The entry consists of prefix, the current local time, the upper-cased
    level centered in 7 columns and colored by level, and finally line.
    level must be "info", "success", "warn", "error" or "log"; other values
    raise KeyError.
    """
    palette = {
        "info": u"\u001b[34m",
        "success": u"\u001b[32m",
        "warn": u"\u001b[33m",
        "error": u"\u001b[31m",
        "log": u"\u001b[37m"
    }
    now_text = Datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    fields = (prefix, now_text, palette[level], level.upper().center(7), line)
    print(u"%s%-12s [%s%s\u001b[0m] %s" % fields, flush=True)



class Client():
    """
    One connected chat client

    Holds the accepted TLS socket, a line reader on top of it, the peer
    certificate (its subject CN serves as the user name) and the daemon
    thread that reads and dispatches the client's requests.
    """

    socket: Socket.socket = None
    rfile: TextIO = None
    address = None
    thread: Threading.Thread = None
    cert = None


    def __init__(self, socket: Socket.socket, address) -> None:
        """
        Wrap an accepted connection and start its reader thread

        socket is expected to be a TLS socket with a peer certificate, because
        the DER form returned by getpeercert(True) is parsed right away.
        """
        self.socket = socket
        self.address = address
        self.rfile = socket.makefile()
        self.cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_ASN1, socket.getpeercert(True))
        self.thread = Threading.Thread(target=self.entrypoint)
        self.thread.daemon = True
        self.thread.start()


    def entrypoint (self):
        """
        The client thread entrypoint

        Reads one JSON request per line until the connection ends, then
        closes the client.
        """
        name = self.cert.get_subject().CN
        log("log", "[client ] %s connected " % name)
        while True:
            try:
                content = self.rfile.readline()
            except (OSError, ValueError) as e:
                # The connection is broken; retrying the read would only fail
                # again in a tight loop.
                log("error", "[client ] %s read failed: %s" % (name, e))
                break

            if not content:
                # Empty read means end-of-file: the peer closed the connection
                break

            try:
                self._dispatch(Json.loads(content))
            except Exception as e:
                # One bad request must not terminate the session
                log("error", "[client ] %s request failed: %s" % (name, e))
        self.close()


    def _dispatch(self, data: Dict):
        """
        Handle a single decoded request ("message", "dm" or "list")
        """
        action = data.get('action')
        sender = self.cert.get_subject().CN

        if action == 'message':
            log("log", "[public ] %s: %s " % (sender, data.get('data')))
            # Iterate over a snapshot: other threads add and remove clients
            # while this loop runs.
            for client in list(clients):
                if client.cert.get_subject().CN != sender:
                    self._deliver(client, { "action": "message", "from": sender, "data": data.get('data') })

        elif action == 'dm':
            log("log", "[private] %s -> %s: %s " % (sender, data.get('to'), data.get('data')))
            for client in list(clients):
                if client.cert.get_subject().CN == data.get('to'):
                    self._deliver(client, { "action": "dm", "from": sender, "data": data.get('data') })

        elif action == 'list':
            log("log", "[list   ] %s" % sender)
            names = {client.cert.get_subject().CN for client in list(clients)}
            # Sorted so the listing is stable; empty names are left out
            self.send({ "action": "list", "data": ", ".join(sorted(name for name in names if name)) })


    def _deliver(self, client, payload: Dict):
        """
        Send payload to another client without letting its failure abort the caller
        """
        try:
            client.send(payload)
        except Exception as e:
            log("warn", "[client ] delivery to %s failed: %s" % (client.cert.get_subject().CN, e))


    def send(self, data: Dict):
        """
        Send data to client

        The data is written as one JSON line. sendall() is used because a
        plain send() may transmit only part of the buffer.
        """
        bdata = bytes(Json.dumps(data) + "\n", "UTF-8")
        self.socket.sendall(bdata)


    def close(self):
        """
        Gracefully close the client socket

        Safe to call more than once: the client is dropped from the registry
        with discard(), so a repeated call does not raise.
        """
        with server_lock:
            name = self.cert.get_subject().CN
            log("log", "[client] %s closing ..." % name)
            try:
                self.socket.close()
            except OSError as e:
                log("warn", "[client] %s close failed: %s" % (name, e))
            clients.discard(self)
            log("warn", "[client] %s closed" % name)

def setup_SSL_context():
    """
    Build the TLS 1.2 context used by the server

    The context presents ./rootCA.pem with the key ./rootCA.key, requires a
    peer certificate verified against ./rootCA.pem, and is limited to the
    ECDHE-RSA-AES128-GCM-SHA256 cipher suite.
    """
    ctx = Ssl.SSLContext(Ssl.PROTOCOL_TLSv1_2)
    # Peers without a certificate signed by the root CA are rejected
    ctx.verify_mode = Ssl.CERT_REQUIRED
    ctx.load_cert_chain(keyfile="./rootCA.key", certfile="./rootCA.pem")
    ctx.load_verify_locations(cafile='./rootCA.pem')
    ctx.set_ciphers('ECDHE-RSA-AES128-GCM-SHA256')
    return ctx

log("info", "[system ] Starting server on port %d ..." % SOCKET_PORT)
# Registry of connected clients, and the lock held while clients are
# registered and while they are closed.
clients: Set[Client] = set()
server_lock = Threading.Lock()

# TLS-wrapped listening socket bound to the configured address
server_ssl = setup_SSL_context()
listen_address = (SOCKET_BIND, SOCKET_PORT)
server_socket = server_ssl.wrap_socket(Socket.socket(family=Socket.AF_INET, type=Socket.SOCK_STREAM))
server_socket.bind(listen_address)
server_socket.listen(10)
log("success", "[system ] Started server on port %d ..." % SOCKET_PORT)



log("info", "[system ] Starting healthchecker")
def healthchecker():
    """
    Send a healthcheck message to every connected client every 5 seconds

    Runs forever; it is meant to be the target of a daemon thread.
    """
    while True:
        # Iterate over a snapshot: clients are added and removed by other
        # threads while this loop runs.
        for client in list(clients):
            try:
                client.send({"action": "healthcheck"})
            except Exception as e:
                # Best effort: a failed probe must not stop the checker, but
                # it should be visible in the log.
                log("warn", "[system ] Healthcheck failed for %s: %s" % (client.address, e))
        Time.sleep(5)

healthchecker_thread = Threading.Thread(target=healthchecker)
healthchecker_thread.daemon = True
healthchecker_thread.start()
log("success", "[system ] Started healthchecker")



# Accept loop: register every new connection as a Client until Ctrl+C.
while True:
    socket = None
    try:
        socket, address = server_socket.accept()

        with server_lock:
            client = Client(socket, address)
            clients.add(client)

    except KeyboardInterrupt:
        break
    except Exception as e:
        # A failed handshake or client setup must not stop the server, but it
        # should not pass unnoticed either.
        log("warn", "[system ] Failed to accept client: %s" % e)
        if socket is not None:
            # The connection was accepted but never registered; release it
            try:
                socket.close()
            except OSError:
                # Nothing more can be done for this connection
                pass



# Shutdown: close every remaining client first, then the listening socket.
log("warn", "[system ] Closing server ...", "\n")
# close() removes the client from the registry, so walk over a snapshot
for client in tuple(clients):
    client.close()

server_socket.close()
log("info", "[system ] Server closed")

Zadnja posodobitev: May 30, 2023