AWS Trainium 和 Inferentia 文档

NeuronTrainer

Hugging Face's logo
加入 Hugging Face 社区

并获得增强文档体验

开始

NeuronTrainer

NeuronTrainer 类为功能齐全的 Transformers Trainer 提供了扩展的 API。它用于所有 示例脚本

NeuronTrainer 类针对在 AWS Trainium 上运行的 🤗 Transformers 模型进行了优化。

以下是如何自定义 NeuronTrainer 以使用加权损失(在训练集不平衡时很有用)的示例

from torch import nn
from optimum.neuron import NeuronTrainer


class CustomNeuronTrainer(NeuronTrainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute custom loss (suppose one has 3 labels with different weights)
        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]))
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

自定义 PyTorch NeuronTrainer 的训练循环行为的另一种方法是使用 回调,这些回调可以检查训练循环状态(用于进度报告、TensorBoard 或其他机器学习平台上的日志记录……)并做出决策(例如提前停止)。

NeuronTrainer

optimum.neuron.NeuronTrainer

< >

( *args **kwargs )

适用于在 AWS Tranium 实例上执行训练的 Trainer。

optimum.neuron.Seq2SeqNeuronTrainer

< >

( *args **kwargs )

适用于在 AWS Tranium 实例上执行训练的 Seq2SeqTrainer。