|
| 1 | +from typing import Any, cast |
| 2 | + |
| 3 | +from mypy.types import ( |
| 4 | + ProperType, UnboundType, AnyType, NoneType, UninhabitedType, ErasedType, DeletedType, |
| 5 | + Instance, TypeVarType, ParamSpecType, PartialType, CallableType, TupleType, TypedDictType, |
| 6 | + LiteralType, UnionType, Overloaded, TypeType, TypeAliasType, UnpackType, Parameters, |
| 7 | + TypeVarTupleType |
| 8 | +) |
| 9 | +from mypy.type_visitor import TypeVisitor |
| 10 | + |
| 11 | + |
| 12 | +def copy_type(t: ProperType) -> ProperType: |
| 13 | + """Create a shallow copy of a type. |
| 14 | +
|
| 15 | + This can be used to mutate the copy with truthiness information. |
| 16 | +
|
| 17 | + Classes compiled with mypyc don't support copy.copy(), so we need |
| 18 | + a custom implementation. |
| 19 | + """ |
| 20 | + return t.accept(TypeShallowCopier()) |
| 21 | + |
| 22 | + |
| 23 | +class TypeShallowCopier(TypeVisitor[ProperType]): |
| 24 | + def visit_unbound_type(self, t: UnboundType) -> ProperType: |
| 25 | + return t |
| 26 | + |
| 27 | + def visit_any(self, t: AnyType) -> ProperType: |
| 28 | + return self.copy_common(t, AnyType(t.type_of_any, t.source_any, t.missing_import_name)) |
| 29 | + |
| 30 | + def visit_none_type(self, t: NoneType) -> ProperType: |
| 31 | + return self.copy_common(t, NoneType()) |
| 32 | + |
| 33 | + def visit_uninhabited_type(self, t: UninhabitedType) -> ProperType: |
| 34 | + dup = UninhabitedType(t.is_noreturn) |
| 35 | + dup.ambiguous = t.ambiguous |
| 36 | + return self.copy_common(t, dup) |
| 37 | + |
| 38 | + def visit_erased_type(self, t: ErasedType) -> ProperType: |
| 39 | + return self.copy_common(t, ErasedType()) |
| 40 | + |
| 41 | + def visit_deleted_type(self, t: DeletedType) -> ProperType: |
| 42 | + return self.copy_common(t, DeletedType(t.source)) |
| 43 | + |
| 44 | + def visit_instance(self, t: Instance) -> ProperType: |
| 45 | + dup = Instance(t.type, t.args, last_known_value=t.last_known_value) |
| 46 | + dup.invalid = t.invalid |
| 47 | + return self.copy_common(t, dup) |
| 48 | + |
| 49 | + def visit_type_var(self, t: TypeVarType) -> ProperType: |
| 50 | + dup = TypeVarType( |
| 51 | + t.name, |
| 52 | + t.fullname, |
| 53 | + t.id, |
| 54 | + values=t.values, |
| 55 | + upper_bound=t.upper_bound, |
| 56 | + variance=t.variance, |
| 57 | + ) |
| 58 | + return self.copy_common(t, dup) |
| 59 | + |
| 60 | + def visit_param_spec(self, t: ParamSpecType) -> ProperType: |
| 61 | + dup = ParamSpecType(t.name, t.fullname, t.id, t.flavor, t.upper_bound, prefix=t.prefix) |
| 62 | + return self.copy_common(t, dup) |
| 63 | + |
| 64 | + def visit_parameters(self, t: Parameters) -> ProperType: |
| 65 | + dup = Parameters(t.arg_types, t.arg_kinds, t.arg_names, |
| 66 | + variables=t.variables, |
| 67 | + is_ellipsis_args=t.is_ellipsis_args) |
| 68 | + return self.copy_common(t, dup) |
| 69 | + |
| 70 | + def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType: |
| 71 | + dup = TypeVarTupleType(t.name, t.fullname, t.id, t.upper_bound) |
| 72 | + return self.copy_common(t, dup) |
| 73 | + |
| 74 | + def visit_unpack_type(self, t: UnpackType) -> ProperType: |
| 75 | + dup = UnpackType(t.type) |
| 76 | + return self.copy_common(t, dup) |
| 77 | + |
| 78 | + def visit_partial_type(self, t: PartialType) -> ProperType: |
| 79 | + return self.copy_common(t, PartialType(t.type, t.var, t.value_type)) |
| 80 | + |
| 81 | + def visit_callable_type(self, t: CallableType) -> ProperType: |
| 82 | + return self.copy_common(t, t.copy_modified()) |
| 83 | + |
| 84 | + def visit_tuple_type(self, t: TupleType) -> ProperType: |
| 85 | + return self.copy_common(t, TupleType(t.items, t.partial_fallback, implicit=t.implicit)) |
| 86 | + |
| 87 | + def visit_typeddict_type(self, t: TypedDictType) -> ProperType: |
| 88 | + return self.copy_common(t, TypedDictType(t.items, t.required_keys, t.fallback)) |
| 89 | + |
| 90 | + def visit_literal_type(self, t: LiteralType) -> ProperType: |
| 91 | + return self.copy_common(t, LiteralType(value=t.value, fallback=t.fallback)) |
| 92 | + |
| 93 | + def visit_union_type(self, t: UnionType) -> ProperType: |
| 94 | + return self.copy_common(t, UnionType(t.items)) |
| 95 | + |
| 96 | + def visit_overloaded(self, t: Overloaded) -> ProperType: |
| 97 | + return self.copy_common(t, Overloaded(items=t.items)) |
| 98 | + |
| 99 | + def visit_type_type(self, t: TypeType) -> ProperType: |
| 100 | + # Use cast since the type annotations in TypeType are imprecise. |
| 101 | + return self.copy_common(t, TypeType(cast(Any, t.item))) |
| 102 | + |
| 103 | + def visit_type_alias_type(self, t: TypeAliasType) -> ProperType: |
| 104 | + assert False, "only ProperTypes supported" |
| 105 | + |
| 106 | + def copy_common(self, t: ProperType, t2: ProperType) -> ProperType: |
| 107 | + t2.line = t.line |
| 108 | + t2.column = t.column |
| 109 | + t2.can_be_false = t.can_be_false |
| 110 | + t2.can_be_true = t.can_be_true |
| 111 | + return t2 |
0 commit comments