机器学习中的装饰器

社区文章 发布于2025年6月8日

Python 装饰器可以说是您将遇到的最强大的语言特性之一。它们允许您在不直接修改原始代码的情况下,为函数或类添加新功能。在这篇文章中,我们将深入研究两个顶级开源机器学习框架——**vllm** 和 **trl** 的源代码,并剖析 Python 装饰器在实践中如何使用:从您每天看到的内置装饰器,到高级自定义装饰器,最后到标准库模块 **functools** 中隐藏的便捷工具。

在本文中,您将学习到

  • 为什么装饰器在现代 Python 编程中如此重要,尤其是在机器学习项目中。
  • 装饰器工作的底层机制。
  • 如何使用常用装饰器来实现实际任务,如缓存、上下文管理、模型封装和配置注入。
  • 如何从零开始构建装饰器并将其集成到您的代码库中。
  • 关于 functools 中常用装饰器。

1. ML 项目中的常用装饰器

类方法装饰器:@classmethod

源代码:https://github.com/huggingface/trl/blob/main/trl/core.py#L91

在 HuggingFace 的 TRL 库的源代码中可以找到一个经典示例:

class PPODecorators:
    optimize_device_cache = False

    @classmethod
    @contextmanager
    def empty_device_cache(cls):
        yield
        if cls.optimize_device_cache:
            if is_torch_xpu_available():
                gc.collect()
                torch.xpu.empty_cache()
                gc.collect()
            elif is_torch_npu_available():
                gc.collect()
                torch.npu.empty_cache()
                gc.collect()
            elif torch.cuda.is_available():
                gc.collect()
                torch.cuda.empty_cache()
                gc.collect()

这里 @classmethod 的目的是在类的作用域内定义一个可以访问类变量的上下文管理器。当一个函数被 @classmethod 装饰时:它成为一个类方法,而不是一个实例方法。它的第一个参数是 cls(类本身),允许它访问类属性,例如 cls.optimize_device_cache。当被 @contextmanager 装饰时:它成为一个上下文管理器,可以使用 with 语句调用

with PPODecorators.empty_device_cache():
    # e.g. run a PPO optimization step

如果没有 @classmethod,这个函数只能作为普通函数或实例方法调用,并且无法通过类访问类变量。通过添加 @classmethod,它可以像这样访问它们:

if cls.optimize_device_cache:
    ...

这意味着它可以根据类变量 optimize_device_cache 的值动态控制是否执行缓存清除操作。

为什么使用类变量而不是实例变量?

类变量定义在类中,由所有实例共享,属于类。它通过 ClassName.varself.__class__.var 访问。实例变量定义在 __init__ 或实例方法中,属于特定实例,通过 self.var 访问。

这里 optimize_device_cache 充当一个全局行为开关,控制是否清除设备缓存,影响所有实例。它不绑定到特定对象,而是绑定到 PPO 策略(或整个系统),从而简化了状态管理逻辑,因为您甚至不需要实例化 PPODecorators 就可以使用它。

为什么这里使用类变量 optimize_device_cache?

因为这个变量控制着一个全局行为——是否清除设备缓存。这个开关应该适用于所有实例;它不属于特定对象,而是属于 PPO 策略或整个系统。这简化了状态控制逻辑(你不需要实例化 PPODecorators 就可以使用它)。使用类变量的典型场景包括:

  • 配置选项标志(例如,optimize_device_cache, DEBUG = True)。
  • 缓存或注册表(例如,model_registry = {})。
  • 计数器、共享资源(例如,instance_count = 0)。
  • 工具类,无需实例化的情况(例如,静态方法、上下文管理器、装饰器类)。

在这种情况下,使用实例变量会增加实例化负担并引入状态一致性问题。

上下文管理器装饰器:@contextmanager

当你用 @contextmanager 装饰一个函数时,该函数就成为了一个上下文管理器,可以与 with ... as x: 语法一起使用。

yield 语句有效地将它后面的值“返回”给 with 语句的 as 部分;这是您可以在 with 块中操作的对象。yield 也标志着上下文的“中点”

  • yield 之前的代码在进入上下文时执行 (__enter__())。
  • yield 之后的代码在退出上下文时执行 (__exit__())。
  • yield 的值:绑定到 with 语句中 as 目标的那个对象。

