指标
指标在 🤗 Datasets 中已弃用。要了解有关如何使用指标的更多信息,请查看 🤗 Evaluate 库!除了指标外,您还可以找到更多用于评估模型和数据集的工具。
指标对于评估模型的预测非常重要。在教程中,您了解了如何在整个评估集上计算指标。您还看到了如何加载指标。
本指南将向您展示如何
- 添加预测和参考。
- 使用不同的方法计算指标。
- 编写您自己的指标加载脚本。
添加预测和参考
当您想要将模型预测和参考添加到 Metric 实例时,您有两个选择
Metric.add() 添加单个
prediction
和reference
。Metric.add_batch() 添加一批
predictions
和references
。
通过将模型预测和模型预测应根据其进行评估的参考传递给它,来使用 Metric.add_batch()
>>> import datasets
>>> metric = datasets.load_metric('my_metric')
>>> for model_input, gold_references in evaluation_dataset:
... model_predictions = model(model_inputs)
... metric.add_batch(predictions=model_predictions, references=gold_references)
>>> final_score = metric.compute()
Metrics 接受各种输入格式(Python 列表、NumPy 数组、PyTorch 张量等),并将它们转换为适合存储和计算的格式。
计算分数
计算指标的最直接方法是调用 Metric.compute()。但某些指标具有其他参数,这些参数允许您修改指标行为。
让我们加载 SacreBLEU 指标,并使用不同的平滑方法计算它。
- 加载 SacreBLEU 指标
>>> import datasets
>>> metric = datasets.load_metric('sacrebleu')
- 检查用于计算指标的不同参数方法
>>> print(metric.inputs_description)
Produces BLEU scores along with its sufficient statistics
from a source against one or more references.
Args:
predictions: The system stream (a sequence of segments).
references: A list of one or more reference streams (each a sequence of segments).
smooth_method: The smoothing method to use. (Default: 'exp').
smooth_value: The smoothing value. Only valid for 'floor' and 'add-k'. (Defaults: floor: 0.1, add-k: 1).
tokenize: Tokenization method to use for BLEU. If not provided, defaults to 'zh' for Chinese, 'ja-mecab' for Japanese and '13a' (mteval) otherwise.
lowercase: Lowercase the data. If True, enables case-insensitivity. (Default: False).
force: Insist that your tokenized input is actually detokenized.
...
- 使用
floor
方法和不同的smooth_value
计算指标
>>> score = metric.compute(smooth_method="floor", smooth_value=0.2)
自定义指标加载脚本
编写一个指标加载脚本以使用您自己的自定义指标(或 Hub 上没有的指标)。然后,您可以像往常一样使用 load_metric() 加载它。
为了帮助您入门,请打开 SQuAD 指标加载脚本 并继续。
使用我们的指标加载脚本 模板 快速入门!
添加指标属性
首先在 Metric._info()
中添加一些关于指标的信息。您应该指定的最重要的属性是
MetricInfo.description
提供了有关指标的简短描述。MetricInfo.citation
包含指标的 BibTex 引用。MetricInfo.inputs_description
描述了预期的输入和输出。它还可以提供指标的示例用法。MetricInfo.features
定义了预测和参考的名称和类型。
在您在模板中填写了所有这些字段后,它应该类似于 SQuAD 指标脚本中的以下示例
class Squad(datasets.Metric):
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": {"id": datasets.Value("string"), "prediction_text": datasets.Value("string")},
"references": {
"id": datasets.Value("string"),
"answers": datasets.features.Sequence(
{
"text": datasets.Value("string"),
"answer_start": datasets.Value("int32"),
}
),
},
}
),
codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
)
下载指标文件
如果您的指标需要下载或检索本地文件,则需要使用 Metric._download_and_prepare()
方法。对于此示例,让我们检查 BLEURT 指标加载脚本。
- 提供指向指标文件的 URL 的字典
CHECKPOINT_URLS = {
"bleurt-tiny-128": "https://storage.googleapis.com/bleurt-oss/bleurt-tiny-128.zip",
"bleurt-tiny-512": "https://storage.googleapis.com/bleurt-oss/bleurt-tiny-512.zip",
"bleurt-base-128": "https://storage.googleapis.com/bleurt-oss/bleurt-base-128.zip",
"bleurt-base-512": "https://storage.googleapis.com/bleurt-oss/bleurt-base-512.zip",
"bleurt-large-128": "https://storage.googleapis.com/bleurt-oss/bleurt-large-128.zip",
"bleurt-large-512": "https://storage.googleapis.com/bleurt-oss/bleurt-large-512.zip",
}
如果文件存储在本地,请提供路径(s)的字典,而不是 URL。
Metric._download_and_prepare()
将获取 URL 并下载指定的指标文件
def _download_and_prepare(self, dl_manager):
# check that config name specifies a valid BLEURT model
if self.config_name == "default":
logger.warning(
"Using default BLEURT-Base checkpoint for sequence maximum length 128. "
"You can use a bigger model for better results with e.g.: datasets.load_metric('bleurt', 'bleurt-large-512')."
)
self.config_name = "bleurt-base-128"
if self.config_name not in CHECKPOINT_URLS.keys():
raise KeyError(
f"{self.config_name} model not found. You should supply the name of a model checkpoint for bleurt in {CHECKPOINT_URLS.keys()}"
)
# download the model checkpoint specified by self.config_name and set up the scorer
model_path = dl_manager.download_and_extract(CHECKPOINT_URLS[self.config_name])
self.scorer = score.BleurtScorer(os.path.join(model_path, self.config_name))
计算分数
DatasetBuilder._compute
提供了在给定预测和参考的情况下如何计算指标的实际说明。现在让我们看看 GLUE 指标加载脚本。
- 为
DatasetBuilder._compute
提供函数以计算您的指标
def simple_accuracy(preds, labels):
return (preds == labels).mean().item()
def acc_and_f1(preds, labels):
acc = simple_accuracy(preds, labels)
f1 = f1_score(y_true=labels, y_pred=preds).item()
return {
"accuracy": acc,
"f1": f1,
}
def pearson_and_spearman(preds, labels):
pearson_corr = pearsonr(preds, labels)[0].item()
spearman_corr = spearmanr(preds, labels)[0].item()
return {
"pearson": pearson_corr,
"spearmanr": spearman_corr,
}
- 使用用于每个配置要计算的指标的说明创建
DatasetBuilder._compute
def _compute(self, predictions, references):
if self.config_name == "cola":
return {"matthews_correlation": matthews_corrcoef(references, predictions)}
elif self.config_name == "stsb":
return pearson_and_spearman(predictions, references)
elif self.config_name in ["mrpc", "qqp"]:
return acc_and_f1(predictions, references)
elif self.config_name in ["sst2", "mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]:
return {"accuracy": simple_accuracy(predictions, references)}
else:
raise KeyError(
"You should supply a configuration name selected in "
'["sst2", "mnli", "mnli_mismatched", "mnli_matched", '
'"cola", "stsb", "mrpc", "qqp", "qnli", "rte", "wnli", "hans"]'
)
测试
完成指标加载脚本的编写后,尝试在本地加载它。
>>> from datasets import load_metric
>>> metric = load_metric('PATH/TO/MY/SCRIPT.py')