diff --git a/src/index.tsx b/src/index.tsx index 8cce19f2..46ca2f7e 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -1,4 +1,4 @@ -import React, { HTMLAttributes } from 'react'; +import React, { HTMLAttributes, RefAttributes, ReactInstance } from 'react'; import ReactDOM from 'react-dom'; import contains from 'rc-util/lib/Dom/contains'; import findDOMNode from 'rc-util/lib/Dom/findDOMNode'; @@ -732,9 +732,8 @@ export function generateTrigger( const { popupVisible } = this.state; const { children, forceRender, alignPoint, className } = this.props; const child = React.Children.only(children) as React.ReactElement; - const newChildProps: HTMLAttributes & { key: string } = { - key: 'trigger', - }; + const newChildProps: HTMLAttributes & + RefAttributes & { key: string } = { key: 'trigger' }; if (this.isContextMenuToShow()) { newChildProps.onContextMenu = this.onContextMenu; @@ -779,10 +778,17 @@ export function generateTrigger( if (childrenClassName) { newChildProps.className = childrenClassName; } - const trigger = React.cloneElement(child, { - ...newChildProps, - ref: composeRef(this.triggerRef, (child as any).ref), - }); + + // Prevent adding ref to Functional child components + if ( + !child.type || + typeof child.type !== 'function' || + (child.type.prototype && child.type.prototype.isReactComponent) + ) { + newChildProps.ref = composeRef(this.triggerRef, (child as any).ref); + } + + const trigger = React.cloneElement(child, newChildProps); let portal: React.ReactElement; // prevent unmounting after it's rendered diff --git a/tests/basic.test.jsx b/tests/basic.test.jsx index c5dbc8e4..c1f56d83 100644 --- a/tests/basic.test.jsx +++ b/tests/basic.test.jsx @@ -538,14 +538,11 @@ describe('Trigger.Basic', () => { }); it('support function component', () => { - const NoRef = React.forwardRef((props, ref) => { - React.useImperativeHandle(ref, () => null); - return ( -
- click -
- ); - }); + const FuncComp = props => ( +
+ click +
+ ); const wrapper = mount( { popupAlign={placementAlignMap.left} popup={tooltip2} > - + , ); @@ -564,6 +561,66 @@ describe('Trigger.Basic', () => { expect(wrapper.isHidden()).toBeTruthy(); }); + describe('passes ref to children where applicable', () => { + function getRefResult(component) { + const triggerMock = jest.fn(); + const wrapper = mount( + tooltip2} + ref={triggerMock} + > + {component} + , + ); + + wrapper.trigger(); + + return triggerMock.mock.calls[0][0].triggerRef.current; + } + + it('does not pass ref to function component', () => { + const NoRef = props => ( +
+ click +
+ ); + + const refVal = getRefResult(); + expect(refVal).toBeNull(); + }); + + it('does pass ref to forwardRef function component', () => { + const WithRef = React.forwardRef((props, ref) => ( +
+ click +
+ )); + + const refVal = getRefResult(); + expect(refVal.id).toBe('target'); + }); + + it('does pass ref to class component', () => { + class ClassComp extends React.Component { + id = 'target'; + + render() { + return
click
; + } + } + + const refVal = getRefResult(); + expect(refVal.id).toBe('target'); + }); + + it('does pass ref to element', () => { + const refVal = getRefResult(
click
); + expect(refVal.id).toBe('target'); + }); + }); + it('Popup with mouseDown prevent', () => { const wrapper = mount(