From a4fc06c8b1cea42733db8ca4edcff69151b8c5cd Mon Sep 17 00:00:00 2001
From: Artucuno <28388670+Artucuno@users.noreply.github.com>
Date: Mon, 6 Nov 2023 10:30:23 +1100
Subject: [PATCH 1/5] Update socket_manager.py

---
 fastapi_socketio/socket_manager.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/fastapi_socketio/socket_manager.py b/fastapi_socketio/socket_manager.py
index ff23726..65a2f56 100644
--- a/fastapi_socketio/socket_manager.py
+++ b/fastapi_socketio/socket_manager.py
@@ -33,6 +33,8 @@ def __init__(
         )
 
         app.mount(mount_location, self._app)
+        app.add_route(f"/{socketio_path}/", route=self._app, methods=["GET", "POST"])
+        app.add_websocket_route(f"/{socketio_path}/", self._app)
         app.sio = self._sio
 
     def is_asyncio_based(self) -> bool:

From 85df9bc148eb818da008c81d091fe8b88a4510b2 Mon Sep 17 00:00:00 2001
From: Artucuno <artucunov@gmail.com>
Date: Sat, 25 Nov 2023 13:24:54 +1100
Subject: [PATCH 2/5] Rewrite SocketManager

---
 fastapi_socketio/__init__.py       |  2 +-
 fastapi_socketio/socket_manager.py | 92 ++++++------------------------
 2 files changed, 17 insertions(+), 77 deletions(-)

diff --git a/fastapi_socketio/__init__.py b/fastapi_socketio/__init__.py
index 4bf4510..34802f2 100644
--- a/fastapi_socketio/__init__.py
+++ b/fastapi_socketio/__init__.py
@@ -1 +1 @@
-from .socket_manager import SocketManager
\ No newline at end of file
+from .socket_manager import SocketManager
diff --git a/fastapi_socketio/socket_manager.py b/fastapi_socketio/socket_manager.py
index 65a2f56..ece4fb2 100644
--- a/fastapi_socketio/socket_manager.py
+++ b/fastapi_socketio/socket_manager.py
@@ -2,8 +2,10 @@
 
 import socketio
 from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
 
-class SocketManager:
+
+class SocketManager(socketio.AsyncServer):
     """
     Integrates SocketIO with FastAPI app.
     Adds `sio` property to FastAPI object (app).
@@ -18,88 +20,26 @@ class SocketManager:
     """
 
     def __init__(
-        self,
-        app: FastAPI,
-        mount_location: str = "/ws",
-        socketio_path: str = "socket.io",
-        cors_allowed_origins: Union[str, list] = '*',
-        async_mode: str = "asgi",
-        **kwargs
+            self,
+            app: FastAPI,
+            mount_location: str = "/ws",
+            socketio_path: str = "socket.io",
+            cors_allowed_origins: Union[str, list] = '*',
+            async_mode: str = "asgi",
+            **kwargs
     ) -> None:
-        # TODO: Change Cors policy based on fastapi cors Middleware
-        self._sio = socketio.AsyncServer(async_mode=async_mode, cors_allowed_origins=cors_allowed_origins, **kwargs)
+        middleware = next((x for x in app.user_middleware if issubclass(x.cls, CORSMiddleware)), None)
+        if middleware:
+            cors_allowed_origins = middleware.options.get("allow_origins", "*")
+        super().__init__(cors_allowed_origins=cors_allowed_origins, async_mode=async_mode, **kwargs)
         self._app = socketio.ASGIApp(
-            socketio_server=self._sio, socketio_path=socketio_path
+            socketio_server=self, socketio_path=socketio_path
         )
 
         app.mount(mount_location, self._app)
         app.add_route(f"/{socketio_path}/", route=self._app, methods=["GET", "POST"])
         app.add_websocket_route(f"/{socketio_path}/", self._app)
-        app.sio = self._sio
+        app.sio = self
 
     def is_asyncio_based(self) -> bool:
         return True
-
-    @property
-    def on(self):
-        return self._sio.on
-
-    @property
-    def attach(self):
-        return self._sio.attach
-
-    @property
-    def emit(self):
-        return self._sio.emit
-
-    @property
-    def send(self):
-        return self._sio.send
-
-    @property
-    def call(self):
-        return self._sio.call
-
-    @property
-    def close_room(self):
-        return self._sio.close_room
-
-    @property
-    def get_session(self):
-        return self._sio.get_session
-
-    @property
-    def save_session(self):
-        return self._sio.save_session
-
-    @property
-    def session(self):
-        return self._sio.session
-
-    @property
-    def disconnect(self):
-        return self._sio.disconnect
-
-    @property
-    def handle_request(self):
-        return self._sio.handle_request
-
-    @property
-    def start_background_task(self):
-        return self._sio.start_background_task
-
-    @property
-    def sleep(self):
-        return self._sio.sleep
-
-    @property
-    def enter_room(self):
-        return self._sio.enter_room
-
-    @property
-    def leave_room(self):
-        return self._sio.leave_room
-    
-    @property
-    def register_namespace(self):
-        return self._sio.register_namespace

