diff --git a/packages/libp2p/src/connection.ts b/packages/libp2p/src/connection.ts index 4f4ee18ade..6308e1e422 100644 --- a/packages/libp2p/src/connection.ts +++ b/packages/libp2p/src/connection.ts @@ -6,7 +6,7 @@ import { CONNECTION_CLOSE_TIMEOUT, PROTOCOL_NEGOTIATION_TIMEOUT } from './connec import { isDirect } from './connection-manager/utils.ts' import { MuxerUnavailableError } from './errors.ts' import { DEFAULT_MAX_INBOUND_STREAMS, DEFAULT_MAX_OUTBOUND_STREAMS } from './registrar.ts' -import type { AbortOptions, Logger, MessageStreamDirection, Connection as ConnectionInterface, Stream, NewStreamOptions, PeerId, ConnectionLimits, StreamMuxer, Metrics, PeerStore, MultiaddrConnection, MessageStreamEvents, MultiaddrConnectionTimeline, ConnectionStatus, MessageStream } from '@libp2p/interface' +import type { AbortOptions, Logger, MessageStreamDirection, Connection as ConnectionInterface, Stream, NewStreamOptions, PeerId, ConnectionLimits, StreamMuxer, Metrics, PeerStore, MultiaddrConnection, MessageStreamEvents, MultiaddrConnectionTimeline, ConnectionStatus, MessageStream, StreamMiddleware } from '@libp2p/interface' import type { Registrar } from '@libp2p/interface-internal' import type { Multiaddr } from '@multiformats/multiaddr' @@ -126,7 +126,7 @@ export class Connection extends TypedEventEmitter implement } this.log.trace('starting new stream for protocols %s', protocols) - let muxedStream = await this.muxer.createStream({ + const muxedStream = await this.muxer.createStream({ ...options, // most underlying transports only support negotiating a single protocol @@ -179,23 +179,7 @@ export class Connection extends TypedEventEmitter implement const middleware = this.components.registrar.getMiddleware(muxedStream.protocol) - middleware.push((stream, connection, next) => { - next(stream, connection) - }) - - let i = 0 - let connection: ConnectionInterface = this - - while (i < middleware.length) { - // eslint-disable-next-line no-loop-func - middleware[i](muxedStream, connection, (s, c) => { - muxedStream = s - connection = c - i++ - }) - } - - return muxedStream + return await this.runMiddlewareChain(muxedStream, this, middleware) } catch (err: any) { if (muxedStream.status === 'open') { muxedStream.abort(err) @@ -208,7 +192,7 @@ export class Connection extends TypedEventEmitter implement } private async onIncomingStream (evt: CustomEvent): Promise { - let muxedStream = evt.detail + const muxedStream = evt.detail const signal = AbortSignal.timeout(this.inboundStreamProtocolNegotiationTimeout) setMaxListeners(Infinity, signal) @@ -260,20 +244,40 @@ export class Connection extends TypedEventEmitter implement next(stream, connection) }) - let connection: ConnectionInterface = this - - for (const m of middleware) { - // eslint-disable-next-line no-loop-func - await m(muxedStream, connection, (s, c) => { - muxedStream = s - connection = c - }) - } + await this.runMiddlewareChain(muxedStream, this, middleware) } catch (err: any) { muxedStream.abort(err) } } + private async runMiddlewareChain (stream: Stream, connection: ConnectionInterface, middleware: StreamMiddleware[]): Promise { + for (let i = 0; i < middleware.length; i++) { + const mw = middleware[i] + stream.log.trace('running middleware', i, mw) + + // eslint-disable-next-line no-loop-func + await new Promise((resolve, reject) => { + try { + const result = mw(stream, connection, (s, c) => { + stream = s + connection = c + resolve() + }) + + if (result instanceof Promise) { + result.catch(reject) + } + } catch (err) { + reject(err) + } + }) + + stream.log.trace('ran middleware', i, mw) + } + + return stream + } + /** * Close the connection */ diff --git a/packages/libp2p/test/connection/index.spec.ts b/packages/libp2p/test/connection/index.spec.ts index a08401a4f1..df97511479 100644 --- a/packages/libp2p/test/connection/index.spec.ts +++ b/packages/libp2p/test/connection/index.spec.ts @@ -1,4 +1,5 @@ import { StreamCloseEvent } from '@libp2p/interface' +import { defaultLogger } from '@libp2p/logger' import { peerIdFromString } from '@libp2p/peer-id' import { echoStream, streamPair, echo, multiaddrConnectionPair, mockMuxer } from '@libp2p/utils' import { multiaddr } from '@multiformats/multiaddr' @@ -361,6 +362,7 @@ describe('connection', () => { } const incomingStream = stubInterface({ + log: defaultLogger().forComponent('stream'), protocol: streamProtocol }) @@ -371,24 +373,100 @@ describe('connection', () => { onIncomingStream(new CustomEvent('stream', { detail: incomingStream })) - /* + + // incoming stream is opened asynchronously + await delay(100) + + expect(middleware1.called).to.be.true() + expect(middleware2.called).to.be.true() + }) + + it('should not call outbound middleware if previous middleware errors', async () => { + const streamProtocol = '/test/protocol' + const err = new Error('boom') + + const middleware1 = Sinon.stub().callsFake((stream, connection, next) => { + throw err + }) + const middleware2 = Sinon.stub().callsFake((stream, connection, next) => { + next(stream, connection) + }) + + const middleware = [ + middleware1, + middleware2 + ] + + registrar.getMiddleware.withArgs(streamProtocol).returns(middleware) + registrar.getHandler.withArgs(streamProtocol).returns({ + handler: () => {}, + options: {} + }) + + const connection = createConnection(components, init) + + await expect(connection.newStream(streamProtocol)) + .to.eventually.be.rejectedWith(err) + + expect(middleware1.called).to.be.true() + expect(middleware2.called).to.be.false() + }) + + it('should not call inbound middleware if previous middleware errors', async () => { + const streamProtocol = '/test/protocol' + + const middleware1 = Sinon.stub().callsFake((stream, connection, next) => { + throw new Error('boom') + }) + const middleware2 = Sinon.stub().callsFake((stream, connection, next) => { + next(stream, connection) + }) + + const middleware = [ + middleware1, + middleware2 + ] + + registrar.getMiddleware.withArgs(streamProtocol).returns(middleware) + registrar.getHandler.withArgs(streamProtocol).returns({ + handler: () => {}, + options: {} + }) + + const muxer = stubInterface({ + streams: [] + }) + + createConnection(components, { + ...init, + muxer + }) + + expect(muxer.addEventListener.getCall(0).args[0]).to.equal('stream') + const onIncomingStream = muxer.addEventListener.getCall(0).args[1] + + if (onIncomingStream == null) { + throw new Error('No incoming stream handler registered') + } + const incomingStream = stubInterface({ - id: 'stream-id', - log: logger('test-stream'), - direction: 'outbound', - sink: async (source) => drain(source), - source: map((async function * () { - yield '/multistream/1.0.0\n' - yield `${streamProtocol}\n` - })(), str => encode.single(uint8ArrayFromString(str))) + log: defaultLogger().forComponent('stream'), + protocol: streamProtocol }) -*/ - // onIncomingStream?.(incomingStream) + + if (typeof onIncomingStream !== 'function') { + throw new Error('Stream handler was not function') + } + + onIncomingStream(new CustomEvent('stream', { + detail: incomingStream + })) // incoming stream is opened asynchronously await delay(100) expect(middleware1.called).to.be.true() - expect(middleware2.called).to.be.true() + expect(middleware2.called).to.be.false() + expect(incomingStream).to.have.nested.property('abort.called', true) }) })