From fc4178476a34e2d51e6ac23de3a131d0705dc835 Mon Sep 17 00:00:00 2001 From: Liz Zhang Date: Sun, 24 May 2026 01:00:40 -0700 Subject: [PATCH] refactor: add missing @override decorators to TypeDecorator subclasses in models/types.py (#36565) --- api/models/types.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/api/models/types.py b/api/models/types.py index 23028220f6..092db63856 100644 --- a/api/models/types.py +++ b/api/models/types.py @@ -1,7 +1,7 @@ import enum import json import uuid -from typing import Any, cast +from typing import Any, cast, override import sqlalchemy as sa from pydantic import BaseModel @@ -18,6 +18,7 @@ class StringUUID(TypeDecorator[uuid.UUID | str | None]): impl = CHAR cache_ok = True + @override def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None: if value is None: return value @@ -28,12 +29,14 @@ class StringUUID(TypeDecorator[uuid.UUID | str | None]): return value.hex return value + @override def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: if dialect.name == "postgresql": return dialect.type_descriptor(UUID()) else: return dialect.type_descriptor(CHAR(36)) + @override def process_result_value(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None: if value is None: return value @@ -44,11 +47,13 @@ class LongText(TypeDecorator[str | None]): impl = TEXT cache_ok = True + @override def process_bind_param(self, value: str | None, dialect: Dialect) -> str | None: if value is None: return value return value + @override def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: if dialect.name == "postgresql": return dialect.type_descriptor(TEXT()) @@ -57,6 +62,7 @@ class LongText(TypeDecorator[str | None]): else: return dialect.type_descriptor(TEXT()) + @override def process_result_value(self, value: str | None, dialect: Dialect) -> str | None: if value is None: return value @@ -77,6 +83,7 @@ class JSONModelColumn[T: BaseModel](TypeDecorator[T | None]): self._model_class = model_class super().__init__() + @override def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: if dialect.name == "postgresql": return dialect.type_descriptor(TEXT()) @@ -85,6 +92,7 @@ class JSONModelColumn[T: BaseModel](TypeDecorator[T | None]): else: return dialect.type_descriptor(TEXT()) + @override def process_bind_param(self, value: T | dict[str, Any] | str | None, dialect: Dialect) -> str | None: if value is None: return None @@ -96,6 +104,7 @@ class JSONModelColumn[T: BaseModel](TypeDecorator[T | None]): model = self._model_class.model_validate(value) return json.dumps(model.model_dump(mode="json"), ensure_ascii=False, sort_keys=True, separators=(",", ":")) + @override def process_result_value(self, value: str | None, dialect: Dialect) -> T | None: if value is None or value == "": return None @@ -106,11 +115,13 @@ class BinaryData(TypeDecorator[bytes | None]): impl = LargeBinary cache_ok = True + @override def process_bind_param(self, value: bytes | None, dialect: Dialect) -> bytes | None: if value is None: return value return value + @override def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: if dialect.name == "postgresql": return dialect.type_descriptor(BYTEA()) @@ -119,6 +130,7 @@ class BinaryData(TypeDecorator[bytes | None]): else: return dialect.type_descriptor(LargeBinary()) + @override def process_result_value(self, value: bytes | None, dialect: Dialect) -> bytes | None: if value is None: return value @@ -133,6 +145,7 @@ class AdjustedJSON(TypeDecorator[dict | list | None]): self.astext_type = astext_type super().__init__() + @override def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: if dialect.name == "postgresql": if self.astext_type: @@ -144,11 +157,13 @@ class AdjustedJSON(TypeDecorator[dict | list | None]): else: return dialect.type_descriptor(sa.JSON()) + @override def process_bind_param( self, value: dict[str, Any] | list[Any] | None, dialect: Dialect ) -> dict[str, Any] | list[Any] | None: return value + @override def process_result_value( self, value: dict[str, Any] | list[Any] | None, dialect: Dialect ) -> dict[str, Any] | list[Any] | None: @@ -173,6 +188,7 @@ class EnumText[T: enum.StrEnum](TypeDecorator[T | None]): # leave some rooms for future longer enum values. self._length = max(max_enum_value_len, 20) + @override def process_bind_param(self, value: T | str | None, dialect: Dialect) -> str | None: if value is None: return value @@ -182,9 +198,11 @@ class EnumText[T: enum.StrEnum](TypeDecorator[T | None]): self._enum_class(value) return value + @override def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: return dialect.type_descriptor(VARCHAR(self._length)) + @override def process_result_value(self, value: str | None, dialect: Dialect) -> T | None: if value is None or value == "": return None @@ -197,6 +215,7 @@ class EnumText[T: enum.StrEnum](TypeDecorator[T | None]): return cast(T, value_of(value)) raise + @override def compare_values(self, x: T | None, y: T | None) -> bool: if x is None or y is None: return x is y