From 77e964fe5e4dc673660ccbba16ff61ab9075f3ec Mon Sep 17 00:00:00 2001 From: Lorenzo Del Castillo Date: Fri, 7 Feb 2025 16:59:36 +0100 Subject: [PATCH] Add: possibility to refine TypedDict fields if they are readonly --- mypy/semanal_typeddict.py | 33 +++++++++++++++--- test-data/unit/check-typeddict.test | 53 +++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 5 deletions(-) diff --git a/mypy/semanal_typeddict.py b/mypy/semanal_typeddict.py index 0d6a0b7ff87f..fcab7373ea8e 100644 --- a/mypy/semanal_typeddict.py +++ b/mypy/semanal_typeddict.py @@ -2,7 +2,6 @@ from __future__ import annotations -from collections.abc import Collection from typing import Final from mypy import errorcodes as codes, message_registry @@ -40,6 +39,7 @@ require_bool_literal_argument, ) from mypy.state import state +from mypy.subtypes import is_subtype from mypy.typeanal import check_for_explicit_any, has_any_from_unimported_type from mypy.types import ( TPDICT_NAMES, @@ -166,7 +166,9 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> tuple[bool, TypeInfo | N base, field_types, required_keys, readonly_keys, defn ) (new_field_types, new_statements, new_required_keys, new_readonly_keys) = ( - self.analyze_typeddict_classdef_fields(defn, oldfields=field_types) + self.analyze_typeddict_classdef_fields( + defn, oldfields=field_types, oldreadonly_keys=readonly_keys + ) ) if new_field_types is None: return True, None # Defer @@ -280,7 +282,10 @@ def map_items_to_base( return mapped_items def analyze_typeddict_classdef_fields( - self, defn: ClassDef, oldfields: Collection[str] | None = None + self, + defn: ClassDef, + oldfields: dict[str, Type] | None = None, + oldreadonly_keys: set[str] | None = None, ) -> tuple[dict[str, Type] | None, list[Statement], set[str], set[str]]: """Analyze fields defined in a TypedDict class definition. @@ -325,8 +330,6 @@ def analyze_typeddict_classdef_fields( self.fail(TPDICT_CLASS_ERROR, stmt) else: name = stmt.lvalues[0].name - if name in (oldfields or []): - self.fail(f'Overwriting TypedDict field "{name}" while extending', stmt) if name in fields: self.fail(f'Duplicate TypedDict key "{name}"', stmt) continue @@ -351,6 +354,26 @@ def analyze_typeddict_classdef_fields( stmt.type = self.extract_meta_info(analyzed, stmt)[0] field_type, required, readonly = self.extract_meta_info(field_type) + + if oldfields and name in oldfields: + # Refinements are only allowed on readonly keys + if name not in (oldreadonly_keys or set()): + self.fail(f'Overwriting TypedDict field "{name}" while extending', stmt) + else: + # Refinements must be ReadOnly too + if not readonly: + self.fail( + f'Overwriting TypedDict ReadOnly field "{name}" with non-ReadOnly type', + stmt, + ) + + # Refinements must be type compatible + if not is_subtype(field_type, oldfields[name]): + self.fail( + f'Overwriting TypedDict ReadOnly field "{name}" with incompatible type', + stmt, + ) + fields[name] = field_type if (total or required is True) and required is not False: diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index c2b734b4b923..f0e261d0d512 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -4138,3 +4138,56 @@ Derived.Params(name="Robert") DerivedOverride.Params(name="Robert") [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] + + +[case testRefinementReadOnlyField] +from typing_extensions import TypedDict, ReadOnly, Literal + +class A(TypedDict): + a: ReadOnly[str] + +class B(A): + a: Literal["foo"] # E: Overwriting TypedDict ReadOnly field "a" with non-ReadOnly type + +class C(TypedDict): + a: str + +class D(C): + a: Literal["foo"] # E: Overwriting TypedDict field "a" while extending + +class F(TypedDict): + a: ReadOnly[str] + +class G(F): + a: str # E: Overwriting TypedDict ReadOnly field "a" with non-ReadOnly type + +class H(TypedDict): + a: ReadOnly[str] + +class E(H): + a: ReadOnly[int] # E: Overwriting TypedDict ReadOnly field "a" with incompatible type + + +class S(str): pass +class I(TypedDict): + a: ReadOnly[str] + +class J(I): + a: ReadOnly[S] + + +def f(a: A, b: B) -> None: + reveal_type(a['a']) # N: Revealed type is "builtins.str" + reveal_type(b['a']) # N: Revealed type is "Literal['foo']" + +def g(i: I, j: J) -> None: + reveal_type(i['a']) # N: Revealed type is "builtins.str" + reveal_type(j['a']) # N: Revealed type is "__main__.S" + +def mutate_dictA(d: A) -> None: + d["a"] = "bar" # E: ReadOnly TypedDict key "a" TypedDict is mutated + +def mutate_dictB(d: B) -> None: + d["a"] = "bar" # E: ReadOnly TypedDict key "a" TypedDict is mutated # E: Value of "a" has incompatible type "Literal['bar']"; expected "Literal['foo']" + +[builtins fixtures/primitives.pyi]