竞赛文档

自定义指标

您正在查看 main 版本,该版本需要从源代码安装。如果您想进行常规 pip 安装,请查看最新的稳定版本 (v0.1.6)。
Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

自定义指标

如果您不满足于默认的 scikit-learn 指标,您可以定义自己的指标。

在这里,我们期望组织者懂 Python。

如何定义自定义指标

要定义自定义指标,请将 conf.json 中的 EVAL_METRIC 更改为 custom。您还必须确保根据指标值越高越好与否,将 EVAL_HIGHER_IS_BETTER 设置为 10

第二步是在私有竞赛仓库中创建一个文件 metric.py。该文件应包含一个 compute 函数,该函数将竞赛参数作为输入。

这是我们检查指标是否为自定义并计算指标值的部分

def compute_metrics(params):
    if params.metric == "custom":
        metric_file = hf_hub_download(
            repo_id=params.competition_id,
            filename="metric.py",
            token=params.token,
            repo_type="dataset",
        )
        sys.path.append(os.path.dirname(metric_file))
        metric = importlib.import_module("metric")
        evaluation = metric.compute(params)
    .
    .
    .

您可以在竞赛 github 仓库 compute_metrics.py 中找到以上部分

params 定义为

class EvalParams(BaseModel):
    competition_id: str
    competition_type: str
    metric: str
    token: str
    team_id: str
    submission_id: str
    submission_id_col: str
    submission_cols: List[str]
    submission_rows: int
    output_path: str
    submission_repo: str
    time_limit: int
    dataset: str  # private test dataset, used only for script competitions

您可以在 compute 函数中自由执行任何操作。最后,它必须返回一个包含以下键的字典

{
    "public_score": {
        "metric1": metric_value,
    },,
    "private_score": {
        "metric1": metric_value,
    },,
}

公共和私有分数必须是字典!您也可以使用多个指标。多个指标的示例

{
    "public_score": {
        "metric1": metric_value,
        "metric2": metric_value,
    },
    "private_score": {
        "metric1": metric_value,
        "metric2": metric_value,
    },
}

注意:当使用多个指标时,conf.json 必须指定 SCORING_METRIC 以对竞赛中的参与者进行排名。

例如,如果我想使用 metric2 对参与者进行排名,我将在 conf.json 中将 SCORING_METRIC 设置为 metric2

自定义指标示例

import pandas as pd
from huggingface_hub import hf_hub_download


def compute(params):
    solution_file = hf_hub_download(
        repo_id=params.competition_id,
        filename="solution.csv",
        token=params.token,
        repo_type="dataset",
    )

    solution_df = pd.read_csv(solution_file)

    submission_filename = f"submissions/{params.team_id}-{params.submission_id}.csv"
    submission_file = hf_hub_download(
        repo_id=params.competition_id,
        filename=submission_filename,
        token=params.token,
        repo_type="dataset",
    )
    submission_df = pd.read_csv(submission_file)

    public_ids = solution_df[solution_df.split == "public"][params.submission_id_col].values
    private_ids = solution_df[solution_df.split == "private"][params.submission_id_col].values

    public_solution_df = solution_df[solution_df[params.submission_id_col].isin(public_ids)]
    public_submission_df = submission_df[submission_df[params.submission_id_col].isin(public_ids)]

    private_solution_df = solution_df[solution_df[params.submission_id_col].isin(private_ids)]
    private_submission_df = submission_df[submission_df[params.submission_id_col].isin(private_ids)]

    public_solution_df = public_solution_df.sort_values(params.submission_id_col).reset_index(drop=True)
    public_submission_df = public_submission_df.sort_values(params.submission_id_col).reset_index(drop=True)

    private_solution_df = private_solution_df.sort_values(params.submission_id_col).reset_index(drop=True)
    private_submission_df = private_submission_df.sort_values(params.submission_id_col).reset_index(drop=True)

    # CALCULATE METRICS HERE.......
    # _metric = SOME METRIC FUNCTION
    target_cols = [col for col in solution_df.columns if col not in [params.submission_id_col, "split"]]
    public_score = _metric(public_solution_df[target_cols], public_submission_df[target_cols])
    private_score = _metric(private_solution_df[target_cols], private_submission_df[target_cols])

    evaluation = {
        "public_score": {
            "metric1": public_score,
        },
        "private_score": {
            "metric1": public_score,
        }
    }
    return evaluation

仔细查看上面的代码。您可以看到我们正在从数据集仓库下载解决方案文件和提交文件。然后,我们正在计算解决方案和提交文件的公共和私有分割上的指标。最后,我们正在字典中返回指标值。

< > 在 GitHub 上更新