Diffusers 文档

UNet2DConditionModel

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

UNet2DConditionModel

The UNet 模型最初由 Ronneberger 等人引入,用于生物医学图像分割,但它也常用于 🤗 Diffusers,因为它输出的图像大小与输入相同。 它是扩散系统最重要的组成部分之一,因为它促进了实际的扩散过程。 🤗 Diffusers 中 UNet 模型有几种变体,具体取决于其维度数量以及它是否是有条件模型。 这是一个 2D UNet 条件模型。

论文摘要如下:

人们普遍认为,深度网络的成功训练需要数千个带注释的训练样本。 在本文中,我们提出了一种网络和训练策略,该策略依赖于大量使用数据增强来更有效地利用可用的带注释样本。 该架构由捕获上下文的收缩路径和实现精确定位的对称扩展路径组成。 我们表明,这样的网络可以从极少量的图像进行端到端训练,并且在 ISBI 挑战赛中,在电子显微镜堆栈中神经元结构分割方面,优于先前最佳方法(滑动窗口卷积网络)。 使用在透射光显微镜图像(相衬和 DIC)上训练的相同网络,我们以很大的优势赢得了 2015 年 ISBI 细胞跟踪挑战赛的这些类别。 此外,该网络速度很快。 在最新的 GPU 上分割 512x512 图像不到一秒。 完整实现(基于 Caffe)和训练好的网络可在 http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net 获取。

UNet2DConditionModel

diffusers.UNet2DConditionModel

< >

( sample_size: typing.Union[int, typing.Tuple[int, int], NoneType] = None in_channels: int = 4 out_channels: int = 4 center_input_sample: bool = False flip_sin_to_cos: bool = True freq_shift: int = 0 down_block_types: typing.Tuple[str] = ('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D') mid_block_type: typing.Optional[str] = 'UNetMidBlock2DCrossAttn' up_block_types: typing.Tuple[str] = ('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D') only_cross_attention: typing.Union[bool, typing.Tuple[bool]] = False block_out_channels: typing.Tuple[int] = (320, 640, 1280, 1280) layers_per_block: typing.Union[int, typing.Tuple[int]] = 2 downsample_padding: int = 1 mid_block_scale_factor: float = 1 dropout: float = 0.0 act_fn: str = 'silu' norm_num_groups: typing.Optional[int] = 32 norm_eps: float = 1e-05 cross_attention_dim: typing.Union[int, typing.Tuple[int]] = 1280 transformer_layers_per_block: typing.Union[int, typing.Tuple[int], typing.Tuple[typing.Tuple]] = 1 reverse_transformer_layers_per_block: typing.Optional[typing.Tuple[typing.Tuple[int]]] = None encoder_hid_dim: typing.Optional[int] = None encoder_hid_dim_type: typing.Optional[str] = None attention_head_dim: typing.Union[int, typing.Tuple[int]] = 8 num_attention_heads: typing.Union[int, typing.Tuple[int], NoneType] = None dual_cross_attention: bool = False use_linear_projection: bool = False class_embed_type: typing.Optional[str] = None addition_embed_type: typing.Optional[str] = None addition_time_embed_dim: typing.Optional[int] = None num_class_embeds: typing.Optional[int] = None upcast_attention: bool = False resnet_time_scale_shift: str = 'default' resnet_skip_time_act: bool = False resnet_out_scale_factor: float = 1.0 time_embedding_type: str = 'positional' time_embedding_dim: typing.Optional[int] = None time_embedding_act_fn: typing.Optional[str] = None timestep_post_act: typing.Optional[str] = None time_cond_proj_dim: typing.Optional[int] = None conv_in_kernel: int = 3 conv_out_kernel: int = 3 projection_class_embeddings_input_dim: typing.Optional[int] = None attention_type: str = 'default' class_embeddings_concat: bool = False mid_block_only_cross_attention: typing.Optional[bool] = None cross_attention_norm: typing.Optional[str] = None addition_embed_type_num_heads: int = 64 )

