From 9371b3cfbc03aac0757020b2f82c462cac5b67ed Mon Sep 17 00:00:00 2001
From: Andrew Thornton <art27@cantab.net>
Date: Mon, 18 Jul 2022 09:20:26 +0100
Subject: [PATCH 1/5] begin adding websocket support

Signed-off-by: Andrew Thornton <art27@cantab.net>
---
 routers/web/web.go | 1 +
 1 file changed, 1 insertion(+)

diff --git a/routers/web/web.go b/routers/web/web.go
index b4e8666c44fd2..897748da78be9 100644
--- a/routers/web/web.go
+++ b/routers/web/web.go
@@ -365,6 +365,7 @@ func RegisterRoutes(m *web.Route) {
 	}, reqSignOut)
 
 	m.Any("/user/events", routing.MarkLongPolling, events.Events)
+	m.Any("/user/websocket", routing.MarkLongPolling, events.Websocket)
 
 	m.Group("/login/oauth", func() {
 		m.Get("/authorize", bindIgnErr(forms.AuthorizationForm{}), auth.AuthorizeOAuth)

From b954177c06a9c64f62462e31e82655e4093595f2 Mon Sep 17 00:00:00 2001
From: Andrew Thornton <art27@cantab.net>
Date: Mon, 18 Jul 2022 15:17:01 +0100
Subject: [PATCH 2/5] partila

---
 go.mod                          |   2 +-
 routers/web/events/websocket.go | 126 ++++++++++++++++++++++++++++++++
 2 files changed, 127 insertions(+), 1 deletion(-)
 create mode 100644 routers/web/events/websocket.go

diff --git a/go.mod b/go.mod
index 6d41af507d351..93149c40d1bd9 100644
--- a/go.mod
+++ b/go.mod
@@ -51,6 +51,7 @@ require (
 	github.com/google/uuid v1.3.0
 	github.com/gorilla/feeds v1.1.1
 	github.com/gorilla/sessions v1.2.1
+	github.com/gorilla/websocket v1.4.2
 	github.com/hashicorp/go-version v1.4.0
 	github.com/hashicorp/golang-lru v0.5.4
 	github.com/huandu/xstrings v1.3.2
@@ -194,7 +195,6 @@ require (
 	github.com/gorilla/handlers v1.5.1 // indirect
 	github.com/gorilla/mux v1.8.0 // indirect
 	github.com/gorilla/securecookie v1.1.1 // indirect
-	github.com/gorilla/websocket v1.4.2 // indirect
 	github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect
 	github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect
 	github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect
diff --git a/routers/web/events/websocket.go b/routers/web/events/websocket.go
new file mode 100644
index 0000000000000..3d01a7ee143ce
--- /dev/null
+++ b/routers/web/events/websocket.go
@@ -0,0 +1,126 @@
+// Copyright 2022 The Gitea Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package events
+
+import (
+	"code.gitea.io/gitea/modules/context"
+	"code.gitea.io/gitea/modules/eventsource"
+	"code.gitea.io/gitea/modules/graceful"
+	"code.gitea.io/gitea/modules/json"
+	"code.gitea.io/gitea/modules/log"
+	"code.gitea.io/gitea/routers/web/auth"
+	"github.com/gorilla/websocket"
+)
+
+type readMessage struct {
+	messageType int
+	message     []byte
+	err         error
+}
+
+// Events listens for events
+func Websocket(ctx *context.Context) {
+	upgrader := websocket.Upgrader{
+		ReadBufferSize:  1024,
+		WriteBufferSize: 1024,
+	}
+
+	conn, err := upgrader.Upgrade(ctx.Resp, ctx.Req, nil)
+	if err != nil {
+		log.Error("Unable to upgrade due to error: %v", err)
+		return
+	}
+	defer conn.Close()
+
+	notify := ctx.Done()
+	shutdownCtx := graceful.GetManager().ShutdownContext()
+
+	eventChan := make(<-chan *eventsource.Event)
+	uid := int64(0)
+	unregister := func() {}
+	if ctx.IsSigned {
+		uid = ctx.Doer.ID
+		eventChan = eventsource.GetManager().Register(uid)
+		unregister = func() {
+			go func() {
+				eventsource.GetManager().Unregister(uid, eventChan)
+				// ensure the messageChan is closed
+				for {
+					_, ok := <-eventChan
+					if !ok {
+						break
+					}
+				}
+			}()
+		}
+	}
+	defer unregister()
+
+	readChan := make(chan readMessage, 20)
+	go func() {
+		for {
+			messageType, message, err := conn.ReadMessage()
+			readChan <- readMessage{
+				messageType: messageType,
+				message:     message,
+				err:         err,
+			}
+			if err != nil {
+				close(readChan)
+				return
+			}
+		}
+	}()
+
+	for {
+		select {
+		case <-notify:
+			return
+		case <-shutdownCtx.Done():
+			return
+		case _, ok := <-readChan:
+			if !ok {
+				break
+			}
+		case event, ok := <-eventChan:
+			if !ok {
+				break
+			}
+			if event.Name == "logout" {
+				if ctx.Session.ID() == event.Data {
+					_, _ = (&eventsource.Event{
+						Name: "logout",
+						Data: "here",
+					}).WriteTo(ctx.Resp)
+					ctx.Resp.Flush()
+					go unregister()
+					auth.HandleSignOut(ctx)
+					break
+				}
+				// Replace the event - we don't want to expose the session ID to the user
+				event = &eventsource.Event{
+					Name: "logout",
+					Data: "elsewhere",
+				}
+			}
+
+			w, err := conn.NextWriter(websocket.TextMessage)
+			if err != nil {
+				log.Warn("Unable to get writer for websocket %v", err)
+				return
+			}
+
+			if err := json.NewEncoder(w).Encode(event); err != nil {
+				log.Error("Unable to create encoder for %v %v", event, err)
+				return
+			}
+			if err := w.Close(); err != nil {
+				log.Warn("Unable to close writer for websocket %v", err)
+				return
+			}
+
+		}
+	}
+}

From 7b5b616693b25ea639804cf827f1d184c49da05f Mon Sep 17 00:00:00 2001
From: Andrew Thornton <art27@cantab.net>
Date: Tue, 19 Jul 2022 10:43:27 +0100
Subject: [PATCH 3/5] still partial

Signed-off-by: Andrew Thornton <art27@cantab.net>
---
 modules/context/response.go                   |  19 +-
 routers/web/events/websocket.go               |  24 +++
 .../js/features/eventsource.sharedworker.js   |   6 +-
 web_src/js/features/notification.js           |  28 ++-
 web_src/js/features/stopwatch.js              |  18 +-
 web_src/js/features/websocket.sharedworker.js | 163 ++++++++++++++++++
 webpack.config.js                             |   3 +
 7 files changed, 246 insertions(+), 15 deletions(-)
 create mode 100644 web_src/js/features/websocket.sharedworker.js

diff --git a/modules/context/response.go b/modules/context/response.go
index 112964dbe14cd..cd5ff4a01a59d 100644
--- a/modules/context/response.go
+++ b/modules/context/response.go
@@ -84,14 +84,29 @@ func (r *Response) Before(f func(ResponseWriter)) {
 	r.befores = append(r.befores, f)
 }
 
+type hijackerResponse struct {
+	*Response
+	http.Hijacker
+}
+
 // NewResponse creates a response
-func NewResponse(resp http.ResponseWriter) *Response {
+func NewResponse(resp http.ResponseWriter) ResponseWriter {
 	if v, ok := resp.(*Response); ok {
 		return v
 	}
-	return &Response{
+	hijacker, ok := resp.(http.Hijacker)
+
+	response := &Response{
 		ResponseWriter: resp,
 		status:         0,
 		befores:        make([]func(ResponseWriter), 0),
 	}
+	if ok {
+		return hijackerResponse{
+			Response: response,
+			Hijacker: hijacker,
+		}
+	}
+
+	return response
 }
diff --git a/routers/web/events/websocket.go b/routers/web/events/websocket.go
index 3d01a7ee143ce..b5ccb7da43b5e 100644
--- a/routers/web/events/websocket.go
+++ b/routers/web/events/websocket.go
@@ -5,11 +5,15 @@
 package events
 
 import (
+	"net/http"
+	"net/url"
+
 	"code.gitea.io/gitea/modules/context"
 	"code.gitea.io/gitea/modules/eventsource"
 	"code.gitea.io/gitea/modules/graceful"
 	"code.gitea.io/gitea/modules/json"
 	"code.gitea.io/gitea/modules/log"
+	"code.gitea.io/gitea/modules/setting"
 	"code.gitea.io/gitea/routers/web/auth"
 	"github.com/gorilla/websocket"
 )
@@ -25,8 +29,28 @@ func Websocket(ctx *context.Context) {
 	upgrader := websocket.Upgrader{
 		ReadBufferSize:  1024,
 		WriteBufferSize: 1024,
+		CheckOrigin: func(r *http.Request) bool {
+			origin := r.Header["Origin"]
+			if len(origin) == 0 {
+				return true
+			}
+			u, err := url.Parse(origin[0])
+			if err != nil {
+				return false
+			}
+			appURLURL, err := url.Parse(setting.AppURL)
+			if err != nil {
+				return true
+			}
+
+			return u.Host == appURLURL.Host
+		},
 	}
 
+	// Because http proxies will tend not to pass these headers
+	ctx.Req.Header.Add("Upgrade", "websocket")
+	ctx.Req.Header.Add("Connection", "Upgrade")
+
 	conn, err := upgrader.Upgrade(ctx.Resp, ctx.Req, nil)
 	if err != nil {
 		log.Error("Unable to upgrade due to error: %v", err)
diff --git a/web_src/js/features/eventsource.sharedworker.js b/web_src/js/features/eventsource.sharedworker.js
index 824ccfea79f84..95b2e1bbd3431 100644
--- a/web_src/js/features/eventsource.sharedworker.js
+++ b/web_src/js/features/eventsource.sharedworker.js
@@ -46,9 +46,13 @@ class Source {
     if (this.listening[eventType]) return;
     this.listening[eventType] = true;
     this.eventSource.addEventListener(eventType, (event) => {
+      let data;
+      if (event.data) {
+        data = JSON.parse(event.data);
+      }
       this.notifyClients({
         type: eventType,
-        data: event.data
+        data
       });
     });
   }
diff --git a/web_src/js/features/notification.js b/web_src/js/features/notification.js
index 36df196cac2d8..af379ecec8b12 100644
--- a/web_src/js/features/notification.js
+++ b/web_src/js/features/notification.js
@@ -24,10 +24,9 @@ export function initNotificationsTable() {
   });
 }
 
-async function receiveUpdateCount(event) {
+async function receiveUpdateCount(data) {
+  console.log(data);
   try {
-    const data = JSON.parse(event.data);
-
     const notificationCount = document.querySelector('.notification_count');
     if (data.Count > 0) {
       notificationCount.classList.remove('hidden');
@@ -36,9 +35,10 @@ async function receiveUpdateCount(event) {
     }
 
     notificationCount.textContent = `${data.Count}`;
+    console.log(notificationCount);
     await updateNotificationTable();
   } catch (error) {
-    console.error(error, event);
+    console.error(error, data);
   }
 }
 
@@ -49,9 +49,20 @@ export function initNotificationCount() {
     return;
   }
 
-  if (notificationSettings.EventSourceUpdateTime > 0 && !!window.EventSource && window.SharedWorker) {
+  let worker;
+  let workerUrl;
+
+  if (notificationSettings.EventSourceUpdateTime > 0 && !!window.WebSocket && window.SharedWorker) {
+    // Try to connect to the event source via the shared worker first
+    worker = new SharedWorker(`${__webpack_public_path__}js/websocket.sharedworker.js`, 'notification-worker');
+    workerUrl = `${window.location.origin}${appSubUrl}/user/websocket`;
+  } else if (notificationSettings.EventSourceUpdateTime > 0 && !!window.EventSource && window.SharedWorker) {
     // Try to connect to the event source via the shared worker first
-    const worker = new SharedWorker(`${__webpack_public_path__}js/eventsource.sharedworker.js`, 'notification-worker');
+    worker = new SharedWorker(`${__webpack_public_path__}js/eventsource.sharedworker.js`, 'notification-worker');
+    workerUrl = `${window.location.origin}${appSubUrl}/user/events`;
+  }
+
+  if (worker) {
     worker.addEventListener('error', (event) => {
       console.error(event);
     });
@@ -60,15 +71,16 @@ export function initNotificationCount() {
     });
     worker.port.postMessage({
       type: 'start',
-      url: `${window.location.origin}${appSubUrl}/user/events`,
+      url: workerUrl,
     });
     worker.port.addEventListener('message', (event) => {
       if (!event.data || !event.data.type) {
         console.error(event);
         return;
       }
+      console.log(event);
       if (event.data.type === 'notification-count') {
-        const _promise = receiveUpdateCount(event.data);
+        const _promise = receiveUpdateCount(event.data.data).then(console.log('done'));
       } else if (event.data.type === 'error') {
         console.error(event.data);
       } else if (event.data.type === 'logout') {
diff --git a/web_src/js/features/stopwatch.js b/web_src/js/features/stopwatch.js
index d63da4155af27..a506dfde6c65a 100644
--- a/web_src/js/features/stopwatch.js
+++ b/web_src/js/features/stopwatch.js
@@ -26,9 +26,19 @@ export function initStopwatch() {
     $(this).parent().trigger('submit');
   });
 
-  if (notificationSettings.EventSourceUpdateTime > 0 && !!window.EventSource && window.SharedWorker) {
+  let worker;
+  let workerUrl;
+  if (notificationSettings.EventSourceUpdateTime > 0 && !!window.WebSocket && window.SharedWorker) {
     // Try to connect to the event source via the shared worker first
-    const worker = new SharedWorker(`${__webpack_public_path__}js/eventsource.sharedworker.js`, 'notification-worker');
+    worker = new SharedWorker(`${__webpack_public_path__}js/websocket.sharedworker.js`, 'notification-worker');
+    workerUrl = `${window.location.origin}${appSubUrl}/user/websocket`;
+  } else if (notificationSettings.EventSourceUpdateTime > 0 && !!window.EventSource && window.SharedWorker) {
+    // Try to connect to the event source via the shared worker first
+    worker = new SharedWorker(`${__webpack_public_path__}js/eventsource.sharedworker.js`, 'notification-worker');
+    workerUrl = `${window.location.origin}${appSubUrl}/user/events`;
+  }
+
+  if (worker) {
     worker.addEventListener('error', (event) => {
       console.error(event);
     });
@@ -37,7 +47,7 @@ export function initStopwatch() {
     });
     worker.port.postMessage({
       type: 'start',
-      url: `${window.location.origin}${appSubUrl}/user/events`,
+      url: workerUrl,
     });
     worker.port.addEventListener('message', (event) => {
       if (!event.data || !event.data.type) {
@@ -45,7 +55,7 @@ export function initStopwatch() {
         return;
       }
       if (event.data.type === 'stopwatches') {
-        updateStopwatchData(JSON.parse(event.data.data));
+        updateStopwatchData(event.data.data);
       } else if (event.data.type === 'error') {
         console.error(event.data);
       } else if (event.data.type === 'logout') {
diff --git a/web_src/js/features/websocket.sharedworker.js b/web_src/js/features/websocket.sharedworker.js
new file mode 100644
index 0000000000000..1dc9dc9dec191
--- /dev/null
+++ b/web_src/js/features/websocket.sharedworker.js
@@ -0,0 +1,163 @@
+const sourcesByUrl = {};
+const sourcesByPort = {};
+
+class Source {
+  constructor(url) {
+    this.url = url.replace(/^http/, 'ws');
+    this.webSocket = new WebSocket(this.url);
+    this.listening = {};
+    this.clients = [];
+    this.listen('open');
+    this.listen('close');
+    this.listen('logout');
+    this.listen('notification-count');
+    this.listen('stopwatches');
+    this.listen('error');
+    this.webSocket.addEventListener('error', (error) => {
+      this.lastError = error;
+    });
+    this.webSocket.addEventListener('message', (event) => {
+      const message = JSON.parse(event.data);
+      if (!message) {
+        return;
+      }
+      if (this.listening[message.Name]) {
+        this.notifyClients({
+          type: message.Name,
+          data: message.Data
+        });
+      }
+    });
+  }
+
+  register(port) {
+    if (this.clients.includes(port)) return;
+
+    this.clients.push(port);
+
+    port.postMessage({
+      type: 'status',
+      message: `registered to ${this.url}`,
+    });
+
+    if (!this.webSocket) {
+      if (this.lastError) {
+        port.postMessage({
+          type: 'error',
+          message: `websocket disconnected: ${this.lastError}`
+        });
+      } else {
+        port.postMessage({
+          type: 'error',
+          message: 'websocket disconnected'
+        });
+      }
+    }
+  }
+
+  deregister(port) {
+    const portIdx = this.clients.indexOf(port);
+    if (portIdx < 0) {
+      return this.clients.length;
+    }
+    this.clients.splice(portIdx, 1);
+    return this.clients.length;
+  }
+
+  close() {
+    if (!this.webSocket) return;
+
+    this.webSocket.close();
+    this.webSocket = null;
+  }
+
+  listen(eventType) {
+    if (this.listening[eventType]) return;
+    this.listening[eventType] = true;
+    this.webSocket.addEventListener(eventType, (event) => {
+      this.notifyClients({
+        type: eventType,
+        data: event.data
+      });
+    });
+  }
+
+  notifyClients(event) {
+    for (const client of this.clients) {
+      client.postMessage(event);
+    }
+  }
+
+  status(port) {
+    port.postMessage({
+      type: 'status',
+      message: `url: ${this.url} readyState: ${this.webSocket.readyState}`,
+    });
+  }
+}
+
+self.addEventListener('connect', (e) => {
+  for (const port of e.ports) {
+    port.addEventListener('message', (event) => {
+      if (event.data.type === 'start') {
+        const url = event.data.url;
+        if (sourcesByUrl[url]) {
+          // we have a Source registered to this url
+          const source = sourcesByUrl[url];
+          source.register(port);
+          sourcesByPort[port] = source;
+          return;
+        }
+        let source = sourcesByPort[port];
+        if (source) {
+          if (source.eventSource && source.url === url) return;
+
+          // How this has happened I don't understand...
+          // deregister from that source
+          const count = source.deregister(port);
+          // Clean-up
+          if (count === 0) {
+            source.close();
+            sourcesByUrl[source.url] = null;
+          }
+        }
+        // Create a new Source
+        source = new Source(url);
+        source.register(port);
+        sourcesByUrl[url] = source;
+        sourcesByPort[port] = source;
+      } else if (event.data.type === 'listen') {
+        const source = sourcesByPort[port];
+        source.listen(event.data.eventType);
+      } else if (event.data.type === 'close') {
+        const source = sourcesByPort[port];
+
+        if (!source) return;
+
+        const count = source.deregister(port);
+        if (count === 0) {
+          source.close();
+          sourcesByUrl[source.url] = null;
+          sourcesByPort[port] = null;
+        }
+      } else if (event.data.type === 'status') {
+        const source = sourcesByPort[port];
+        if (!source) {
+          port.postMessage({
+            type: 'status',
+            message: 'not connected',
+          });
+          return;
+        }
+        source.status(port);
+      } else {
+        // just send it back
+        port.postMessage({
+          type: 'error',
+          message: `received but don't know how to handle: ${event.data}`,
+        });
+      }
+    });
+    port.start();
+  }
+});
diff --git a/webpack.config.js b/webpack.config.js
index 5109103f7faf7..c865c472f017d 100644
--- a/webpack.config.js
+++ b/webpack.config.js
@@ -62,6 +62,9 @@ export default {
     'eventsource.sharedworker': [
       fileURLToPath(new URL('web_src/js/features/eventsource.sharedworker.js', import.meta.url)),
     ],
+    'websocket.sharedworker': [
+      fileURLToPath(new URL('web_src/js/features/websocket.sharedworker.js', import.meta.url)),
+    ],
     ...themes,
   },
   devtool: false,

From a8d0f9153f048b47c8bc80eb6ecb20a11f1c674c Mon Sep 17 00:00:00 2001
From: Andrew Thornton <art27@cantab.net>
Date: Thu, 28 Jul 2022 18:05:58 +0100
Subject: [PATCH 4/5] partial

Signed-off-by: Andrew Thornton <art27@cantab.net>
---
 routers/web/events/websocket.go               | 103 +++++++++++++++---
 web_src/js/features/notification.js           |  16 ++-
 web_src/js/features/websocket.sharedworker.js |  33 ++++--
 3 files changed, 119 insertions(+), 33 deletions(-)

diff --git a/routers/web/events/websocket.go b/routers/web/events/websocket.go
index b5ccb7da43b5e..6efc115f098f5 100644
--- a/routers/web/events/websocket.go
+++ b/routers/web/events/websocket.go
@@ -7,6 +7,7 @@ package events
 import (
 	"net/http"
 	"net/url"
+	"time"
 
 	"code.gitea.io/gitea/modules/context"
 	"code.gitea.io/gitea/modules/eventsource"
@@ -18,6 +19,12 @@ import (
 	"github.com/gorilla/websocket"
 )
 
+const (
+	writeWait  = 10 * time.Second
+	pongWait   = 60 * time.Second
+	pingPeriod = (pongWait * 9) / 10
+)
+
 type readMessage struct {
 	messageType int
 	message     []byte
@@ -84,6 +91,19 @@ func Websocket(ctx *context.Context) {
 
 	readChan := make(chan readMessage, 20)
 	go func() {
+		defer conn.Close()
+		conn.SetReadLimit(2048)
+		if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
+			log.Error("unable to SetReadDeadline: %v", err)
+			return
+		}
+		conn.SetPongHandler(func(string) error {
+			if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
+				log.Error("unable to SetReadDeadline: %v", err)
+			}
+			return nil
+		})
+
 		for {
 			messageType, message, err := conn.ReadMessage()
 			readChan <- readMessage{
@@ -95,30 +115,70 @@ func Websocket(ctx *context.Context) {
 				close(readChan)
 				return
 			}
+			if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
+				log.Error("unable to SetReadDeadline: %v", err)
+				return
+			}
 		}
 	}()
 
+	pingTicker := time.NewTicker(pingPeriod)
+
 	for {
 		select {
 		case <-notify:
 			return
 		case <-shutdownCtx.Done():
 			return
-		case _, ok := <-readChan:
+		case <-pingTicker.C:
+			// ensure that we're not already cancelled
+			select {
+			case <-notify:
+				return
+			case <-shutdownCtx.Done():
+				return
+			default:
+			}
+			if err := conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
+				log.Error("unable to SetWriteDeadline: %v", err)
+				return
+			}
+			if err := conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
+				log.Error("unable to send PingMessage: %v", err)
+				return
+			}
+		case message, ok := <-readChan:
 			if !ok {
 				break
 			}
+			// ensure that we're not already cancelled
+			select {
+			case <-notify:
+				return
+			case <-shutdownCtx.Done():
+				return
+			default:
+			}
+			log.Info("Got Message: %d:%s:%v", message.messageType, message.message, message.err)
 		case event, ok := <-eventChan:
 			if !ok {
 				break
 			}
+			// ensure that we're not already cancelled
+			select {
+			case <-notify:
+				return
+			case <-shutdownCtx.Done():
+				return
+			default:
+			}
 			if event.Name == "logout" {
 				if ctx.Session.ID() == event.Data {
-					_, _ = (&eventsource.Event{
+					event = &eventsource.Event{
 						Name: "logout",
 						Data: "here",
-					}).WriteTo(ctx.Resp)
-					ctx.Resp.Flush()
+					}
+					_ = writeEvent(conn, event)
 					go unregister()
 					auth.HandleSignOut(ctx)
 					break
@@ -129,22 +189,31 @@ func Websocket(ctx *context.Context) {
 					Data: "elsewhere",
 				}
 			}
-
-			w, err := conn.NextWriter(websocket.TextMessage)
-			if err != nil {
-				log.Warn("Unable to get writer for websocket %v", err)
+			if err := writeEvent(conn, event); err != nil {
 				return
 			}
+		}
+	}
+}
 
-			if err := json.NewEncoder(w).Encode(event); err != nil {
-				log.Error("Unable to create encoder for %v %v", event, err)
-				return
-			}
-			if err := w.Close(); err != nil {
-				log.Warn("Unable to close writer for websocket %v", err)
-				return
-			}
+func writeEvent(conn *websocket.Conn, event *eventsource.Event) error {
+	if err := conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
+		log.Error("unable to SetWriteDeadline: %v", err)
+		return err
+	}
+	w, err := conn.NextWriter(websocket.TextMessage)
+	if err != nil {
+		log.Warn("Unable to get writer for websocket %v", err)
+		return err
+	}
 
-		}
+	if err := json.NewEncoder(w).Encode(event); err != nil {
+		log.Error("Unable to create encoder for %v %v", event, err)
+		return err
+	}
+	if err := w.Close(); err != nil {
+		log.Warn("Unable to close writer for websocket %v", err)
+		return err
 	}
+	return nil
 }
diff --git a/web_src/js/features/notification.js b/web_src/js/features/notification.js
index af379ecec8b12..291747ac11cc8 100644
--- a/web_src/js/features/notification.js
+++ b/web_src/js/features/notification.js
@@ -24,9 +24,10 @@ export function initNotificationsTable() {
   });
 }
 
-async function receiveUpdateCount(data) {
+async function receiveUpdateCount(data, document) {
   console.log(data);
   try {
+    console.log(window, document);
     const notificationCount = document.querySelector('.notification_count');
     if (data.Count > 0) {
       notificationCount.classList.remove('hidden');
@@ -36,6 +37,9 @@ async function receiveUpdateCount(data) {
 
     notificationCount.textContent = `${data.Count}`;
     console.log(notificationCount);
+    const oldDisplay = notificationCount.style.display;
+    notificationCount.style.display = 'none';
+    notificationCount.style.display = oldDisplay;
     await updateNotificationTable();
   } catch (error) {
     console.error(error, data);
@@ -62,9 +66,11 @@ export function initNotificationCount() {
     workerUrl = `${window.location.origin}${appSubUrl}/user/events`;
   }
 
+  const currentDocument = document;
+
   if (worker) {
     worker.addEventListener('error', (event) => {
-      console.error(event);
+      console.error('error from listener: ', event);
     });
     worker.port.addEventListener('messageerror', () => {
       console.error('Unable to deserialize message');
@@ -75,12 +81,13 @@ export function initNotificationCount() {
     });
     worker.port.addEventListener('message', (event) => {
       if (!event.data || !event.data.type) {
-        console.error(event);
+        console.error('Unexpected event:', event);
         return;
       }
       console.log(event);
+      console.log(currentDocument === document);
       if (event.data.type === 'notification-count') {
-        const _promise = receiveUpdateCount(event.data.data).then(console.log('done'));
+        const _promise = receiveUpdateCount(event.data.data, currentDocument).then(console.log('done'));
       } else if (event.data.type === 'error') {
         console.error(event.data);
       } else if (event.data.type === 'logout') {
@@ -98,6 +105,7 @@ export function initNotificationCount() {
         });
         worker.port.close();
       }
+      console.log('done eventlistenter');
     });
     worker.port.addEventListener('error', (e) => {
       console.error(e);
diff --git a/web_src/js/features/websocket.sharedworker.js b/web_src/js/features/websocket.sharedworker.js
index 1dc9dc9dec191..096f485895a34 100644
--- a/web_src/js/features/websocket.sharedworker.js
+++ b/web_src/js/features/websocket.sharedworker.js
@@ -28,6 +28,18 @@ class Source {
         });
       }
     });
+    this.webSocket.addEventListener('close', (event) => {
+      if (!this.webSocket) {
+        return;
+      }
+      const oldWebSocket = this.webSocket;
+      this.webSocket = null;
+      this.notifyClients({
+        type: 'close',
+        data: event
+      });
+      oldWebSocket.close();
+    });
   }
 
   register(port) {
@@ -66,20 +78,14 @@ class Source {
 
   close() {
     if (!this.webSocket) return;
-
-    this.webSocket.close();
+    const oldWebSocket = this.webSocket;
     this.webSocket = null;
+    oldWebSocket.close();
   }
 
   listen(eventType) {
     if (this.listening[eventType]) return;
     this.listening[eventType] = true;
-    this.webSocket.addEventListener(eventType, (event) => {
-      this.notifyClients({
-        type: eventType,
-        data: event.data
-      });
-    });
   }
 
   notifyClients(event) {
@@ -104,13 +110,16 @@ self.addEventListener('connect', (e) => {
         if (sourcesByUrl[url]) {
           // we have a Source registered to this url
           const source = sourcesByUrl[url];
-          source.register(port);
-          sourcesByPort[port] = source;
-          return;
+          if (source.webSocket) {
+            source.register(port);
+            sourcesByPort[port] = source;
+            return;
+          }
+          sourcesByUrl[url] = null;
         }
         let source = sourcesByPort[port];
         if (source) {
-          if (source.eventSource && source.url === url) return;
+          if (source.webSocket && source.url === url) return;
 
           // How this has happened I don't understand...
           // deregister from that source

From 4100a72d1f9fc68371c33ed56e1b2c864bbbdefe Mon Sep 17 00:00:00 2001
From: Andrew Thornton <art27@cantab.net>
Date: Sat, 30 Jul 2022 11:44:02 +0100
Subject: [PATCH 5/5] merge in #20544 and remove a lot of unnecessary logging

Signed-off-by: Andrew Thornton <art27@cantab.net>
---
 modules/context/response.go         |  2 +
 routers/web/events/websocket.go     | 84 +++++++++++++++++------------
 web_src/js/features/notification.js | 21 ++------
 3 files changed, 57 insertions(+), 50 deletions(-)

diff --git a/modules/context/response.go b/modules/context/response.go
index cd5ff4a01a59d..0a844f9d4ae57 100644
--- a/modules/context/response.go
+++ b/modules/context/response.go
@@ -84,6 +84,7 @@ func (r *Response) Before(f func(ResponseWriter)) {
 	r.befores = append(r.befores, f)
 }
 
+// hijackerResponse wraps the Response to allow casting as a Hijacker if the underlying response is a hijacker
 type hijackerResponse struct {
 	*Response
 	http.Hijacker
@@ -102,6 +103,7 @@ func NewResponse(resp http.ResponseWriter) ResponseWriter {
 		befores:        make([]func(ResponseWriter), 0),
 	}
 	if ok {
+		// ensure that the Response we return is also hijackable
 		return hijackerResponse{
 			Response: response,
 			Hijacker: hijacker,
diff --git a/routers/web/events/websocket.go b/routers/web/events/websocket.go
index 6efc115f098f5..7acdcf1df187c 100644
--- a/routers/web/events/websocket.go
+++ b/routers/web/events/websocket.go
@@ -23,6 +23,9 @@ const (
 	writeWait  = 10 * time.Second
 	pongWait   = 60 * time.Second
 	pingPeriod = (pongWait * 9) / 10
+
+	maximumMessageSize  = 2048
+	readMessageChanSize = 20 // <- I've put 20 here because it seems like a reasonable buffer but it may to increase
 )
 
 type readMessage struct {
@@ -89,38 +92,8 @@ func Websocket(ctx *context.Context) {
 	}
 	defer unregister()
 
-	readChan := make(chan readMessage, 20)
-	go func() {
-		defer conn.Close()
-		conn.SetReadLimit(2048)
-		if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
-			log.Error("unable to SetReadDeadline: %v", err)
-			return
-		}
-		conn.SetPongHandler(func(string) error {
-			if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
-				log.Error("unable to SetReadDeadline: %v", err)
-			}
-			return nil
-		})
-
-		for {
-			messageType, message, err := conn.ReadMessage()
-			readChan <- readMessage{
-				messageType: messageType,
-				message:     message,
-				err:         err,
-			}
-			if err != nil {
-				close(readChan)
-				return
-			}
-			if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
-				log.Error("unable to SetReadDeadline: %v", err)
-				return
-			}
-		}
-	}()
+	readMessageChan := make(chan readMessage, readMessageChanSize)
+	go readMessagesFromConnToChan(conn, readMessageChan)
 
 	pingTicker := time.NewTicker(pingPeriod)
 
@@ -139,6 +112,7 @@ func Websocket(ctx *context.Context) {
 				return
 			default:
 			}
+
 			if err := conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
 				log.Error("unable to SetWriteDeadline: %v", err)
 				return
@@ -147,10 +121,11 @@ func Websocket(ctx *context.Context) {
 				log.Error("unable to send PingMessage: %v", err)
 				return
 			}
-		case message, ok := <-readChan:
+		case message, ok := <-readMessageChan:
 			if !ok {
 				break
 			}
+
 			// ensure that we're not already cancelled
 			select {
 			case <-notify:
@@ -159,11 +134,14 @@ func Websocket(ctx *context.Context) {
 				return
 			default:
 			}
+
+			// FIXME: HANDLE MESSAGES
 			log.Info("Got Message: %d:%s:%v", message.messageType, message.message, message.err)
 		case event, ok := <-eventChan:
 			if !ok {
 				break
 			}
+
 			// ensure that we're not already cancelled
 			select {
 			case <-notify:
@@ -172,6 +150,8 @@ func Websocket(ctx *context.Context) {
 				return
 			default:
 			}
+
+			// Handle events
 			if event.Name == "logout" {
 				if ctx.Session.ID() == event.Data {
 					event = &eventsource.Event{
@@ -196,14 +176,50 @@ func Websocket(ctx *context.Context) {
 	}
 }
 
+func readMessagesFromConnToChan(conn *websocket.Conn, messageChan chan readMessage) {
+	defer func() {
+		close(messageChan) // Please note: this has to be within a wrapping anonymous func otherwise it will be evaluated when creating the defer
+		_ = conn.Close()
+	}()
+	conn.SetReadLimit(maximumMessageSize)
+	if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
+		log.Error("unable to SetReadDeadline: %v", err)
+		return
+	}
+	conn.SetPongHandler(func(string) error {
+		if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
+			log.Error("unable to SetReadDeadline: %v", err)
+		}
+		return nil
+	})
+
+	for {
+		messageType, message, err := conn.ReadMessage()
+		messageChan <- readMessage{
+			messageType: messageType,
+			message:     message,
+			err:         err,
+		}
+		if err != nil {
+			// don't need to handle the error here as it is passed down the channel
+			return
+		}
+		if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
+			log.Error("unable to SetReadDeadline: %v", err)
+			return
+		}
+	}
+}
+
 func writeEvent(conn *websocket.Conn, event *eventsource.Event) error {
 	if err := conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
 		log.Error("unable to SetWriteDeadline: %v", err)
 		return err
 	}
+
 	w, err := conn.NextWriter(websocket.TextMessage)
 	if err != nil {
-		log.Warn("Unable to get writer for websocket %v", err)
+		log.Error("Unable to get writer for websocket %v", err)
 		return err
 	}
 
diff --git a/web_src/js/features/notification.js b/web_src/js/features/notification.js
index 291747ac11cc8..24d83057c51a0 100644
--- a/web_src/js/features/notification.js
+++ b/web_src/js/features/notification.js
@@ -25,21 +25,13 @@ export function initNotificationsTable() {
 }
 
 async function receiveUpdateCount(data, document) {
-  console.log(data);
   try {
-    console.log(window, document);
-    const notificationCount = document.querySelector('.notification_count');
-    if (data.Count > 0) {
-      notificationCount.classList.remove('hidden');
-    } else {
-      notificationCount.classList.add('hidden');
+    const notificationCounts = document.querySelectorAll('.notification_count');
+    for (const count of notificationCounts) {
+      count.classList.toggle('hidden', data.Count === 0);
+      count.textContent = `${data.Count}`;
     }
 
-    notificationCount.textContent = `${data.Count}`;
-    console.log(notificationCount);
-    const oldDisplay = notificationCount.style.display;
-    notificationCount.style.display = 'none';
-    notificationCount.style.display = oldDisplay;
     await updateNotificationTable();
   } catch (error) {
     console.error(error, data);
@@ -84,10 +76,8 @@ export function initNotificationCount() {
         console.error('Unexpected event:', event);
         return;
       }
-      console.log(event);
-      console.log(currentDocument === document);
       if (event.data.type === 'notification-count') {
-        const _promise = receiveUpdateCount(event.data.data, currentDocument).then(console.log('done'));
+        const _promise = receiveUpdateCount(event.data.data, currentDocument);
       } else if (event.data.type === 'error') {
         console.error(event.data);
       } else if (event.data.type === 'logout') {
@@ -105,7 +95,6 @@ export function initNotificationCount() {
         });
         worker.port.close();
       }
-      console.log('done eventlistenter');
     });
     worker.port.addEventListener('error', (e) => {
       console.error(e);