上述代码的执行顺序是什么?

Python 上下文管理器(即 with 语句)的执行流程,以上述代码为例:

with PPODecorators.empty_device_cache():
    do_something()

执行顺序如下:

  1. 调用上下文管理器(即调用 empty_device_cache())。
  2. 进入上下文:执行 yield 之前的代码。
  3. 执行 with 块内的语句(例如,do_something())。
  4. with 块完成后(或发生异常后),执行 yield 之后的代码。退出上下文。
@contextmanager
def empty_device_cache(cls):
    yield  # This is the insertion point for the code inside the with block
    if cls.optimize_device_cache:
        ...  # The cache clearing code executes after the with statement

yield 之前:没有代码(这里是空的)。在 yield 之后:代码在 with 块完成后自动触发。因此,清理操作在 with 块结束之后执行。这等同于以下逻辑:

gen = empty_device_cache()
next(gen)            # Enter the context, execute code before yield (empty here)
do_something()       # Your code inside the 'with' block
try:
    next(gen)        # Execute logic after yield
except StopIteration:
    pass             # Generator is exhausted

让我们看一个更具体的例子:

源代码:https://github.com/huggingface/trl/blob/main/trl/models/utils.py#L185

@contextmanager
def unwrap_model_for_generation(
    model: Union["DistributedDataParallel", "DeepSpeedEngine"],
    accelerator: "Accelerator",
    gather_deepspeed3_params: bool = True,
):
    unwrapped_model = accelerator.unwrap_model(model)
    if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
        if not gather_deepspeed3_params:
            yield accelerator.unwrap_model(model)
        else:
            import deepspeed

            with deepspeed.zero.GatheredParameters(model.parameters()):
                remove_hooks(model)
                yield accelerator.unwrap_model(model)
                add_hooks(model)
    else:
        yield unwrapped_model

这是一个上下文管理函数,用 contextlib 中的 @contextmanager 装饰器来简化上下文管理器的编写。它的功能是根据是否使用 DeepSpeed Stage 3 以及是否需要收集参数,返回一个“准备好进行生成任务的解包模型”。与前面的示例类似,流程如下:

with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
    # When this code is executed, the code before yield has already run.
    # The object returned by yield is assigned to unwrapped_model.

此时:

  • yield 之前的语句 = 上下文进入阶段(做准备工作,如参数聚合)。
  • yield 的值 = 真正的“解包”模型,unwrapped_model,在 with 语句中作为变量可用。
  • yield 之后的语句(如果有)= 上下文退出阶段(做清理工作,如重新添加钩子)。

让我们一步步来看逻辑分支:

  • 情况 1:不使用 DeepSpeed Stage 3
    if accelerator.state.deepspeed_plugin is None or accelerator.state.deepspeed_plugin.zero_stage != 3:
      yield unwrapped_model
    
    它直接返回解包后的模型;没有复杂的运算。这对应于常规 DDP 或 DeepSpeed Stage 1/2。
  • 情况 2:DeepSpeed Stage 3 且不需要收集参数
    if not gather_deepspeed3_params:
      yield accelerator.unwrap_model(model)
    
    如果使用 DeepSpeed ZeRO Stage 3 而不收集参数,则它会跳过参数收集。这可以节省 VRAM,但可能会降低生成速度。在这里,它也直接 yield 解包后的模型。
  • 情况 3:DeepSpeed Stage 3 且需要收集参数
    with deepspeed.zero.GatheredParameters(model.parameters()):
    remove_hooks(model)
    yield accelerator.unwrap_model(model)
    add_hooks(model)
    
    这种实现意味着:
  • 你可以用相同的代码来支持各种分布式训练封装器(DDP/DeepSpeed)。
  • 解包逻辑是自动处理的,它会根据条件执行必要的参数聚合和钩子清理。
  • with 块内部,你可以像对待普通模型一样调用 .generate(...)
  • 并且在 with 块结束后,它会自动清理状态(例如,恢复钩子)。

抽象类装饰器:@abstractmethod 和属性装饰器:@property

在下一个示例中,我们将研究 Python 中用于定义类接口的两个经典装饰器:@abstractmethod@property。它们经常一起使用,以在面向对象的架构中构建严格的接口规范。这种模式在框架设计、分布式系统和机器学习服务代码中尤其常见。让我们从一个实际示例开始,了解它们如何协同工作。

