Python类型提示进阶:Protocol与TypeVar

Python 的 typing 模块远不只 intstrList[int] 这些基础用法。TypeVarProtocolParamSpec 这些进阶工具能写出更安全、更灵活的类型标注。

TypeVar 基础与约束

TypeVar 用于定义泛型函数和泛型类。最基本的用法:

from typing import TypeVar, List

T = TypeVar('T')

def first(items: List[T]) -> T:
    return items[0]

# 类型检查器能推断出返回类型
x: int = first([1, 2, 3])       # OK, T=int
y: str = first(["a", "b"])      # OK, T=str

TypeVar 支持两种约束方式:

# bound: T 必须是某个类型的子类
from typing import TypeVar

class Comparable:
    def __lt__(self, other) -> bool: ...

CT = TypeVar('CT', bound=Comparable)

def min_val(a: CT, b: CT) -> CT:
    return a if a < b else b


# constraints: T 只能是指定的几种类型之一
StrOrBytes = TypeVar('StrOrBytes', str, bytes)

def concat(a: StrOrBytes, b: StrOrBytes) -> StrOrBytes:
    return a + b

concat("hello", " world")   # OK
concat(b"hello", b" world") # OK
concat("hello", b" world")  # Error! 不能混用

boundconstraints 的区别:bound=X 表示 T 是 X 或其子类;constraints=(X, Y) 表示 T 只能是 X 或 Y 中的一个。

Protocol:结构化子类型

Python 3.8 引入的 Protocol 实现了类似 Go interface 的结构化子类型——不需要显式继承,只要有对应的方法就行:

from typing import Protocol, runtime_checkable

@runtime_checkable
class Drawable(Protocol):
    def draw(self, x: int, y: int) -> None: ...

class Circle:
    def draw(self, x: int, y: int) -> None:
        print(f"Drawing circle at ({x}, {y})")

class Rectangle:
    def draw(self, x: int, y: int) -> None:
        print(f"Drawing rectangle at ({x}, {y})")

# Circle 和 Rectangle 没有继承 Drawable,但类型检查器认为它们兼容
def render(shape: Drawable) -> None:
    shape.draw(10, 20)

render(Circle())      # OK
render(Rectangle())   # OK

@runtime_checkableisinstance 在运行时也能检查:

c = Circle()
print(isinstance(c, Drawable))  # True

Protocol 可以定义属性和多个方法:

class Sized(Protocol):
    @property
    def size(self) -> int: ...

class HasLength(Protocol):
    def __len__(self) -> int: ...

class Repository(Protocol):
    def get(self, id: str) -> dict: ...
    def save(self, entity: dict) -> None: ...
    def delete(self, id: str) -> bool: ...

Generic 类

结合 TypeVarGeneric 创建泛型类:

from typing import TypeVar, Generic, Optional, Iterator

T = TypeVar('T')

class Stack(Generic[T]):
    def __init__(self) -> None:
        self._items: list[T] = []

    def push(self, item: T) -> None:
        self._items.append(item)

    def pop(self) -> T:
        if not self._items:
            raise IndexError("empty stack")
        return self._items.pop()

    def peek(self) -> Optional[T]:
        return self._items[-1] if self._items else None

    def __iter__(self) -> Iterator[T]:
        return iter(reversed(self._items))

    def __len__(self) -> int:
        return len(self._items)


# 使用时指定类型参数
int_stack: Stack[int] = Stack()
int_stack.push(1)
int_stack.push(2)
val: int = int_stack.pop()  # 类型检查器知道 val 是 int

多类型参数的泛型:

K = TypeVar('K')
V = TypeVar('V')

class Pair(Generic[K, V]):
    def __init__(self, key: K, value: V) -> None:
        self.key = key
        self.value = value

    def swap(self) -> 'Pair[V, K]':
        return Pair(self.value, self.key)

p = Pair("name", 42)  # Pair[str, int]
q = p.swap()           # Pair[int, str]

ParamSpec:参数规格

Python 3.10 引入的 ParamSpec 解决了装饰器中保留原函数签名的问题:

from typing import TypeVar, ParamSpec, Callable
import functools
import time

P = ParamSpec('P')
R = TypeVar('R')

def timing(func: Callable[P, R]) -> Callable[P, R]:
    '''一个保留原函数签名的计时装饰器'''
    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print(f"{func.__name__} took {elapsed:.4f}s")
        return result
    return wrapper

@timing
def fetch_data(url: str, timeout: int = 30) -> dict:
    ...

# 类型检查器知道 fetch_data 的签名仍然是 (url: str, timeout: int) -> dict
fetch_data("http://example.com", timeout=10)  # OK
fetch_data(123)  # Error!

不用 ParamSpec 的话,装饰器返回的函数签名会丢失,变成 (*args, **kwargs) -> R,IDE 的自动补全和类型检查都失效。

实际应用:类型安全的仓储模式

把上面的工具组合起来,实现一个类型安全的仓储接口:

from typing import TypeVar, Generic, Protocol, Optional, List
from dataclasses import dataclass

# 实体协议
class Entity(Protocol):
    @property
    def id(self) -> str: ...

# 泛型仓储接口
E = TypeVar('E', bound=Entity)

class Repository(Generic[E]):
    def get(self, id: str) -> Optional[E]:
        raise NotImplementedError

    def list_all(self) -> List[E]:
        raise NotImplementedError

    def save(self, entity: E) -> None:
        raise NotImplementedError

    def delete(self, id: str) -> bool:
        raise NotImplementedError

# 具体实体
@dataclass
class User:
    id: str
    name: str
    email: str

@dataclass
class Product:
    id: str
    title: str
    price: float

# 具体仓储
class UserRepository(Repository[User]):
    def __init__(self):
        self._store: dict[str, User] = {}

    def get(self, id: str) -> Optional[User]:
        return self._store.get(id)

    def list_all(self) -> List[User]:
        return list(self._store.values())

    def save(self, entity: User) -> None:
        self._store[entity.id] = entity

    def delete(self, id: str) -> bool:
        return self._store.pop(id, None) is not None


# 使用
repo = UserRepository()
repo.save(User(id="1", name="Alice", email="alice@example.com"))
user: Optional[User] = repo.get("1")  # 类型检查器知道是 User

这些类型工具配合 mypy 或 pyright,能在大型项目中提前发现很多类型错误,写库代码时尤其有用。