AutoTrain 文档

Seq2Seq

您正在查看 main 版本,该版本需要从源码安装。如果您想要常规的 pip 安装,请查看最新的稳定版本 (v0.8.24)。
Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

Seq2Seq

Seq2Seq 是一项将一个单词序列转换为另一个单词序列的任务。它用于机器翻译、文本摘要和问题解答。

数据格式

您可以将数据集设为 CSV 文件

text,target
"this movie is great","dieser Film ist großartig"
"this movie is bad","dieser Film ist schlecht"
.
.
.

或 JSONL 文件

{"text": "this movie is great", "target": "dieser Film ist großartig"}
{"text": "this movie is bad", "target": "dieser Film ist schlecht"}
.
.
.

您的 CSV/JSONL 数据集必须有两列:texttarget

参数

class autotrain.trainers.seq2seq.params.Seq2SeqParams

< >

( data_path: str = None model: str = 'google/flan-t5-base' username: typing.Optional[str] = None seed: int = 42 train_split: str = 'train' valid_split: typing.Optional[str] = None project_name: str = 'project-name' token: typing.Optional[str] = None push_to_hub: bool = False text_column: str = 'text' target_column: str = 'target' lr: float = 5e-05 epochs: int = 3 max_seq_length: int = 128 max_target_length: int = 128 batch_size: int = 2 warmup_ratio: float = 0.1 gradient_accumulation: int = 1 optimizer: str = 'adamw_torch' scheduler: str = 'linear' weight_decay: float = 0.0 max_grad_norm: float = 1.0 logging_steps: int = -1 eval_strategy: str = 'epoch' auto_find_batch_size: bool = False mixed_precision: typing.Optional[str] = None save_total_limit: int = 1 peft: bool = False quantization: typing.Optional[str] = 'int8' lora_r: int = 16 lora_alpha: int = 32 lora_dropout: float = 0.05 target_modules: str = 'all-linear' log: str = 'none' early_stopping_patience: int = 5 early_stopping_threshold: float = 0.01 )

参数

  • data_path (str) — 数据集路径。
  • model (str) — 要使用的模型名称。默认为 “google/flan-t5-base”。
  • username (Optional[str]) — Hugging Face 用户名。
  • seed (int) — 用于复现的随机种子。默认为 42。
  • train_split (str) — 训练数据拆分名称。默认为 “train”。
  • valid_split (Optional[str]) — 验证数据拆分名称。
  • project_name (str) — 项目名称或输出目录。默认为 “project-name”。
  • token (Optional[str]) — 用于身份验证的 Hub Token。
  • push_to_hub (bool) — 是否将模型推送到 Hugging Face Hub。默认为 False。
  • text_column (str) — 数据集中文本列的名称。默认为 “text”。
  • target_column (str) — 数据集中目标文本列的名称。默认为 “target”。
  • lr (float) — 训练的学习率。默认为 5e-5。
  • epochs (int) — 训练 epoch 的数量。默认为 3。
  • max_seq_length (int) — 输入文本的最大序列长度。默认为 128。
  • max_target_length (int) — 目标文本的最大序列长度。默认为 128。
  • batch_size (int) — 训练批次大小。默认为 2。
  • warmup_ratio (float) — 预热步长的比例。默认为 0.1。
  • gradient_accumulation (int) — 梯度累积步数。默认为 1。
  • optimizer (str) — 要使用的优化器。默认为 “adamw_torch”。
  • scheduler (str) — 要使用的学习率调度器。默认为 “linear”。
  • weight_decay (float) — 优化器的权重衰减。默认为 0.0。
  • max_grad_norm (float) — 梯度裁剪的最大梯度范数。默认为 1.0。
  • logging_steps (int) — 日志记录之间的步数。默认为 -1 (禁用)。
  • eval_strategy (str) — 评估策略。默认为 “epoch”。
  • auto_find_batch_size (bool) — 是否自动查找批大小。默认为 False。
  • mixed_precision (Optional[str]) — 混合精度训练模式 (fp16、bf16 或 None)。
  • save_total_limit (int) — 要保存的最大检查点数量。默认为 1。
  • peft (bool) — 是否使用参数高效微调 (PEFT)。默认为 False。
  • quantization (Optional[str]) — 量化模式 (int4、int8 或 None)。默认为 “int8”。
  • lora_r (int) — PEFT 的 LoRA-R 参数。默认为 16。
  • lora_alpha (int) — PEFT 的 LoRA-Alpha 参数。默认为 32。
  • lora_dropout (float) — PEFT 的 LoRA-Dropout 参数。默认为 0.05。
  • target_modules (str) — PEFT 的目标模块。默认为 “all-linear”。
  • log (str) — 用于实验跟踪的日志记录方法。默认为 “none”。
  • early_stopping_patience (int) — 提前停止的耐心值。默认为 5。
  • early_stopping_threshold (float) — 提前停止的阈值。默认为 0.01。

Seq2SeqParams 是用于序列到序列训练参数的配置类。

< > 在 GitHub 上更新