diff --git a/src/compiler/checker.ts b/src/compiler/checker.ts index 7c7fab735856e..676acc639dab0 100644 --- a/src/compiler/checker.ts +++ b/src/compiler/checker.ts @@ -726,6 +726,7 @@ namespace ts { const templateLiteralTypes = new Map<string, TemplateLiteralType>(); const stringMappingTypes = new Map<string, StringMappingType>(); const substitutionTypes = new Map<string, SubstitutionType>(); + const subtypeReductionCache = new Map<string, Type[]>(); const evolvingArrayTypes: EvolvingArrayType[] = []; const undefinedProperties: SymbolTable = new Map(); @@ -13323,7 +13324,12 @@ namespace ts { return includes; } - function removeSubtypes(types: Type[], hasObjectTypes: boolean): boolean { + function removeSubtypes(types: Type[], hasObjectTypes: boolean): Type[] | undefined { + const id = getTypeListId(types); + const match = subtypeReductionCache.get(id); + if (match) { + return match; + } // We assume that redundant primitive types have already been removed from the types array and that there // are no any and unknown types in the array. Thus, the only possible supertypes for primitive types are empty // object types, and if none of those are present we can exclude primitive types from the subtype check. @@ -13335,6 +13341,13 @@ namespace ts { i--; const source = types[i]; if (hasEmptyObject || source.flags & TypeFlags.StructuredOrInstantiable) { + // Find the first property with a unit type, if any. When constituents have a property by the same name + // but of a different unit type, we can quickly disqualify them from subtype checks. This helps subtype + // reduction of large discriminated union types. + const keyProperty = source.flags & (TypeFlags.Object | TypeFlags.Intersection | TypeFlags.InstantiableNonPrimitive) ? + find(getPropertiesOfType(source), p => isUnitType(getTypeOfSymbol(p))) : + undefined; + const keyPropertyType = keyProperty && getRegularTypeOfLiteralType(getTypeOfSymbol(keyProperty)); for (const target of types) { if (source !== target) { if (count === 100000) { @@ -13346,10 +13359,16 @@ namespace ts { if (estimatedCount > 1000000) { tracing?.instant(tracing.Phase.CheckTypes, "removeSubtypes_DepthLimit", { typeIds: types.map(t => t.id) }); error(currentNode, Diagnostics.Expression_produces_a_union_type_that_is_too_complex_to_represent); - return false; + return undefined; } } count++; + if (keyProperty && target.flags & (TypeFlags.Object | TypeFlags.Intersection | TypeFlags.InstantiableNonPrimitive)) { + const t = getTypeOfPropertyOfType(target, keyProperty.escapedName); + if (t && isUnitType(t) && getRegularTypeOfLiteralType(t) !== keyPropertyType) { + continue; + } + } if (isTypeRelatedTo(source, target, strictSubtypeRelation) && ( !(getObjectFlags(getTargetType(source)) & ObjectFlags.Class) || !(getObjectFlags(getTargetType(target)) & ObjectFlags.Class) || @@ -13361,7 +13380,8 @@ namespace ts { } } } - return true; + subtypeReductionCache.set(id, types); + return types; } function removeRedundantLiteralTypes(types: Type[], includes: TypeFlags, reduceVoidUndefined: boolean) { @@ -13435,22 +13455,21 @@ namespace ts { if (types.length === 1) { return types[0]; } - const typeSet: Type[] = []; + let typeSet: Type[] | undefined = []; const includes = addTypesToUnion(typeSet, 0, types); if (unionReduction !== UnionReduction.None) { if (includes & TypeFlags.AnyOrUnknown) { return includes & TypeFlags.Any ? includes & TypeFlags.IncludesWildcard ? wildcardType : anyType : unknownType; } - if (unionReduction & (UnionReduction.Literal | UnionReduction.Subtype)) { - if (includes & (TypeFlags.Literal | TypeFlags.UniqueESSymbol) || includes & TypeFlags.Void && includes & TypeFlags.Undefined) { - removeRedundantLiteralTypes(typeSet, includes, !!(unionReduction & UnionReduction.Subtype)); - } - if (includes & TypeFlags.StringLiteral && includes & TypeFlags.TemplateLiteral) { - removeStringLiteralsMatchedByTemplateLiterals(typeSet); - } + if (includes & (TypeFlags.Literal | TypeFlags.UniqueESSymbol) || includes & TypeFlags.Void && includes & TypeFlags.Undefined) { + removeRedundantLiteralTypes(typeSet, includes, !!(unionReduction & UnionReduction.Subtype)); } - if (unionReduction & UnionReduction.Subtype) { - if (!removeSubtypes(typeSet, !!(includes & TypeFlags.Object))) { + if (includes & TypeFlags.StringLiteral && includes & TypeFlags.TemplateLiteral) { + removeStringLiteralsMatchedByTemplateLiterals(typeSet); + } + if (unionReduction === UnionReduction.Subtype) { + typeSet = removeSubtypes(typeSet, !!(includes & TypeFlags.Object)); + if (!typeSet) { return errorType; } } @@ -17529,8 +17548,17 @@ namespace ts { function typeRelatedToSomeType(source: Type, target: UnionOrIntersectionType, reportErrors: boolean): Ternary { const targetTypes = target.types; - if (target.flags & TypeFlags.Union && containsType(targetTypes, source)) { - return Ternary.True; + if (target.flags & TypeFlags.Union) { + if (containsType(targetTypes, source)) { + return Ternary.True; + } + const match = getMatchingUnionConstituentForType(<UnionType>target, source); + if (match) { + const related = isRelatedTo(source, match, /*reportErrors*/ false); + if (related) { + return related; + } + } } for (const type of targetTypes) { const related = isRelatedTo(source, type, /*reportErrors*/ false); @@ -21364,6 +21392,82 @@ namespace ts { return result; } + // Given a set of constituent types and a property name, create and return a map keyed by the literal + // types of the property by that name in each constituent type. No map is returned if some key property + // has a non-literal type or if less than 10 or less than 50% of the constituents have a unique key. + // Entries with duplicate keys have unknownType as the value. + function mapTypesByKeyProperty(types: Type[], name: __String) { + const map = new Map<TypeId, Type>(); + let count = 0; + for (const type of types) { + if (type.flags & (TypeFlags.Object | TypeFlags.Intersection | TypeFlags.InstantiableNonPrimitive)) { + const discriminant = getTypeOfPropertyOfType(type, name); + if (discriminant) { + if (!isLiteralType(discriminant)) { + return undefined; + } + let duplicate = false; + forEachType(discriminant, t => { + const id = getTypeId(getRegularTypeOfLiteralType(t)); + const existing = map.get(id); + if (!existing) { + map.set(id, type); + } + else if (existing !== unknownType) { + map.set(id, unknownType); + duplicate = true; + } + }); + if (!duplicate) count++; + } + } + } + return count >= 10 && count * 2 >= types.length ? map : undefined; + } + + // Return the name of a discriminant property for which it was possible and feasible to construct a map of + // constituent types keyed by the literal types of the property by that name in each constituent type. + function getKeyPropertyName(unionType: UnionType): __String | undefined { + const types = unionType.types; + // We only construct maps for large unions with non-primitive constituents. + if (types.length < 10 || getObjectFlags(unionType) & ObjectFlags.PrimitiveUnion) { + return undefined; + } + if (unionType.keyPropertyName === undefined) { + // The candidate key property name is the name of the first property with a unit type in one of the + // constituent types. + const keyPropertyName = forEach(types, t => + t.flags & (TypeFlags.Object | TypeFlags.Intersection | TypeFlags.InstantiableNonPrimitive) ? + forEach(getPropertiesOfType(t), p => isUnitType(getTypeOfSymbol(p)) ? p.escapedName : undefined) : + undefined); + const mapByKeyProperty = keyPropertyName && mapTypesByKeyProperty(types, keyPropertyName); + unionType.keyPropertyName = mapByKeyProperty ? keyPropertyName : "" as __String; + unionType.constituentMap = mapByKeyProperty; + } + return (unionType.keyPropertyName as string).length ? unionType.keyPropertyName : undefined; + } + + // Given a union type for which getKeyPropertyName returned a non-undefined result, return the constituent + // that corresponds to the given key type for that property name. + function getConstituentTypeForKeyType(unionType: UnionType, keyType: Type) { + const result = unionType.constituentMap?.get(getTypeId(getRegularTypeOfLiteralType(keyType))); + return result !== unknownType ? result : undefined; + } + + function getMatchingUnionConstituentForType(unionType: UnionType, type: Type) { + const keyPropertyName = getKeyPropertyName(unionType); + const propType = keyPropertyName && getTypeOfPropertyOfType(type, keyPropertyName); + return propType && getConstituentTypeForKeyType(unionType, propType); + } + + function getMatchingUnionConstituentForObjectLiteral(unionType: UnionType, node: ObjectLiteralExpression) { + const keyPropertyName = getKeyPropertyName(unionType); + const propNode = keyPropertyName && find(node.properties, p => p.symbol && p.kind === SyntaxKind.PropertyAssignment && + p.symbol.escapedName === keyPropertyName && isPossiblyDiscriminantValue(p.initializer)); + const propType = propNode && getTypeOfExpression((<PropertyAssignment>propNode).initializer); + return propType && getConstituentTypeForKeyType(unionType, propType); + } + function isOrContainsMatchingReference(source: Node, target: Node) { return isMatchingReference(source, target) || containsMatchingReference(source, target); } @@ -22475,8 +22579,7 @@ namespace ts { } } if (isMatchingReferenceDiscriminant(expr, type)) { - type = narrowTypeByDiscriminant(type, expr as AccessExpression, - t => narrowTypeBySwitchOnDiscriminant(t, flow.switchStatement, flow.clauseStart, flow.clauseEnd)); + type = narrowTypeBySwitchOnDiscriminantProperty(type, expr as AccessExpression, flow.switchStatement, flow.clauseStart, flow.clauseEnd); } } return createFlowType(type, isIncomplete(flowType)); @@ -22650,8 +22753,7 @@ namespace ts { if (propName === undefined) { return type; } - const includesNullable = strictNullChecks && maybeTypeOfKind(type, TypeFlags.Nullable); - const removeNullable = includesNullable && isOptionalChain(access); + const removeNullable = strictNullChecks && isOptionalChain(access) && maybeTypeOfKind(type, TypeFlags.Nullable); let propType = getTypeOfPropertyOfType(removeNullable ? getTypeWithFacts(type, TypeFacts.NEUndefinedOrNull) : type, propName); if (!propType) { return type; @@ -22664,6 +22766,32 @@ namespace ts { }); } + function narrowTypeByDiscriminantProperty(type: Type, access: AccessExpression, operator: SyntaxKind, value: Expression, assumeTrue: boolean) { + if ((operator === SyntaxKind.EqualsEqualsEqualsToken || operator === SyntaxKind.ExclamationEqualsEqualsToken) && type.flags & TypeFlags.Union) { + const keyPropertyName = getKeyPropertyName(<UnionType>type); + if (keyPropertyName && keyPropertyName === getAccessedPropertyName(access)) { + const candidate = getConstituentTypeForKeyType(<UnionType>type, getTypeOfExpression(value)); + if (candidate) { + return operator === (assumeTrue ? SyntaxKind.EqualsEqualsEqualsToken : SyntaxKind.ExclamationEqualsEqualsToken) ? candidate : + isUnitType(getTypeOfPropertyOfType(candidate, keyPropertyName) || unknownType) ? filterType(type, t => t !== candidate) : + type; + } + } + } + return narrowTypeByDiscriminant(type, access, t => narrowTypeByEquality(t, operator, value, assumeTrue)); + } + + function narrowTypeBySwitchOnDiscriminantProperty(type: Type, access: AccessExpression, switchStatement: SwitchStatement, clauseStart: number, clauseEnd: number) { + if (clauseStart < clauseEnd && type.flags & TypeFlags.Union && getKeyPropertyName(<UnionType>type) === getAccessedPropertyName(access)) { + const clauseTypes = getSwitchClauseTypes(switchStatement).slice(clauseStart, clauseEnd); + const candidate = getUnionType(map(clauseTypes, t => getConstituentTypeForKeyType(<UnionType>type, t) || unknownType)); + if (candidate !== unknownType) { + return candidate; + } + } + return narrowTypeByDiscriminant(type, access, t => narrowTypeBySwitchOnDiscriminant(t, switchStatement, clauseStart, clauseEnd)); + } + function narrowTypeByTruthiness(type: Type, expr: Expression, assumeTrue: boolean): Type { if (isMatchingReference(reference, expr)) { return getTypeWithFacts(type, assumeTrue ? TypeFacts.Truthy : TypeFacts.Falsy); @@ -22733,10 +22861,10 @@ namespace ts { } } if (isMatchingReferenceDiscriminant(left, type)) { - return narrowTypeByDiscriminant(type, <AccessExpression>left, t => narrowTypeByEquality(t, operator, right, assumeTrue)); + return narrowTypeByDiscriminantProperty(type, <AccessExpression>left, operator, right, assumeTrue); } if (isMatchingReferenceDiscriminant(right, type)) { - return narrowTypeByDiscriminant(type, <AccessExpression>right, t => narrowTypeByEquality(t, operator, left, assumeTrue)); + return narrowTypeByDiscriminantProperty(type, <AccessExpression>right, operator, left, assumeTrue); } if (isMatchingConstructorReference(left)) { return narrowTypeByConstructor(type, operator, right, assumeTrue); @@ -22809,7 +22937,7 @@ namespace ts { } if (assumeTrue) { const filterFn: (t: Type) => boolean = operator === SyntaxKind.EqualsEqualsToken ? - (t => areTypesComparable(t, valueType) || isCoercibleUnderDoubleEquals(t, valueType)) : + t => areTypesComparable(t, valueType) || isCoercibleUnderDoubleEquals(t, valueType) : t => areTypesComparable(t, valueType); return replacePrimitivesWithLiterals(filterType(type, filterFn), valueType); } @@ -24623,7 +24751,7 @@ namespace ts { } function discriminateContextualTypeByObjectMembers(node: ObjectLiteralExpression, contextualType: UnionType) { - return discriminateTypeByDiscriminableItems(contextualType, + return getMatchingUnionConstituentForObjectLiteral(contextualType, node) || discriminateTypeByDiscriminableItems(contextualType, map( filter(node.properties, p => !!p.symbol && p.kind === SyntaxKind.PropertyAssignment && isPossiblyDiscriminantValue(p.initializer) && isDiscriminantProperty(contextualType, p.symbol.escapedName)), prop => ([() => checkExpression((prop as PropertyAssignment).initializer), prop.symbol.escapedName] as [() => Type, __String]) @@ -24653,15 +24781,9 @@ namespace ts { const instantiatedType = instantiateContextualType(contextualType, node, contextFlags); if (instantiatedType && !(contextFlags && contextFlags & ContextFlags.NoConstraints && instantiatedType.flags & TypeFlags.TypeVariable)) { const apparentType = mapType(instantiatedType, getApparentType, /*noReductions*/ true); - if (apparentType.flags & TypeFlags.Union) { - if (isObjectLiteralExpression(node)) { - return discriminateContextualTypeByObjectMembers(node, apparentType as UnionType); - } - else if (isJsxAttributes(node)) { - return discriminateContextualTypeByJSXAttributes(node, apparentType as UnionType); - } - } - return apparentType; + return apparentType.flags & TypeFlags.Union && isObjectLiteralExpression(node) ? discriminateContextualTypeByObjectMembers(node, apparentType as UnionType) : + apparentType.flags & TypeFlags.Union && isJsxAttributes(node) ? discriminateContextualTypeByJSXAttributes(node, apparentType as UnionType) : + apparentType; } } @@ -41077,6 +41199,10 @@ namespace ts { // Keep this up-to-date with the same logic within `getApparentTypeOfContextualType`, since they should behave similarly function findMatchingDiscriminantType(source: Type, target: Type, isRelatedTo: (source: Type, target: Type) => Ternary, skipPartial?: boolean) { if (target.flags & TypeFlags.Union && source.flags & (TypeFlags.Intersection | TypeFlags.Object)) { + const match = getMatchingUnionConstituentForType(<UnionType>target, source); + if (match) { + return match; + } const sourceProperties = getPropertiesOfType(source); if (sourceProperties) { const sourcePropertiesFiltered = findDiscriminantProperties(sourceProperties, target); diff --git a/src/compiler/types.ts b/src/compiler/types.ts index 9f3042896223e..7f922271ccf4f 100644 --- a/src/compiler/types.ts +++ b/src/compiler/types.ts @@ -5299,6 +5299,10 @@ namespace ts { regularType?: UnionType; /* @internal */ origin?: Type; // Denormalized union, intersection, or index type in which union originates + /* @internal */ + keyPropertyName?: __String; // Property with unique unit type that exists in every object/intersection in union type + /* @internal */ + constituentMap?: ESMap<TypeId, Type>; // Constituents keyed by unit type discriminants } export interface IntersectionType extends UnionOrIntersectionType { diff --git a/tests/baselines/reference/typeParameterLeak.types b/tests/baselines/reference/typeParameterLeak.types index d5df7698ae874..f7d14ea165ecc 100644 --- a/tests/baselines/reference/typeParameterLeak.types +++ b/tests/baselines/reference/typeParameterLeak.types @@ -28,7 +28,7 @@ declare const f: BoxFactoryFactory<BoxTypes>; const b = f({ x: "", y: "" })?.getBox(); >b : Box<{ x: string; }> | Box<{ y: string; }> | undefined >f({ x: "", y: "" })?.getBox() : Box<{ x: string; }> | Box<{ y: string; }> | undefined ->f({ x: "", y: "" })?.getBox : (() => Box<{ x: string; }>) | (() => Box<{ y: string; }>) | undefined +>f({ x: "", y: "" })?.getBox : (() => Box<{ y: string; }>) | (() => Box<{ x: string; }>) | undefined >f({ x: "", y: "" }) : BoxFactory<Box<{ x: string; }>> | BoxFactory<Box<{ y: string; }>> | undefined >f : ((arg: { x: string; }) => BoxFactory<Box<{ x: string; }>> | undefined) | ((arg: { y: string; }) => BoxFactory<Box<{ y: string; }>> | undefined) >{ x: "", y: "" } : { x: string; y: string; } @@ -36,7 +36,7 @@ const b = f({ x: "", y: "" })?.getBox(); >"" : "" >y : string >"" : "" ->getBox : (() => Box<{ x: string; }>) | (() => Box<{ y: string; }>) | undefined +>getBox : (() => Box<{ y: string; }>) | (() => Box<{ x: string; }>) | undefined if (b) { >b : Box<{ x: string; }> | Box<{ y: string; }> | undefined diff --git a/tests/baselines/reference/unionOfClassCalls.types b/tests/baselines/reference/unionOfClassCalls.types index 043ed9826f6bd..7edd1c7f02ac2 100644 --- a/tests/baselines/reference/unionOfClassCalls.types +++ b/tests/baselines/reference/unionOfClassCalls.types @@ -211,12 +211,12 @@ declare var a: Bar | Baz; // note, you must annotate `result` for now a.doThing().then((result: Bar | Baz) => { >a.doThing().then((result: Bar | Baz) => { // whatever}) : Promise<void> ->a.doThing().then : (<TResult1 = Bar, TResult2 = never>(onfulfilled?: ((value: Bar) => TResult1 | PromiseLike<TResult1>) | null | undefined, onrejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null | undefined) => Promise<TResult1 | TResult2>) | (<TResult1 = Baz, TResult2 = never>(onfulfilled?: ((value: Baz) => TResult1 | PromiseLike<TResult1>) | null | undefined, onrejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null | undefined) => Promise<TResult1 | TResult2>) +>a.doThing().then : (<TResult1 = Baz, TResult2 = never>(onfulfilled?: ((value: Baz) => TResult1 | PromiseLike<TResult1>) | null | undefined, onrejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null | undefined) => Promise<TResult1 | TResult2>) | (<TResult1 = Bar, TResult2 = never>(onfulfilled?: ((value: Bar) => TResult1 | PromiseLike<TResult1>) | null | undefined, onrejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null | undefined) => Promise<TResult1 | TResult2>) >a.doThing() : Promise<Bar> | Promise<Baz> >a.doThing : (() => Promise<Bar>) | (() => Promise<Baz>) >a : Bar | Baz >doThing : (() => Promise<Bar>) | (() => Promise<Baz>) ->then : (<TResult1 = Bar, TResult2 = never>(onfulfilled?: ((value: Bar) => TResult1 | PromiseLike<TResult1>) | null | undefined, onrejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null | undefined) => Promise<TResult1 | TResult2>) | (<TResult1 = Baz, TResult2 = never>(onfulfilled?: ((value: Baz) => TResult1 | PromiseLike<TResult1>) | null | undefined, onrejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null | undefined) => Promise<TResult1 | TResult2>) +>then : (<TResult1 = Baz, TResult2 = never>(onfulfilled?: ((value: Baz) => TResult1 | PromiseLike<TResult1>) | null | undefined, onrejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null | undefined) => Promise<TResult1 | TResult2>) | (<TResult1 = Bar, TResult2 = never>(onfulfilled?: ((value: Bar) => TResult1 | PromiseLike<TResult1>) | null | undefined, onrejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null | undefined) => Promise<TResult1 | TResult2>) >(result: Bar | Baz) => { // whatever} : (result: Bar | Baz) => void >result : Bar | Baz