数据集文档

指标

Hugging Face's logo
加入 Hugging Face 社区

并获取增强型文档体验

开始

指标

指标在 🤗 Datasets 中已弃用。要了解有关如何使用指标的更多信息,请查看 🤗 Evaluate 库!除了指标外,您还可以找到更多用于评估模型和数据集的工具。

指标对于评估模型的预测非常重要。在教程中,您了解了如何在整个评估集上计算指标。您还看到了如何加载指标。

本指南将向您展示如何

  • 添加预测和参考。
  • 使用不同的方法计算指标。
  • 编写您自己的指标加载脚本。

添加预测和参考

当您想要将模型预测和参考添加到 Metric 实例时,您有两个选择

通过将模型预测和模型预测应根据其进行评估的参考传递给它,来使用 Metric.add_batch()

>>> import datasets
>>> metric = datasets.load_metric('my_metric')
>>> for model_input, gold_references in evaluation_dataset:
...     model_predictions = model(model_inputs)
...     metric.add_batch(predictions=model_predictions, references=gold_references)
>>> final_score = metric.compute()

Metrics 接受各种输入格式(Python 列表、NumPy 数组、PyTorch 张量等),并将它们转换为适合存储和计算的格式。

计算分数

计算指标的最直接方法是调用 Metric.compute()。但某些指标具有其他参数,这些参数允许您修改指标行为。

让我们加载 SacreBLEU 指标,并使用不同的平滑方法计算它。

  1. 加载 SacreBLEU 指标
>>> import datasets
>>> metric = datasets.load_metric('sacrebleu')
  1. 检查用于计算指标的不同参数方法
>>> print(metric.inputs_description)
Produces BLEU scores along with its sufficient statistics
from a source against one or more references.

Args:
    predictions: The system stream (a sequence of segments).
    references: A list of one or more reference streams (each a sequence of segments).
    smooth_method: The smoothing method to use. (Default: 'exp').
    smooth_value: The smoothing value. Only valid for 'floor' and 'add-k'. (Defaults: floor: 0.1, add-k: 1).
    tokenize: Tokenization method to use for BLEU. If not provided, defaults to 'zh' for Chinese, 'ja-mecab' for Japanese and '13a' (mteval) otherwise.
    lowercase: Lowercase the data. If True, enables case-insensitivity. (Default: False).
    force: Insist that your tokenized input is actually detokenized.
...
  1. 使用 floor 方法和不同的 smooth_value 计算指标
>>> score = metric.compute(smooth_method="floor", smooth_value=0.2)

自定义指标加载脚本

编写一个指标加载脚本以使用您自己的自定义指标(或 Hub 上没有的指标)。然后,您可以像往常一样使用 load_metric() 加载它。

为了帮助您入门,请打开 SQuAD 指标加载脚本 并继续。

使用我们的指标加载脚本 模板 快速入门!

添加指标属性

首先在 Metric._info() 中添加一些关于指标的信息。您应该指定的最重要的属性是

  1. MetricInfo.description 提供了有关指标的简短描述。

  2. MetricInfo.citation 包含指标的 BibTex 引用。

  3. MetricInfo.inputs_description 描述了预期的输入和输出。它还可以提供指标的示例用法。

  4. MetricInfo.features 定义了预测和参考的名称和类型。

在您在模板中填写了所有这些字段后,它应该类似于 SQuAD 指标脚本中的以下示例

class Squad(datasets.Metric):
    def _info(self):
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": {"id": datasets.Value("string"), "prediction_text": datasets.Value("string")},
                    "references": {
                        "id": datasets.Value("string"),
                        "answers": datasets.features.Sequence(
                            {
                                "text": datasets.Value("string"),
                                "answer_start": datasets.Value("int32"),
                            }
                        ),
                    },
                }
            ),
            codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
            reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
        )

下载指标文件

如果您的指标需要下载或检索本地文件,则需要使用 Metric._download_and_prepare() 方法。对于此示例,让我们检查 BLEURT 指标加载脚本

  1. 提供指向指标文件的 URL 的字典
CHECKPOINT_URLS = {
    "bleurt-tiny-128": "https://storage.googleapis.com/bleurt-oss/bleurt-tiny-128.zip",
    "bleurt-tiny-512": "https://storage.googleapis.com/bleurt-oss/bleurt-tiny-512.zip",
    "bleurt-base-128": "https://storage.googleapis.com/bleurt-oss/bleurt-base-128.zip",
    "bleurt-base-512": "https://storage.googleapis.com/bleurt-oss/bleurt-base-512.zip",
    "bleurt-large-128": "https://storage.googleapis.com/bleurt-oss/bleurt-large-128.zip",
    "bleurt-large-512": "https://storage.googleapis.com/bleurt-oss/bleurt-large-512.zip",
}

如果文件存储在本地,请提供路径(s)的字典,而不是 URL。

  1. Metric._download_and_prepare() 将获取 URL 并下载指定的指标文件
def _download_and_prepare(self, dl_manager):

    # check that config name specifies a valid BLEURT model
    if self.config_name == "default":
        logger.warning(
            "Using default BLEURT-Base checkpoint for sequence maximum length 128. "
            "You can use a bigger model for better results with e.g.: datasets.load_metric('bleurt', 'bleurt-large-512')."
        )
        self.config_name = "bleurt-base-128"
    if self.config_name not in CHECKPOINT_URLS.keys():
        raise KeyError(
            f"{self.config_name} model not found. You should supply the name of a model checkpoint for bleurt in {CHECKPOINT_URLS.keys()}"
        )

    # download the model checkpoint specified by self.config_name and set up the scorer
    model_path = dl_manager.download_and_extract(CHECKPOINT_URLS[self.config_name])
    self.scorer = score.BleurtScorer(os.path.join(model_path, self.config_name))

计算分数

DatasetBuilder._compute 提供了在给定预测和参考的情况下如何计算指标的实际说明。现在让我们看看 GLUE 指标加载脚本

  1. DatasetBuilder._compute 提供函数以计算您的指标
def simple_accuracy(preds, labels):
    return (preds == labels).mean().item()

def acc_and_f1(preds, labels):
    acc = simple_accuracy(preds, labels)
    f1 = f1_score(y_true=labels, y_pred=preds).item()
    return {
        "accuracy": acc,
        "f1": f1,
    }

def pearson_and_spearman(preds, labels):
    pearson_corr = pearsonr(preds, labels)[0].item()
    spearman_corr = spearmanr(preds, labels)[0].item()
    return {
        "pearson": pearson_corr,
        "spearmanr": spearman_corr,
    }
  1. 使用用于每个配置要计算的指标的说明创建 DatasetBuilder._compute
def _compute(self, predictions, references):
    if self.config_name == "cola":
        return {"matthews_correlation": matthews_corrcoef(references, predictions)}
    elif self.config_name == "stsb":
        return pearson_and_spearman(predictions, references)
    elif self.config_name in ["mrpc", "qqp"]:
        return acc_and_f1(predictions, references)
    elif self.config_name in ["sst2", "mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]:
        return {"accuracy": simple_accuracy(predictions, references)}
    else:
        raise KeyError(
            "You should supply a configuration name selected in "
            '["sst2", "mnli", "mnli_mismatched", "mnli_matched", '
            '"cola", "stsb", "mrpc", "qqp", "qnli", "rte", "wnli", "hans"]'
        )

测试

完成指标加载脚本的编写后,尝试在本地加载它。

>>> from datasets import load_metric
>>> metric = load_metric('PATH/TO/MY/SCRIPT.py')
< > 更新 在GitHub上