From da5974a7fb0abc1043c786859732b3268907fae6 Mon Sep 17 00:00:00 2001
From: Artucuno <artucunov@gmail.com>
Date: Sat, 25 Nov 2023 13:25:09 +1100
Subject: [PATCH 3/5] Update examples

---
 examples/app.py  |  3 +--
 examples/cors.py | 22 ++++++++++++++++++++++
 2 files changed, 23 insertions(+), 2 deletions(-)
 create mode 100644 examples/cors.py

diff --git a/examples/app.py b/examples/app.py
index fe62c18..78cd6f9 100644
--- a/examples/app.py
+++ b/examples/app.py
@@ -15,7 +15,6 @@ async def test(sid, *args, **kwargs):
     await sio.emit('hey', 'joe')
 
 
-
 if __name__ == '__main__':
     import logging
     import sys
@@ -25,4 +24,4 @@ async def test(sid, *args, **kwargs):
 
     import uvicorn
 
-    uvicorn.run("examples.app:app", host='0.0.0.0', port=8000, reload=True, debug=False)
+    uvicorn.run("examples.app:app", host='0.0.0.0', port=8000)
diff --git a/examples/cors.py b/examples/cors.py
new file mode 100644
index 0000000..b0e45af
--- /dev/null
+++ b/examples/cors.py
@@ -0,0 +1,22 @@
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+
+from fastapi_socketio import SocketManager
+
+app = FastAPI()
+# Adding the CORS middleware will overwrite SocketManager's CORS settings
+# Make sure to add the CORS middleware before SocketManager
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+sio = SocketManager(app=app, cors_allowed_origins="*")
+
+
+if __name__ == '__main__':
+    import uvicorn
+
+    uvicorn.run("examples.cors:app", host='0.0.0.0', port=8000)

From 87d8a4874171f6bf88f52910ef4a2add4f503976 Mon Sep 17 00:00:00 2001
From: Artucuno <artucunov@gmail.com>
Date: Sun, 26 Nov 2023 14:01:39 +1100
Subject: [PATCH 4/5] Updated examples/app.py

---
 examples/app.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/examples/app.py b/examples/app.py
index 78cd6f9..05ebd81 100644
--- a/examples/app.py
+++ b/examples/app.py
@@ -5,14 +5,15 @@
 sio = SocketManager(app=app)
 
 
-@app.sio.on('join')
-async def handle_join(sid, *args, **kwargs):
-    await sio.emit('lobby', 'User joined')
+@sio.event
+async def connect(sid, *args, **kwargs):
+    print(f"[{sid}] Connected!")
+    await sio.emit('test', 'Hello world!')
 
 
 @sio.on('test')
-async def test(sid, *args, **kwargs):
-    await sio.emit('hey', 'joe')
+async def test(sid, data, **kwargs):
+    print(f'[{sid}] Message Received! >> ', data)
 
 
 if __name__ == '__main__':
@@ -24,4 +25,4 @@ async def test(sid, *args, **kwargs):
 
     import uvicorn
 
-    uvicorn.run("examples.app:app", host='0.0.0.0', port=8000)
+    uvicorn.run("app:app", host='0.0.0.0', port=8000)

From 13955c9612f2d190a40987dd0b754e68b6ff3f57 Mon Sep 17 00:00:00 2001
From: Artucuno <artucunov@gmail.com>
Date: Sun, 26 Nov 2023 14:02:17 +1100
Subject: [PATCH 5/5] Update examples/client.py from #32

---
 examples/client.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)
 create mode 100644 examples/client.py

diff --git a/examples/client.py b/examples/client.py
new file mode 100644
index 0000000..ceece6f
--- /dev/null
+++ b/examples/client.py
@@ -0,0 +1,18 @@
+import socketio
+
+sio = socketio.Client()
+
+
+@sio.event
+def connect():
+    print("Connected!")
+
+
+@sio.on('test')
+def on_message(data):
+    print('Message Received! >> ', data)
+    sio.emit('test', 'Hello world!')
+
+
+sio.connect('http://127.0.0.1:8000')
+sio.wait()