Accelerate 文档
Kwargs 关键字参数处理程序
并获得增强的文档体验
开始使用
关键字参数处理程序
以下对象可以传递给主要的 Accelerator,以自定义如何创建与分布式训练或混合精度相关的一些 PyTorch 对象。
AutocastKwargs
在你的 Accelerator 中使用此对象来自定义 torch.autocast
的行为。请参阅此上下文管理器的文档以获取有关每个参数的更多信息。
DistributedDataParallelKwargs
class accelerate.DistributedDataParallelKwargs
< source >( dim: int = 0 broadcast_buffers: bool = True bucket_cap_mb: int = 25 find_unused_parameters: bool = False check_reduction: bool = False gradient_as_bucket_view: bool = False static_graph: bool = False comm_hook: DDPCommunicationHookType = <DDPCommunicationHookType.NO: 'no'> comm_wrapper: typing.Literal[<DDPCommunicationHookType.NO: 'no'>, <DDPCommunicationHookType.FP16: 'fp16'>, <DDPCommunicationHookType.BF16: 'bf16'>] = <DDPCommunicationHookType.NO: 'no'> comm_state_option: dict = <factory> )
在你的 Accelerator 中使用此对象来自定义你的模型如何被封装在 torch.nn.parallel.DistributedDataParallel
中。请参阅此封装器的文档以获取有关每个参数的更多信息。
gradient_as_bucket_view
仅在 PyTorch 1.7.0 及更高版本中可用。
static_graph
仅在 PyTorch 1.11.0 及更高版本中可用。
FP8RecipeKwargs
class accelerate.utils.FP8RecipeKwargs
< source >( opt_level: typing.Literal['O1', 'O2'] = None use_autocast_during_eval: bool = None margin: int = None interval: int = None fp8_format: typing.Literal['E4M3', 'HYBRID'] = None amax_history_len: int = None amax_compute_algo: typing.Literal['max', 'most_recent'] = None override_linear_precision: tuple = None backend: typing.Literal['MSAMP', 'TE'] = None )
已弃用。请使用适当的 FP8 配方关键字参数类之一,例如 TERecipeKwargs
或 MSAMPRecipeKwargs
。
ProfileKwargs
class accelerate.ProfileKwargs
< source >( activities: typing.Optional[list[typing.Literal['cpu', 'xpu', 'mtia', 'cuda', 'hpu']]] = None schedule_option: typing.Optional[dict[str, int]] = None on_trace_ready: typing.Optional[typing.Callable] = None record_shapes: bool = False profile_memory: bool = False with_stack: bool = False with_flops: bool = False with_modules: bool = False output_trace_dir: typing.Optional[str] = None )
参数
- activities (
List[str]
, 可选, 默认为None
) — 用于性能分析的活动组列表。必须是 "cpu"、"xpu"、"mtia"、“hpu” 或 "cuda" 之一。 - schedule_option (
Dict[str, int]
, 可选, 默认为None
) — 用于性能分析器的调度选项。可用键包括wait
、warmup
、active
、repeat
和skip_first
。性能分析器将跳过前skip_first
步,然后等待wait
步,然后为接下来的warmup
步进行预热,然后为接下来的active
步进行活动记录,然后重复循环,从wait
步开始。可选的循环次数由repeat
参数指定,零值表示循环将持续到性能分析完成。 - on_trace_ready (
Callable
, 可选, 默认为None
) — 在性能分析期间,当调度返回ProfilerAction.RECORD_AND_SAVE
时,每步调用的可调用对象。 - record_shapes (
bool
, 可选, 默认为False
) — 保存有关运算符输入形状的信息。 - profile_memory (
bool
, 可选, 默认为False
) — 跟踪张量内存分配/释放 - with_stack (
bool
, 可选, 默认为False
) — 记录操作的源信息(文件和行号)。 - with_flops (
bool
, 可选, 默认为False
) — 使用公式估算特定运算符的 FLOPS - with_modules (
bool
, 可选, 默认为False
) — 记录与操作的调用堆栈相对应的模块层次结构(包括函数名)。 - output_trace_dir (
str
, 可选, 默认为None
) — 以 Chrome JSON 格式导出收集的跟踪信息。Chrome 使用 ‘chrome://tracing’ 查看 json 文件。默认为 None,表示性能分析不存储 json 文件。
在您的 Accelerator 中使用此对象来自定义性能分析器的初始化。有关每个参数的更多信息,请参阅此 上下文管理器 的文档。
torch.profiler
仅在 PyTorch 1.8.1 及更高版本中可用。
示例
from accelerate import Accelerator
from accelerate.utils import ProfileKwargs
kwargs = ProfileKwargs(activities=["cpu", "cuda"])
accelerator = Accelerator(kwargs_handlers=[kwargs])
使用当前配置构建性能分析器对象。
GradScalerKwargs
class accelerate.GradScalerKwargs
< source >( init_scale: float = 65536.0 growth_factor: float = 2.0 backoff_factor: float = 0.5 growth_interval: int = 2000 enabled: bool = True )
在您的 Accelerator 中使用此对象来自定义混合精度的行为,特别是如何创建使用的 torch.cuda.amp.GradScaler
。有关每个参数的更多信息,请参阅此 scaler 的文档。
GradScaler
仅在 PyTorch 1.5.0 及更高版本中可用。
InitProcessGroupKwargs
class accelerate.InitProcessGroupKwargs
< source >( backend: typing.Optional[str] = 'nccl' init_method: typing.Optional[str] = None timeout: typing.Optional[datetime.timedelta] = None )
在您的 Accelerator 中使用此对象来自定义分布式进程的初始化。有关每个参数的更多信息,请参阅此 方法 的文档。
注意:如果 timeout
设置为 None
,则默认值将基于 backend
的设置方式。
KwargsHandler
实现数据类 to_kwargs()
方法的内部混入。
返回一个字典,其中包含与此类默认值不同的属性值。