参数

  • sample_size (intTuple[int, int], 可选, 默认为 None) — 输入/输出样本的高度和宽度。
  • in_channels (int, 可选, 默认为 4) — 输入样本中的通道数。
  • out_channels (int, 可选, 默认为 4) — 输出中的通道数。
  • center_input_sample (bool, 可选, 默认为 False) — 是否将输入样本居中。
  • flip_sin_to_cos (bool, 可选, 默认为 True) — 是否在时间嵌入中将 sin 翻转为 cos。
  • freq_shift (int, 可选, 默认为 0) — 应用于时间嵌入的频率偏移。
  • down_block_types (Tuple[str], 可选, 默认为 ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")) — 要使用的下采样块的元组。
  • mid_block_type (str, 可选, 默认为 "UNetMidBlock2DCrossAttn") — UNet 中间块的块类型,可以是 UNetMidBlock2DCrossAttn, UNetMidBlock2D, 或 UNetMidBlock2DSimpleCrossAttn 之一。如果为 None,则跳过中间块层。
  • up_block_types (Tuple[str], 可选, 默认为 ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")) — 要使用的上采样块的元组。
  • only_cross_attention(boolTuple[bool], 可选, 默认为 False) — 是否在基本 Transformer 块中包含自注意力,请参阅 BasicTransformerBlock
  • block_out_channels (Tuple[int], 可选, 默认为 (320, 640, 1280, 1280)) — 每个块的输出通道的元组。
  • layers_per_block (int, 可选, 默认为 2) — 每个块的层数。
  • downsample_padding (int, 可选, 默认为 1) — 用于下采样卷积的填充。
  • mid_block_scale_factor (float, 可选, 默认为 1.0) — 用于中间块的缩放因子。
  • dropout (float, 可选, 默认为 0.0) — 要使用的 dropout 概率。
  • act_fn (str, 可选, 默认为 "silu") — 要使用的激活函数。
  • norm_num_groups (int, 可选, 默认为 32) — 用于归一化的组数。如果为 None,则在后处理中跳过归一化和激活层。
  • norm_eps (float, 可选, 默认为 1e-5) — 用于归一化的 epsilon 值。
  • cross_attention_dim (intTuple[int], 可选, 默认为 1280) — 交叉注意力特征的维度。
  • transformer_layers_per_block (int, Tuple[int], 或 Tuple[Tuple] , 可选, 默认为 1) — BasicTransformerBlock 类型的 Transformer 块的数量。 仅与 CrossAttnDownBlock2D, CrossAttnUpBlock2D, UNetMidBlock2DCrossAttn 相关。
  • reverse_transformer_layers_per_block — (Tuple[Tuple], 可选, 默认为 None): BasicTransformerBlock 类型的 Transformer 块的数量,在 U-Net 的上采样块中。 仅当 transformer_layers_per_block 的类型为 Tuple[Tuple] 且与 CrossAttnDownBlock2D, CrossAttnUpBlock2D, UNetMidBlock2DCrossAttn 相关时才相关。
  • encoder_hid_dim (int, 可选, 默认为 None) — 如果定义了 encoder_hid_dim_type,则 encoder_hidden_states 将从 encoder_hid_dim 维度投影到 cross_attention_dim
  • encoder_hid_dim_type (str, 可选, 默认为 None) — 如果给定,encoder_hidden_states 以及可能的其他嵌入将根据 encoder_hid_dim_type 向下投影到维度为 cross_attention 的文本嵌入。
  • attention_head_dim (int, 可选, 默认为 8) — 注意力头的维度。
  • num_attention_heads (int, 可选) — 注意力头的数量。 如果未定义,则默认为 attention_head_dim
  • resnet_time_scale_shift (str, 可选, 默认为 "default") — ResNet 块的时间尺度偏移配置(参见 ResnetBlock2D)。 从 defaultscale_shift 中选择。
  • class_embed_type (str, 可选, 默认为 None) — 要使用的类嵌入类型,它最终与时间嵌入求和。 从 None, "timestep", "identity", "projection", 或 "simple_projection" 中选择。
  • addition_embed_type (str, 可选, 默认为 None) — 配置可选的嵌入,它将与时间嵌入求和。 从 None 或 “text” 中选择。“text” 将使用 TextTimeEmbedding 层。
  • addition_time_embed_dim — (int, 可选, 默认为 None): 时间步嵌入的维度。
  • num_class_embeds (int, 可选, 默认为 None) — 可学习嵌入矩阵的输入维度,当使用 class_embed_type 等于 None 执行类条件调节时,该矩阵将被投影到 time_embed_dim
  • time_embedding_type (str, 可选, 默认为 positional) — 用于时间步的位置嵌入类型。 从 positionalfourier 中选择。
  • time_embedding_dim (int, 可选, 默认为 None) — 用于投影时间嵌入维度的可选覆盖。
  • time_embedding_act_fn (str, 可选, 默认为 None) — 可选的激活函数,仅在时间嵌入传递到 UNet 的其余部分之前使用一次。 从 silumishgeluswish 中选择。
  • timestep_post_act (str, 可选, 默认为 None) — 在时间步嵌入中使用的第二个激活函数。 从 silumishgelu 中选择。
  • time_cond_proj_dim (int, 可选, 默认为 None) — 时间步嵌入中 cond_proj 层的维度。
  • conv_in_kernel (int, 可选, 默认为 3) — conv_in 层的内核大小。
  • conv_out_kernel (int, 可选, 默认为 3) — conv_out 层的内核大小。
  • projection_class_embeddings_input_dim (int, 可选) — 当 class_embed_type="projection" 时,class_labels 输入的维度。 当 class_embed_type="projection" 时为必需。
  • class_embeddings_concat (bool, 可选, 默认为 False) — 是否将时间嵌入与类别嵌入连接。
  • mid_block_only_cross_attention (bool, 可选, 默认为 None) — 当使用 UNetMidBlock2DSimpleCrossAttn 时,是否将交叉注意力与中间块一起使用。 如果将 only_cross_attention 作为单个布尔值给出,并且 mid_block_only_cross_attentionNone,则 only_cross_attention 值将用作 mid_block_only_cross_attention 的值。 否则默认为 False

一个有条件的 2D UNet 模型,它接受一个噪声样本、条件状态和一个时间步,并返回一个样本形状的输出。

此模型继承自 ModelMixin。 查看超类文档,了解为所有模型实现的通用方法(例如下载或保存)。

disable_freeu

< >

( )

禁用 FreeU 机制。

enable_freeu

< >

( s1: float s2: float b1: float b2: float )

参数

  • s1 (float) — 阶段 1 的缩放因子,用于衰减跳跃特征的贡献。 这样做是为了减轻增强去噪过程中的“过度平滑效应”。
  • s2 (float) — 阶段 2 的缩放因子,用于衰减跳跃特征的贡献。 这样做是为了减轻增强去噪过程中的“过度平滑效应”。
  • b1 (float) — 阶段 1 的缩放因子,用于放大骨干特征的贡献。
  • b2 (float) — 阶段 2 的缩放因子,用于放大骨干特征的贡献。

启用来自 https://arxiv.org/abs/2309.11497 的 FreeU 机制。

缩放因子后的后缀表示应用它们的阶段块。

请参阅 官方存储库,了解已知适用于不同管道(如 Stable Diffusion v1、v2 和 Stable Diffusion XL)的值组合。

forward

< >

( sample: Tensor timestep: typing.Union[torch.Tensor, float, int] encoder_hidden_states: Tensor class_labels: typing.Optional[torch.Tensor] = None timestep_cond: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None cross_attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None added_cond_kwargs: typing.Optional[typing.Dict[str, torch.Tensor]] = None down_block_additional_residuals: typing.Optional[typing.Tuple[torch.Tensor]] = None mid_block_additional_residual: typing.Optional[torch.Tensor] = None down_intrablock_additional_residuals: typing.Optional[typing.Tuple[torch.Tensor]] = None encoder_attention_mask: typing.Optional[torch.Tensor] = None return_dict: bool = True ) UNet2DConditionOutputtuple

参数

  • sample (torch.Tensor) — 噪声输入张量,形状如下:(batch, channel, height, width)
  • timestep (torch.Tensorfloatint) — 用于去噪输入的timestep数量。
  • encoder_hidden_states (torch.Tensor) — 编码器隐藏状态,形状为 (batch, sequence_length, feature_dim)
  • class_labels (torch.Tensor, 可选, 默认为 None) — 用于条件调节的可选类别标签。 它们的嵌入将与时间步嵌入求和。
  • timestep_cond — (torch.Tensor, 可选, 默认为 None): 时间步的条件嵌入。 如果提供,则嵌入将与通过 self.time_embedding 层传递的样本求和,以获得时间步嵌入。
  • attention_mask (torch.Tensor, 可选, 默认为 None) — 形状为 (batch, key_tokens) 的注意力掩码应用于 encoder_hidden_states。 如果为 1,则保留掩码;否则,如果为 0,则丢弃掩码。 掩码将转换为偏差,这会将较大的负值添加到对应于“丢弃”标记的注意力分数。
  • cross_attention_kwargs (dict, 可选) — 一个 kwargs 字典,如果指定,则会传递给 AttentionProcessor,如 diffusers.models.attention_processorself.processor 下所定义。
  • added_cond_kwargs — (dict, 可选): 一个 kwargs 字典,包含如果指定则添加到传递给 UNet 块的嵌入的其他嵌入。
  • down_block_additional_residuals — (torch.Tensortuple, 可选): 一个张量元组,如果指定,则会添加到下采样 unet 块的残差中。
  • mid_block_additional_residual — (torch.Tensor, 可选): 一个张量,如果指定,则会添加到中间 unet 块的残差中。
  • down_intrablock_additional_residuals (tuple of torch.Tensor, optional) — additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
  • encoder_attention_mask (torch.Tensor) — A cross-attention mask of shape (batch, sequence_length) is applied to encoder_hidden_states. If True the mask is kept, otherwise if False it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to “discard” tokens.
  • return_dict (bool, optional, defaults to True) — Whether or not to return a UNet2DConditionOutput instead of a plain tuple.

Returns

UNet2DConditionOutput or tuple

If return_dict is True, an UNet2DConditionOutput is returned, otherwise a tuple is returned where the first element is the sample tensor.

The UNet2DConditionModel forward method.

fuse_qkv_projections

< >

( )

Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. For cross-attention modules, key and value projection matrices are fused.

This API is 🧪 experimental.

set_attention_slice

< >

( slice_size: typing.Union[str, int, typing.List[int]] = 'auto' )

参数

  • slice_size (str or int or list(int), optional, defaults to "auto") — When "auto", input to the attention heads is halved, so attention is computed in two steps. If "max", maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as attention_head_dim // slice_size. In this case, attention_head_dim must be a multiple of slice_size.

Enable sliced attention computation.

When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a small decrease in speed.

set_attn_processor

< >

( processor: typing.Union[diffusers.models.attention_processor.AttnProcessor, diffusers.models.attention_processor.CustomDiffusionAttnProcessor, diffusers.models.attention_processor.AttnAddedKVProcessor, diffusers.models.attention_processor.AttnAddedKVProcessor2_0, diffusers.models.attention_processor.JointAttnProcessor2_0, diffusers.models.attention_processor.PAGJointAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGJointAttnProcessor2_0, diffusers.models.attention_processor.FusedJointAttnProcessor2_0, diffusers.models.attention_processor.AllegroAttnProcessor2_0, diffusers.models.attention_processor.AuraFlowAttnProcessor2_0, diffusers.models.attention_processor.FusedAuraFlowAttnProcessor2_0, diffusers.models.attention_processor.FluxAttnProcessor2_0, diffusers.models.attention_processor.FluxAttnProcessor2_0_NPU, diffusers.models.attention_processor.FusedFluxAttnProcessor2_0, diffusers.models.attention_processor.FusedFluxAttnProcessor2_0_NPU, diffusers.models.attention_processor.CogVideoXAttnProcessor2_0, diffusers.models.attention_processor.FusedCogVideoXAttnProcessor2_0, diffusers.models.attention_processor.XFormersAttnAddedKVProcessor, diffusers.models.attention_processor.XFormersAttnProcessor, diffusers.models.attention_processor.XLAFlashAttnProcessor2_0, diffusers.models.attention_processor.AttnProcessorNPU, diffusers.models.attention_processor.AttnProcessor2_0, diffusers.models.attention_processor.MochiVaeAttnProcessor2_0, diffusers.models.attention_processor.MochiAttnProcessor2_0, diffusers.models.attention_processor.StableAudioAttnProcessor2_0, diffusers.models.attention_processor.HunyuanAttnProcessor2_0, diffusers.models.attention_processor.FusedHunyuanAttnProcessor2_0, diffusers.models.attention_processor.PAGHunyuanAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGHunyuanAttnProcessor2_0, diffusers.models.attention_processor.LuminaAttnProcessor2_0, diffusers.models.attention_processor.FusedAttnProcessor2_0, diffusers.models.attention_processor.CustomDiffusionXFormersAttnProcessor, diffusers.models.attention_processor.CustomDiffusionAttnProcessor2_0, diffusers.models.attention_processor.SlicedAttnProcessor, diffusers.models.attention_processor.SlicedAttnAddedKVProcessor, diffusers.models.attention_processor.SanaLinearAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGSanaLinearAttnProcessor2_0, diffusers.models.attention_processor.PAGIdentitySanaLinearAttnProcessor2_0, diffusers.models.attention_processor.SanaMultiscaleLinearAttention, diffusers.models.attention_processor.SanaMultiscaleAttnProcessor2_0, diffusers.models.attention_processor.SanaMultiscaleAttentionProjection, diffusers.models.attention_processor.IPAdapterAttnProcessor, diffusers.models.attention_processor.IPAdapterAttnProcessor2_0, diffusers.models.attention_processor.IPAdapterXFormersAttnProcessor, diffusers.models.attention_processor.SD3IPAdapterJointAttnProcessor2_0, diffusers.models.attention_processor.PAGIdentitySelfAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGIdentitySelfAttnProcessor2_0, diffusers.models.attention_processor.LoRAAttnProcessor, diffusers.models.attention_processor.LoRAAttnProcessor2_0, diffusers.models.attention_processor.LoRAXFormersAttnProcessor, diffusers.models.attention_processor.LoRAAttnAddedKVProcessor]]] )

