Python 的 typing 模块远不只 int、str、List[int] 这些基础用法。TypeVar、Protocol、ParamSpec 这些进阶工具能写出更安全、更灵活的类型标注。
TypeVar 基础与约束
TypeVar 用于定义泛型函数和泛型类。最基本的用法:
from typing import TypeVar, List
T = TypeVar('T')
def first(items: List[T]) -> T:
return items[0]
# 类型检查器能推断出返回类型
x: int = first([1, 2, 3]) # OK, T=int
y: str = first(["a", "b"]) # OK, T=str
TypeVar 支持两种约束方式:
# bound: T 必须是某个类型的子类
from typing import TypeVar
class Comparable:
def __lt__(self, other) -> bool: ...
CT = TypeVar('CT', bound=Comparable)
def min_val(a: CT, b: CT) -> CT:
return a if a < b else b
# constraints: T 只能是指定的几种类型之一
StrOrBytes = TypeVar('StrOrBytes', str, bytes)
def concat(a: StrOrBytes, b: StrOrBytes) -> StrOrBytes:
return a + b
concat("hello", " world") # OK
concat(b"hello", b" world") # OK
concat("hello", b" world") # Error! 不能混用
bound 和 constraints 的区别:bound=X 表示 T 是 X 或其子类;constraints=(X, Y) 表示 T 只能是 X 或 Y 中的一个。
Protocol:结构化子类型
Python 3.8 引入的 Protocol 实现了类似 Go interface 的结构化子类型——不需要显式继承,只要有对应的方法就行:
from typing import Protocol, runtime_checkable
@runtime_checkable
class Drawable(Protocol):
def draw(self, x: int, y: int) -> None: ...
class Circle:
def draw(self, x: int, y: int) -> None:
print(f"Drawing circle at ({x}, {y})")
class Rectangle:
def draw(self, x: int, y: int) -> None:
print(f"Drawing rectangle at ({x}, {y})")
# Circle 和 Rectangle 没有继承 Drawable,但类型检查器认为它们兼容
def render(shape: Drawable) -> None:
shape.draw(10, 20)
render(Circle()) # OK
render(Rectangle()) # OK
@runtime_checkable 让 isinstance 在运行时也能检查:
c = Circle()
print(isinstance(c, Drawable)) # True
Protocol 可以定义属性和多个方法:
class Sized(Protocol):
@property
def size(self) -> int: ...
class HasLength(Protocol):
def __len__(self) -> int: ...
class Repository(Protocol):
def get(self, id: str) -> dict: ...
def save(self, entity: dict) -> None: ...
def delete(self, id: str) -> bool: ...
Generic 类
结合 TypeVar 和 Generic 创建泛型类:
from typing import TypeVar, Generic, Optional, Iterator
T = TypeVar('T')
class Stack(Generic[T]):
def __init__(self) -> None:
self._items: list[T] = []
def push(self, item: T) -> None:
self._items.append(item)
def pop(self) -> T:
if not self._items:
raise IndexError("empty stack")
return self._items.pop()
def peek(self) -> Optional[T]:
return self._items[-1] if self._items else None
def __iter__(self) -> Iterator[T]:
return iter(reversed(self._items))
def __len__(self) -> int:
return len(self._items)
# 使用时指定类型参数
int_stack: Stack[int] = Stack()
int_stack.push(1)
int_stack.push(2)
val: int = int_stack.pop() # 类型检查器知道 val 是 int
多类型参数的泛型:
K = TypeVar('K')
V = TypeVar('V')
class Pair(Generic[K, V]):
def __init__(self, key: K, value: V) -> None:
self.key = key
self.value = value
def swap(self) -> 'Pair[V, K]':
return Pair(self.value, self.key)
p = Pair("name", 42) # Pair[str, int]
q = p.swap() # Pair[int, str]
ParamSpec:参数规格
Python 3.10 引入的 ParamSpec 解决了装饰器中保留原函数签名的问题:
from typing import TypeVar, ParamSpec, Callable
import functools
import time
P = ParamSpec('P')
R = TypeVar('R')
def timing(func: Callable[P, R]) -> Callable[P, R]:
'''一个保留原函数签名的计时装饰器'''
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
start = time.perf_counter()
result = func(*args, **kwargs)
elapsed = time.perf_counter() - start
print(f"{func.__name__} took {elapsed:.4f}s")
return result
return wrapper
@timing
def fetch_data(url: str, timeout: int = 30) -> dict:
...
# 类型检查器知道 fetch_data 的签名仍然是 (url: str, timeout: int) -> dict
fetch_data("http://example.com", timeout=10) # OK
fetch_data(123) # Error!
不用 ParamSpec 的话,装饰器返回的函数签名会丢失,变成 (*args, **kwargs) -> R,IDE 的自动补全和类型检查都失效。
实际应用:类型安全的仓储模式
把上面的工具组合起来,实现一个类型安全的仓储接口:
from typing import TypeVar, Generic, Protocol, Optional, List
from dataclasses import dataclass
# 实体协议
class Entity(Protocol):
@property
def id(self) -> str: ...
# 泛型仓储接口
E = TypeVar('E', bound=Entity)
class Repository(Generic[E]):
def get(self, id: str) -> Optional[E]:
raise NotImplementedError
def list_all(self) -> List[E]:
raise NotImplementedError
def save(self, entity: E) -> None:
raise NotImplementedError
def delete(self, id: str) -> bool:
raise NotImplementedError
# 具体实体
@dataclass
class User:
id: str
name: str
email: str
@dataclass
class Product:
id: str
title: str
price: float
# 具体仓储
class UserRepository(Repository[User]):
def __init__(self):
self._store: dict[str, User] = {}
def get(self, id: str) -> Optional[User]:
return self._store.get(id)
def list_all(self) -> List[User]:
return list(self._store.values())
def save(self, entity: User) -> None:
self._store[entity.id] = entity
def delete(self, id: str) -> bool:
return self._store.pop(id, None) is not None
# 使用
repo = UserRepository()
repo.save(User(id="1", name="Alice", email="alice@example.com"))
user: Optional[User] = repo.get("1") # 类型检查器知道是 User
这些类型工具配合 mypy 或 pyright,能在大型项目中提前发现很多类型错误,写库代码时尤其有用。