Skip to content
62 changes: 33 additions & 29 deletions packages/libp2p/src/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { CONNECTION_CLOSE_TIMEOUT, PROTOCOL_NEGOTIATION_TIMEOUT } from './connec
import { isDirect } from './connection-manager/utils.ts'
import { MuxerUnavailableError } from './errors.ts'
import { DEFAULT_MAX_INBOUND_STREAMS, DEFAULT_MAX_OUTBOUND_STREAMS } from './registrar.ts'
import type { AbortOptions, Logger, MessageStreamDirection, Connection as ConnectionInterface, Stream, NewStreamOptions, PeerId, ConnectionLimits, StreamMuxer, Metrics, PeerStore, MultiaddrConnection, MessageStreamEvents, MultiaddrConnectionTimeline, ConnectionStatus, MessageStream } from '@libp2p/interface'
import type { AbortOptions, Logger, MessageStreamDirection, Connection as ConnectionInterface, Stream, NewStreamOptions, PeerId, ConnectionLimits, StreamMuxer, Metrics, PeerStore, MultiaddrConnection, MessageStreamEvents, MultiaddrConnectionTimeline, ConnectionStatus, MessageStream, StreamMiddleware } from '@libp2p/interface'
import type { Registrar } from '@libp2p/interface-internal'
import type { Multiaddr } from '@multiformats/multiaddr'

Expand Down Expand Up @@ -126,7 +126,7 @@ export class Connection extends TypedEventEmitter<MessageStreamEvents> implement
}

this.log.trace('starting new stream for protocols %s', protocols)
let muxedStream = await this.muxer.createStream({
const muxedStream = await this.muxer.createStream({
...options,

// most underlying transports only support negotiating a single protocol
Expand Down Expand Up @@ -179,23 +179,7 @@ export class Connection extends TypedEventEmitter<MessageStreamEvents> implement

const middleware = this.components.registrar.getMiddleware(muxedStream.protocol)

middleware.push((stream, connection, next) => {
next(stream, connection)
})

let i = 0
let connection: ConnectionInterface = this

while (i < middleware.length) {
// eslint-disable-next-line no-loop-func
middleware[i](muxedStream, connection, (s, c) => {
muxedStream = s
connection = c
i++
})
}

return muxedStream
return await this.runMiddlewareChain(muxedStream, this, middleware)
} catch (err: any) {
if (muxedStream.status === 'open') {
muxedStream.abort(err)
Expand All @@ -208,7 +192,7 @@ export class Connection extends TypedEventEmitter<MessageStreamEvents> implement
}

private async onIncomingStream (evt: CustomEvent<Stream>): Promise<void> {
let muxedStream = evt.detail
const muxedStream = evt.detail

const signal = AbortSignal.timeout(this.inboundStreamProtocolNegotiationTimeout)
setMaxListeners(Infinity, signal)
Expand Down Expand Up @@ -260,20 +244,40 @@ export class Connection extends TypedEventEmitter<MessageStreamEvents> implement
next(stream, connection)
})

let connection: ConnectionInterface = this

for (const m of middleware) {
// eslint-disable-next-line no-loop-func
await m(muxedStream, connection, (s, c) => {
muxedStream = s
connection = c
})
}
await this.runMiddlewareChain(muxedStream, this, middleware)
} catch (err: any) {
muxedStream.abort(err)
}
}

private async runMiddlewareChain (stream: Stream, connection: ConnectionInterface, middleware: StreamMiddleware[]): Promise<Stream> {
for (let i = 0; i < middleware.length; i++) {
const mw = middleware[i]
stream.log.trace('running middleware', i, mw)

// eslint-disable-next-line no-loop-func
await new Promise<void>((resolve, reject) => {
try {
const result = mw(stream, connection, (s, c) => {
stream = s
connection = c
resolve()
})

if (result instanceof Promise) {
result.catch(reject)
}
} catch (err) {
reject(err)
}
})

stream.log.trace('ran middleware', i, mw)
}

return stream
}

/**
* Close the connection
*/
Expand Down
102 changes: 90 additions & 12 deletions packages/libp2p/test/connection/index.spec.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { StreamCloseEvent } from '@libp2p/interface'
import { defaultLogger } from '@libp2p/logger'
import { peerIdFromString } from '@libp2p/peer-id'
import { echoStream, streamPair, echo, multiaddrConnectionPair, mockMuxer } from '@libp2p/utils'
import { multiaddr } from '@multiformats/multiaddr'
Expand Down Expand Up @@ -361,6 +362,7 @@ describe('connection', () => {
}

const incomingStream = stubInterface<Stream>({
log: defaultLogger().forComponent('stream'),
protocol: streamProtocol
})

Expand All @@ -371,24 +373,100 @@ describe('connection', () => {
onIncomingStream(new CustomEvent('stream', {
detail: incomingStream
}))
/*

// incoming stream is opened asynchronously
await delay(100)

expect(middleware1.called).to.be.true()
expect(middleware2.called).to.be.true()
})

it('should not call outbound middleware if previous middleware errors', async () => {
const streamProtocol = '/test/protocol'
const err = new Error('boom')

const middleware1 = Sinon.stub().callsFake((stream, connection, next) => {
throw err
})
const middleware2 = Sinon.stub().callsFake((stream, connection, next) => {
next(stream, connection)
})

const middleware = [
middleware1,
middleware2
]

registrar.getMiddleware.withArgs(streamProtocol).returns(middleware)
registrar.getHandler.withArgs(streamProtocol).returns({
handler: () => {},
options: {}
})

const connection = createConnection(components, init)

await expect(connection.newStream(streamProtocol))
.to.eventually.be.rejectedWith(err)

expect(middleware1.called).to.be.true()
expect(middleware2.called).to.be.false()
})

it('should not call inbound middleware if previous middleware errors', async () => {
const streamProtocol = '/test/protocol'

const middleware1 = Sinon.stub().callsFake((stream, connection, next) => {
throw new Error('boom')
})
const middleware2 = Sinon.stub().callsFake((stream, connection, next) => {
next(stream, connection)
})

const middleware = [
middleware1,
middleware2
]

registrar.getMiddleware.withArgs(streamProtocol).returns(middleware)
registrar.getHandler.withArgs(streamProtocol).returns({
handler: () => {},
options: {}
})

const muxer = stubInterface<StreamMuxer>({
streams: []
})

createConnection(components, {
...init,
muxer
})

expect(muxer.addEventListener.getCall(0).args[0]).to.equal('stream')
const onIncomingStream = muxer.addEventListener.getCall(0).args[1]

if (onIncomingStream == null) {
throw new Error('No incoming stream handler registered')
}

const incomingStream = stubInterface<Stream>({
id: 'stream-id',
log: logger('test-stream'),
direction: 'outbound',
sink: async (source) => drain(source),
source: map((async function * () {
yield '/multistream/1.0.0\n'
yield `${streamProtocol}\n`
})(), str => encode.single(uint8ArrayFromString(str)))
log: defaultLogger().forComponent('stream'),
protocol: streamProtocol
})
*/
// onIncomingStream?.(incomingStream)

if (typeof onIncomingStream !== 'function') {
throw new Error('Stream handler was not function')
}

onIncomingStream(new CustomEvent('stream', {
detail: incomingStream
}))

// incoming stream is opened asynchronously
await delay(100)

expect(middleware1.called).to.be.true()
expect(middleware2.called).to.be.true()
expect(middleware2.called).to.be.false()
expect(incomingStream).to.have.nested.property('abort.called', true)
})
})
Loading