源代码:https://github.com/vllm-project/vllm/blob/main/vllm/engine/protocol.py#L27

from abc import ABC, abstractmethod
class EngineClient(ABC):

    @property
    @abstractmethod
    def is_running(self) -> bool:
        ...

    @property
    @abstractmethod
    def is_stopped(self) -> bool:
        ...

    @property
    @abstractmethod
    def errored(self) -> bool:
        ...

    @property
    @abstractmethod
    def dead_error(self) -> BaseException:
        ...

    @abstractmethod
    def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request."""
        ...

什么是 ABC

ABC 是抽象基类(Abstract Base Class)的缩写,来自 abc 模块:from abc import ABCABC 用于定义接口或协议类。它指定了子类必须实现的方法或属性。它有助于设计更健壮、模块化和面向对象的代码结构。在运行时,您不能直接实例化具有未实现抽象方法的类;这样做会引发 TypeError

为什么要使用 @abstractmethod

@abstractmethod 指示一个方法或属性是抽象的,必须由子类实现。当与 ABC 一起使用时,该类成为一个抽象类,不能直接实例化。任何继承自该类的子类都必须实现所有用 @abstractmethod 标记的方法和属性;否则,它也无法实例化。

为什么要使用 @property

@property 将方法转换为“属性”,允许像访问字段一样访问它。例如:

@property
def is_running(self) -> bool:
    ...

使用 @property 有以下好处:

  • 更直观:您调用 obj.is_running 而不是 obj.is_running(),这更具可读性。
  • 封装内部逻辑:尽管它看起来像一个属性,但您可以在方法内部执行逻辑检查。
  • 统一接口:对于某些状态属性(如 is_runningerrored),它们显示为字段,但实际上是由逻辑动态计算的。

总结:EngineClient 设计讨论

EngineClient 类是一个接口定义,它规范了所有“客户端”类的结构。

  • 它要求客户端类实现某些状态属性(如 is_running)。
  • 它要求实现一个异步生成方法。
  • 它使用 ABC@abstractmethod 来强制所有子类必须实现这些接口。
  • 它使用 @property 为状态属性提供更简洁的接口。

这种方法在大型工程中非常常见,是一种优雅的接口设计方式。

静态方法装饰器:@staticmethod

@staticmethod 是 Python 中的一个方法装饰器,它表示一个“静态方法”。

class MyClass:
    @staticmethod
    def foo(x):
        ...

@staticmethod 装饰的方法:

  • 不接收 selfcls 作为第一个参数。
  • 无法访问实例属性 (self.x) 或类变量 (cls.y)。
  • 几乎与普通函数相同,但它被放置在类的命名空间中,以便在逻辑上将其组织为类的一部分。

源代码:https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L568

@staticmethod
    def tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens):
        tokenizer = processing_class  # the processing class is a tokenizer
        prompt_input_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
        chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"]
        rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"]

        # Add special tokens (typically for encoder-decoder models)
        if add_special_tokens:
            if tokenizer.bos_token_id is not None:
                prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids
            if tokenizer.eos_token_id is not None:
                prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id]
        chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id]
        rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id]

        # Truncate prompt and completion sequences
        if max_prompt_length is not None:
            prompt_input_ids = prompt_input_ids[-max_prompt_length:]
        if max_completion_length is not None:
            chosen_input_ids = chosen_input_ids[:max_completion_length]
            rejected_input_ids = rejected_input_ids[:max_completion_length]

        return {
            "prompt_input_ids": prompt_input_ids,
            "chosen_input_ids": chosen_input_ids,
            "rejected_input_ids": rejected_input_ids,
        }

那么为什么这里要使用 @staticmethod 呢?

@staticmethod
def tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens):

该方法的特点是:

  • 它不使用 self(它不特定于任何一个对象)。
  • 它不使用 cls(它不访问任何类变量)。
  • 它只是一个用于处理输入数据的工具函数。
  • 它对 DPOTrainer 类的其他成员没有状态依赖。

因此,它是一个封装在类中但不依赖于类或实例状态的纯函数,这使得 @staticmethod 成为最合适的选择。使用 @staticmethod 的好处包括:

  • 语义清晰:表达“此函数不依赖于对象或类;它只是一段逻辑。”
  • 无需实例化:可以直接通过 ClassName.method() 调用。
  • 更清晰的代码结构:它是类的一部分,但保持功能独立性。
  • 更易于测试:因为它不涉及类状态,所以不需要构建实例进行测试。

调用它时,可以直接使用 tokenized = DPOTrainer.tokenize_row(features, tokenizer, 3, 3, False),而无需首先实例化 DPOTrainertokenize_row 方法显然只是一个与分词相关的实用函数,因此使用 @staticmethod 是一个非常合适的设计选择。

@staticmethod 不需要 selfcls,不能访问类/实例状态,并作为独立于类或实例状态的实用函数。@classmethod 需要 cls,可以访问类变量,用于在类级别操作的逻辑(如工厂方法)。

一个非常实用的工程经验法则是:如果一个方法不依赖于类的状态(实例变量或类变量),则考虑将其设为静态方法。如果它甚至不属于类的“概念域”,那么就干脆将其设为一个完全独立的函数。

什么时候应该编写独立的工具函数?当该函数在多个类中都有用,或者其逻辑与当前类的语义没有强绑定时。例如:分词器的填充函数、字符串清理或通用日志格式化。在这种情况下,将其编写为 utils.py 中的单独函数更有利于重用、解耦和测试。

2. 自定义函数装饰器

装饰器基础知识快速回顾

当您想以统一、自动化和可重用的方式修改或增强函数行为时,装饰器是最佳选择。

常见用例包括:

  • 日志记录: 例如打印函数名称、输入参数和返回值。
  • 性能分析: 例如自动记录执行时间与内存使用情况。
  • 缓存(记忆化): 记住函数输出以避免重复计算。
  • 访问控制/验证: 检查用户权限或参数有效性。
  • 并发控制: 例如,为函数添加锁以确保线程安全。
  • 重试机制: 例如,在失败后自动重试函数(如 API 调用)。

装饰器的本质是“语法糖”。当您看到这种语法时:

@my_decorator
def foo():
    ...

# It is actually equivalent to:

def foo():
    ...

foo = my_decorator(foo)

任何接受另一个函数作为参数并返回可调用对象(通常也是一个函数)的函数都可以用作装饰器。

def my_decorator(func):
    def wrapper(*args, **kwargs):
        print("Before")
        result = func(*args, **kwargs)
        print("After")
        return result
    return wrapper

@my_decorator 语法糖将它下面的 foo 函数作为 func 参数传递给 my_decorator。调用 my_decorator(func) 返回一个新函数 wrapper,因此原始的 foo() 被替换为包含前/后逻辑的新函数。

带参数的装饰器怎么办?

如果你想写这样的东西:

@my_decorator_with_args("DEBUG")
def foo(): ...

你需要两层函数嵌套:

def my_decorator_with_args(log_level):
    def real_decorator(func):
        def wrapper(*args, **kwargs):
            print(f"[{log_level}] Calling {func.__name__}")
            return func(*args, **kwargs)
        return wrapper
    return real_decorator
 

使用时:

@my_decorator_with_args("DEBUG")  # 实际执行顺序:
def foo():
    pass

# This is equivalent to:
# foo = my_decorator_with_args("DEBUG")(foo)

所以,你看到的 @something(...) 实际上是:

  • 首先,something(...) 被执行,它返回一个真正的装饰器函数。
  • 然后,foo 函数被传递给这个返回的装饰器。

trl 的性能分析装饰器:@profiling_decorator

在 HuggingFace TRL 库中,有一个装饰器可以自动记录函数的执行时间,这非常适合大型训练库。

源代码:https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L822

class GRPOTrainer(Trainer):
...
    @profiling_decorator
    def _get_last_hidden_state(self, unwrapped_model, input_ids, attention_mask, logits_to_keep=None):
        if is_peft_model(unwrapped_model):
            unwrapped_model = unwrapped_model.base_model.model
        last_hidden_state = unwrapped_model.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        last_hidden_state = last_hidden_state[:, :-1, :]  # (B, L-1, H)
        if logits_to_keep is not None:
            last_hidden_state = last_hidden_state[:, -logits_to_keep:, :]  # (B, logits_to_keep, H)
        return last_hidden_state

def profiling_decorator(func: callable) -> callable:
    """
    Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`].

    Args:
        func (`callable`):
            Function to be profiled.

    Example:
    ```python
    from transformers import Trainer
    from trl.extras.profiling import profiling_decorator

    class MyTrainer(Trainer):
        @profiling_decorator
        def some_method(self):
            A = np.random.rand(1000, 1000)
            B = np.random.rand(1000, 1000)
            # Code to profile: simulate a computationally expensive operation
            result = A @ B
    ```
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        with profiling_context(self, func.__name__):
            return func(self, *args, **kwargs)

    return wrapper


@contextmanager
def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None]:
    """
    A context manager function for profiling a block of code. Results are logged to Weights & Biases or MLflow
    depending on the trainer's configuration.

    Args:
        trainer (`~transformers.Trainer`):
            Trainer object.
        name (`str`):
            Name of the block to be profiled. Used as a key in the logged dictionary.

    Example:
    ```python
    from transformers import Trainer
    from trl.extras.profiling import profiling_context

    class MyTrainer(Trainer):
        def some_method(self):
            A = np.random.rand(1000, 1000)
            B = np.random.rand(1000, 1000)
            with profiling_context(self, "matrix_multiplication"):
                # Code to profile: simulate a computationally expensive operation
                result = A @ B  # Matrix multiplication
    ```
    """
    start_time = time.perf_counter()
    yield
    end_time = time.perf_counter()
    duration = end_time - start_time

    profiling_metrics = {f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration}
    if "wandb" in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process:
        wandb.log(profiling_metrics)

    if "mlflow" in trainer.args.report_to and mlflow.run is not None and trainer.accelerator.is_main_process:
        mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step)

这个装饰器有什么用?它会自动为我们的函数执行性能分析,无需手动检测。


def profiling_decorator(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        with profiling_context(self, func.__name__):  # ← This is the "timer"
            return func(self, *args, **kwargs)
    return wrapper

这个装饰器执行以下操作:

  • 它创建了一个包装函数 wrapper
  • 它用 profiling_context(...) 封装了原始函数调用。
  • profiling_context 是一个上下文管理器,负责计时和收集分析器数据。
  • return func(...) 调用原始函数,允许函数体正常执行。
  • 它使用 @functools.wraps(func) 来保留原始函数的元数据(如其名称和文档字符串)。

有些函数需要进行性能分析,但您不想为每个函数手动编写性能分析逻辑。使用装饰器,您只需一行代码即可添加性能分析功能。它会自动记录函数名称,具有清晰的作用范围,非常适合在大型训练库中使用。

3. functools 模块中的常用装饰器

保留元数据:@functools.wraps

如前面的示例所示,此装饰器的目的是在装饰器中保留原始函数的元数据(如其名称、文档字符串和签名)。它在自定义装饰器定义中起着不可或缺的作用。@functools.wraps 的用法如下:

@functools.wraps(orig_method)
def wrapped_method(model_self, *args, **kwargs):
    ...

它用于保留原始函数 orig_method 的元数据:

  • 函数名 __name__:否则,它将变为 wrapped_method
  • 函数文档字符串 __doc__
  • 函数签名信息。
  • 这有助于调试、日志记录、文档工具和跟踪工具。

没有 @wraps

>>> module.forward.__name__
'wrapped_method'

使用 @wraps(forward)

>>> module.forward.__name__
'forward'

为什么它几乎总是出现在自定义装饰器中?正如前面讨论装饰器语法时提到的,当你看到这个:

@my_decorator
def foo():
    ...

它实际上等同于:

def foo():
    ...

foo = my_decorator(foo)

在这里,语句 foo = my_decorator(foo) 导致 foo 的自身元数据(如其名称和文档字符串)被 my_decorator 返回的 wrapper 函数的元数据替换。这可能导致许多问题,特别是:

  • 在 IDE 中查看函数时,无法看到原始的文档字符串和签名。
  • 分析代码时,无法检查正确的函数结构。
  • 自动生成文档时,无法显示正确的信息。
  • 使用断点调试时,调试器显示 wrapper,而不是原始函数。

当你正确定义装饰器如下时:

import functools

def my_decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

@my_decorator
def say_hello():
    """This function says hello"""
    print("Hello!")

print(say_hello.__name__)   # Output: 'say_hello'
print(say_hello.__doc__)    # Output: 'This function says hello'

这实质上是告诉 Python,wrapper 函数是 func 的代理,指示它将 func 的所有元数据传输到 wrapper。因此,无论何时编写装饰器函数,都几乎总是应该添加 @functools.wraps(func),除非你确实不需要保留原始函数的信息(这种情况很少见)。

除了在装饰器定义中使用外,@functools.wraps 也经常单独使用。让我们看另一个 trl 中的例子:

源代码:https://github.com/huggingface/trl/blob/main/trl/trainer/online_dpo_trainer.py#L381

@wraps(Trainer.get_eval_dataloader)
    def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        # If we have persistent workers, don't do a fork bomb especially as eval datasets
        # don't change during training
        dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
        if (
            hasattr(self, "_eval_dataloaders")
            and dataloader_key in self._eval_dataloaders
            and self.args.dataloader_persistent_workers
        ):
            return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])

        eval_dataset = (
            self.eval_dataset[eval_dataset]
            if isinstance(eval_dataset, str)
            else eval_dataset
            if eval_dataset is not None
            else self.eval_dataset
        )
        data_collator = self.data_collator

        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        # accelerator.free_memory() will destroy the references, so
        # we need to store the non-prepared version
        eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
        if self.args.dataloader_persistent_workers:
            if hasattr(self, "_eval_dataloaders"):
                self._eval_dataloaders[dataloader_key] = eval_dataloader
            else:
                self._eval_dataloaders = {dataloader_key: eval_dataloader}

        return self.accelerator.prepare(eval_dataloader)

上面的代码是 HuggingFace Trainer 类的 get_eval_dataloader() 方法的一个重写或增强版本。

@wraps(Trainer.get_eval_dataloader)
def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
    ...

在这里,@wraps(Trainer.get_eval_dataloader) 本质上是说:“我编写了一个新的 get_eval_dataloader(),它增强了原始方法,但我希望保留原始方法的元数据(例如其名称、文档字符串、签名等)。”这可能涉及的一个重要用例是在重写父类方法时保留原始方法的信息。让我们抽象这个过程:

from functools import wraps

class Base:
    def say(self):
        """Say something."""
        print("Base says")

class Sub(Base):
    @wraps(Base.say)
    def say(self):
        print("Sub overrides")

print(Sub.say.__name__)  #  'say'
print(Sub.say.__doc__)   # 'Say something.'

在类中重写方法时使用 @wraps(...) 可确保子类方法从外部角度看与父类方法保持一致,这有助于 IDE、文档工具、调试器和装饰器链分析器等工具识别。

另外,@functools.wraps 也常用于手动包装实例方法(猴子补丁)。猴子补丁是指在运行时(而不是在源代码中)动态修改类、模块或函数的行为。简单来说:您不更改原始代码文件,而是在代码运行时“秘密地”重写函数或类的实现。这在调试时修改第三方库行为时很常见:

class Greeter:
    def greet(self, name):
        """Greet someone."""
        return f"Hello {name}"

g = Greeter()
original = g.greet

@functools.wraps(original)
def new_greet(self, name):
    print("Pre-hook")
    result = original(name)
    print("Post-hook")
    return result

g.greet = new_greet.__get__(g, Greeter)

print(g.greet.__name__)  #  'greet'
print(g.greet.__doc__)   #  'Greet someone.'

在这种猴子补丁的情况下,@wraps(...) 确保替换的方法看起来仍然像原始方法,避免“元数据损坏”。

既然谈到了,我们再简要讨论一下 functools 库在机器学习领域中其他常用的装饰器。

自动缓存计算结果:@functools.cache@functools.lru_cache

@functools.cache 是标准库 functools 中的一个装饰器,它自动缓存函数的返回值(基于其输入参数),以避免重复计算并提高效率。

源代码:https://github.com/volcengine/verl/blob/main/verl/utils/import_utils.py#L24

@cache
def is_megatron_core_available():
    try:
        mcore_spec = importlib.util.find_spec("megatron.core")
    except ModuleNotFoundError:
        mcore_spec = None
    return mcore_spec is not None


@cache
def is_vllm_available():
    try:
        vllm_spec = importlib.util.find_spec("vllm")
    except ModuleNotFoundError:
        vllm_spec = None
    return vllm_spec is not None


@cache
def is_sglang_available():
    try:
        sglang_spec = importlib.util.find_spec("sglang")
    except ModuleNotFoundError:
        sglang_spec = None
    return sglang_spec is not None

在以下情况下,使用此装饰器是理想的选择:

  • 结果仅取决于输入,并且不随时间变化,例如检查模块是否已安装、计算斐波那契数或路径查找。
  • 函数运行成本高昂但结果稳定,例如加载模型、查找依赖项或编译过程。
  • 您不希望函数重复运行,例如检查、探测或初始化函数。

以上面的例子为例:

@cache
def is_megatron_core_available():
    try:
        mcore_spec = importlib.util.find_spec("megatron.core")
    except ModuleNotFoundError:
        mcore_spec = None
    return mcore_spec is not None

此函数的行为是调用 importlib.util.find_spec() 来确定模块是否存在。此操作涉及搜索系统路径和加载信息,使其成为 I/O 密集型操作。结果在程序生命周期中是稳定的,因此在第一次调用后缓存结果非常合理!

在幕后,@cache 使用一个无界字典进行缓存,其中函数参数作为键,返回值作为值:

def f(x): ...
f(1)  # → computes and caches
f(1)  # → directly returns the cached result, does not execute the function body again

@functools.lru_cache@functools.cache 的“进化”或“更灵活”版本,提供了更强大和更细粒度的控制。@functools.lru_cache(maxsize=N) 缓存最近 N 次调用的结果,并支持基于最不常用 (LRU) 策略的自动淘汰。让我们看下面的示例:

源代码:https://github.com/vllm-project/vllm/blob/main/vllm/engine/output_processor/multi_step.py#L72

@functools.lru_cache
    def _log_prompt_logprob_unsupported_warning_once():
        # Reminder: Please update docs/features/compatibility_matrix.md
        # If the feature combo become valid
        logger.warning(
            "Prompt logprob is not supported by multi step workers. "
            "(e.g., speculative decode uses multi step workers).")

这段代码的目的是明确的:即使函数被多次调用,这个警告也应该只打印一次。函数没有参数,所以:

  • 第一次调用时,它会打印警告并缓存返回值(为 None)。
  • 后续调用时,由于参数相同(无),它将直接返回缓存结果,而不会执行函数体。
  • 因此,logger.warning 只执行一次。

这实际上等同于 @functools.lru_cache(maxsize=128)(默认缓存大小为 128 种不同的输入组合)。@functools.lru_cache(maxsize=None) 等同于 @cache(无界缓存)。

理论上,同样的目标可以通过 if-else 语句实现:

_warned = False

def _log_prompt_logprob_unsupported_warning_once():
    global _warned
    if not _warned:
        logger.warning("...")
        _warned = True

然而,使用 @lru_cache 的好处是:

  • 简洁性:一行代码即可完成,无需管理全局变量。
  • 线程安全:内部缓存机制是线程安全的。
  • 更清晰的函数式风格:没有副作用变量,适用于多贡献者模块。
  • 可扩展性:支持基于参数的缓存。如果您将来需要为不同参数打印不同的警告,只需添加参数即可。

一个更复杂的用例:针对特定 key 只记录一次日志。

@functools.lru_cache(maxsize=None)
def warn_once_for_key(key):
    logger.warning(f"Warning for {key}")

调用它:

warn_once_for_key("feature_a")  # Logs once
warn_once_for_key("feature_a")  # Does not log again
warn_once_for_key("feature_b")  # New key, logs once

@cache@property 结合起来,你就得到了 @functools.cached_property,它合并了它们的效果。由于篇幅限制,这里不再赘述。

4. 结论

我们从两个流行的开源库 vllmtrl 中的实际代码出发,剖析了 Python 装饰器在机器学习项目中的各种用途。这包括类方法、上下文管理器、抽象方法、静态方法、自定义装饰器以及标准库中的缓存装饰器。巧妙地使用装饰器不仅能使代码更清晰、更有条理,更重要的是,还能显著提高项目的可维护性和可扩展性。希望通过本文,您能对装饰器的本质和应用有更深入的理解,并能启发您在自己的机器学习项目中设计出更优雅、更强大的装饰器!

社区

注册登录 以发表评论