参数

  • processor (dict of AttentionProcessor or only AttentionProcessor) — The instantiated processor class or a dictionary of processor classes that will be set as the processor for all Attention layers.

    If processor is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.

Sets the attention processor to use to compute attention.

set_default_attn_processor

< >

( )

Disables custom attention processors and sets the default attention implementation.

unfuse_qkv_projections

< >

( )

Disables the fused QKV projection if enabled.

This API is 🧪 experimental.

UNet2DConditionOutput

class diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput

< >

( sample: Tensor = None )

参数

  • sample (torch.Tensor of shape (batch_size, num_channels, height, width)) — The hidden states output conditioned on encoder_hidden_states input. Output of last layer of model.

The output of UNet2DConditionModel.

FlaxUNet2DConditionModel

class diffusers.FlaxUNet2DConditionModel

< >

( sample_size: int = 32 in_channels: int = 4 out_channels: int = 4 down_block_types: typing.Tuple[str, ...] = ('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D') up_block_types: typing.Tuple[str, ...] = ('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D') mid_block_type: typing.Optional[str] = 'UNetMidBlock2DCrossAttn' only_cross_attention: typing.Union[bool, typing.Tuple[bool]] = False block_out_channels: typing.Tuple[int, ...] = (320, 640, 1280, 1280) layers_per_block: int = 2 attention_head_dim: typing.Union[int, typing.Tuple[int, ...]] = 8 num_attention_heads: typing.Union[int, typing.Tuple[int, ...], NoneType] = None cross_attention_dim: int = 1280 dropout: float = 0.0 use_linear_projection: bool = False dtype: dtype = <class 'jax.numpy.float32'> flip_sin_to_cos: bool = True freq_shift: int = 0 use_memory_efficient_attention: bool = False split_head_dim: bool = False transformer_layers_per_block: typing.Union[int, typing.Tuple[int, ...]] = 1 addition_embed_type: typing.Optional[str] = None addition_time_embed_dim: typing.Optional[int] = None addition_embed_type_num_heads: int = 64 projection_class_embeddings_input_dim: typing.Optional[int] = None parent: typing.Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object at 0x7f484d8851e0> name: typing.Optional[str] = None )

