StableAudioDiTModel
Stable Audio Open 用于音频波形的 Transformer 模型。
StableAudioDiTModel
类 diffusers.StableAudioDiTModel
< 源代码 >( sample_size: int = 1024 in_channels: int = 64 num_layers: int = 24 attention_head_dim: int = 64 num_attention_heads: int = 24 num_key_value_attention_heads: int = 12 out_channels: int = 64 cross_attention_dim: int = 768 time_proj_dim: int = 256 global_states_input_dim: int = 1536 cross_attention_input_dim: int = 768 )
Stable Audio 中引入的 Diffusion Transformer 模型。
参考:https://github.com/Stability-AI/stable-audio-tools
前向
< 源 > ( hidden_states: FloatTensor timestep: LongTensor = None encoder_hidden_states: FloatTensor = None global_hidden_states: FloatTensor = None rotary_embedding: FloatTensor = None return_dict: bool = True attention_mask: Optional = None encoder_attention_mask: Optional = None )
参数
- hidden_states (
torch.FloatTensor
of shape(batch size, in_channels, sequence_len)
) — 输入hidden_states
。 - timestep (
torch.LongTensor
)— 用于指示去噪步骤。 - encoder_hidden_states (
torch.FloatTensor
形状为(batch size, encoder_sequence_len, cross_attention_input_dim)
) — 要使用的条件嵌入(根据输入条件(例如提示)计算的嵌入)。 - global_hidden_states (
torch.FloatTensor
形状为(batch size, global_sequence_len, global_states_input_dim)
) — 将预置到隐藏状态的全局嵌入。 - attention_mask (
torch.Tensor
形状为(batch_size, sequence_len)
, 可选) — 掩码,用于避免在填充令牌索引上执行注意力,通过连接两个文本编码器的注意力掩码形成。在[0, 1]
中选定的掩码值:- 未掩码的令牌为 1,
- 已掩码的令牌为 0。
- encoder_attention_mask (
torch.Tensor
形状为(batch_size, sequence_len)
,可选) — 避免对 padding 标记执行注意力交叉注意索引的掩码,通过将两个文本编码器的注意力掩码组合在一起形成。在[0, 1]
中选择的掩码值:- 1 表示未掩码的标记,
- 0 表示已掩码的标记。
StableAudioDiTModel 前向方法。
set_attn_processor
< 源 > ( processor: Union )
设置注意力处理器以用于计算注意力。
禁用自定义注意力处理器并设置默认注意力实施。