ControlNetModel
ControlNet 模型在论文中提出,向文本到图像扩散模型添加条件控制,作者为 Lvmin Zhang、Anyi Rao、Maneesh Agrawala。它通过对模型进行条件化来提供更高程度的文本到图像生成控制,条件包括边缘图、深度图、分割图和姿态检测的关键点。
论文摘要为:
我们提出了 ControlNet,这是一种神经网络架构,用以向大型、预训练文本到图像扩散模型添加空间条件控制。ControlNet 锁定生产就绪的大型扩散模型,并重复使用其采用数十亿张图像预训练的强大编码层作为坚固骨干来了解各种条件控制。神经架构与“零卷积(零初始化卷积层)”相连接,后者逐步增加参数(从零开始),并确保没有有害噪声会影响微调。我们使用各种条件控制(例如边缘、深度、分割、人体姿态等)对 Stable Diffusion 进行了测试,方法是使用单个或多个条件(无论是否带有提示)。我们已证明 ControlNets 的训练对于小型数据集(<50k)和大型数据集(>1m)都很稳健。大量的结果表明 ControlNet 可以促进更广泛的应用程序来控制图像扩散模型。
从原始格式加载
默认情况下,应该使用ControlNetModel加载 from_pretrained(),但它也可以使用`FromOriginalModelMixin.from_single_file`从原始格式加载,如下所示
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
controlnet = ControlNetModel.from_single_file(url)
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
ControlNetModel
ControlNet 模型。
forward
< source > ( 样本: 张量 时间步长: 联合 编码器隐藏状态: 张量 控制网络条件: 张量 条件缩放: 浮点 = 1.0 类标签: 可选 = 无 时间步长条件: 可选 = 无 注意力掩码: 可选 = 无 添加条件关键字参数: 可选 = 无 交叉注意力关键字参数: 可选 = 无 猜测模式: 布尔 = 假 返回字典: 布尔 = 真 ) → ControlNet输出或元组
参数
- 样本 (
torch.张量
) — 噪声输入张量。 - timestep (
Union[torch.Tensor, float, int]
) — 去噪输入的时间步数。 - encoder_hidden_states (
torch.Tensor
) — 编码器隐藏状态。 - class_labels(
torch.Tensor
,可选,默认为None
)— 可选的用于调节的类别标签。它们的嵌入将与时间步长嵌入相加。 - timestep_cond (
torch.Tensor
,可选,默认为None
)——时序的额外条件嵌入。如果提供,此嵌入将与通过self.time_embedding
层传递的 timestep_embedding 求和,以获取最终时序嵌入。 - attention_mask (
torch.Tensor
,可选,默认为None
)——形状为(批次,密钥 token)
的注意掩码应用于encoder_hidden_states
。如果为1
则保留掩码,如果为0
则丢弃掩码。掩码将转换为偏置,这会为对应于“丢弃”token 的注意分数添加较大的负值。 - added_cond_kwargs (
dict
)— Stable Diffusion XL UNet 的其他条件。 - cross_attention_kwargs (
dict[str]
,可选,默认值为None
)— 如果指定,则将一个 kwargs 字典传递到AttnProcessor
。
返回
ControlNetOutput 或 tuple
如果 return_dict
是 True
,将返回一个 ControlNetOutput;否则,将返回一个元组,其中第一个元素是样本张量。
ControlNetModel 前向方法。
from_unet
< 源 > ( unet: UNet2DConditionModel controlnet_conditioning_channel_order: str = 'rgb' conditioning_embedding_out_channels: Optional = (16, 32, 96, 256) load_weights_from_unet: bool = True conditioning_channels: int = 3 )
参数
- unet (
UNet2DConditionModel
) — 要复制到 ControlNetModel 的 UNet 模型权重。所有配置选项也会在适用范围内进行复制。
使用 UNet2DConditionModel 实例化一个 ControlNetModel。
set_attention_slice
< ( slice_size: Union )
启用切片注意力计算。
启用此选项时,注意力模块将输入张量拆分为切片,以分多步计算注意力。这对于以小幅降速为代价节省一些内存非常有用。
set_attn_processor
< source > ( 处理器: 联合 )
设置用于计算注意力的注意力处理器。
禁用自定义注意力处理器并设置默认注意力实现。
ControlNetOutput
类 diffusers.models.controlnet.ControlNetOutput
< 源代码 >( down_block_res_samples: 元组 mid_block_res_sample: 张量 )
参数
- down_block_res_samples (
tuple[torch.Tensor]
) — 每个下采样块不同分辨率的下采样激活元组。每个张量应具有以下形状:(batch_size, channel * resolution, height //resolution, width // resolution)
。输出可用于调整原始 UNet 的下采样激活。 - mid_down_block_re_sample (
torch.Tensor
) — 中间块(最低样本分辨率)的激活。每个张量应具有以下形状:(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)
。输出可用于调整原始 UNet 中间块激活。
输出 ControlNetModel。
FlaxControlNetModel
类 diffusers.FlaxControlNetModel
< 源 >( sample_size: int = 32 in_channels: int = 4 down_block_types: Tuple = ('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D') only_cross_attention: Union = False block_out_channels: Tuple = (320, 640, 1280, 1280) layers_per_block: int = 2 attention_head_dim: Union = 8 num_attention_heads: Union = None cross_attention_dim: int = 1280 dropout: float = 0.0 use_linear_projection: bool = False dtype: dtype = <class 'jax.numpy.float32'> flip_sin_to_cos: bool = True freq_shift: int = 0 controlnet_conditioning_channel_order: str = 'rgb' conditioning_embedding_out_channels: Tuple = (16, 32, 96, 256) parent: Union = <flax.linen.module._Sentinel object at 0x7f529575c910> name: Optional = None )
ControlNet 模型。
此模型继承自 FlaxModelMixin。检查超类的文档,了解为所有模型实现的通用方法(例如下载或保存)。
此模型也是 Flax Linen flax.linen.Module
子类。将其用作常规的 Flax Linen 模块,并参考 Flax 文档了解与一般用法和行为相关的所有事项。
支持以下固有的 JAX 特性
FlaxControlNetOutput
类 diffusers.models.controlnet_flax.FlaxControlNetOutput
< 源代码 >( down_block_res_samples: 数组 mid_block_res_sample: 数组 )
“返回新的对象,用新值替换指定字段。”