Diffusers 文档
UNet2DConditionModel
并获得增强的文档体验
开始使用
UNet2DConditionModel
UNet 模型最初由 Ronneberger 等人引入,用于生物医学图像分割,但它也常用于 🤗 Diffusers,因为它输出的图像大小与输入相同。它是扩散系统最重要的组件之一,因为它促进了实际的扩散过程。在 🤗 Diffusers 中,UNet 模型有多种变体,具体取决于其维度数量以及它是否是条件模型。这是一个 2D UNet 条件模型。
论文摘要如下:
人们普遍认为,深度网络的成功训练需要数千个带注释的训练样本。在本文中,我们提出了一种网络和训练策略,它强烈依赖数据增强,以更有效地利用可用的带注释样本。该架构包括一个收缩路径用于捕获上下文,以及一个对称的扩展路径,可实现精确的定位。我们展示了这样的网络可以从很少的图像端到端训练,并且在 ISBI 挑战赛中,在电子显微镜堆栈中分割神经元结构的任务上,其性能优于先前的最佳方法(滑动窗口卷积网络)。使用在透射光显微镜图像(相差和 DIC)上训练的相同网络,我们在 2015 年 ISBI 细胞追踪挑战赛的这些类别中以大幅优势获胜。此外,该网络速度很快。在最新的 GPU 上分割 512x512 图像所需时间不到一秒。完整实现(基于 Caffe)和训练好的网络可在 http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net 获取。
UNet2DConditionModel
类 diffusers.UNet2DConditionModel
< 来源 >( sample_size: typing.Union[int, typing.Tuple[int, int], NoneType] = None in_channels: int = 4 out_channels: int = 4 center_input_sample: bool = False flip_sin_to_cos: bool = True freq_shift: int = 0 down_block_types: typing.Tuple[str] = ('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D') mid_block_type: typing.Optional[str] = 'UNetMidBlock2DCrossAttn' up_block_types: typing.Tuple[str] = ('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D') only_cross_attention: typing.Union[bool, typing.Tuple[bool]] = False block_out_channels: typing.Tuple[int] = (320, 640, 1280, 1280) layers_per_block: typing.Union[int, typing.Tuple[int]] = 2 downsample_padding: int = 1 mid_block_scale_factor: float = 1 dropout: float = 0.0 act_fn: str = 'silu' norm_num_groups: typing.Optional[int] = 32 norm_eps: float = 1e-05 cross_attention_dim: typing.Union[int, typing.Tuple[int]] = 1280 transformer_layers_per_block: typing.Union[int, typing.Tuple[int], typing.Tuple[typing.Tuple]] = 1 reverse_transformer_layers_per_block: typing.Optional[typing.Tuple[typing.Tuple[int]]] = None encoder_hid_dim: typing.Optional[int] = None encoder_hid_dim_type: typing.Optional[str] = None attention_head_dim: typing.Union[int, typing.Tuple[int]] = 8 num_attention_heads: typing.Union[int, typing.Tuple[int], NoneType] = None dual_cross_attention: bool = False use_linear_projection: bool = False class_embed_type: typing.Optional[str] = None addition_embed_type: typing.Optional[str] = None addition_time_embed_dim: typing.Optional[int] = None num_class_embeds: typing.Optional[int] = None upcast_attention: bool = False resnet_time_scale_shift: str = 'default' resnet_skip_time_act: bool = False resnet_out_scale_factor: float = 1.0 time_embedding_type: str = 'positional' time_embedding_dim: typing.Optional[int] = None time_embedding_act_fn: typing.Optional[str] = None timestep_post_act: typing.Optional[str] = None time_cond_proj_dim: typing.Optional[int] = None conv_in_kernel: int = 3 conv_out_kernel: int = 3 projection_class_embeddings_input_dim: typing.Optional[int] = None attention_type: str = 'default' class_embeddings_concat: bool = False mid_block_only_cross_attention: typing.Optional[bool] = None cross_attention_norm: typing.Optional[str] = None addition_embed_type_num_heads: int = 64 )
参数
- sample_size (
int
或Tuple[int, int]
, 可选, 默认为None
) — 输入/输出样本的高度和宽度。 - in_channels (
int
, 可选, 默认为 4) — 输入样本中的通道数。 - out_channels (
int
, 可选, 默认为 4) — 输出中的通道数。 - center_input_sample (
bool
, 可选, 默认为False
) — 是否对输入样本进行居中。 - flip_sin_to_cos (
bool
, 可选, 默认为True
) — 是否在时间嵌入中将 sin 翻转为 cos。 - freq_shift (
int
, 可选, 默认为 0) — 应用于时间嵌入的频率偏移。 - down_block_types (
Tuple[str]
, 可选, 默认为("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
) — 要使用的下采样块的元组。 - mid_block_type (
str
, 可选, 默认为"UNetMidBlock2DCrossAttn"
) — UNet 中间块的块类型,可以是UNetMidBlock2DCrossAttn
,UNetMidBlock2D
, 或UNetMidBlock2DSimpleCrossAttn
之一。如果为None
,则跳过中间块层。 - up_block_types (
Tuple[str]
, 可选, 默认为("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
) — 要使用的上采样块的元组。 - only_cross_attention(
bool
或Tuple[bool]
, 可选, 默认为False
) — 是否在基本转换器块中包含自注意力,请参阅BasicTransformerBlock
。 - block_out_channels (
Tuple[int]
, 可选, 默认为(320, 640, 1280, 1280)
) — 每个块的输出通道元组。 - layers_per_block (
int
, 可选, 默认为 2) — 每个块的层数。 - downsample_padding (
int
, 可选, 默认为 1) — 用于下采样卷积的填充。 - mid_block_scale_factor (
float
, 可选, 默认为 1.0) — 用于中间块的比例因子。 - dropout (
float
, 可选, 默认为 0.0) — 要使用的 dropout 概率。 - act_fn (
str
, 可选, 默认为"silu"
) — 要使用的激活函数。 - norm_num_groups (
int
, 可选, 默认为 32) — 用于归一化的组数。如果为None
,则在后处理中跳过归一化和激活层。 - norm_eps (
float
, 可选, 默认为 1e-5) — 用于归一化的 epsilon。 - cross_attention_dim (
int
或Tuple[int]
, 可选, 默认为 1280) — 交叉注意力特征的维度。 - transformer_layers_per_block (
int
,Tuple[int]
, 或Tuple[Tuple]
, 可选, 默认为 1) —BasicTransformerBlock
类型的转换器块的数量。仅与CrossAttnDownBlock2D
,CrossAttnUpBlock2D
,UNetMidBlock2DCrossAttn
相关。 - reverse_transformer_layers_per_block — (
Tuple[Tuple]
, 可选, 默认为 None): U-Net 上采样块中BasicTransformerBlock
类型的转换器块的数量。仅当transformer_layers_per_block
为Tuple[Tuple]
类型时,且对于CrossAttnDownBlock2D
,CrossAttnUpBlock2D
,UNetMidBlock2DCrossAttn
相关。 - encoder_hid_dim (
int
, 可选, 默认为 None) — 如果定义了encoder_hid_dim_type
,则encoder_hidden_states
将从encoder_hid_dim
维度投射到cross_attention_dim
。 - encoder_hid_dim_type (
str
, 可选, 默认为None
) — 如果给定,encoder_hidden_states
和可能其他嵌入将根据encoder_hid_dim_type
降维投射到cross_attention
维度的文本嵌入。 - attention_head_dim (
int
, 可选, 默认为 8) — 注意力头的维度。 - num_attention_heads (
int
, 可选) — 注意力头的数量。如果未定义,则默认为attention_head_dim
- resnet_time_scale_shift (
str
, 可选, 默认为"default"
) — ResNet 块的时间尺度偏移配置(参见ResnetBlock2D
)。可选择default
或scale_shift
。 - class_embed_type (
str
, 可选, 默认为None
) — 类嵌入的类型,最终会与时间嵌入相加。可选择None
,"timestep"
,"identity"
,"projection"
, 或"simple_projection"
。 - addition_embed_type (
str
, 可选, 默认为None
) — 配置一个可选的嵌入,该嵌入将与时间嵌入相加。可选择None
或 “text”。“text” 将使用TextTimeEmbedding
层。 - addition_time_embed_dim — (
int
, 可选, 默认为None
): 时间步嵌入的维度。 - num_class_embeds (
int
, 可选, 默认为None
) — 可学习嵌入矩阵的输入维度,当使用class_embed_type
为None
进行类别条件化时,该矩阵将投射到time_embed_dim
。 - time_embedding_type (
str
, 可选, 默认为positional
) — 用于时间步长的位置嵌入类型。可选择positional
或fourier
。 - time_embedding_dim (
int
, 可选, 默认为None
) — 投影时间嵌入的可选维度覆盖。 - time_embedding_act_fn (
str
, 可选, 默认为None
) — 在时间嵌入传递给 UNet 的其余部分之前,仅使用一次的可选激活函数。可选择silu
、mish
、gelu
和swish
。 - timestep_post_act (
str
, 可选, 默认为None
) — 在时间步长嵌入中使用的第二个激活函数。可选择silu
、mish
和gelu
。 - time_cond_proj_dim (
int
, 可选, 默认为None
) — 时间步长嵌入中cond_proj
层的维度。 - conv_in_kernel (
int
, 可选, 默认为3
) —conv_in
层的核大小。 - conv_out_kernel (
int
, 可选, 默认为3
) —conv_out
层的核大小。 - projection_class_embeddings_input_dim (
int
, 可选) — 当class_embed_type="projection"
时,class_labels
输入的维度。当class_embed_type="projection"
时必需。 - class_embeddings_concat (
bool
, 可选, 默认为False
) — 是否将时间嵌入与类别嵌入拼接。 - mid_block_only_cross_attention (
bool
, 可选, 默认为None
) — 在使用UNetMidBlock2DSimpleCrossAttn
时,是否使用带有中间块的交叉注意力。如果only_cross_attention
给定为单个布尔值且mid_block_only_cross_attention
为None
,则only_cross_attention
的值将用作mid_block_only_cross_attention
的值。否则默认为False
。
一个条件 2D UNet 模型,接收噪声样本、条件状态和时间步长,并返回一个样本形状的输出。
此模型继承自 ModelMixin。有关所有模型实现的通用方法(如下载或保存),请参阅超类文档。
禁用 FreeU 机制。
启用来自 https://huggingface.ac.cn/papers/2309.11497 的 FreeU 机制。
缩放因子后面的后缀表示它们正在应用的阶段块。
请参阅官方仓库,了解适用于 Stable Diffusion v1、v2 和 Stable Diffusion XL 等不同管道的已知良好值组合。
前向传播
< 源 >( sample: Tensor timestep: typing.Union[torch.Tensor, float, int] encoder_hidden_states: Tensor class_labels: typing.Optional[torch.Tensor] = None timestep_cond: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None cross_attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None added_cond_kwargs: typing.Optional[typing.Dict[str, torch.Tensor]] = None down_block_additional_residuals: typing.Optional[typing.Tuple[torch.Tensor]] = None mid_block_additional_residual: typing.Optional[torch.Tensor] = None down_intrablock_additional_residuals: typing.Optional[typing.Tuple[torch.Tensor]] = None encoder_attention_mask: typing.Optional[torch.Tensor] = None return_dict: bool = True ) → UNet2DConditionOutput 或 tuple
参数
- sample (
torch.Tensor
) — 形状为(batch, channel, height, width)
的带噪输入张量。 - timestep (
torch.Tensor
或float
或int
) — 用于去噪输入的时间步长数量。 - encoder_hidden_states (
torch.Tensor
) — 形状为(batch, sequence_length, feature_dim)
的编码器隐藏状态。 - class_labels (
torch.Tensor
, 可选, 默认为None
) — 用于条件作用的可选类别标签。它们的嵌入将与时间步长嵌入求和。 - timestep_cond — (
torch.Tensor
, 可选, 默认为None
):时间步长的条件嵌入。如果提供,嵌入将与通过self.time_embedding
层传递的样本求和,以获得时间步长嵌入。 - attention_mask (
torch.Tensor
, 可选, 默认为None
) — 形状为(batch, key_tokens)
的注意力掩码应用于encoder_hidden_states
。如果为1
,则保留掩码,否则为0
则丢弃。掩码将被转换为偏差,这将为与“丢弃”标记对应的注意力分数添加较大的负值。 - cross_attention_kwargs (
dict
, 可选) — 如果指定,将作为 kwargs 字典传递给AttentionProcessor
,其定义在 diffusers.models.attention_processor 中的self.processor
下。 - added_cond_kwargs — (
dict
, 可选):一个 kwargs 字典,如果指定,其中包含的额外嵌入将添加到传递给 UNet 块的嵌入中。 - down_block_additional_residuals — (
tuple
oftorch.Tensor
, 可选):如果指定,将添加到 UNet 下行块残差的张量元组。 - mid_block_additional_residual — (
torch.Tensor
, 可选):如果指定,将添加到中间 UNet 块残差的张量。 - down_intrablock_additional_residuals (
tuple
oftorch.Tensor
, 可选) — 要添加到 UNet 下行块内的额外残差,例如来自 T2I-Adapter 侧模型的残差。 - encoder_attention_mask (
torch.Tensor
) — 形状为(batch, sequence_length)
的交叉注意力掩码应用于encoder_hidden_states
。如果为True
,则保留掩码,否则为False
则丢弃。掩码将被转换为偏差,这将为与“丢弃”标记对应的注意力分数添加较大的负值。 - return_dict (
bool
, 可选, 默认为True
) — 是否返回 UNet2DConditionOutput 而不是普通元组。
返回
UNet2DConditionOutput 或 tuple
如果 return_dict
为 True,则返回 UNet2DConditionOutput,否则返回 tuple
,其中第一个元素是样本张量。
UNet2DConditionModel 前向传播方法。
启用分片注意力计算。
启用此选项后,注意力模块会将输入张量分片以分步计算注意力。这对于节省内存非常有用,但会稍微降低速度。
设置注意力处理器
< 源 >( processor: typing.Union[diffusers.models.attention_processor.AttnProcessor, diffusers.models.attention_processor.CustomDiffusionAttnProcessor, diffusers.models.attention_processor.AttnAddedKVProcessor, diffusers.models.attention_processor.AttnAddedKVProcessor2_0, diffusers.models.attention_processor.JointAttnProcessor2_0, diffusers.models.attention_processor.PAGJointAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGJointAttnProcessor2_0, diffusers.models.attention_processor.FusedJointAttnProcessor2_0, diffusers.models.attention_processor.AllegroAttnProcessor2_0, diffusers.models.attention_processor.AuraFlowAttnProcessor2_0, diffusers.models.attention_processor.FusedAuraFlowAttnProcessor2_0, diffusers.models.attention_processor.FluxAttnProcessor2_0, diffusers.models.attention_processor.FluxAttnProcessor2_0_NPU, diffusers.models.attention_processor.FusedFluxAttnProcessor2_0, diffusers.models.attention_processor.FusedFluxAttnProcessor2_0_NPU, diffusers.models.attention_processor.CogVideoXAttnProcessor2_0, diffusers.models.attention_processor.FusedCogVideoXAttnProcessor2_0, diffusers.models.attention_processor.XFormersAttnAddedKVProcessor, diffusers.models.attention_processor.XFormersAttnProcessor, diffusers.models.attention_processor.XLAFlashAttnProcessor2_0, diffusers.models.attention_processor.AttnProcessorNPU, diffusers.models.attention_processor.AttnProcessor2_0, diffusers.models.attention_processor.MochiVaeAttnProcessor2_0, diffusers.models.attention_processor.MochiAttnProcessor2_0, diffusers.models.attention_processor.StableAudioAttnProcessor2_0, diffusers.models.attention_processor.HunyuanAttnProcessor2_0, diffusers.models.attention_processor.FusedHunyuanAttnProcessor2_0, diffusers.models.attention_processor.PAGHunyuanAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGHunyuanAttnProcessor2_0, diffusers.models.attention_processor.LuminaAttnProcessor2_0, diffusers.models.attention_processor.FusedAttnProcessor2_0, diffusers.models.attention_processor.CustomDiffusionXFormersAttnProcessor, diffusers.models.attention_processor.CustomDiffusionAttnProcessor2_0, diffusers.models.attention_processor.SlicedAttnProcessor, diffusers.models.attention_processor.SlicedAttnAddedKVProcessor, diffusers.models.attention_processor.SanaLinearAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGSanaLinearAttnProcessor2_0, diffusers.models.attention_processor.PAGIdentitySanaLinearAttnProcessor2_0, diffusers.models.attention_processor.SanaMultiscaleLinearAttention, diffusers.models.attention_processor.SanaMultiscaleAttnProcessor2_0, diffusers.models.attention_processor.SanaMultiscaleAttentionProjection, diffusers.models.attention_processor.IPAdapterAttnProcessor, diffusers.models.attention_processor.IPAdapterAttnProcessor2_0, diffusers.models.attention_processor.IPAdapterXFormersAttnProcessor, diffusers.models.attention_processor.SD3IPAdapterJointAttnProcessor2_0, diffusers.models.attention_processor.PAGIdentitySelfAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGIdentitySelfAttnProcessor2_0, diffusers.models.attention_processor.LoRAAttnProcessor, diffusers.models.attention_processor.LoRAAttnProcessor2_0, diffusers.models.attention_processor.LoRAXFormersAttnProcessor, diffusers.models.attention_processor.LoRAAttnAddedKVProcessor, typing.Dict[str, typing.Union[diffusers.models.attention_processor.AttnProcessor, diffusers.models.attention_processor.CustomDiffusionAttnProcessor, diffusers.models.attention_processor.AttnAddedKVProcessor, diffusers.models.attention_processor.AttnAddedKVProcessor2_0, diffusers.models.attention_processor.JointAttnProcessor2_0, diffusers.models.attention_processor.PAGJointAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGJointAttnProcessor2_0, diffusers.models.attention_processor.FusedJointAttnProcessor2_0, diffusers.models.attention_processor.AllegroAttnProcessor2_0, diffusers.models.attention_processor.AuraFlowAttnProcessor2_0, diffusers.models.attention_processor.FusedAuraFlowAttnProcessor2_0, diffusers.models.attention_processor.FluxAttnProcessor2_0, diffusers.models.attention_processor.FluxAttnProcessor2_0_NPU, diffusers.models.attention_processor.FusedFluxAttnProcessor2_0, diffusers.models.attention_processor.FusedFluxAttnProcessor2_0_NPU, diffusers.models.attention_processor.CogVideoXAttnProcessor2_0, diffusers.models.attention_processor.FusedCogVideoXAttnProcessor2_0, diffusers.models.attention_processor.XFormersAttnAddedKVProcessor, diffusers.models.attention_processor.XFormersAttnProcessor, diffusers.models.attention_processor.XLAFlashAttnProcessor2_0, diffusers.models.attention_processor.AttnProcessorNPU, diffusers.models.attention_processor.AttnProcessor2_0, diffusers.models.attention_processor.MochiVaeAttnProcessor2_0, diffusers.models.attention_processor.MochiAttnProcessor2_0, diffusers.models.attention_processor.StableAudioAttnProcessor2_0, diffusers.models.attention_processor.HunyuanAttnProcessor2_0, diffusers.models.attention_processor.FusedHunyuanAttnProcessor2_0, diffusers.models.attention_processor.PAGHunyuanAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGHunyuanAttnProcessor2_0, diffusers.models.attention_processor.LuminaAttnProcessor2_0, diffusers.models.attention_processor.FusedAttnProcessor2_0, diffusers.models.attention_processor.CustomDiffusionXFormersAttnProcessor, diffusers.models.attention_processor.CustomDiffusionAttnProcessor2_0, diffusers.models.attention_processor.SlicedAttnProcessor, diffusers.models.attention_processor.SlicedAttnAddedKVProcessor, diffusers.models.attention_processor.SanaLinearAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGSanaLinearAttnProcessor2_0, diffusers.models.attention_processor.PAGIdentitySanaLinearAttnProcessor2_0, diffusers.models.attention_processor.SanaMultiscaleLinearAttention, diffusers.models.attention_processor.SanaMultiscaleAttnProcessor2_0, diffusers.models.attention_processor.SanaMultiscaleAttentionProjection, diffusers.models.attention_processor.IPAdapterAttnProcessor, diffusers.models.attention_processor.IPAdapterAttnProcessor2_0, diffusers.models.attention_processor.IPAdapterXFormersAttnProcessor, diffusers.models.attention_processor.SD3IPAdapterJointAttnProcessor2_0, diffusers.models.attention_processor.PAGIdentitySelfAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGIdentitySelfAttnProcessor2_0, diffusers.models.attention_processor.LoRAAttnProcessor, diffusers.models.attention_processor.LoRAAttnProcessor2_0, diffusers.models.attention_processor.LoRAXFormersAttnProcessor, diffusers.models.attention_processor.LoRAAttnAddedKVProcessor]]] )
设置用于计算注意力的注意力处理器。
禁用自定义注意力处理器并设置默认注意力实现。
UNet2DConditionOutput
class diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput
< 源 >( sample: Tensor = None )
UNet2DConditionModel 的输出。
FlaxUNet2DConditionModel
class diffusers.FlaxUNet2DConditionModel
< 源 >( sample_size: int = 32 in_channels: int = 4 out_channels: int = 4 down_block_types: typing.Tuple[str, ...] = ('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D') up_block_types: typing.Tuple[str, ...] = ('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D') mid_block_type: typing.Optional[str] = 'UNetMidBlock2DCrossAttn' only_cross_attention: typing.Union[bool, typing.Tuple[bool]] = False block_out_channels: typing.Tuple[int, ...] = (320, 640, 1280, 1280) layers_per_block: int = 2 attention_head_dim: typing.Union[int, typing.Tuple[int, ...]] = 8 num_attention_heads: typing.Union[int, typing.Tuple[int, ...], NoneType] = None cross_attention_dim: int = 1280 dropout: float = 0.0 use_linear_projection: bool = False dtype: dtype = <class 'jax.numpy.float32'> flip_sin_to_cos: bool = True freq_shift: int = 0 use_memory_efficient_attention: bool = False split_head_dim: bool = False transformer_layers_per_block: typing.Union[int, typing.Tuple[int, ...]] = 1 addition_embed_type: typing.Optional[str] = None addition_time_embed_dim: typing.Optional[int] = None addition_embed_type_num_heads: int = 64 projection_class_embeddings_input_dim: typing.Optional[int] = None parent: typing.Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object at 0x7fc460aac610> name: typing.Optional[str] = None )
参数
- sample_size (
int
, 可选) — 输入样本的大小。 - in_channels (
int
, 可选, 默认为 4) — 输入样本中的通道数。 - out_channels (
int
, 可选, 默认为 4) — 输出中的通道数。 - down_block_types (
Tuple[str]
, 可选, 默认为("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")
) — 要使用的下采样块的元组。 - up_block_types (
Tuple[str]
, 可选, 默认为("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")
) — 要使用的上采样块的元组。 - mid_block_type (
str
, 可选, 默认为"UNetMidBlock2DCrossAttn"
) — UNet 中间块的块类型,可以是UNetMidBlock2DCrossAttn
之一。如果为None
,则跳过中间块层。 - block_out_channels (
Tuple[int]
, 可选, 默认为(320, 640, 1280, 1280)
) — 每个块的输出通道元组。 - layers_per_block (
int
, 可选, 默认为 2) — 每个块的层数。 - attention_head_dim (
int
或Tuple[int]
, 可选, 默认为 8) — 注意力头的维度。 - num_attention_heads (
int
或Tuple[int]
, 可选) — 注意力头的数量。 - cross_attention_dim (
int
, 可选, 默认为 768) — 交叉注意力特征的维度。 - dropout (
float
, 可选, 默认为 0) — 下采样、上采样和瓶颈块的 dropout 概率。 - flip_sin_to_cos (
bool
, 可选, 默认为True
) — 是否在时间嵌入中将 sin 翻转为 cos。 - freq_shift (
int
, 可选, 默认为 0) — 应用于时间嵌入的频率偏移。 - use_memory_efficient_attention (
bool
, 可选, 默认为False
) — 启用 此处 描述的内存高效注意力。 - split_head_dim (
bool
, 可选, 默认为False
) — 是否将头部维度拆分为自注意力计算的新轴。在大多数情况下,启用此标志应能加速 Stable Diffusion 2.x 和 Stable Diffusion XL 的计算。
一个条件 2D UNet 模型,接收噪声样本、条件状态和时间步长,并返回一个样本形状的输出。
此模型继承自 FlaxModelMixin。请查看超类文档以了解所有模型实现的通用方法(例如下载或保存)。
此模型也是 Flax Linen flax.linen.Module 的子类。将其作为常规 Flax Linen 模块使用,并参阅 Flax 文档中与其一般用法和行为相关的所有内容。
支持以下固有的 JAX 功能
FlaxUNet2DConditionOutput
class diffusers.models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput
< source >( sample: Array )
返回一个新对象,用新值替换指定的字段。