竞赛文档

自定义指标

您正在查看 main 版本,需要从源代码安装. 如果您想使用常规 pip 安装,请查看最新的稳定版本 (v0.1.6).
Hugging Face's logo
加入 Hugging Face 社区

并获得增强型文档体验

开始

自定义指标

如果您不满足于默认的 scikit-learn 指标,您可以定义自己的指标。

在这里,我们希望组织者了解 python。

如何定义自定义指标

要定义自定义指标,请将 conf.json 中的 EVAL_METRIC 更改为 custom。您还必须确保 EVAL_HIGHER_IS_BETTER 设置为 10,具体取决于指标的较高值是否更好。

第二步是在私有竞赛仓库中创建一个文件 metric.py。该文件应包含一个 compute 函数,该函数以竞赛参数作为输入。

这是我们检查指标是否为自定义并计算指标值的部分

def compute_metrics(params):
    if params.metric == "custom":
        metric_file = hf_hub_download(
            repo_id=params.competition_id,
            filename="metric.py",
            token=params.token,
            repo_type="dataset",
        )
        sys.path.append(os.path.dirname(metric_file))
        metric = importlib.import_module("metric")
        evaluation = metric.compute(params)
    .
    .
    .

您可以在竞赛 github 仓库的 compute_metrics.py 中找到上述部分

params 定义为

class EvalParams(BaseModel):
    competition_id: str
    competition_type: str
    metric: str
    token: str
    team_id: str
    submission_id: str
    submission_id_col: str
    submission_cols: List[str]
    submission_rows: int
    output_path: str
    submission_repo: str
    time_limit: int
    dataset: str  # private test dataset, used only for script competitions

您可以随意在 compute 函数中执行任何操作。最后,它必须返回一个包含以下键的字典

{
    "public_score": {
        "metric1": metric_value,
    },,
    "private_score": {
        "metric1": metric_value,
    },,
}

公共和私有分数必须是字典!您也可以使用多个指标。多个指标示例

{
    "public_score": {
        "metric1": metric_value,
        "metric2": metric_value,
    },
    "private_score": {
        "metric1": metric_value,
        "metric2": metric_value,
    },
}

注意:使用多个指标时,conf.json 必须指定 SCORING_METRIC 以对竞赛中的参与者进行排名。

例如,如果我想使用 metric2 对参与者进行排名,我将在 conf.json 中将 SCORING_METRIC 设置为 metric2

自定义指标示例

import pandas as pd
from huggingface_hub import hf_hub_download


def compute(params):
    solution_file = hf_hub_download(
        repo_id=params.competition_id,
        filename="solution.csv",
        token=params.token,
        repo_type="dataset",
    )

    solution_df = pd.read_csv(solution_file)

    submission_filename = f"submissions/{params.team_id}-{params.submission_id}.csv"
    submission_file = hf_hub_download(
        repo_id=params.competition_id,
        filename=submission_filename,
        token=params.token,
        repo_type="dataset",
    )
    submission_df = pd.read_csv(submission_file)

    public_ids = solution_df[solution_df.split == "public"][params.submission_id_col].values
    private_ids = solution_df[solution_df.split == "private"][params.submission_id_col].values

    public_solution_df = solution_df[solution_df[params.submission_id_col].isin(public_ids)]
    public_submission_df = submission_df[submission_df[params.submission_id_col].isin(public_ids)]

    private_solution_df = solution_df[solution_df[params.submission_id_col].isin(private_ids)]
    private_submission_df = submission_df[submission_df[params.submission_id_col].isin(private_ids)]

    public_solution_df = public_solution_df.sort_values(params.submission_id_col).reset_index(drop=True)
    public_submission_df = public_submission_df.sort_values(params.submission_id_col).reset_index(drop=True)

    private_solution_df = private_solution_df.sort_values(params.submission_id_col).reset_index(drop=True)
    private_submission_df = private_submission_df.sort_values(params.submission_id_col).reset_index(drop=True)

    # CALCULATE METRICS HERE.......
    # _metric = SOME METRIC FUNCTION
    target_cols = [col for col in solution_df.columns if col not in [params.submission_id_col, "split"]]
    public_score = _metric(public_solution_df[target_cols], public_submission_df[target_cols])
    private_score = _metric(private_solution_df[target_cols], private_submission_df[target_cols])

    evaluation = {
        "public_score": {
            "metric1": public_score,
        },
        "private_score": {
            "metric1": public_score,
        }
    }
    return evaluation

仔细查看上面的代码。您可以看到我们正在从数据集仓库中下载解决方案文件和提交文件。然后,我们正在计算解决方案和提交文件的公共和私有拆分的指标。最后,我们将指标值以字典形式返回。

< > 更新 在 GitHub 上