refactor: add missing @override decorators to TypeDecorator subclasses in models/types.py (#36565)

This commit is contained in:
Liz Zhang
2026-05-24 01:00:40 -07:00
committed by GitHub
parent 6133c2ab6a
commit fc4178476a

View File

@@ -1,7 +1,7 @@
import enum
import json
import uuid
from typing import Any, cast
from typing import Any, cast, override
import sqlalchemy as sa
from pydantic import BaseModel
@@ -18,6 +18,7 @@ class StringUUID(TypeDecorator[uuid.UUID | str | None]):
impl = CHAR
cache_ok = True
@override
def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
if value is None:
return value
@@ -28,12 +29,14 @@ class StringUUID(TypeDecorator[uuid.UUID | str | None]):
return value.hex
return value
@override
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
if dialect.name == "postgresql":
return dialect.type_descriptor(UUID())
else:
return dialect.type_descriptor(CHAR(36))
@override
def process_result_value(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
if value is None:
return value
@@ -44,11 +47,13 @@ class LongText(TypeDecorator[str | None]):
impl = TEXT
cache_ok = True
@override
def process_bind_param(self, value: str | None, dialect: Dialect) -> str | None:
if value is None:
return value
return value
@override
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
if dialect.name == "postgresql":
return dialect.type_descriptor(TEXT())
@@ -57,6 +62,7 @@ class LongText(TypeDecorator[str | None]):
else:
return dialect.type_descriptor(TEXT())
@override
def process_result_value(self, value: str | None, dialect: Dialect) -> str | None:
if value is None:
return value
@@ -77,6 +83,7 @@ class JSONModelColumn[T: BaseModel](TypeDecorator[T | None]):
self._model_class = model_class
super().__init__()
@override
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
if dialect.name == "postgresql":
return dialect.type_descriptor(TEXT())
@@ -85,6 +92,7 @@ class JSONModelColumn[T: BaseModel](TypeDecorator[T | None]):
else:
return dialect.type_descriptor(TEXT())
@override
def process_bind_param(self, value: T | dict[str, Any] | str | None, dialect: Dialect) -> str | None:
if value is None:
return None
@@ -96,6 +104,7 @@ class JSONModelColumn[T: BaseModel](TypeDecorator[T | None]):
model = self._model_class.model_validate(value)
return json.dumps(model.model_dump(mode="json"), ensure_ascii=False, sort_keys=True, separators=(",", ":"))
@override
def process_result_value(self, value: str | None, dialect: Dialect) -> T | None:
if value is None or value == "":
return None
@@ -106,11 +115,13 @@ class BinaryData(TypeDecorator[bytes | None]):
impl = LargeBinary
cache_ok = True
@override
def process_bind_param(self, value: bytes | None, dialect: Dialect) -> bytes | None:
if value is None:
return value
return value
@override
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
if dialect.name == "postgresql":
return dialect.type_descriptor(BYTEA())
@@ -119,6 +130,7 @@ class BinaryData(TypeDecorator[bytes | None]):
else:
return dialect.type_descriptor(LargeBinary())
@override
def process_result_value(self, value: bytes | None, dialect: Dialect) -> bytes | None:
if value is None:
return value
@@ -133,6 +145,7 @@ class AdjustedJSON(TypeDecorator[dict | list | None]):
self.astext_type = astext_type
super().__init__()
@override
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
if dialect.name == "postgresql":
if self.astext_type:
@@ -144,11 +157,13 @@ class AdjustedJSON(TypeDecorator[dict | list | None]):
else:
return dialect.type_descriptor(sa.JSON())
@override
def process_bind_param(
self, value: dict[str, Any] | list[Any] | None, dialect: Dialect
) -> dict[str, Any] | list[Any] | None:
return value
@override
def process_result_value(
self, value: dict[str, Any] | list[Any] | None, dialect: Dialect
) -> dict[str, Any] | list[Any] | None:
@@ -173,6 +188,7 @@ class EnumText[T: enum.StrEnum](TypeDecorator[T | None]):
# leave some rooms for future longer enum values.
self._length = max(max_enum_value_len, 20)
@override
def process_bind_param(self, value: T | str | None, dialect: Dialect) -> str | None:
if value is None:
return value
@@ -182,9 +198,11 @@ class EnumText[T: enum.StrEnum](TypeDecorator[T | None]):
self._enum_class(value)
return value
@override
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
return dialect.type_descriptor(VARCHAR(self._length))
@override
def process_result_value(self, value: str | None, dialect: Dialect) -> T | None:
if value is None or value == "":
return None
@@ -197,6 +215,7 @@ class EnumText[T: enum.StrEnum](TypeDecorator[T | None]):
return cast(T, value_of(value))
raise
@override
def compare_values(self, x: T | None, y: T | None) -> bool:
if x is None or y is None:
return x is y