参数

  • sample_size (int, optional) — The size of the input sample.
  • in_channels (int, optional, defaults to 4) — The number of channels in the input sample.
  • out_channels (int, 可选, 默认为 4) — 输出中的通道数。
  • down_block_types (Tuple[str], 可选, 默认为 ("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")) — 要使用的下采样块的元组。
  • up_block_types (Tuple[str], 可选, 默认为 ("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")) — 要使用的上采样块的元组。
  • mid_block_type (str, 可选, 默认为 "UNetMidBlock2DCrossAttn") — UNet 中间块的块类型,可以是 UNetMidBlock2DCrossAttn 之一。如果为 None,则跳过中间块层。
  • block_out_channels (Tuple[int], 可选, 默认为 (320, 640, 1280, 1280)) — 每个块的输出通道的元组。
  • layers_per_block (int, 可选, 默认为 2) — 每个块的层数。
  • attention_head_dim (intTuple[int], 可选, 默认为 8) — 注意力头的维度。
  • num_attention_heads (intTuple[int], 可选) — 注意力头的数量。
  • cross_attention_dim (int, 可选, 默认为 768) — 交叉注意力特征的维度。
  • dropout (float, 可选, 默认为 0) — 用于下、上和瓶颈块的 Dropout 概率。
  • flip_sin_to_cos (bool, 可选, 默认为 True) — 是否在时间嵌入中将 sin 翻转为 cos。
  • freq_shift (int, 可选, 默认为 0) — 应用于时间嵌入的频率偏移。
  • use_memory_efficient_attention (bool, 可选, 默认为 False) — 启用内存高效注意力,如此处所述。
  • split_head_dim (bool, 可选, 默认为 False) — 是否将 head 维度拆分为自注意力计算的新轴。在大多数情况下,启用此标志应加速 Stable Diffusion 2.x 和 Stable Diffusion XL 的计算。

一个有条件的 2D UNet 模型,它接受一个噪声样本、条件状态和一个时间步,并返回一个样本形状的输出。

此模型继承自 FlaxModelMixin。查看超类文档,了解为其所有模型实现的通用方法(例如下载或保存)。

此模型也是 Flax Linen flax.linen.Module 子类。像使用常规 Flax Linen 模块一样使用它,并参考 Flax 文档了解与其通用用法和行为相关的所有事项。

支持固有的 JAX 功能,例如以下各项

FlaxUNet2DConditionOutput

class diffusers.models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput

< >

( sample: Array )

参数

  • sample (jnp.ndarray,形状为 (batch_size, num_channels, height, width)) — 以 encoder_hidden_states 输入为条件的隐藏状态输出。模型的最后一层的输出。

FlaxUNet2DConditionModel 的输出。

replace

< >

( **updates )

返回一个新对象,该对象将指定的字段替换为新值。

< > Update on GitHub