Transformers 文档
自动类
并获得增强的文档体验
开始使用
自动类
在许多情况下,您想要使用的架构可以从您提供给 from_pretrained()
方法的预训练模型的名称或路径中猜测出来。AutoClasses 的作用就是为您完成这项工作,以便您根据预训练权重/配置/词汇表的名称/路径自动检索相关模型。
实例化 AutoConfig、AutoModel 和 AutoTokenizer 中的一个将直接创建一个相关架构的类。例如
model = AutoModel.from_pretrained("google-bert/bert-base-cased")
将创建一个模型,该模型是 BertModel 的实例。
每个任务和每个后端(PyTorch、TensorFlow 或 Flax)都有一个 AutoModel
类。
扩展自动类
每个自动类都有一种方法可以使用您的自定义类进行扩展。例如,如果您定义了一个自定义模型类 NewModel
,请确保您有一个 NewModelConfig
,然后您可以像这样将它们添加到自动类中
from transformers import AutoConfig, AutoModel
AutoConfig.register("new-model", NewModelConfig)
AutoModel.register(NewModelConfig, NewModel)
然后您就可以像平常一样使用自动类了!
如果您的 NewModelConfig
是 PretrainedConfig 的子类,请确保其 model_type
属性设置为您注册配置时使用的相同键(此处为 "new-model"
)。
同样,如果您的 NewModel
是 PreTrainedModel 的子类,请确保其 config_class
属性设置为您注册模型时使用的相同类(此处为 NewModelConfig
)。
AutoConfig
这是一个通用配置类,当使用 from_pretrained() 类方法创建时,它将被实例化为库的配置类之一。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_pretrained
< source >( pretrained_model_name_or_path **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是:- 字符串,托管在 huggingface.co 上的模型仓库内的预训练模型配置的模型 ID。
- 目录的路径,其中包含使用 save_pretrained() 方法或 save_pretrained() 方法保存的配置文件,例如,
./my_model_directory/
。 - 保存的配置 JSON 文件的路径或 URL,例如,
./my_model_directory/configuration.json
。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,并覆盖缓存版本(如果存在)。 - resume_download — 已弃用并忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - return_unused_kwargs (
bool
, 可选, 默认为False
) — 如果为False
,则此函数仅返回最终配置对象。如果为
True
,则此函数返回Tuple(config, unused_kwargs)
,其中 unused_kwargs 是一个字典,由键/值对组成,其键不是配置属性:即,kwargs
中未用于更新config
且在其他情况下被忽略的部分。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。此选项仅应为信任的存储库设置为True
,并且您已阅读代码,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - kwargs(additional 关键字参数, 可选) — kwargs 中任何键是配置属性的值将用于覆盖加载的值。关于键不是配置属性的键/值对的行为由
return_unused_kwargs
关键字参数控制。
从预训练模型配置实例化库的配置类之一。
要实例化的配置类是根据加载的配置对象的 model_type
属性选择的,或者在缺少该属性时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — AlbertConfig (ALBERT 模型)
- align — AlignConfig (ALIGN 模型)
- altclip — AltCLIPConfig (AltCLIP 模型)
- aria — AriaConfig (Aria 模型)
- aria_text — AriaTextConfig (AriaText 模型)
- audio-spectrogram-transformer — ASTConfig (Audio Spectrogram Transformer 模型)
- autoformer — AutoformerConfig (Autoformer 模型)
- aya_vision — AyaVisionConfig (AyaVision 模型)
- bamba — BambaConfig (Bamba 模型)
- bark — BarkConfig (Bark 模型)
- bart — BartConfig (BART 模型)
- beit — BeitConfig (BEiT 模型)
- bert — BertConfig (BERT 模型)
- bert-generation — BertGenerationConfig (Bert Generation 模型)
- big_bird — BigBirdConfig (BigBird 模型)
- bigbird_pegasus — BigBirdPegasusConfig (BigBird-Pegasus 模型)
- biogpt — BioGptConfig (BioGpt 模型)
- bit — BitConfig (BiT 模型)
- blenderbot — BlenderbotConfig (Blenderbot 模型)
- blenderbot-small — BlenderbotSmallConfig (BlenderbotSmall 模型)
- blip — BlipConfig (BLIP 模型)
- blip-2 — Blip2Config (BLIP-2 模型)
- bloom — BloomConfig (BLOOM 模型)
- bridgetower — BridgeTowerConfig (BridgeTower 模型)
- bros — BrosConfig (BROS 模型)
- camembert — CamembertConfig (CamemBERT 模型)
- canine — CanineConfig (CANINE 模型)
- chameleon — ChameleonConfig (Chameleon 模型)
- chinese_clip — ChineseCLIPConfig (Chinese-CLIP 模型)
- chinese_clip_vision_model — ChineseCLIPVisionConfig (ChineseCLIPVisionModel 模型)
- clap — ClapConfig (CLAP 模型)
- clip — CLIPConfig (CLIP 模型)
- clip_text_model — CLIPTextConfig (CLIPTextModel 模型)
- clip_vision_model — CLIPVisionConfig (CLIPVisionModel 模型)
- clipseg — CLIPSegConfig (CLIPSeg 模型)
- clvp — ClvpConfig (CLVP 模型)
- code_llama — LlamaConfig (CodeLlama 模型)
- codegen — CodeGenConfig (CodeGen 模型)
- cohere — CohereConfig (Cohere 模型)
- cohere2 — Cohere2Config (Cohere2 模型)
- colpali — ColPaliConfig (ColPali 模型)
- conditional_detr — ConditionalDetrConfig (Conditional DETR 模型)
- convbert — ConvBertConfig (ConvBERT 模型)
- convnext — ConvNextConfig (ConvNeXT 模型)
- convnextv2 — ConvNextV2Config (ConvNeXTV2 模型)
- cpmant — CpmAntConfig (CPM-Ant 模型)
- ctrl — CTRLConfig (CTRL 模型)
- cvt — CvtConfig (CvT 模型)
- dab-detr — DabDetrConfig (DAB-DETR 模型)
- dac — DacConfig (DAC 模型)
- data2vec-audio — Data2VecAudioConfig (Data2VecAudio 模型)
- data2vec-text — Data2VecTextConfig (Data2VecText 模型)
- data2vec-vision — Data2VecVisionConfig (Data2VecVision 模型)
- dbrx — DbrxConfig (DBRX 模型)
- deberta — DebertaConfig (DeBERTa 模型)
- deberta-v2 — DebertaV2Config (DeBERTa-v2 模型)
- decision_transformer — DecisionTransformerConfig (Decision Transformer 模型)
- deformable_detr — DeformableDetrConfig (Deformable DETR 模型)
- deit — DeiTConfig (DeiT 模型)
- depth_anything — DepthAnythingConfig (Depth Anything 模型)
- depth_pro — DepthProConfig (DepthPro 模型)
- deta — DetaConfig (DETA 模型)
- detr — DetrConfig (DETR 模型)
- diffllama — DiffLlamaConfig (DiffLlama 模型)
- dinat — DinatConfig (DiNAT 模型)
- dinov2 — Dinov2Config (DINOv2 模型)
- dinov2_with_registers — Dinov2WithRegistersConfig (DINOv2 with Registers 模型)
- distilbert — DistilBertConfig (DistilBERT 模型)
- donut-swin — DonutSwinConfig (DonutSwin 模型)
- dpr — DPRConfig (DPR 模型)
- dpt — DPTConfig (DPT 模型)
- efficientformer — EfficientFormerConfig (EfficientFormer 模型)
- efficientnet — EfficientNetConfig (EfficientNet 模型)
- electra — ElectraConfig (ELECTRA 模型)
- emu3 — Emu3Config (Emu3 模型)
- encodec — EncodecConfig (EnCodec 模型)
- encoder-decoder — EncoderDecoderConfig (Encoder decoder 模型)
- ernie — ErnieConfig (ERNIE 模型)
- ernie_m — ErnieMConfig (ErnieM 模型)
- esm — EsmConfig (ESM 模型)
- falcon — FalconConfig (Falcon 模型)
- falcon_mamba — FalconMambaConfig (FalconMamba 模型)
- fastspeech2_conformer — FastSpeech2ConformerConfig (FastSpeech2Conformer 模型)
- flaubert — FlaubertConfig (FlauBERT 模型)
- flava — FlavaConfig (FLAVA 模型)
- fnet — FNetConfig (FNet 模型)
- focalnet — FocalNetConfig (FocalNet 模型)
- fsmt — FSMTConfig (FairSeq 机器翻译模型)
- funnel — FunnelConfig (Funnel Transformer 模型)
- fuyu — FuyuConfig (Fuyu 模型)
- gemma — GemmaConfig (Gemma 模型)
- gemma2 — Gemma2Config (Gemma2 模型)
- gemma3 — Gemma3Config (Gemma3ForConditionalGeneration 模型)
- gemma3_text — Gemma3TextConfig (Gemma3ForCausalLM 模型)
- git — GitConfig (GIT 模型)
- glm — GlmConfig (GLM 模型)
- glpn — GLPNConfig (GLPN 模型)
- got_ocr2 — GotOcr2Config (GOT-OCR2 模型)
- gpt-sw3 — GPT2Config (GPT-Sw3 模型)
- gpt2 — GPT2Config (OpenAI GPT-2 模型)
- gpt_bigcode — GPTBigCodeConfig (GPTBigCode 模型)
- gpt_neo — GPTNeoConfig (GPT Neo 模型)
- gpt_neox — GPTNeoXConfig (GPT NeoX 模型)
- gpt_neox_japanese — GPTNeoXJapaneseConfig (GPT NeoX Japanese 模型)
- gptj — GPTJConfig (GPT-J 模型)
- gptsan-japanese — GPTSanJapaneseConfig (GPTSAN-japanese 模型)
- granite — GraniteConfig (Granite 模型)
- granitemoe — GraniteMoeConfig (GraniteMoeMoe 模型)
- granitemoeshared — GraniteMoeSharedConfig (GraniteMoeSharedMoe 模型)
- granitevision — LlavaNextConfig (LLaVA-NeXT 模型)
- graphormer — GraphormerConfig (Graphormer 模型)
- grounding-dino — GroundingDinoConfig (Grounding DINO 模型)
- groupvit — GroupViTConfig (GroupViT 模型)
- helium — HeliumConfig (Helium 模型)
- hiera — HieraConfig (Hiera 模型)
- hubert — HubertConfig (Hubert 模型)
- ibert — IBertConfig (I-BERT 模型)
- idefics — IdeficsConfig (IDEFICS 模型)
- idefics2 — Idefics2Config (Idefics2 模型)
- idefics3 — Idefics3Config (Idefics3 模型)
- idefics3_vision — Idefics3VisionConfig (Idefics3VisionTransformer 模型)
- ijepa — IJepaConfig (I-JEPA 模型)
- imagegpt — ImageGPTConfig (ImageGPT 模型)
- informer — InformerConfig (Informer 模型)
- instructblip — InstructBlipConfig (InstructBLIP 模型)
- instructblipvideo — InstructBlipVideoConfig (InstructBlipVideo 模型)
- jamba — JambaConfig (Jamba 模型)
- jetmoe — JetMoeConfig (JetMoe 模型)
- jukebox — JukeboxConfig (Jukebox 模型)
- kosmos-2 — Kosmos2Config (KOSMOS-2 模型)
- layoutlm — LayoutLMConfig (LayoutLM 模型)
- layoutlmv2 — LayoutLMv2Config (LayoutLMv2 模型)
- layoutlmv3 — LayoutLMv3Config (LayoutLMv3 模型)
- led — LEDConfig (LED 模型)
- levit — LevitConfig (LeViT 模型)
- lilt — LiltConfig (LiLT 模型)
- llama — LlamaConfig (LLaMA 模型)
- llava — LlavaConfig (LLaVa 模型)
- llava_next — LlavaNextConfig (LLaVA-NeXT 模型)
- llava_next_video — LlavaNextVideoConfig (LLaVa-NeXT-Video 模型)
- llava_onevision — LlavaOnevisionConfig (LLaVA-Onevision 模型)
- longformer — LongformerConfig (Longformer 模型)
- longt5 — LongT5Config (LongT5 模型)
- luke — LukeConfig (LUKE 模型)
- lxmert — LxmertConfig (LXMERT 模型)
- m2m_100 — M2M100Config (M2M100 模型)
- mamba — MambaConfig (Mamba 模型)
- mamba2 — Mamba2Config (mamba2 模型)
- marian — MarianConfig (Marian 模型)
- markuplm — MarkupLMConfig (MarkupLM 模型)
- mask2former — Mask2FormerConfig (Mask2Former 模型)
- maskformer — MaskFormerConfig (MaskFormer 模型)
- maskformer-swin —
MaskFormerSwinConfig
(MaskFormerSwin 模型) - mbart — MBartConfig (mBART 模型)
- mctct — MCTCTConfig (M-CTC-T 模型)
- mega — MegaConfig (MEGA 模型)
- megatron-bert — MegatronBertConfig (Megatron-BERT 模型)
- mgp-str — MgpstrConfig (MGP-STR 模型)
- mimi — MimiConfig (Mimi 模型)
- mistral — MistralConfig (Mistral 模型)
- mistral3 — Mistral3Config (Mistral3 模型)
- mixtral — MixtralConfig (Mixtral 模型)
- mllama — MllamaConfig (Mllama 模型)
- mobilebert — MobileBertConfig (MobileBERT 模型)
- mobilenet_v1 — MobileNetV1Config (MobileNetV1 模型)
- mobilenet_v2 — MobileNetV2Config (MobileNetV2 模型)
- mobilevit — MobileViTConfig (MobileViT 模型)
- mobilevitv2 — MobileViTV2Config (MobileViTV2 模型)
- modernbert — ModernBertConfig (ModernBERT 模型)
- moonshine — MoonshineConfig (Moonshine 模型)
- moshi — MoshiConfig (Moshi 模型)
- mpnet — MPNetConfig (MPNet 模型)
- mpt — MptConfig (MPT 模型)
- mra — MraConfig (MRA 模型)
- mt5 — MT5Config (MT5 模型)
- musicgen — MusicgenConfig (MusicGen 模型)
- musicgen_melody — MusicgenMelodyConfig (MusicGen Melody 模型)
- mvp — MvpConfig (MVP 模型)
- nat — NatConfig (NAT 模型)
- nemotron — NemotronConfig (Nemotron 模型)
- nezha — NezhaConfig (Nezha 模型)
- nllb-moe — NllbMoeConfig (NLLB-MOE 模型)
- nougat — VisionEncoderDecoderConfig (Nougat 模型)
- nystromformer — NystromformerConfig (Nyströmformer 模型)
- olmo — OlmoConfig (OLMo 模型)
- olmo2 — Olmo2Config (OLMo2 模型)
- olmoe — OlmoeConfig (OLMoE 模型)
- omdet-turbo — OmDetTurboConfig (OmDet-Turbo 模型)
- oneformer — OneFormerConfig (OneFormer 模型)
- open-llama — OpenLlamaConfig (OpenLlama 模型)
- openai-gpt — OpenAIGPTConfig (OpenAI GPT 模型)
- opt — OPTConfig (OPT 模型)
- owlv2 — Owlv2Config (OWLv2 模型)
- owlvit — OwlViTConfig (OWL-ViT 模型)
- paligemma — PaliGemmaConfig (PaliGemma 模型)
- patchtsmixer — PatchTSMixerConfig (PatchTSMixer 模型)
- patchtst — PatchTSTConfig (PatchTST 模型)
- pegasus — PegasusConfig (Pegasus 模型)
- pegasus_x — PegasusXConfig (PEGASUS-X 模型)
- perceiver — PerceiverConfig (Perceiver 模型)
- persimmon — PersimmonConfig (Persimmon 模型)
- phi — PhiConfig (Phi 模型)
- phi3 — Phi3Config (Phi3 模型)
- phimoe — PhimoeConfig (Phimoe 模型)
- pix2struct — Pix2StructConfig (Pix2Struct 模型)
- pixtral — PixtralVisionConfig (Pixtral 模型)
- plbart — PLBartConfig (PLBart 模型)
- poolformer — PoolFormerConfig (PoolFormer 模型)
- pop2piano — Pop2PianoConfig (Pop2Piano 模型)
- prompt_depth_anything — PromptDepthAnythingConfig (PromptDepthAnything 模型)
- prophetnet — ProphetNetConfig (ProphetNet 模型)
- pvt — PvtConfig (PVT 模型)
- pvt_v2 — PvtV2Config (PVTv2 模型)
- qdqbert — QDQBertConfig (QDQBert 模型)
- qwen2 — Qwen2Config (Qwen2 模型)
- qwen2_5_vl — Qwen2_5_VLConfig (Qwen2_5_VL 模型)
- qwen2_audio — Qwen2AudioConfig (Qwen2Audio 模型)
- qwen2_audio_encoder — Qwen2AudioEncoderConfig (Qwen2AudioEncoder 模型)
- qwen2_moe — Qwen2MoeConfig (Qwen2MoE 模型)
- qwen2_vl — Qwen2VLConfig (Qwen2VL 模型)
- rag — RagConfig (RAG 模型)
- realm — RealmConfig (REALM 模型)
- recurrent_gemma — RecurrentGemmaConfig (RecurrentGemma 模型)
- reformer — ReformerConfig (Reformer 模型)
- regnet — RegNetConfig (RegNet 模型)
- rembert — RemBertConfig (RemBERT 模型)
- resnet — ResNetConfig (ResNet 模型)
- retribert — RetriBertConfig (RetriBERT 模型)
- roberta — RobertaConfig (RoBERTa 模型)
- roberta-prelayernorm — RobertaPreLayerNormConfig (RoBERTa-PreLayerNorm 模型)
- roc_bert — RoCBertConfig (RoCBert 模型)
- roformer — RoFormerConfig (RoFormer 模型)
- rt_detr — RTDetrConfig (RT-DETR 模型)
- rt_detr_resnet — RTDetrResNetConfig (RT-DETR-ResNet 模型)
- rt_detr_v2 — RTDetrV2Config (RT-DETRv2 模型)
- rwkv — RwkvConfig (RWKV 模型)
- sam — SamConfig (SAM 模型)
- seamless_m4t — SeamlessM4TConfig (SeamlessM4T 模型)
- seamless_m4t_v2 — SeamlessM4Tv2Config (SeamlessM4Tv2 模型)
- segformer — SegformerConfig (SegFormer 模型)
- seggpt — SegGptConfig (SegGPT 模型)
- sew — SEWConfig (SEW 模型)
- sew-d — SEWDConfig (SEW-D 模型)
- shieldgemma2 — ShieldGemma2Config (Shieldgemma2 模型)
- siglip — SiglipConfig (SigLIP 模型)
- siglip2 — Siglip2Config (SigLIP2 模型)
- siglip_vision_model — SiglipVisionConfig (SiglipVisionModel 模型)
- smolvlm — SmolVLMConfig (SmolVLM 模型)
- smolvlm_vision — SmolVLMVisionConfig (SmolVLMVisionTransformer 模型)
- speech-encoder-decoder — SpeechEncoderDecoderConfig (Speech Encoder decoder 模型)
- speech_to_text — Speech2TextConfig (Speech2Text 模型)
- speech_to_text_2 — Speech2Text2Config (Speech2Text2 模型)
- speecht5 — SpeechT5Config (SpeechT5 模型)
- splinter — SplinterConfig (Splinter 模型)
- squeezebert — SqueezeBertConfig (SqueezeBERT 模型)
- stablelm — StableLmConfig (StableLm 模型)
- starcoder2 — Starcoder2Config (Starcoder2 模型)
- superglue — SuperGlueConfig (SuperGlue 模型)
- superpoint — SuperPointConfig (SuperPoint 模型)
- swiftformer — SwiftFormerConfig (SwiftFormer 模型)
- swin — SwinConfig (Swin Transformer 模型)
- swin2sr — Swin2SRConfig (Swin2SR 模型)
- swinv2 — Swinv2Config (Swin Transformer V2 模型)
- switch_transformers — SwitchTransformersConfig (SwitchTransformers 模型)
- t5 — T5Config (T5 模型)
- table-transformer — TableTransformerConfig (Table Transformer 模型)
- tapas — TapasConfig (TAPAS 模型)
- textnet — TextNetConfig (TextNet 模型)
- time_series_transformer — TimeSeriesTransformerConfig (Time Series Transformer 模型)
- timesformer — TimesformerConfig (TimeSformer 模型)
- timm_backbone — TimmBackboneConfig (TimmBackbone 模型)
- timm_wrapper — TimmWrapperConfig (TimmWrapperModel 模型)
- trajectory_transformer — TrajectoryTransformerConfig (Trajectory Transformer 模型)
- transfo-xl — TransfoXLConfig (Transformer-XL 模型)
- trocr — TrOCRConfig (TrOCR 模型)
- tvlt — TvltConfig (TVLT 模型)
- tvp — TvpConfig (TVP 模型)
- udop — UdopConfig (UDOP 模型)
- umt5 — UMT5Config (UMT5 模型)
- unispeech — UniSpeechConfig (UniSpeech 模型)
- unispeech-sat — UniSpeechSatConfig (UniSpeechSat 模型)
- univnet — UnivNetConfig (UnivNet 模型)
- upernet — UperNetConfig (UPerNet 模型)
- van — VanConfig (VAN 模型)
- video_llava — VideoLlavaConfig (VideoLlava 模型)
- videomae — VideoMAEConfig (VideoMAE 模型)
- vilt — ViltConfig (ViLT 模型)
- vipllava — VipLlavaConfig (VipLlava 模型)
- vision-encoder-decoder — VisionEncoderDecoderConfig (Vision Encoder decoder 模型)
- vision-text-dual-encoder — VisionTextDualEncoderConfig (VisionTextDualEncoder 模型)
- visual_bert — VisualBertConfig (VisualBERT 模型)
- vit — ViTConfig (ViT 模型)
- vit_hybrid — ViTHybridConfig (ViT Hybrid 模型)
- vit_mae — ViTMAEConfig (ViTMAE 模型)
- vit_msn — ViTMSNConfig (ViTMSN 模型)
- vitdet — VitDetConfig (VitDet 模型)
- vitmatte — VitMatteConfig (ViTMatte 模型)
- vitpose — VitPoseConfig (ViTPose 模型)
- vitpose_backbone —
VitPoseBackboneConfig
(ViTPoseBackbone 模型) - vits — VitsConfig (VITS 模型)
- vivit — VivitConfig (ViViT 模型)
- wav2vec2 — Wav2Vec2Config (Wav2Vec2 模型)
- wav2vec2-bert — Wav2Vec2BertConfig (Wav2Vec2-BERT 模型)
- wav2vec2-conformer — Wav2Vec2ConformerConfig (Wav2Vec2-Conformer 模型)
- wavlm — WavLMConfig (WavLM 模型)
- whisper — WhisperConfig (Whisper 模型)
- xclip — XCLIPConfig (X-CLIP 模型)
- xglm — XGLMConfig (XGLM 模型)
- xlm — XLMConfig (XLM 模型)
- xlm-prophetnet — XLMProphetNetConfig (XLM-ProphetNet 模型)
- xlm-roberta — XLMRobertaConfig (XLM-RoBERTa 模型)
- xlm-roberta-xl — XLMRobertaXLConfig (XLM-RoBERTa-XL 模型)
- xlnet — XLNetConfig (XLNet 模型)
- xmod — XmodConfig (X-MOD 模型)
- yolos — YolosConfig (YOLOS 模型)
- yoso — YosoConfig (YOSO 模型)
- zamba — ZambaConfig (Zamba 模型)
- zamba2 — Zamba2Config (Zamba2 模型)
- zoedepth — ZoeDepthConfig (ZoeDepth 模型)
示例
>>> from transformers import AutoConfig
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased")
>>> # Download configuration from huggingface.co (user-uploaded) and cache.
>>> config = AutoConfig.from_pretrained("dbmdz/bert-base-german-cased")
>>> # If configuration file is in a directory (e.g., was saved using *save_pretrained('./test/saved_model/')*).
>>> config = AutoConfig.from_pretrained("./test/bert_saved_model/")
>>> # Load a specific configuration file.
>>> config = AutoConfig.from_pretrained("./test/bert_saved_model/my_configuration.json")
>>> # Change some config attributes when loading a pretrained config.
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
>>> config.output_attentions
True
>>> config, unused_kwargs = AutoConfig.from_pretrained(
... "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
... )
>>> config.output_attentions
True
>>> unused_kwargs
{'foo': False}
注册
< 源代码 >( model_type config exist_ok = False )
参数
- model_type (
str
) — 模型类型,例如 “bert” 或 “gpt”。 - config (PretrainedConfig) — 要注册的配置。
为此类注册新的配置。
AutoTokenizer
这是一个通用分词器类,当使用 AutoTokenizer.from_pretrained() 类方法创建时,它将被实例化为库中的分词器类之一。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_pretrained
< 源代码 >( pretrained_model_name_or_path *inputs **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即托管在 huggingface.co 上的模型仓库内的预定义分词器的模型 ID。
- 一个目录的路径,其中包含分词器所需的词汇表文件,例如使用 save_pretrained() 方法保存的,例如
./my_model_directory/
。 - 如果分词器仅需要单个词汇表文件(如 Bert 或 XLNet),则可以是单个已保存词汇表文件的路径或 URL,例如:
./my_model_directory/vocab.txt
。(不适用于所有派生类)
- inputs (额外的的位置参数,可选) — 将传递给 Tokenizer
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于确定要实例化的分词器类的配置对象。
- cache_dir (
str
或os.PathLike
, 可选) — 如果不应使用标准缓存,则应在其中缓存下载的预训练模型配置的目录路径。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,并覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在默认情况下,所有下载都尽可能恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - subfolder (
str
, 可选) — 如果相关文件位于 huggingface.co 上的模型仓库的子文件夹中(例如,对于 facebook/rag-token-base),请在此处指定它。 - use_fast (
bool
, 可选, 默认为True
) — 如果给定模型支持 快速的基于 Rust 的分词器,则使用它。如果给定模型没有可用的快速分词器,则返回正常的基于 Python 的分词器。 - tokenizer_type (
str
, 可选) — 要加载的分词器类型。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许使用 Hub 上自定义模型文件中定义的自定义模型。此选项仅应针对您信任且已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - kwargs (额外的关键字参数,可选) — 将传递给 Tokenizer
__init__()
方法。可用于设置特殊标记,例如bos_token
、eos_token
、unk_token
、sep_token
、pad_token
、cls_token
、mask_token
、additional_special_tokens
。 有关更多详细信息,请参见__init__()
中的参数。
从预训练模型词汇表实例化库中的一个分词器类。
要实例化的分词器类是基于配置对象的 model_type
属性选择的(可以作为参数传递,也可以在可能的情况下从 pretrained_model_name_or_path
加载),或者在缺少该属性时,通过回退到对 pretrained_model_name_or_path
使用模式匹配来选择。
- albert — AlbertTokenizer 或 AlbertTokenizerFast (ALBERT 模型)
- align — BertTokenizer 或 BertTokenizerFast (ALIGN 模型)
- aria — LlamaTokenizer 或 LlamaTokenizerFast (Aria 模型)
- aya_vision — CohereTokenizerFast (AyaVision 模型)
- bark — BertTokenizer 或 BertTokenizerFast (Bark 模型)
- bart — BartTokenizer 或 BartTokenizerFast (BART 模型)
- barthez — BarthezTokenizer 或 BarthezTokenizerFast (BARThez 模型)
- bartpho — BartphoTokenizer (BARTpho 模型)
- bert — BertTokenizer 或 BertTokenizerFast (BERT 模型)
- bert-generation — BertGenerationTokenizer (Bert Generation 模型)
- bert-japanese — BertJapaneseTokenizer (BertJapanese 模型)
- bertweet — BertweetTokenizer (BERTweet 模型)
- big_bird — BigBirdTokenizer 或 BigBirdTokenizerFast (BigBird 模型)
- bigbird_pegasus — PegasusTokenizer 或 PegasusTokenizerFast (BigBird-Pegasus 模型)
- biogpt — BioGptTokenizer (BioGpt 模型)
- blenderbot — BlenderbotTokenizer 或 BlenderbotTokenizerFast (Blenderbot 模型)
- blenderbot-small — BlenderbotSmallTokenizer (BlenderbotSmall 模型)
- blip — BertTokenizer 或 BertTokenizerFast (BLIP 模型)
- blip-2 — GPT2Tokenizer 或 GPT2TokenizerFast (BLIP-2 模型)
- bloom — BloomTokenizerFast (BLOOM 模型)
- bridgetower — RobertaTokenizer 或 RobertaTokenizerFast (BridgeTower 模型)
- bros — BertTokenizer 或 BertTokenizerFast (BROS 模型)
- byt5 — ByT5Tokenizer (ByT5 模型)
- camembert — CamembertTokenizer 或 CamembertTokenizerFast (CamemBERT 模型)
- canine — CanineTokenizer (CANINE 模型)
- chameleon — LlamaTokenizer 或 LlamaTokenizerFast (Chameleon 模型)
- chinese_clip — BertTokenizer 或 BertTokenizerFast (Chinese-CLIP 模型)
- clap — RobertaTokenizer 或 RobertaTokenizerFast (CLAP 模型)
- clip — CLIPTokenizer 或 CLIPTokenizerFast (CLIP 模型)
- clipseg — CLIPTokenizer 或 CLIPTokenizerFast (CLIPSeg 模型)
- clvp — ClvpTokenizer (CLVP 模型)
- code_llama — CodeLlamaTokenizer 或 CodeLlamaTokenizerFast (CodeLlama 模型)
- codegen — CodeGenTokenizer 或 CodeGenTokenizerFast (CodeGen 模型)
- cohere — CohereTokenizerFast (Cohere 模型)
- cohere2 — CohereTokenizerFast (Cohere2 模型)
- colpali — LlamaTokenizer 或 LlamaTokenizerFast (ColPali 模型)
- convbert — ConvBertTokenizer 或 ConvBertTokenizerFast (ConvBERT 模型)
- cpm — CpmTokenizer 或 CpmTokenizerFast (CPM 模型)
- cpmant — CpmAntTokenizer (CPM-Ant 模型)
- ctrl — CTRLTokenizer (CTRL 模型)
- data2vec-audio — Wav2Vec2CTCTokenizer (Data2VecAudio 模型)
- data2vec-text — RobertaTokenizer 或 RobertaTokenizerFast (Data2VecText 模型)
- dbrx — GPT2Tokenizer 或 GPT2TokenizerFast (DBRX 模型)
- deberta — DebertaTokenizer 或 DebertaTokenizerFast (DeBERTa 模型)
- deberta-v2 — DebertaV2Tokenizer 或 DebertaV2TokenizerFast (DeBERTa-v2 模型)
- diffllama — LlamaTokenizer 或 LlamaTokenizerFast (DiffLlama 模型)
- distilbert — DistilBertTokenizer 或 DistilBertTokenizerFast (DistilBERT 模型)
- dpr — DPRQuestionEncoderTokenizer 或 DPRQuestionEncoderTokenizerFast (DPR 模型)
- electra — ElectraTokenizer 或 ElectraTokenizerFast (ELECTRA 模型)
- emu3 — GPT2Tokenizer 或 GPT2TokenizerFast (Emu3 模型)
- ernie — BertTokenizer 或 BertTokenizerFast (ERNIE 模型)
- ernie_m — ErnieMTokenizer (ErnieM 模型)
- esm — EsmTokenizer (ESM 模型)
- falcon — PreTrainedTokenizerFast (Falcon 模型)
- falcon_mamba — GPTNeoXTokenizerFast (FalconMamba 模型)
- fastspeech2_conformer — (FastSpeech2Conformer 模型)
- flaubert — FlaubertTokenizer (FlauBERT 模型)
- fnet — FNetTokenizer 或 FNetTokenizerFast (FNet 模型)
- fsmt — FSMTTokenizer (FairSeq 机器翻译模型)
- funnel — FunnelTokenizer 或 FunnelTokenizerFast (Funnel Transformer 模型)
- gemma — GemmaTokenizer 或 GemmaTokenizerFast (Gemma 模型)
- gemma2 — GemmaTokenizer 或 GemmaTokenizerFast (Gemma2 模型)
- gemma3 — GemmaTokenizer 或 GemmaTokenizerFast (Gemma3ForConditionalGeneration 模型)
- gemma3_text — GemmaTokenizer 或 GemmaTokenizerFast (Gemma3ForCausalLM 模型)
- git — BertTokenizer 或 BertTokenizerFast (GIT 模型)
- glm — PreTrainedTokenizerFast (GLM 模型)
- gpt-sw3 — GPTSw3Tokenizer (GPT-Sw3 模型)
- gpt2 — GPT2Tokenizer 或 GPT2TokenizerFast (OpenAI GPT-2 模型)
- gpt_bigcode — GPT2Tokenizer 或 GPT2TokenizerFast (GPTBigCode 模型)
- gpt_neo — GPT2Tokenizer 或 GPT2TokenizerFast (GPT Neo 模型)
- gpt_neox — GPTNeoXTokenizerFast (GPT NeoX 模型)
- gpt_neox_japanese — GPTNeoXJapaneseTokenizer (GPT NeoX 日语模型)
- gptj — GPT2Tokenizer 或 GPT2TokenizerFast (GPT-J 模型)
- gptsan-japanese — GPTSanJapaneseTokenizer (GPTSAN-日语模型)
- grounding-dino — BertTokenizer 或 BertTokenizerFast (Grounding DINO 模型)
- groupvit — CLIPTokenizer 或 CLIPTokenizerFast (GroupViT 模型)
- helium — PreTrainedTokenizerFast (Helium 模型)
- herbert — HerbertTokenizer 或 HerbertTokenizerFast (HerBERT 模型)
- hubert — Wav2Vec2CTCTokenizer (Hubert 模型)
- ibert — RobertaTokenizer 或 RobertaTokenizerFast (I-BERT 模型)
- idefics — LlamaTokenizerFast (IDEFICS 模型)
- idefics2 — LlamaTokenizer 或 LlamaTokenizerFast (Idefics2 模型)
- idefics3 — LlamaTokenizer 或 LlamaTokenizerFast (Idefics3 模型)
- instructblip — GPT2Tokenizer 或 GPT2TokenizerFast (InstructBLIP 模型)
- instructblipvideo — GPT2Tokenizer 或 GPT2TokenizerFast (InstructBlipVideo 模型)
- jamba — LlamaTokenizer 或 LlamaTokenizerFast (Jamba 模型)
- jetmoe — LlamaTokenizer 或 LlamaTokenizerFast (JetMoe 模型)
- jukebox — JukeboxTokenizer (Jukebox 模型)
- kosmos-2 — XLMRobertaTokenizer 或 XLMRobertaTokenizerFast (KOSMOS-2 模型)
- layoutlm — LayoutLMTokenizer 或 LayoutLMTokenizerFast (LayoutLM 模型)
- layoutlmv2 — LayoutLMv2Tokenizer 或 LayoutLMv2TokenizerFast (LayoutLMv2 模型)
- layoutlmv3 — LayoutLMv3Tokenizer 或 LayoutLMv3TokenizerFast (LayoutLMv3 模型)
- layoutxlm — LayoutXLMTokenizer 或 LayoutXLMTokenizerFast (LayoutXLM 模型)
- led — LEDTokenizer 或 LEDTokenizerFast (LED 模型)
- lilt — LayoutLMv3Tokenizer 或 LayoutLMv3TokenizerFast (LiLT 模型)
- llama — LlamaTokenizer 或 LlamaTokenizerFast (LLaMA 模型)
- llava — LlamaTokenizer 或 LlamaTokenizerFast (LLaVa 模型)
- llava_next — LlamaTokenizer 或 LlamaTokenizerFast (LLaVA-NeXT 模型)
- llava_next_video — LlamaTokenizer 或 LlamaTokenizerFast (LLaVa-NeXT-Video 模型)
- llava_onevision — LlamaTokenizer 或 LlamaTokenizerFast (LLaVA-Onevision 模型)
- longformer — LongformerTokenizer 或 LongformerTokenizerFast (Longformer 模型)
- longt5 — T5Tokenizer 或 T5TokenizerFast (LongT5 模型)
- luke — LukeTokenizer (LUKE 模型)
- lxmert — LxmertTokenizer 或 LxmertTokenizerFast (LXMERT 模型)
- m2m_100 — M2M100Tokenizer (M2M100 模型)
- mamba — GPTNeoXTokenizerFast (Mamba 模型)
- mamba2 — GPTNeoXTokenizerFast (mamba2 模型)
- marian — MarianTokenizer (Marian 模型)
- mbart — MBartTokenizer 或 MBartTokenizerFast (mBART 模型)
- mbart50 — MBart50Tokenizer 或 MBart50TokenizerFast (mBART-50 模型)
- mega — RobertaTokenizer 或 RobertaTokenizerFast (MEGA 模型)
- megatron-bert — BertTokenizer 或 BertTokenizerFast (Megatron-BERT 模型)
- mgp-str — MgpstrTokenizer (MGP-STR 模型)
- mistral — LlamaTokenizer 或 LlamaTokenizerFast (Mistral 模型)
- mixtral — LlamaTokenizer 或 LlamaTokenizerFast (Mixtral 模型)
- mllama — LlamaTokenizer 或 LlamaTokenizerFast (Mllama 模型)
- mluke — MLukeTokenizer (mLUKE 模型)
- mobilebert — MobileBertTokenizer 或 MobileBertTokenizerFast (MobileBERT 模型)
- modernbert — PreTrainedTokenizerFast (ModernBERT 模型)
- moonshine — PreTrainedTokenizerFast (Moonshine 模型)
- moshi — PreTrainedTokenizerFast (Moshi 模型)
- mpnet — MPNetTokenizer 或 MPNetTokenizerFast (MPNet 模型)
- mpt — GPTNeoXTokenizerFast (MPT 模型)
- mra — RobertaTokenizer 或 RobertaTokenizerFast (MRA 模型)
- mt5 — MT5Tokenizer 或 MT5TokenizerFast (MT5 模型)
- musicgen — T5Tokenizer 或 T5TokenizerFast (MusicGen 模型)
- musicgen_melody — T5Tokenizer 或 T5TokenizerFast (MusicGen Melody 模型)
- mvp — MvpTokenizer 或 MvpTokenizerFast (MVP 模型)
- myt5 — MyT5Tokenizer (myt5 模型)
- nemotron — PreTrainedTokenizerFast (Nemotron 模型)
- nezha — BertTokenizer 或 BertTokenizerFast (Nezha 模型)
- nllb — NllbTokenizer 或 NllbTokenizerFast (NLLB 模型)
- nllb-moe — NllbTokenizer 或 NllbTokenizerFast (NLLB-MOE 模型)
- nystromformer — AlbertTokenizer 或 AlbertTokenizerFast (Nyströmformer 模型)
- olmo — GPTNeoXTokenizerFast (OLMo 模型)
- olmo2 — GPTNeoXTokenizerFast (OLMo2 模型)
- olmoe — GPTNeoXTokenizerFast (OLMoE 模型)
- omdet-turbo — CLIPTokenizer 或 CLIPTokenizerFast (OmDet-Turbo 模型)
- oneformer — CLIPTokenizer 或 CLIPTokenizerFast (OneFormer 模型)
- openai-gpt — OpenAIGPTTokenizer 或 OpenAIGPTTokenizerFast (OpenAI GPT 模型)
- opt — GPT2Tokenizer 或 GPT2TokenizerFast (OPT 模型)
- owlv2 — CLIPTokenizer 或 CLIPTokenizerFast (OWLv2 模型)
- owlvit — CLIPTokenizer 或 CLIPTokenizerFast (OWL-ViT 模型)
- paligemma — LlamaTokenizer 或 LlamaTokenizerFast (PaliGemma 模型)
- pegasus — PegasusTokenizer 或 PegasusTokenizerFast (Pegasus 模型)
- pegasus_x — PegasusTokenizer 或 PegasusTokenizerFast (PEGASUS-X 模型)
- perceiver — PerceiverTokenizer (Perceiver 模型)
- persimmon — LlamaTokenizer 或 LlamaTokenizerFast (Persimmon 模型)
- phi — CodeGenTokenizer 或 CodeGenTokenizerFast (Phi 模型)
- phi3 — LlamaTokenizer 或 LlamaTokenizerFast (Phi3 模型)
- phimoe — LlamaTokenizer 或 LlamaTokenizerFast (Phimoe 模型)
- phobert — PhobertTokenizer (PhoBERT 模型)
- pix2struct — T5Tokenizer 或 T5TokenizerFast (Pix2Struct 模型)
- pixtral — PreTrainedTokenizerFast (Pixtral 模型)
- plbart — PLBartTokenizer (PLBart 模型)
- prophetnet — ProphetNetTokenizer (ProphetNet 模型)
- qdqbert — BertTokenizer 或 BertTokenizerFast (QDQBert 模型)
- qwen2 — Qwen2Tokenizer 或 Qwen2TokenizerFast (Qwen2 模型)
- qwen2_5_vl — Qwen2Tokenizer 或 Qwen2TokenizerFast (Qwen2_5_VL 模型)
- qwen2_audio — Qwen2Tokenizer 或 Qwen2TokenizerFast (Qwen2Audio 模型)
- qwen2_moe — Qwen2Tokenizer 或 Qwen2TokenizerFast (Qwen2MoE 模型)
- qwen2_vl — Qwen2Tokenizer 或 Qwen2TokenizerFast (Qwen2VL 模型)
- rag — RagTokenizer (RAG 模型)
- realm — RealmTokenizer 或 RealmTokenizerFast (REALM 模型)
- recurrent_gemma — GemmaTokenizer 或 GemmaTokenizerFast (RecurrentGemma 模型)
- reformer — ReformerTokenizer 或 ReformerTokenizerFast (Reformer 模型)
- rembert — RemBertTokenizer 或 RemBertTokenizerFast (RemBERT 模型)
- retribert — RetriBertTokenizer 或 RetriBertTokenizerFast (RetriBERT 模型)
- roberta — RobertaTokenizer 或 RobertaTokenizerFast (RoBERTa 模型)
- roberta-prelayernorm — RobertaTokenizer 或 RobertaTokenizerFast (RoBERTa-PreLayerNorm 模型)
- roc_bert — RoCBertTokenizer (RoCBert 模型)
- roformer — RoFormerTokenizer 或 RoFormerTokenizerFast (RoFormer 模型)
- rwkv — GPTNeoXTokenizerFast (RWKV 模型)
- seamless_m4t — SeamlessM4TTokenizer 或 SeamlessM4TTokenizerFast (SeamlessM4T 模型)
- seamless_m4t_v2 — SeamlessM4TTokenizer 或 SeamlessM4TTokenizerFast (SeamlessM4Tv2 模型)
- shieldgemma2 — GemmaTokenizer 或 GemmaTokenizerFast (Shieldgemma2 模型)
- siglip — SiglipTokenizer (SigLIP 模型)
- siglip2 — GemmaTokenizer 或 GemmaTokenizerFast (SigLIP2 模型)
- speech_to_text — Speech2TextTokenizer (Speech2Text 模型)
- speech_to_text_2 — Speech2Text2Tokenizer (Speech2Text2 模型)
- speecht5 — SpeechT5Tokenizer (SpeechT5 模型)
- splinter — SplinterTokenizer 或 SplinterTokenizerFast (Splinter 模型)
- squeezebert — SqueezeBertTokenizer 或 SqueezeBertTokenizerFast (SqueezeBERT 模型)
- stablelm — GPTNeoXTokenizerFast (StableLm 模型)
- starcoder2 — GPT2Tokenizer 或 GPT2TokenizerFast (Starcoder2 模型)
- switch_transformers — T5Tokenizer 或 T5TokenizerFast (SwitchTransformers 模型)
- t5 — T5Tokenizer 或 T5TokenizerFast (T5 模型)
- tapas — TapasTokenizer (TAPAS 模型)
- tapex — TapexTokenizer (TAPEX 模型)
- transfo-xl — TransfoXLTokenizer (Transformer-XL 模型)
- tvp — BertTokenizer 或 BertTokenizerFast (TVP 模型)
- udop — UdopTokenizer 或 UdopTokenizerFast (UDOP 模型)
- umt5 — T5Tokenizer 或 T5TokenizerFast (UMT5 模型)
- video_llava — LlamaTokenizer 或 LlamaTokenizerFast (VideoLlava 模型)
- vilt — BertTokenizer 或 BertTokenizerFast (ViLT 模型)
- vipllava — LlamaTokenizer 或 LlamaTokenizerFast (VipLlava 模型)
- visual_bert — BertTokenizer 或 BertTokenizerFast (VisualBERT 模型)
- vits — VitsTokenizer (VITS 模型)
- wav2vec2 — Wav2Vec2CTCTokenizer (Wav2Vec2 模型)
- wav2vec2-bert — Wav2Vec2CTCTokenizer (Wav2Vec2-BERT 模型)
- wav2vec2-conformer — Wav2Vec2CTCTokenizer (Wav2Vec2-Conformer 模型)
- wav2vec2_phoneme — Wav2Vec2PhonemeCTCTokenizer (Wav2Vec2Phoneme 模型)
- whisper — WhisperTokenizer 或 WhisperTokenizerFast (Whisper 模型)
- xclip — CLIPTokenizer 或 CLIPTokenizerFast (X-CLIP 模型)
- xglm — XGLMTokenizer 或 XGLMTokenizerFast (XGLM 模型)
- xlm — XLMTokenizer (XLM 模型)
- xlm-prophetnet — XLMProphetNetTokenizer (XLM-ProphetNet 模型)
- xlm-roberta — XLMRobertaTokenizer 或 XLMRobertaTokenizerFast (XLM-RoBERTa 模型)
- xlm-roberta-xl — XLMRobertaTokenizer 或 XLMRobertaTokenizerFast (XLM-RoBERTa-XL 模型)
- xlnet — XLNetTokenizer 或 XLNetTokenizerFast (XLNet 模型)
- xmod — XLMRobertaTokenizer 或 XLMRobertaTokenizerFast (X-MOD 模型)
- yoso — AlbertTokenizer 或 AlbertTokenizerFast (YOSO 模型)
- zamba — LlamaTokenizer 或 LlamaTokenizerFast (Zamba 模型)
- zamba2 — LlamaTokenizer 或 LlamaTokenizerFast (Zamba2 模型)
示例
>>> from transformers import AutoTokenizer
>>> # Download vocabulary from huggingface.co and cache.
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> # Download vocabulary from huggingface.co (user-uploaded) and cache.
>>> tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased")
>>> # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*)
>>> # tokenizer = AutoTokenizer.from_pretrained("./test/bert_saved_model/")
>>> # Download vocabulary from huggingface.co and define model-specific arguments
>>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base", add_prefix_space=True)
注册
< source >( config_class slow_tokenizer_class = None fast_tokenizer_class = None exist_ok = False )
参数
- config_class (PretrainedConfig) — 与要注册的模型相对应的配置。
- slow_tokenizer_class (
PretrainedTokenizer
, 可选的) — 要注册的慢速分词器。 - fast_tokenizer_class (
PretrainedTokenizerFast
, 可选的) — 要注册的快速分词器。
在此映射中注册一个新的分词器。
AutoFeatureExtractor
这是一个通用的特征提取器类,当使用 AutoFeatureExtractor.from_pretrained() 类方法创建时,它将被实例化为库中的特征提取器类之一。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_pretrained
< 源文件 >( pretrained_model_name_or_path **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 上模型仓库中托管的预训练 feature_extractor 的模型 ID。
- 一个目录的路径,该目录包含使用 save_pretrained() 方法保存的特征提取器文件,例如
./my_model_directory/
。 - 保存的特征提取器 JSON 文件的路径或 URL,例如
./my_model_directory/preprocessor_config.json
。
- cache_dir (
str
或os.PathLike
, 可选的) — 如果不应使用标准缓存,则下载的预训练模型特征提取器应缓存到的目录路径。 - force_download (
bool
, 可选的, 默认为False
) — 是否强制(重新)下载特征提取器文件,并覆盖已存在的缓存版本。 - resume_download — 已弃用并忽略。现在所有下载在可能的情况下默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选的) — 要按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
代理用于每个请求。 - token (
str
或 bool, 可选的) — 用作远程文件的 HTTP Bearer 授权的令牌。如果为True
,将使用运行huggingface-cli login
时生成的令牌(存储在~/.huggingface
中)。 - revision (
str
, 可选的, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - return_unused_kwargs (
bool
, 可选的, 默认为False
) — 如果为False
,则此函数仅返回最终的特征提取器对象。如果为True
,则此函数返回Tuple(feature_extractor, unused_kwargs)
,其中 unused_kwargs 是一个字典,由键/值对组成,其键不是特征提取器属性:即kwargs
中未用于更新feature_extractor
的部分,否则将被忽略。 - trust_remote_code (
bool
, 可选的, 默认为False
) — 是否允许 Hub 上自定义模型及其自身的建模文件。此选项仅应在您信任的仓库中设置为True
,并且您已阅读过代码,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - kwargs (
Dict[str, Any]
, 可选的) — kwargs 中任何键是特征提取器属性的值将用于覆盖加载的值。关于键不是特征提取器属性的键/值对的行为由return_unused_kwargs
关键字参数控制。
从预训练模型词汇表实例化库中的特征提取器类之一。
要实例化的特征提取器类是基于配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它缺失时,通过回退到使用 pretrained_model_name_or_path
上的模式匹配来选择。
- audio-spectrogram-transformer — ASTFeatureExtractor (音频频谱图 Transformer 模型)
- beit — BeitFeatureExtractor (BEiT 模型)
- chinese_clip — ChineseCLIPFeatureExtractor (Chinese-CLIP 模型)
- clap — ClapFeatureExtractor (CLAP 模型)
- clip — CLIPFeatureExtractor (CLIP 模型)
- clipseg — ViTFeatureExtractor (CLIPSeg 模型)
- clvp — ClvpFeatureExtractor (CLVP 模型)
- conditional_detr — ConditionalDetrFeatureExtractor (Conditional DETR 模型)
- convnext — ConvNextFeatureExtractor (ConvNeXT 模型)
- cvt — ConvNextFeatureExtractor (CvT 模型)
- dac — DacFeatureExtractor (DAC 模型)
- data2vec-audio — Wav2Vec2FeatureExtractor (Data2VecAudio 模型)
- data2vec-vision — BeitFeatureExtractor (Data2VecVision 模型)
- deformable_detr — DeformableDetrFeatureExtractor (Deformable DETR 模型)
- deit — DeiTFeatureExtractor (DeiT 模型)
- detr — DetrFeatureExtractor (DETR 模型)
- dinat — ViTFeatureExtractor (DiNAT 模型)
- donut-swin — DonutFeatureExtractor (DonutSwin 模型)
- dpt — DPTFeatureExtractor (DPT 模型)
- encodec — EncodecFeatureExtractor (EnCodec 模型)
- flava — FlavaFeatureExtractor (FLAVA 模型)
- glpn — GLPNFeatureExtractor (GLPN 模型)
- groupvit — CLIPFeatureExtractor (GroupViT 模型)
- hubert — Wav2Vec2FeatureExtractor (Hubert 模型)
- imagegpt — ImageGPTFeatureExtractor (ImageGPT 模型)
- layoutlmv2 — LayoutLMv2FeatureExtractor (LayoutLMv2 模型)
- layoutlmv3 — LayoutLMv3FeatureExtractor (LayoutLMv3 模型)
- levit — LevitFeatureExtractor (LeViT 模型)
- maskformer — MaskFormerFeatureExtractor (MaskFormer 模型)
- mctct — MCTCTFeatureExtractor (M-CTC-T 模型)
- mimi — EncodecFeatureExtractor (Mimi 模型)
- mobilenet_v1 — MobileNetV1FeatureExtractor (MobileNetV1 模型)
- mobilenet_v2 — MobileNetV2FeatureExtractor (MobileNetV2 模型)
- mobilevit — MobileViTFeatureExtractor (MobileViT 模型)
- moonshine — Wav2Vec2FeatureExtractor (Moonshine 模型)
- moshi — EncodecFeatureExtractor (Moshi 模型)
- nat — ViTFeatureExtractor (NAT 模型)
- owlvit —
OwlViTFeatureExtractor
(OWL-ViT 模型) - perceiver — PerceiverFeatureExtractor (Perceiver 模型)
- poolformer — PoolFormerFeatureExtractor (PoolFormer 模型)
- pop2piano — Pop2PianoFeatureExtractor (Pop2Piano 模型)
- regnet — ConvNextFeatureExtractor (RegNet 模型)
- resnet — ConvNextFeatureExtractor (ResNet 模型)
- seamless_m4t — SeamlessM4TFeatureExtractor (SeamlessM4T 模型)
- seamless_m4t_v2 — SeamlessM4TFeatureExtractor (SeamlessM4Tv2 模型)
- segformer — SegformerFeatureExtractor (SegFormer 模型)
- sew — Wav2Vec2FeatureExtractor (SEW 模型)
- sew-d — Wav2Vec2FeatureExtractor (SEW-D 模型)
- speech_to_text — Speech2TextFeatureExtractor (Speech2Text 模型)
- speecht5 — SpeechT5FeatureExtractor (SpeechT5 模型)
- swiftformer — ViTFeatureExtractor (SwiftFormer 模型)
- swin — ViTFeatureExtractor (Swin Transformer 模型)
- swinv2 — ViTFeatureExtractor (Swin Transformer V2 模型)
- table-transformer — DetrFeatureExtractor (Table Transformer 模型)
- timesformer — VideoMAEFeatureExtractor (TimeSformer 模型)
- tvlt — TvltFeatureExtractor (TVLT 模型)
- unispeech — Wav2Vec2FeatureExtractor (UniSpeech 模型)
- unispeech-sat — Wav2Vec2FeatureExtractor (UniSpeechSat 模型)
- univnet — UnivNetFeatureExtractor (UnivNet 模型)
- van — ConvNextFeatureExtractor (VAN 模型)
- videomae — VideoMAEFeatureExtractor (VideoMAE 模型)
- vilt — ViltFeatureExtractor (ViLT 模型)
- vit — ViTFeatureExtractor (ViT 模型)
- vit_mae — ViTFeatureExtractor (ViTMAE 模型)
- vit_msn — ViTFeatureExtractor (ViTMSN 模型)
- wav2vec2 — Wav2Vec2FeatureExtractor (Wav2Vec2 模型)
- wav2vec2-bert — Wav2Vec2FeatureExtractor (Wav2Vec2-BERT 模型)
- wav2vec2-conformer — Wav2Vec2FeatureExtractor (Wav2Vec2-Conformer 模型)
- wavlm — Wav2Vec2FeatureExtractor (WavLM 模型)
- whisper — WhisperFeatureExtractor (Whisper 模型)
- xclip — CLIPFeatureExtractor (X-CLIP 模型)
- yolos — YolosFeatureExtractor (YOLOS 模型)
当您想使用私有模型时,需要传递 token=True
。
示例
>>> from transformers import AutoFeatureExtractor
>>> # Download feature extractor from huggingface.co and cache.
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
>>> # If feature extractor files are in a directory (e.g. feature extractor was saved using *save_pretrained('./test/saved_model/')*)
>>> # feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/")
注册
< 源文件 >( config_class feature_extractor_class exist_ok = False )
参数
- config_class (PretrainedConfig) — 与要注册的模型对应的配置。
- feature_extractor_class (
FeatureExtractorMixin
) — 要注册的特征提取器。
为此类注册一个新的特征提取器。
AutoImageProcessor
这是一个通用图像处理器类,当使用 AutoImageProcessor.from_pretrained() 类方法创建时,它将被实例化为库中的图像处理器类之一。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_pretrained
< source >( pretrained_model_name_or_path *inputs **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 这可以是以下之一:- 一个字符串,即托管在 huggingface.co 上的模型仓库内的预训练 image_processor 的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 方法保存的图像处理器文件,例如,
./my_model_directory/
。 - 一个保存的图像处理器 JSON 文件的路径或 URL,例如,
./my_model_directory/preprocessor_config.json
。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型图像处理器的目录路径,如果不想使用标准缓存目录,则应指定此项。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载图像处理器文件,并覆盖缓存版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下都默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选) — 一个代理服务器字典,用于按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - token (
str
或 bool, 可选) — 用作远程文件的 HTTP Bearer 授权的令牌。如果为True
,将使用运行huggingface-cli login
时生成的令牌(存储在~/.huggingface
中)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - use_fast (
bool
, 可选, 默认为False
) — 如果给定模型支持快速 torchvision 基础的图像处理器,则使用它。如果给定模型没有快速图像处理器,则返回普通的基于 numpy 的图像处理器。 - return_unused_kwargs (
bool
, 可选, 默认为False
) — 如果为False
,则此函数仅返回最终的图像处理器对象。如果为True
,则此函数返回Tuple(image_processor, unused_kwargs)
,其中 unused_kwargs 是一个字典,包含键/值对,这些键不是图像处理器属性:即kwargs
中未用于更新image_processor
且在其他情况下被忽略的部分。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。此选项仅应设置为True
,用于您信任的存储库,并且您已阅读其中的代码,因为它将在您的本地计算机上执行 Hub 上的代码。 - image_processor_filename (
str
, 可选, 默认为"config.json"
) — 模型目录中用于图像处理器配置的文件名。 - kwargs (
Dict[str, Any]
, 可选) —kwargs
中任何键是图像处理器属性的键值对,都将用于覆盖已加载的值。关于键值对的键不是图像处理器属性的行为,由return_unused_kwargs
关键字参数控制。
从预训练模型词汇表实例化库中的一个图像处理器类。
要实例化的图像处理器类是基于配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者在缺少 model_type
属性时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- align — EfficientNetImageProcessor (ALIGN 模型)
- aria —
A
或r
(Aria 模型) - beit — BeitImageProcessor (BEiT 模型)
- bit — BitImageProcessor (BiT 模型)
- blip — BlipImageProcessor 或 BlipImageProcessorFast (BLIP 模型)
- blip-2 — BlipImageProcessor 或 BlipImageProcessorFast (BLIP-2 模型)
- bridgetower — BridgeTowerImageProcessor (BridgeTower 模型)
- chameleon — ChameleonImageProcessor (Chameleon 模型)
- chinese_clip — ChineseCLIPImageProcessor (Chinese-CLIP 模型)
- clip — CLIPImageProcessor 或 CLIPImageProcessorFast (CLIP 模型)
- clipseg — ViTImageProcessor 或 ViTImageProcessorFast (CLIPSeg 模型)
- conditional_detr — ConditionalDetrImageProcessor (Conditional DETR 模型)
- convnext — ConvNextImageProcessor 或 ConvNextImageProcessorFast (ConvNeXT 模型)
- convnextv2 — ConvNextImageProcessor 或 ConvNextImageProcessorFast (ConvNeXTV2 模型)
- cvt — ConvNextImageProcessor 或 ConvNextImageProcessorFast (CvT 模型)
- data2vec-vision — BeitImageProcessor (Data2VecVision 模型)
- deformable_detr — DeformableDetrImageProcessor 或 DeformableDetrImageProcessorFast (Deformable DETR 模型)
- deit — DeiTImageProcessor 或 DeiTImageProcessorFast (DeiT 模型)
- depth_anything — DPTImageProcessor (Depth Anything 模型)
- depth_pro — DepthProImageProcessor 或 DepthProImageProcessorFast (DepthPro 模型)
- deta — DetaImageProcessor (DETA 模型)
- detr — DetrImageProcessor 或 DetrImageProcessorFast (DETR 模型)
- dinat — ViTImageProcessor 或 ViTImageProcessorFast (DiNAT 模型)
- dinov2 — BitImageProcessor (DINOv2 模型)
- donut-swin — DonutImageProcessor (DonutSwin 模型)
- dpt — DPTImageProcessor (DPT 模型)
- efficientformer — EfficientFormerImageProcessor (EfficientFormer 模型)
- efficientnet — EfficientNetImageProcessor (EfficientNet 模型)
- flava — FlavaImageProcessor (FLAVA 模型)
- focalnet — BitImageProcessor (FocalNet 模型)
- fuyu — FuyuImageProcessor (Fuyu 模型)
- gemma3 — Gemma3ImageProcessor 或 Gemma3ImageProcessorFast (Gemma3ForConditionalGeneration 模型)
- git — CLIPImageProcessor 或 CLIPImageProcessorFast (GIT 模型)
- glpn — GLPNImageProcessor (GLPN 模型)
- got_ocr2 — GotOcr2ImageProcessor 或 GotOcr2ImageProcessorFast (GOT-OCR2 模型)
- grounding-dino — GroundingDinoImageProcessor (Grounding DINO 模型)
- groupvit — CLIPImageProcessor 或 CLIPImageProcessorFast (GroupViT 模型)
- hiera — BitImageProcessor (Hiera 模型)
- idefics — IdeficsImageProcessor (IDEFICS 模型)
- idefics2 — Idefics2ImageProcessor (Idefics2 模型)
- idefics3 — Idefics3ImageProcessor (Idefics3 模型)
- ijepa — ViTImageProcessor 或 ViTImageProcessorFast (I-JEPA 模型)
- imagegpt — ImageGPTImageProcessor (ImageGPT 模型)
- instructblip — BlipImageProcessor 或 BlipImageProcessorFast (InstructBLIP 模型)
- instructblipvideo — InstructBlipVideoImageProcessor (InstructBlipVideo 模型)
- kosmos-2 — CLIPImageProcessor 或 CLIPImageProcessorFast (KOSMOS-2 模型)
- layoutlmv2 — LayoutLMv2ImageProcessor (LayoutLMv2 模型)
- layoutlmv3 — LayoutLMv3ImageProcessor (LayoutLMv3 模型)
- levit — LevitImageProcessor (LeViT 模型)
- llava — LlavaImageProcessor 或 LlavaImageProcessorFast (LLaVa 模型)
- llava_next — LlavaNextImageProcessor 或 LlavaNextImageProcessorFast (LLaVA-NeXT 模型)
- llava_next_video — LlavaNextVideoImageProcessor (LLaVa-NeXT-Video 模型)
- llava_onevision — LlavaOnevisionImageProcessor 或 LlavaOnevisionImageProcessorFast (LLaVA-Onevision 模型)
- mask2former — Mask2FormerImageProcessor (Mask2Former 模型)
- maskformer — MaskFormerImageProcessor (MaskFormer 模型)
- mgp-str — ViTImageProcessor 或 ViTImageProcessorFast (MGP-STR 模型)
- mistral3 — PixtralImageProcessor 或 PixtralImageProcessorFast (Mistral3 模型)
- mllama — MllamaImageProcessor (Mllama 模型)
- mobilenet_v1 — MobileNetV1ImageProcessor (MobileNetV1 模型)
- mobilenet_v2 — MobileNetV2ImageProcessor (MobileNetV2 模型)
- mobilevit — MobileViTImageProcessor (MobileViT 模型)
- mobilevitv2 — MobileViTImageProcessor (MobileViTV2 模型)
- nat — ViTImageProcessor 或 ViTImageProcessorFast (NAT 模型)
- nougat — NougatImageProcessor (Nougat 模型)
- oneformer — OneFormerImageProcessor (OneFormer 模型)
- owlv2 — Owlv2ImageProcessor (OWLv2 模型)
- owlvit — OwlViTImageProcessor (OWL-ViT 模型)
- paligemma — SiglipImageProcessor 或 SiglipImageProcessorFast (PaliGemma 模型)
- perceiver — PerceiverImageProcessor (Perceiver 模型)
- pix2struct — Pix2StructImageProcessor (Pix2Struct 模型)
- pixtral — PixtralImageProcessor 或 PixtralImageProcessorFast (Pixtral 模型)
- poolformer — PoolFormerImageProcessor (PoolFormer 模型)
- prompt_depth_anything — PromptDepthAnythingImageProcessor (PromptDepthAnything 模型)
- pvt — PvtImageProcessor (PVT 模型)
- pvt_v2 — PvtImageProcessor (PVTv2 模型)
- qwen2_5_vl — Qwen2VLImageProcessor 或 Qwen2VLImageProcessorFast (Qwen2_5_VL 模型)
- qwen2_vl — Qwen2VLImageProcessor 或 Qwen2VLImageProcessorFast (Qwen2VL 模型)
- regnet — ConvNextImageProcessor 或 ConvNextImageProcessorFast (RegNet 模型)
- resnet — ConvNextImageProcessor 或 ConvNextImageProcessorFast (ResNet 模型)
- rt_detr — RTDetrImageProcessor 或 RTDetrImageProcessorFast (RT-DETR 模型)
- sam — SamImageProcessor (SAM 模型)
- segformer — SegformerImageProcessor (SegFormer 模型)
- seggpt — SegGptImageProcessor (SegGPT 模型)
- shieldgemma2 — Gemma3ImageProcessor 或 Gemma3ImageProcessorFast (Shieldgemma2 模型)
- siglip — SiglipImageProcessor 或 SiglipImageProcessorFast (SigLIP 模型)
- siglip2 — Siglip2ImageProcessor 或 Siglip2ImageProcessorFast (SigLIP2 模型)
- superglue —
S
或u
(SuperGlue 模型) - swiftformer — ViTImageProcessor 或 ViTImageProcessorFast (SwiftFormer 模型)
- swin — ViTImageProcessor 或 ViTImageProcessorFast (Swin Transformer 模型)
- swin2sr — Swin2SRImageProcessor (Swin2SR 模型)
- swinv2 — ViTImageProcessor 或 ViTImageProcessorFast (Swin Transformer V2 模型)
- table-transformer — DetrImageProcessor (Table Transformer 模型)
- timesformer — VideoMAEImageProcessor (TimeSformer 模型)
- timm_wrapper — TimmWrapperImageProcessor (TimmWrapperModel 模型)
- tvlt — TvltImageProcessor (TVLT 模型)
- tvp — TvpImageProcessor (TVP 模型)
- udop — LayoutLMv3ImageProcessor (UDOP 模型)
- upernet — SegformerImageProcessor (UPerNet 模型)
- van — ConvNextImageProcessor 或 ConvNextImageProcessorFast (VAN 模型)
- videomae — VideoMAEImageProcessor (VideoMAE 模型)
- vilt — ViltImageProcessor (ViLT 模型)
- vipllava — CLIPImageProcessor 或 CLIPImageProcessorFast (VipLlava 模型)
- vit — ViTImageProcessor 或 ViTImageProcessorFast (ViT 模型)
- vit_hybrid — ViTHybridImageProcessor (ViT Hybrid 模型)
- vit_mae — ViTImageProcessor 或 ViTImageProcessorFast (ViTMAE 模型)
- vit_msn — ViTImageProcessor 或 ViTImageProcessorFast (ViTMSN 模型)
- vitmatte — VitMatteImageProcessor (ViTMatte 模型)
- xclip — CLIPImageProcessor 或 CLIPImageProcessorFast (X-CLIP 模型)
- yolos — YolosImageProcessor (YOLOS 模型)
- zoedepth — ZoeDepthImageProcessor (ZoeDepth 模型)
当您想使用私有模型时,需要传递 token=True
。
示例
>>> from transformers import AutoImageProcessor
>>> # Download image processor from huggingface.co and cache.
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> # If image processor files are in a directory (e.g. image processor was saved using *save_pretrained('./test/saved_model/')*)
>>> # image_processor = AutoImageProcessor.from_pretrained("./test/saved_model/")
注册
< source >( config_class image_processor_class = None slow_image_processor_class = None fast_image_processor_class = None exist_ok = False )
参数
- config_class (PretrainedConfig) — 注册模型对应的配置类。
- image_processor_class (ImageProcessingMixin) — 要注册的图像处理器类。
为此类注册一个新的图像处理器。
AutoProcessor
这是一个通用处理器类,当使用 AutoProcessor.from_pretrained() 类方法创建时,它将被实例化为库中的处理器类之一。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_pretrained
< source >( pretrained_model_name_or_path **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 模型仓库中托管的预训练 feature_extractor 的模型 ID。
- 一个目录的路径,其中包含使用
save_pretrained()
方法保存的处理器文件,例如,./my_model_directory/
。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型特征提取器的目录路径,如果不想使用标准缓存。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载特征提取器文件并覆盖缓存版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下都默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - token (
str
或 bool, 可选) — 用作远程文件的 HTTP Bearer 授权的令牌。如果为True
,将使用运行huggingface-cli login
时生成的令牌(存储在~/.huggingface
中)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - return_unused_kwargs (
bool
, 可选, 默认为False
) — 如果为False
,则此函数仅返回最终的特征提取器对象。如果为True
,则此函数返回一个Tuple(feature_extractor, unused_kwargs)
,其中 unused_kwargs 是一个字典,其中包含键/值对,这些键不是特征提取器属性:即,kwargs
中未用于更新feature_extractor
且在其他情况下被忽略的部分。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。此选项应仅对您信任且已阅读代码的存储库设置为True
,因为它将在您的本地机器上执行 Hub 上存在的代码。 - kwargs (
Dict[str, Any]
, 可选) — kwargs 中任何键是特征提取器属性的值将用于覆盖加载的值。关于键/值对的行为,其键不是特征提取器属性,由return_unused_kwargs
关键字参数控制。
从预训练模型词汇表实例化库中的一个处理器类。
要实例化的处理器类是根据配置对象的 model_type
属性选择的(可以作为参数传递,也可以从 pretrained_model_name_or_path
加载,如果可能)。
- align — AlignProcessor (ALIGN 模型)
- altclip — AltCLIPProcessor (AltCLIP 模型)
- aria — AriaProcessor (Aria 模型)
- aya_vision — AyaVisionProcessor (AyaVision 模型)
- bark — BarkProcessor (Bark 模型)
- blip — BlipProcessor (BLIP 模型)
- blip-2 — Blip2Processor (BLIP-2 模型)
- bridgetower — BridgeTowerProcessor (BridgeTower 模型)
- chameleon — ChameleonProcessor (Chameleon 模型)
- chinese_clip — ChineseCLIPProcessor (Chinese-CLIP 模型)
- clap — ClapProcessor (CLAP 模型)
- clip — CLIPProcessor (CLIP 模型)
- clipseg — CLIPSegProcessor (CLIPSeg 模型)
- clvp — ClvpProcessor (CLVP 模型)
- colpali — ColPaliProcessor (ColPali 模型)
- emu3 — Emu3Processor (Emu3 模型)
- flava — FlavaProcessor (FLAVA 模型)
- fuyu — FuyuProcessor (Fuyu 模型)
- gemma3 — Gemma3Processor (Gemma3ForConditionalGeneration 模型)
- git — GitProcessor (GIT 模型)
- got_ocr2 — GotOcr2Processor (GOT-OCR2 模型)
- grounding-dino — GroundingDinoProcessor (Grounding DINO 模型)
- groupvit — CLIPProcessor (GroupViT 模型)
- hubert — Wav2Vec2Processor (Hubert 模型)
- idefics — IdeficsProcessor (IDEFICS 模型)
- idefics2 — Idefics2Processor (Idefics2 模型)
- idefics3 — Idefics3Processor (Idefics3 模型)
- instructblip — InstructBlipProcessor (InstructBLIP 模型)
- instructblipvideo — InstructBlipVideoProcessor (InstructBlipVideo 模型)
- kosmos-2 — Kosmos2Processor (KOSMOS-2 模型)
- layoutlmv2 — LayoutLMv2Processor (LayoutLMv2 模型)
- layoutlmv3 — LayoutLMv3Processor (LayoutLMv3 模型)
- llava — LlavaProcessor (LLaVa 模型)
- llava_next — LlavaNextProcessor (LLaVA-NeXT 模型)
- llava_next_video — LlavaNextVideoProcessor (LLaVa-NeXT-Video 模型)
- llava_onevision — LlavaOnevisionProcessor (LLaVA-Onevision 模型)
- markuplm — MarkupLMProcessor (MarkupLM 模型)
- mctct — MCTCTProcessor (M-CTC-T 模型)
- mgp-str — MgpstrProcessor (MGP-STR 模型)
- mistral3 — PixtralProcessor (Mistral3 模型)
- mllama — MllamaProcessor (Mllama 模型)
- moonshine — Wav2Vec2Processor (Moonshine 模型)
- oneformer — OneFormerProcessor (OneFormer 模型)
- owlv2 — Owlv2Processor (OWLv2 模型)
- owlvit — OwlViTProcessor (OWL-ViT 模型)
- paligemma — PaliGemmaProcessor (PaliGemma 模型)
- pix2struct — Pix2StructProcessor (Pix2Struct 模型)
- pixtral — PixtralProcessor (Pixtral 模型)
- pop2piano — Pop2PianoProcessor (Pop2Piano 模型)
- qwen2_5_vl — Qwen2_5_VLProcessor (Qwen2_5_VL 模型)
- qwen2_audio — Qwen2AudioProcessor (Qwen2Audio 模型)
- qwen2_vl — Qwen2VLProcessor (Qwen2VL 模型)
- sam — SamProcessor (SAM 模型)
- seamless_m4t — SeamlessM4TProcessor (SeamlessM4T 模型)
- sew — Wav2Vec2Processor (SEW 模型)
- sew-d — Wav2Vec2Processor (SEW-D 模型)
- shieldgemma2 — ShieldGemma2Processor (Shieldgemma2 模型)
- siglip — SiglipProcessor (SigLIP 模型)
- siglip2 — Siglip2Processor (SigLIP2 模型)
- speech_to_text — Speech2TextProcessor (Speech2Text 模型)
- speech_to_text_2 — Speech2Text2Processor (Speech2Text2 模型)
- speecht5 — SpeechT5Processor (SpeechT5 模型)
- trocr — TrOCRProcessor (TrOCR 模型)
- tvlt — TvltProcessor (TVLT 模型)
- tvp — TvpProcessor (TVP 模型)
- udop — UdopProcessor (UDOP 模型)
- unispeech — Wav2Vec2Processor (UniSpeech 模型)
- unispeech-sat — Wav2Vec2Processor (UniSpeechSat 模型)
- video_llava — VideoLlavaProcessor (VideoLlava 模型)
- vilt — ViltProcessor (ViLT 模型)
- vipllava — LlavaProcessor (VipLlava 模型)
- vision-text-dual-encoder — VisionTextDualEncoderProcessor (VisionTextDualEncoder 模型)
- wav2vec2 — Wav2Vec2Processor (Wav2Vec2 模型)
- wav2vec2-bert — Wav2Vec2Processor (Wav2Vec2-BERT 模型)
- wav2vec2-conformer — Wav2Vec2Processor (Wav2Vec2-Conformer 模型)
- wavlm — Wav2Vec2Processor (WavLM 模型)
- whisper — WhisperProcessor (Whisper 模型)
- xclip — XCLIPProcessor (X-CLIP 模型)
当您想使用私有模型时,需要传递 token=True
。
示例
>>> from transformers import AutoProcessor
>>> # Download processor from huggingface.co and cache.
>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
>>> # If processor files are in a directory (e.g. processor was saved using *save_pretrained('./test/saved_model/')*)
>>> # processor = AutoProcessor.from_pretrained("./test/saved_model/")
注册
< source >( config_class processor_class exist_ok = False )
参数
- config_class (PretrainedConfig) — 要注册的模型对应的配置类。
- processor_class (ProcessorMixin) — 要注册的处理器。
为此类注册一个新的处理器。
通用模型类
以下自动类可用于实例化没有特定 head 的基本模型类。
AutoModel
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的基础模型类之一。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
- ASTConfig configuration class: ASTModel (Audio Spectrogram Transformer model)
- AlbertConfig configuration class: AlbertModel (ALBERT model)
- AlignConfig configuration class: AlignModel (ALIGN model)
- AltCLIPConfig configuration class: AltCLIPModel (AltCLIP model)
- AriaConfig configuration class: AriaForConditionalGeneration (Aria model)
- AriaTextConfig configuration class: AriaTextModel (AriaText model)
- AutoformerConfig configuration class: AutoformerModel (Autoformer model)
- BambaConfig configuration class: BambaModel (Bamba model)
- BarkConfig configuration class: BarkModel (Bark model)
- BartConfig configuration class: BartModel (BART model)
- BeitConfig configuration class: BeitModel (BEiT model)
- BertConfig configuration class: BertModel (BERT model)
- BertGenerationConfig configuration class: BertGenerationEncoder (Bert Generation model)
- BigBirdConfig configuration class: BigBirdModel (BigBird model)
- BigBirdPegasusConfig configuration class: BigBirdPegasusModel (BigBird-Pegasus model)
- BioGptConfig configuration class: BioGptModel (BioGpt model)
- BitConfig configuration class: BitModel (BiT model)
- BlenderbotConfig configuration class: BlenderbotModel (Blenderbot model)
- BlenderbotSmallConfig configuration class: BlenderbotSmallModel (BlenderbotSmall model)
- Blip2Config configuration class: Blip2Model (BLIP-2 model)
- BlipConfig configuration class: BlipModel (BLIP model)
- BloomConfig configuration class: BloomModel (BLOOM model)
- BridgeTowerConfig configuration class: BridgeTowerModel (BridgeTower model)
- BrosConfig configuration class: BrosModel (BROS model)
- CLIPConfig configuration class: CLIPModel (CLIP model)
- CLIPSegConfig configuration class: CLIPSegModel (CLIPSeg model)
- CLIPTextConfig configuration class: CLIPTextModel (CLIPTextModel model)
- CLIPVisionConfig configuration class: CLIPVisionModel (CLIPVisionModel model)
- CTRLConfig configuration class: CTRLModel (CTRL model)
- CamembertConfig configuration class: CamembertModel (CamemBERT model)
- CanineConfig configuration class: CanineModel (CANINE model)
- ChameleonConfig configuration class: ChameleonModel (Chameleon model)
- ChineseCLIPConfig configuration class: ChineseCLIPModel (Chinese-CLIP model)
- ChineseCLIPVisionConfig configuration class: ChineseCLIPVisionModel (ChineseCLIPVisionModel model)
- ClapConfig configuration class: ClapModel (CLAP model)
- ClvpConfig configuration class: ClvpModelForConditionalGeneration (CLVP model)
- CodeGenConfig configuration class: CodeGenModel (CodeGen model)
- Cohere2Config configuration class: Cohere2Model (Cohere2 model)
- CohereConfig configuration class: CohereModel (Cohere model)
- ConditionalDetrConfig configuration class: ConditionalDetrModel (Conditional DETR model)
- ConvBertConfig configuration class: ConvBertModel (ConvBERT model)
- ConvNextConfig configuration class: ConvNextModel (ConvNeXT model)
- ConvNextV2Config configuration class: ConvNextV2Model (ConvNeXTV2 model)
- CpmAntConfig configuration class: CpmAntModel (CPM-Ant model)
- CvtConfig configuration class: CvtModel (CvT model)
- DPRConfig configuration class: DPRQuestionEncoder (DPR model)
- DPTConfig configuration class: DPTModel (DPT model)
- DabDetrConfig configuration class: DabDetrModel (DAB-DETR model)
- DacConfig configuration class: DacModel (DAC model)
- Data2VecAudioConfig configuration class: Data2VecAudioModel (Data2VecAudio model)
- Data2VecTextConfig configuration class: Data2VecTextModel (Data2VecText model)
- Data2VecVisionConfig configuration class: Data2VecVisionModel (Data2VecVision model)
- DbrxConfig configuration class: DbrxModel (DBRX model)
- DebertaConfig configuration class: DebertaModel (DeBERTa model)
- DebertaV2Config configuration class: DebertaV2Model (DeBERTa-v2 model)
- DecisionTransformerConfig configuration class: DecisionTransformerModel (Decision Transformer model)
- DeformableDetrConfig configuration class: DeformableDetrModel (Deformable DETR model)
- DeiTConfig configuration class: DeiTModel (DeiT model)
- DepthProConfig configuration class: DepthProModel (DepthPro model)
- DetaConfig configuration class: DetaModel (DETA model)
- DetrConfig configuration class: DetrModel (DETR model)
- DiffLlamaConfig configuration class: DiffLlamaModel (DiffLlama model)
- DinatConfig configuration class: DinatModel (DiNAT model)
- Dinov2Config configuration class: Dinov2Model (DINOv2 model)
- Dinov2WithRegistersConfig configuration class: Dinov2WithRegistersModel (DINOv2 with Registers model)
- DistilBertConfig configuration class: DistilBertModel (DistilBERT model)
- DonutSwinConfig configuration class: DonutSwinModel (DonutSwin model)
- EfficientFormerConfig configuration class: EfficientFormerModel (EfficientFormer model)
- EfficientNetConfig configuration class: EfficientNetModel (EfficientNet model)
- ElectraConfig configuration class: ElectraModel (ELECTRA model)
- EncodecConfig configuration class: EncodecModel (EnCodec model)
- ErnieConfig configuration class: ErnieModel (ERNIE model)
- ErnieMConfig configuration class: ErnieMModel (ErnieM model)
- EsmConfig configuration class: EsmModel (ESM model)
- FNetConfig configuration class: FNetModel (FNet model)
- FSMTConfig configuration class: FSMTModel (FairSeq Machine-Translation model)
- FalconConfig configuration class: FalconModel (Falcon model)
- FalconMambaConfig configuration class: FalconMambaModel (FalconMamba model)
- FastSpeech2ConformerConfig configuration class: FastSpeech2ConformerModel (FastSpeech2Conformer model)
- FlaubertConfig configuration class: FlaubertModel (FlauBERT model)
- FlavaConfig configuration class: FlavaModel (FLAVA model)
- FocalNetConfig configuration class: FocalNetModel (FocalNet model)
- FunnelConfig configuration class: FunnelModel or FunnelBaseModel (Funnel Transformer model)
- GLPNConfig configuration class: GLPNModel (GLPN model)
- GPT2Config configuration class: GPT2Model (OpenAI GPT-2 model)
- GPTBigCodeConfig configuration class: GPTBigCodeModel (GPTBigCode model)
- GPTJConfig configuration class: GPTJModel (GPT-J model)
- GPTNeoConfig configuration class: GPTNeoModel (GPT Neo model)
- GPTNeoXConfig configuration class: GPTNeoXModel (GPT NeoX model)
- GPTNeoXJapaneseConfig configuration class: GPTNeoXJapaneseModel (GPT NeoX Japanese model)
- GPTSanJapaneseConfig configuration class: GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese model)
- Gemma2Config configuration class: Gemma2Model (Gemma2 model)
- Gemma3TextConfig configuration class: Gemma3TextModel (Gemma3ForCausalLM model)
- GemmaConfig configuration class: GemmaModel (Gemma model)
- GitConfig configuration class: GitModel (GIT model)
- GlmConfig configuration class: GlmModel (GLM model)
- GotOcr2Config configuration class: GotOcr2ForConditionalGeneration (GOT-OCR2 model)
- GraniteConfig configuration class: GraniteModel (Granite model)
- GraniteMoeConfig configuration class: GraniteMoeModel (GraniteMoeMoe model)
- GraniteMoeSharedConfig configuration class: GraniteMoeSharedModel (GraniteMoeSharedMoe model)
- GraphormerConfig configuration class: GraphormerModel (Graphormer model)
- GroundingDinoConfig configuration class: GroundingDinoModel (Grounding DINO model)
- GroupViTConfig configuration class: GroupViTModel (GroupViT model)
- HeliumConfig configuration class: HeliumModel (Helium model)
- HieraConfig configuration class: HieraModel (Hiera model)
- HubertConfig configuration class: HubertModel (Hubert model)
- IBertConfig configuration class: IBertModel (I-BERT model)
- IJepaConfig configuration class: IJepaModel (I-JEPA model)
- Idefics2Config configuration class: Idefics2Model (Idefics2 model)
- Idefics3Config configuration class: Idefics3Model (Idefics3 model)
- Idefics3VisionConfig configuration class: Idefics3VisionTransformer (Idefics3VisionTransformer model)
- IdeficsConfig configuration class: IdeficsModel (IDEFICS model)
- ImageGPTConfig configuration class: ImageGPTModel (ImageGPT model)
- InformerConfig configuration class: InformerModel (Informer model)
- JambaConfig configuration class: JambaModel (Jamba model)
- JetMoeConfig configuration class: JetMoeModel (JetMoe model)
- JukeboxConfig configuration class: JukeboxModel (Jukebox model)
- Kosmos2Config configuration class: Kosmos2Model (KOSMOS-2 model)
- LEDConfig configuration class: LEDModel (LED model)
- LayoutLMConfig configuration class: LayoutLMModel (LayoutLM model)
- LayoutLMv2Config configuration class: LayoutLMv2Model (LayoutLMv2 model)
- LayoutLMv3Config configuration class: LayoutLMv3Model (LayoutLMv3 model)
- LevitConfig configuration class: LevitModel (LeViT model)
- LiltConfig configuration class: LiltModel (LiLT model)
- LlamaConfig configuration class: LlamaModel (LLaMA model)
- LongT5Config configuration class: LongT5Model (LongT5 model)
- LongformerConfig configuration class: LongformerModel (Longformer model)
- LukeConfig configuration class: LukeModel (LUKE model)
- LxmertConfig configuration class: LxmertModel (LXMERT model)
- M2M100Config configuration class: M2M100Model (M2M100 model)
- MBartConfig configuration class: MBartModel (mBART model)
- MCTCTConfig configuration class: MCTCTModel (M-CTC-T model)
- MPNetConfig configuration class: MPNetModel (MPNet model)
- MT5Config configuration class: MT5Model (MT5 model)
- Mamba2Config configuration class: Mamba2Model (mamba2 model)
- MambaConfig configuration class: MambaModel (Mamba model)
- MarianConfig configuration class: MarianModel (Marian model)
- MarkupLMConfig configuration class: MarkupLMModel (MarkupLM model)
- Mask2FormerConfig configuration class: Mask2FormerModel (Mask2Former model)
- MaskFormerConfig configuration class: MaskFormerModel (MaskFormer model)
MaskFormerSwinConfig
configuration class:MaskFormerSwinModel
(MaskFormerSwin model)- MegaConfig configuration class: MegaModel (MEGA model)
- MegatronBertConfig configuration class: MegatronBertModel (Megatron-BERT model)
- MgpstrConfig configuration class: MgpstrForSceneTextRecognition (MGP-STR model)
- MimiConfig configuration class: MimiModel (Mimi model)
- MistralConfig configuration class: MistralModel (Mistral model)
- MixtralConfig configuration class: MixtralModel (Mixtral model)
- MobileBertConfig configuration class: MobileBertModel (MobileBERT model)
- MobileNetV1Config configuration class: MobileNetV1Model (MobileNetV1 model)
- MobileNetV2Config configuration class: MobileNetV2Model (MobileNetV2 model)
- MobileViTConfig configuration class: MobileViTModel (MobileViT model)
- MobileViTV2Config configuration class: MobileViTV2Model (MobileViTV2 model)
- ModernBertConfig configuration class: ModernBertModel (ModernBERT model)
- MoonshineConfig configuration class: MoonshineModel (Moonshine model)
- MoshiConfig configuration class: MoshiModel (Moshi model)
- MptConfig configuration class: MptModel (MPT model)
- MraConfig configuration class: MraModel (MRA model)
- MusicgenConfig configuration class: MusicgenModel (MusicGen model)
- MusicgenMelodyConfig configuration class: MusicgenMelodyModel (MusicGen Melody model)
- MvpConfig configuration class: MvpModel (MVP model)
- NatConfig configuration class: NatModel (NAT model)
- NemotronConfig configuration class: NemotronModel (Nemotron model)
- NezhaConfig configuration class: NezhaModel (Nezha model)
- NllbMoeConfig configuration class: NllbMoeModel (NLLB-MOE model)
- NystromformerConfig configuration class: NystromformerModel (Nyströmformer model)
- OPTConfig configuration class: OPTModel (OPT model)
- Olmo2Config configuration class: Olmo2Model (OLMo2 model)
- OlmoConfig configuration class: OlmoModel (OLMo model)
- OlmoeConfig configuration class: OlmoeModel (OLMoE model)
- OmDetTurboConfig configuration class: OmDetTurboForObjectDetection (OmDet-Turbo model)
- OneFormerConfig configuration class: OneFormerModel (OneFormer model)
- OpenAIGPTConfig configuration class: OpenAIGPTModel (OpenAI GPT model)
- OpenLlamaConfig configuration class: OpenLlamaModel (OpenLlama model)
- OwlViTConfig configuration class: OwlViTModel (OWL-ViT model)
- Owlv2Config configuration class: Owlv2Model (OWLv2 model)
- PLBartConfig configuration class: PLBartModel (PLBart model)
- PatchTSMixerConfig configuration class: PatchTSMixerModel (PatchTSMixer model)
- PatchTSTConfig configuration class: PatchTSTModel (PatchTST model)
- PegasusConfig configuration class: PegasusModel (Pegasus model)
- PegasusXConfig configuration class: PegasusXModel (PEGASUS-X model)
- PerceiverConfig configuration class: PerceiverModel (Perceiver model)
- PersimmonConfig configuration class: PersimmonModel (Persimmon model)
- Phi3Config configuration class: Phi3Model (Phi3 model)
- PhiConfig configuration class: PhiModel (Phi model)
- PhimoeConfig configuration class: PhimoeModel (Phimoe model)
- PixtralVisionConfig configuration class: PixtralVisionModel (Pixtral model)
- PoolFormerConfig configuration class: PoolFormerModel (PoolFormer model)
- ProphetNetConfig configuration class: ProphetNetModel (ProphetNet model)
- PvtConfig configuration class: PvtModel (PVT model)
- PvtV2Config configuration class: PvtV2Model (PVTv2 model)
- QDQBertConfig configuration class: QDQBertModel (QDQBert model)
- Qwen2AudioEncoderConfig configuration class: Qwen2AudioEncoder (Qwen2AudioEncoder model)
- Qwen2Config configuration class: Qwen2Model (Qwen2 model)
- Qwen2MoeConfig configuration class: Qwen2MoeModel (Qwen2MoE model)
- Qwen2VLConfig configuration class: Qwen2VLModel (Qwen2VL model)
- Qwen2_5_VLConfig configuration class: Qwen2_5_VLModel (Qwen2_5_VL model)
- RTDetrConfig configuration class: RTDetrModel (RT-DETR model)
- RTDetrV2Config configuration class: RTDetrV2Model (RT-DETRv2 model)
- RecurrentGemmaConfig configuration class: RecurrentGemmaModel (RecurrentGemma model)
- ReformerConfig configuration class: ReformerModel (Reformer model)
- RegNetConfig configuration class: RegNetModel (RegNet model)
- RemBertConfig configuration class: RemBertModel (RemBERT model)
- ResNetConfig configuration class: ResNetModel (ResNet model)
- RetriBertConfig configuration class: RetriBertModel (RetriBERT model)
- RoCBertConfig configuration class: RoCBertModel (RoCBert model)
- RoFormerConfig configuration class: RoFormerModel (RoFormer model)
- RobertaConfig configuration class: RobertaModel (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: RobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
- RwkvConfig configuration class: RwkvModel (RWKV model)
- SEWConfig configuration class: SEWModel (SEW model)
- SEWDConfig configuration class: SEWDModel (SEW-D model)
- SamConfig configuration class: SamModel (SAM model)
- SeamlessM4TConfig configuration class: SeamlessM4TModel (SeamlessM4T model)
- SeamlessM4Tv2Config configuration class: SeamlessM4Tv2Model (SeamlessM4Tv2 model)
- SegGptConfig configuration class: SegGptModel (SegGPT model)
- SegformerConfig configuration class: SegformerModel (SegFormer model)
- Siglip2Config configuration class: Siglip2Model (SigLIP2 model)
- SiglipConfig configuration class: SiglipModel (SigLIP model)
- SiglipVisionConfig configuration class: SiglipVisionModel (SiglipVisionModel model)
- SmolVLMConfig configuration class: SmolVLMModel (SmolVLM model)
- SmolVLMVisionConfig configuration class: SmolVLMVisionTransformer (SmolVLMVisionTransformer model)
- Speech2TextConfig configuration class: Speech2TextModel (Speech2Text model)
- SpeechT5Config configuration class: SpeechT5Model (SpeechT5 model)
- SplinterConfig configuration class: SplinterModel (Splinter model)
- SqueezeBertConfig configuration class: SqueezeBertModel (SqueezeBERT model)
- StableLmConfig configuration class: StableLmModel (StableLm model)
- Starcoder2Config configuration class: Starcoder2Model (Starcoder2 model)
- SuperGlueConfig configuration class: SuperGlueForKeypointMatching (SuperGlue model)
- SwiftFormerConfig configuration class: SwiftFormerModel (SwiftFormer model)
- Swin2SRConfig configuration class: Swin2SRModel (Swin2SR model)
- SwinConfig configuration class: SwinModel (Swin Transformer model)
- Swinv2Config configuration class: Swinv2Model (Swin Transformer V2 model)
- SwitchTransformersConfig configuration class: SwitchTransformersModel (SwitchTransformers model)
- T5Config configuration class: T5Model (T5 model)
- TableTransformerConfig configuration class: TableTransformerModel (Table Transformer model)
- TapasConfig configuration class: TapasModel (TAPAS model)
- TextNetConfig configuration class: TextNetModel (TextNet model)
- TimeSeriesTransformerConfig configuration class: TimeSeriesTransformerModel (Time Series Transformer model)
- TimesformerConfig configuration class: TimesformerModel (TimeSformer model)
- TimmBackboneConfig configuration class: TimmBackbone (TimmBackbone model)
- TimmWrapperConfig configuration class: TimmWrapperModel (TimmWrapperModel model)
- TrajectoryTransformerConfig configuration class: TrajectoryTransformerModel (Trajectory Transformer model)
- TransfoXLConfig configuration class: TransfoXLModel (Transformer-XL model)
- TvltConfig configuration class: TvltModel (TVLT model)
- TvpConfig configuration class: TvpModel (TVP model)
- UMT5Config configuration class: UMT5Model (UMT5 model)
- UdopConfig configuration class: UdopModel (UDOP model)
- UniSpeechConfig configuration class: UniSpeechModel (UniSpeech model)
- UniSpeechSatConfig configuration class: UniSpeechSatModel (UniSpeechSat model)
- UnivNetConfig configuration class: UnivNetModel (UnivNet model)
- VanConfig configuration class: VanModel (VAN model)
- ViTConfig configuration class: ViTModel (ViT model)
- ViTHybridConfig configuration class: ViTHybridModel (ViT Hybrid model)
- ViTMAEConfig configuration class: ViTMAEModel (ViTMAE model)
- ViTMSNConfig configuration class: ViTMSNModel (ViTMSN model)
- VideoMAEConfig configuration class: VideoMAEModel (VideoMAE model)
- ViltConfig configuration class: ViltModel (ViLT model)
- VisionTextDualEncoderConfig configuration class: VisionTextDualEncoderModel (VisionTextDualEncoder model)
- VisualBertConfig configuration class: VisualBertModel (VisualBERT model)
- VitDetConfig configuration class: VitDetModel (VitDet model)
- VitsConfig configuration class: VitsModel (VITS model)
- VivitConfig configuration class: VivitModel (ViViT model)
- Wav2Vec2BertConfig configuration class: Wav2Vec2BertModel (Wav2Vec2-BERT model)
- Wav2Vec2Config configuration class: Wav2Vec2Model (Wav2Vec2 model)
- Wav2Vec2ConformerConfig configuration class: Wav2Vec2ConformerModel (Wav2Vec2-Conformer model)
- WavLMConfig configuration class: WavLMModel (WavLM model)
- WhisperConfig configuration class: WhisperModel (Whisper model)
- XCLIPConfig configuration class: XCLIPModel (X-CLIP model)
- XGLMConfig configuration class: XGLMModel (XGLM model)
- XLMConfig configuration class: XLMModel (XLM model)
- XLMProphetNetConfig configuration class: XLMProphetNetModel (XLM-ProphetNet model)
- XLMRobertaConfig configuration class: XLMRobertaModel (XLM-RoBERTa model)
- XLMRobertaXLConfig configuration class: XLMRobertaXLModel (XLM-RoBERTa-XL model)
- XLNetConfig configuration class: XLNetModel (XLNet model)
- XmodConfig configuration class: XmodModel (X-MOD model)
- YolosConfig configuration class: YolosModel (YOLOS model)
- YosoConfig configuration class: YosoModel (YOSO model)
- Zamba2Config configuration class: Zamba2Model (Zamba2 model)
- ZambaConfig configuration class: ZambaModel (Zamba model)
- attn_implementation (
str
, optional) — 模型中使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention )。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则,默认为手动"eager"
实现。
从配置实例化库的基础模型类之一。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 上的模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个tensorflow 索引检查点文件的路径或 URL(例如,
./tf_model/model.ckpt.index
)。在这种情况下,应将from_tf
设置为True
,并且应将配置对象作为config
参数提供。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (附加的位置参数, 可选) — 将传递给底层模型
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,而不是自动加载的配置。在以下情况下可以自动加载配置:
- 该模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型使用 save_pretrained() 保存,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 要使用的状态字典,而不是从已保存的权重文件加载的状态字典。
如果要从预训练配置创建模型但加载自己的权重,可以使用此选项。不过,在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 目录的路径,如果应不使用标准缓存,则应在该目录中缓存下载的预训练模型配置。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下都默认恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 要按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。此选项仅应针对您信任且已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加的关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
而表现不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键(对应于配置属性)将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的基础模型类之一。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — AlbertModel (ALBERT 模型)
- align — AlignModel (ALIGN 模型)
- altclip — AltCLIPModel (AltCLIP 模型)
- aria — AriaForConditionalGeneration (Aria 模型)
- aria_text — AriaTextModel (AriaText 模型)
- audio-spectrogram-transformer — ASTModel (音频频谱图转换器模型)
- autoformer — AutoformerModel (Autoformer 模型)
- bamba — BambaModel (Bamba 模型)
- bark — BarkModel (Bark 模型)
- bart — BartModel (BART 模型)
- beit — BeitModel (BEiT 模型)
- bert — BertModel (BERT 模型)
- bert-generation — BertGenerationEncoder (Bert Generation 模型)
- big_bird — BigBirdModel (BigBird 模型)
- bigbird_pegasus — BigBirdPegasusModel (BigBird-Pegasus 模型)
- biogpt — BioGptModel (BioGpt 模型)
- bit — BitModel (BiT 模型)
- blenderbot — BlenderbotModel (Blenderbot 模型)
- blenderbot-small — BlenderbotSmallModel (BlenderbotSmall 模型)
- blip — BlipModel (BLIP 模型)
- blip-2 — Blip2Model (BLIP-2 模型)
- bloom — BloomModel (BLOOM 模型)
- bridgetower — BridgeTowerModel (BridgeTower 模型)
- bros — BrosModel (BROS 模型)
- camembert — CamembertModel (CamemBERT 模型)
- canine — CanineModel (CANINE 模型)
- chameleon — ChameleonModel (Chameleon 模型)
- chinese_clip — ChineseCLIPModel (Chinese-CLIP 模型)
- chinese_clip_vision_model — ChineseCLIPVisionModel (ChineseCLIPVisionModel 模型)
- clap — ClapModel (CLAP 模型)
- clip — CLIPModel (CLIP 模型)
- clip_text_model — CLIPTextModel (CLIPTextModel 模型)
- clip_vision_model — CLIPVisionModel (CLIPVisionModel 模型)
- clipseg — CLIPSegModel (CLIPSeg 模型)
- clvp — ClvpModelForConditionalGeneration (CLVP 模型)
- code_llama — LlamaModel (CodeLlama 模型)
- codegen — CodeGenModel (CodeGen 模型)
- cohere — CohereModel (Cohere 模型)
- cohere2 — Cohere2Model (Cohere2 模型)
- conditional_detr — ConditionalDetrModel (Conditional DETR 模型)
- convbert — ConvBertModel (ConvBERT 模型)
- convnext — ConvNextModel (ConvNeXT 模型)
- convnextv2 — ConvNextV2Model (ConvNeXTV2 模型)
- cpmant — CpmAntModel (CPM-Ant 模型)
- ctrl — CTRLModel (CTRL 模型)
- cvt — CvtModel (CvT 模型)
- dab-detr — DabDetrModel (DAB-DETR 模型)
- dac — DacModel (DAC 模型)
- data2vec-audio — Data2VecAudioModel (Data2VecAudio 模型)
- data2vec-text — Data2VecTextModel (Data2VecText 模型)
- data2vec-vision — Data2VecVisionModel (Data2VecVision 模型)
- dbrx — DbrxModel (DBRX 模型)
- deberta — DebertaModel (DeBERTa 模型)
- deberta-v2 — DebertaV2Model (DeBERTa-v2 模型)
- decision_transformer — DecisionTransformerModel (Decision Transformer 模型)
- deformable_detr — DeformableDetrModel (Deformable DETR 模型)
- deit — DeiTModel (DeiT 模型)
- depth_pro — DepthProModel (DepthPro 模型)
- deta — DetaModel (DETA 模型)
- detr — DetrModel (DETR 模型)
- diffllama — DiffLlamaModel (DiffLlama 模型)
- dinat — DinatModel (DiNAT 模型)
- dinov2 — Dinov2Model (DINOv2 模型)
- dinov2_with_registers — Dinov2WithRegistersModel (带寄存器的 DINOv2 模型)
- distilbert — DistilBertModel (DistilBERT 模型)
- donut-swin — DonutSwinModel (DonutSwin 模型)
- dpr — DPRQuestionEncoder (DPR 模型)
- dpt — DPTModel (DPT 模型)
- efficientformer — EfficientFormerModel (EfficientFormer 模型)
- efficientnet — EfficientNetModel (EfficientNet 模型)
- electra — ElectraModel (ELECTRA 模型)
- encodec — EncodecModel (EnCodec 模型)
- ernie — ErnieModel (ERNIE 模型)
- ernie_m — ErnieMModel (ErnieM 模型)
- esm — EsmModel (ESM 模型)
- falcon — FalconModel (Falcon 模型)
- falcon_mamba — FalconMambaModel (FalconMamba 模型)
- fastspeech2_conformer — FastSpeech2ConformerModel (FastSpeech2Conformer 模型)
- flaubert — FlaubertModel (FlauBERT 模型)
- flava — FlavaModel (FLAVA 模型)
- fnet — FNetModel (FNet 模型)
- focalnet — FocalNetModel (FocalNet 模型)
- fsmt — FSMTModel (FairSeq 机器翻译模型)
- funnel — FunnelModel 或 FunnelBaseModel (Funnel Transformer 模型)
- gemma — GemmaModel (Gemma 模型)
- gemma2 — Gemma2Model (Gemma2 模型)
- gemma3_text — Gemma3TextModel (Gemma3ForCausalLM 模型)
- git — GitModel (GIT 模型)
- glm — GlmModel (GLM 模型)
- glpn — GLPNModel (GLPN 模型)
- got_ocr2 — GotOcr2ForConditionalGeneration (GOT-OCR2 模型)
- gpt-sw3 — GPT2Model (GPT-Sw3 模型)
- gpt2 — GPT2Model (OpenAI GPT-2 模型)
- gpt_bigcode — GPTBigCodeModel (GPTBigCode 模型)
- gpt_neo — GPTNeoModel (GPT Neo 模型)
- gpt_neox — GPTNeoXModel (GPT NeoX 模型)
- gpt_neox_japanese — GPTNeoXJapaneseModel (GPT NeoX Japanese 模型)
- gptj — GPTJModel (GPT-J 模型)
- gptsan-japanese — GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese 模型)
- granite — GraniteModel (Granite 模型)
- granitemoe — GraniteMoeModel (GraniteMoeMoe 模型)
- granitemoeshared — GraniteMoeSharedModel (GraniteMoeSharedMoe 模型)
- graphormer — GraphormerModel (Graphormer 模型)
- grounding-dino — GroundingDinoModel (Grounding DINO 模型)
- groupvit — GroupViTModel (GroupViT 模型)
- helium — HeliumModel (Helium 模型)
- hiera — HieraModel (Hiera 模型)
- hubert — HubertModel (Hubert 模型)
- ibert — IBertModel (I-BERT 模型)
- idefics — IdeficsModel (IDEFICS 模型)
- idefics2 — Idefics2Model (Idefics2 模型)
- idefics3 — Idefics3Model (Idefics3 模型)
- idefics3_vision — Idefics3VisionTransformer (Idefics3VisionTransformer 模型)
- ijepa — IJepaModel (I-JEPA 模型)
- imagegpt — ImageGPTModel (ImageGPT 模型)
- informer — InformerModel (Informer 模型)
- jamba — JambaModel (Jamba 模型)
- jetmoe — JetMoeModel (JetMoe 模型)
- jukebox — JukeboxModel (Jukebox 模型)
- kosmos-2 — Kosmos2Model (KOSMOS-2 模型)
- layoutlm — LayoutLMModel (LayoutLM 模型)
- layoutlmv2 — LayoutLMv2Model (LayoutLMv2 模型)
- layoutlmv3 — LayoutLMv3Model (LayoutLMv3 模型)
- led — LEDModel (LED 模型)
- levit — LevitModel (LeViT 模型)
- lilt — LiltModel (LiLT 模型)
- llama — LlamaModel (LLaMA 模型)
- longformer — LongformerModel (Longformer 模型)
- longt5 — LongT5Model (LongT5 模型)
- luke — LukeModel (LUKE 模型)
- lxmert — LxmertModel (LXMERT 模型)
- m2m_100 — M2M100Model (M2M100 模型)
- mamba — MambaModel (Mamba 模型)
- mamba2 — Mamba2Model (mamba2 模型)
- marian — MarianModel (Marian 模型)
- markuplm — MarkupLMModel (MarkupLM 模型)
- mask2former — Mask2FormerModel (Mask2Former 模型)
- maskformer — MaskFormerModel (MaskFormer 模型)
- maskformer-swin —
MaskFormerSwinModel
(MaskFormerSwin 模型) - mbart — MBartModel (mBART 模型)
- mctct — MCTCTModel (M-CTC-T 模型)
- mega — MegaModel (MEGA 模型)
- megatron-bert — MegatronBertModel (Megatron-BERT 模型)
- mgp-str — MgpstrForSceneTextRecognition (MGP-STR 模型)
- mimi — MimiModel (Mimi 模型)
- mistral — MistralModel (Mistral 模型)
- mixtral — MixtralModel (Mixtral 模型)
- mobilebert — MobileBertModel (MobileBERT 模型)
- mobilenet_v1 — MobileNetV1Model (MobileNetV1 模型)
- mobilenet_v2 — MobileNetV2Model (MobileNetV2 模型)
- mobilevit — MobileViTModel (MobileViT 模型)
- mobilevitv2 — MobileViTV2Model (MobileViTV2 模型)
- modernbert — ModernBertModel (ModernBERT 模型)
- moonshine — MoonshineModel (Moonshine 模型)
- moshi — MoshiModel (Moshi 模型)
- mpnet — MPNetModel (MPNet 模型)
- mpt — MptModel (MPT 模型)
- mra — MraModel (MRA 模型)
- mt5 — MT5Model (MT5 模型)
- musicgen — MusicgenModel (MusicGen 模型)
- musicgen_melody — MusicgenMelodyModel (MusicGen Melody 模型)
- mvp — MvpModel (MVP 模型)
- nat — NatModel (NAT 模型)
- nemotron — NemotronModel (Nemotron 模型)
- nezha — NezhaModel (Nezha 模型)
- nllb-moe — NllbMoeModel (NLLB-MOE 模型)
- nystromformer — NystromformerModel (Nyströmformer 模型)
- olmo — OlmoModel (OLMo 模型)
- olmo2 — Olmo2Model (OLMo2 模型)
- olmoe — OlmoeModel (OLMoE 模型)
- omdet-turbo — OmDetTurboForObjectDetection (OmDet-Turbo 模型)
- oneformer — OneFormerModel (OneFormer 模型)
- open-llama — OpenLlamaModel (OpenLlama 模型)
- openai-gpt — OpenAIGPTModel (OpenAI GPT 模型)
- opt — OPTModel (OPT 模型)
- owlv2 — Owlv2Model (OWLv2 模型)
- owlvit — OwlViTModel (OWL-ViT 模型)
- patchtsmixer — PatchTSMixerModel (PatchTSMixer 模型)
- patchtst — PatchTSTModel (PatchTST 模型)
- pegasus — PegasusModel (Pegasus 模型)
- pegasus_x — PegasusXModel (PEGASUS-X 模型)
- perceiver — PerceiverModel (Perceiver 模型)
- persimmon — PersimmonModel (Persimmon 模型)
- phi — PhiModel (Phi 模型)
- phi3 — Phi3Model (Phi3 模型)
- phimoe — PhimoeModel (Phimoe 模型)
- pixtral — PixtralVisionModel (Pixtral 模型)
- plbart — PLBartModel (PLBart 模型)
- poolformer — PoolFormerModel (PoolFormer 模型)
- prophetnet — ProphetNetModel (ProphetNet 模型)
- pvt — PvtModel (PVT 模型)
- pvt_v2 — PvtV2Model (PVTv2 模型)
- qdqbert — QDQBertModel (QDQBert 模型)
- qwen2 — Qwen2Model (Qwen2 模型)
- qwen2_5_vl — Qwen2_5_VLModel (Qwen2_5_VL 模型)
- qwen2_audio_encoder — Qwen2AudioEncoder (Qwen2AudioEncoder 模型)
- qwen2_moe — Qwen2MoeModel (Qwen2MoE 模型)
- qwen2_vl — Qwen2VLModel (Qwen2VL 模型)
- recurrent_gemma — RecurrentGemmaModel (RecurrentGemma 模型)
- reformer — ReformerModel (Reformer 模型)
- regnet — RegNetModel (RegNet 模型)
- rembert — RemBertModel (RemBERT 模型)
- resnet — ResNetModel (ResNet 模型)
- retribert — RetriBertModel (RetriBERT 模型)
- roberta — RobertaModel (RoBERTa 模型)
- roberta-prelayernorm — RobertaPreLayerNormModel (RoBERTa-PreLayerNorm 模型)
- roc_bert — RoCBertModel (RoCBert 模型)
- roformer — RoFormerModel (RoFormer 模型)
- rt_detr — RTDetrModel (RT-DETR 模型)
- rt_detr_v2 — RTDetrV2Model (RT-DETRv2 模型)
- rwkv — RwkvModel (RWKV 模型)
- sam — SamModel (SAM 模型)
- seamless_m4t — SeamlessM4TModel (SeamlessM4T 模型)
- seamless_m4t_v2 — SeamlessM4Tv2Model (SeamlessM4Tv2 模型)
- segformer — SegformerModel (SegFormer 模型)
- seggpt — SegGptModel (SegGPT 模型)
- sew — SEWModel (SEW 模型)
- sew-d — SEWDModel (SEW-D 模型)
- siglip — SiglipModel (SigLIP 模型)
- siglip2 — Siglip2Model (SigLIP2 模型)
- siglip_vision_model — SiglipVisionModel (SiglipVisionModel 模型)
- smolvlm — SmolVLMModel (SmolVLM 模型)
- smolvlm_vision — SmolVLMVisionTransformer (SmolVLMVisionTransformer 模型)
- speech_to_text — Speech2TextModel (Speech2Text 模型)
- speecht5 — SpeechT5Model (SpeechT5 模型)
- splinter — SplinterModel (Splinter 模型)
- squeezebert — SqueezeBertModel (SqueezeBERT 模型)
- stablelm — StableLmModel (StableLm 模型)
- starcoder2 — Starcoder2Model (Starcoder2 模型)
- superglue — SuperGlueForKeypointMatching (SuperGlue 模型)
- swiftformer — SwiftFormerModel (SwiftFormer 模型)
- swin — SwinModel (Swin Transformer 模型)
- swin2sr — Swin2SRModel (Swin2SR 模型)
- swinv2 — Swinv2Model (Swin Transformer V2 模型)
- switch_transformers — SwitchTransformersModel (SwitchTransformers 模型)
- t5 — T5Model (T5 模型)
- table-transformer — TableTransformerModel (Table Transformer 模型)
- tapas — TapasModel (TAPAS 模型)
- textnet — TextNetModel (TextNet 模型)
- time_series_transformer — TimeSeriesTransformerModel (Time Series Transformer 模型)
- timesformer — TimesformerModel (TimeSformer 模型)
- timm_backbone — TimmBackbone (TimmBackbone 模型)
- timm_wrapper — TimmWrapperModel (TimmWrapperModel 模型)
- trajectory_transformer — TrajectoryTransformerModel (Trajectory Transformer 模型)
- transfo-xl — TransfoXLModel (Transformer-XL 模型)
- tvlt — TvltModel (TVLT 模型)
- tvp — TvpModel (TVP 模型)
- udop — UdopModel (UDOP 模型)
- umt5 — UMT5Model (UMT5 模型)
- unispeech — UniSpeechModel (UniSpeech 模型)
- unispeech-sat — UniSpeechSatModel (UniSpeechSat 模型)
- univnet — UnivNetModel (UnivNet 模型)
- van — VanModel (VAN 模型)
- videomae — VideoMAEModel (VideoMAE 模型)
- vilt — ViltModel (ViLT 模型)
- vision-text-dual-encoder — VisionTextDualEncoderModel (VisionTextDualEncoder 模型)
- visual_bert — VisualBertModel (VisualBERT 模型)
- vit — ViTModel (ViT 模型)
- vit_hybrid — ViTHybridModel (ViT Hybrid 模型)
- vit_mae — ViTMAEModel (ViTMAE 模型)
- vit_msn — ViTMSNModel (ViTMSN 模型)
- vitdet — VitDetModel (VitDet 模型)
- vits — VitsModel (VITS 模型)
- vivit — VivitModel (ViViT 模型)
- wav2vec2 — Wav2Vec2Model (Wav2Vec2 模型)
- wav2vec2-bert — Wav2Vec2BertModel (Wav2Vec2-BERT 模型)
- wav2vec2-conformer — Wav2Vec2ConformerModel (Wav2Vec2-Conformer 模型)
- wavlm — WavLMModel (WavLM 模型)
- whisper — WhisperModel (Whisper 模型)
- xclip — XCLIPModel (X-CLIP 模型)
- xglm — XGLMModel (XGLM 模型)
- xlm — XLMModel (XLM 模型)
- xlm-prophetnet — XLMProphetNetModel (XLM-ProphetNet 模型)
- xlm-roberta — XLMRobertaModel (XLM-RoBERTa 模型)
- xlm-roberta-xl — XLMRobertaXLModel (XLM-RoBERTa-XL 模型)
- xlnet — XLNetModel (XLNet 模型)
- xmod — XmodModel (X-MOD 模型)
- yolos — YolosModel (YOLOS 模型)
- yoso — YosoModel (YOSO 模型)
- zamba — ZambaModel (Zamba 模型)
- zamba2 — Zamba2Model (Zamba2 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModel.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModel.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModel.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModel
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的基础模型类之一。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< 源代码 >( **kwargs )
参数
- config (PretrainedConfig) — 基于配置类选择要实例化的模型类:
- AlbertConfig 配置类: TFAlbertModel (ALBERT 模型)
- BartConfig 配置类: TFBartModel (BART 模型)
- BertConfig 配置类: TFBertModel (BERT 模型)
- BlenderbotConfig 配置类: TFBlenderbotModel (Blenderbot 模型)
- BlenderbotSmallConfig 配置类: TFBlenderbotSmallModel (BlenderbotSmall 模型)
- BlipConfig 配置类: TFBlipModel (BLIP 模型)
- CLIPConfig 配置类: TFCLIPModel (CLIP 模型)
- CTRLConfig 配置类: TFCTRLModel (CTRL 模型)
- CamembertConfig 配置类: TFCamembertModel (CamemBERT 模型)
- ConvBertConfig 配置类: TFConvBertModel (ConvBERT 模型)
- ConvNextConfig 配置类: TFConvNextModel (ConvNeXT 模型)
- ConvNextV2Config 配置类: TFConvNextV2Model (ConvNeXTV2 模型)
- CvtConfig 配置类: TFCvtModel (CvT 模型)
- DPRConfig 配置类: TFDPRQuestionEncoder (DPR 模型)
- Data2VecVisionConfig 配置类: TFData2VecVisionModel (Data2VecVision 模型)
- DebertaConfig 配置类: TFDebertaModel (DeBERTa 模型)
- DebertaV2Config 配置类: TFDebertaV2Model (DeBERTa-v2 模型)
- DeiTConfig 配置类: TFDeiTModel (DeiT 模型)
- DistilBertConfig 配置类: TFDistilBertModel (DistilBERT 模型)
- EfficientFormerConfig 配置类: TFEfficientFormerModel (EfficientFormer 模型)
- ElectraConfig 配置类: TFElectraModel (ELECTRA 模型)
- EsmConfig 配置类: TFEsmModel (ESM 模型)
- FlaubertConfig 配置类: TFFlaubertModel (FlauBERT 模型)
- FunnelConfig 配置类: TFFunnelModel 或 TFFunnelBaseModel (Funnel Transformer 模型)
- GPT2Config 配置类: TFGPT2Model (OpenAI GPT-2 模型)
- GPTJConfig 配置类: TFGPTJModel (GPT-J 模型)
- GroupViTConfig 配置类: TFGroupViTModel (GroupViT 模型)
- HubertConfig 配置类: TFHubertModel (Hubert 模型)
- IdeficsConfig 配置类: TFIdeficsModel (IDEFICS 模型)
- LEDConfig 配置类: TFLEDModel (LED 模型)
- LayoutLMConfig 配置类: TFLayoutLMModel (LayoutLM 模型)
- LayoutLMv3Config 配置类: TFLayoutLMv3Model (LayoutLMv3 模型)
- LongformerConfig 配置类: TFLongformerModel (Longformer 模型)
- LxmertConfig 配置类: TFLxmertModel (LXMERT 模型)
- MBartConfig 配置类: TFMBartModel (mBART 模型)
- MPNetConfig 配置类: TFMPNetModel (MPNet 模型)
- MT5Config 配置类: TFMT5Model (MT5 模型)
- MarianConfig 配置类: TFMarianModel (Marian 模型)
- MistralConfig 配置类: TFMistralModel (Mistral 模型)
- MobileBertConfig 配置类: TFMobileBertModel (MobileBERT 模型)
- MobileViTConfig 配置类: TFMobileViTModel (MobileViT 模型)
- OPTConfig 配置类: TFOPTModel (OPT 模型)
- OpenAIGPTConfig 配置类: TFOpenAIGPTModel (OpenAI GPT 模型)
- PegasusConfig 配置类: TFPegasusModel (Pegasus 模型)
- RegNetConfig 配置类: TFRegNetModel (RegNet 模型)
- RemBertConfig 配置类: TFRemBertModel (RemBERT 模型)
- ResNetConfig 配置类: TFResNetModel (ResNet 模型)
- RoFormerConfig 配置类: TFRoFormerModel (RoFormer 模型)
- RobertaConfig 配置类: TFRobertaModel (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: TFRobertaPreLayerNormModel (RoBERTa-PreLayerNorm 模型)
- SamConfig 配置类: TFSamModel (SAM 模型)
- SegformerConfig 配置类: TFSegformerModel (SegFormer 模型)
- Speech2TextConfig 配置类: TFSpeech2TextModel (Speech2Text 模型)
- SwiftFormerConfig 配置类: TFSwiftFormerModel (SwiftFormer 模型)
- SwinConfig 配置类: TFSwinModel (Swin Transformer 模型)
- T5Config 配置类: TFT5Model (T5 模型)
- TapasConfig 配置类: TFTapasModel (TAPAS 模型)
- TransfoXLConfig 配置类: TFTransfoXLModel (Transformer-XL 模型)
- ViTConfig 配置类: TFViTModel (ViT 模型)
- ViTMAEConfig 配置类: TFViTMAEModel (ViTMAE 模型)
- VisionTextDualEncoderConfig 配置类: TFVisionTextDualEncoderModel (VisionTextDualEncoder 模型)
- Wav2Vec2Config 配置类: TFWav2Vec2Model (Wav2Vec2 模型)
- WhisperConfig 配置类: TFWhisperModel (Whisper 模型)
- XGLMConfig 配置类: TFXGLMModel (XGLM 模型)
- XLMConfig 配置类: TFXLMModel (XLM 模型)
- XLMRobertaConfig 配置类: TFXLMRobertaModel (XLM-RoBERTa 模型)
- XLNetConfig 配置类: TFXLNetModel (XLNet 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力实现方式(如果相关)。可以是"eager"
(手动实现的注意力),"sdpa"
(使用F.scaled_dot_product_attention
),或"flash_attention_2"
(使用 Dao-AILab/flash-attention)。默认情况下,如果可用,torch>=2.1.1 将使用 SDPA。否则,默认为手动"eager"
实现。
从配置实例化库的基础模型类之一。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个PyTorch state_dict 保存文件的路径或 URL (例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应该设置为True
,并且应该将配置对象作为config
参数提供。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (额外的的位置参数,可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以代替自动加载的配置。在以下情况下可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, 可选) — 如果不应使用标准缓存,则应在其中缓存下载的预训练模型配置的目录路径。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 - resume_download — 已弃用并忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回包含缺失键、意外键和错误消息的字典。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。此选项仅应针对您信任的存储库和您已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的关键字参数,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果与配置属性对应,将用于使用提供的kwargs
值覆盖所述属性。不与任何配置属性对应的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的基础模型类之一。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — TFAlbertModel (ALBERT 模型)
- bart — TFBartModel (BART 模型)
- bert — TFBertModel (BERT 模型)
- blenderbot — TFBlenderbotModel (Blenderbot 模型)
- blenderbot-small — TFBlenderbotSmallModel (BlenderbotSmall 模型)
- blip — TFBlipModel (BLIP 模型)
- camembert — TFCamembertModel (CamemBERT 模型)
- clip — TFCLIPModel (CLIP 模型)
- convbert — TFConvBertModel (ConvBERT 模型)
- convnext — TFConvNextModel (ConvNeXT 模型)
- convnextv2 — TFConvNextV2Model (ConvNeXTV2 模型)
- ctrl — TFCTRLModel (CTRL 模型)
- cvt — TFCvtModel (CvT 模型)
- data2vec-vision — TFData2VecVisionModel (Data2VecVision 模型)
- deberta — TFDebertaModel (DeBERTa 模型)
- deberta-v2 — TFDebertaV2Model (DeBERTa-v2 模型)
- deit — TFDeiTModel (DeiT 模型)
- distilbert — TFDistilBertModel (DistilBERT 模型)
- dpr — TFDPRQuestionEncoder (DPR 模型)
- efficientformer — TFEfficientFormerModel (EfficientFormer 模型)
- electra — TFElectraModel (ELECTRA 模型)
- esm — TFEsmModel (ESM 模型)
- flaubert — TFFlaubertModel (FlauBERT 模型)
- funnel — TFFunnelModel 或 TFFunnelBaseModel (Funnel Transformer 模型)
- gpt-sw3 — TFGPT2Model (GPT-Sw3 模型)
- gpt2 — TFGPT2Model (OpenAI GPT-2 模型)
- gptj — TFGPTJModel (GPT-J 模型)
- groupvit — TFGroupViTModel (GroupViT 模型)
- hubert — TFHubertModel (Hubert 模型)
- idefics — TFIdeficsModel (IDEFICS 模型)
- layoutlm — TFLayoutLMModel (LayoutLM 模型)
- layoutlmv3 — TFLayoutLMv3Model (LayoutLMv3 模型)
- led — TFLEDModel (LED 模型)
- longformer — TFLongformerModel (Longformer 模型)
- lxmert — TFLxmertModel (LXMERT 模型)
- marian — TFMarianModel (Marian 模型)
- mbart — TFMBartModel (mBART 模型)
- mistral — TFMistralModel (Mistral 模型)
- mobilebert — TFMobileBertModel (MobileBERT 模型)
- mobilevit — TFMobileViTModel (MobileViT 模型)
- mpnet — TFMPNetModel (MPNet 模型)
- mt5 — TFMT5Model (MT5 模型)
- openai-gpt — TFOpenAIGPTModel (OpenAI GPT 模型)
- opt — TFOPTModel (OPT 模型)
- pegasus — TFPegasusModel (Pegasus 模型)
- regnet — TFRegNetModel (RegNet 模型)
- rembert — TFRemBertModel (RemBERT 模型)
- resnet — TFResNetModel (ResNet 模型)
- roberta — TFRobertaModel (RoBERTa 模型)
- roberta-prelayernorm — TFRobertaPreLayerNormModel (RoBERTa-PreLayerNorm 模型)
- roformer — TFRoFormerModel (RoFormer 模型)
- sam — TFSamModel (SAM 模型)
- segformer — TFSegformerModel (SegFormer 模型)
- speech_to_text — TFSpeech2TextModel (Speech2Text 模型)
- swiftformer — TFSwiftFormerModel (SwiftFormer 模型)
- swin — TFSwinModel (Swin Transformer 模型)
- t5 — TFT5Model (T5 模型)
- tapas — TFTapasModel (TAPAS 模型)
- transfo-xl — TFTransfoXLModel (Transformer-XL 模型)
- vision-text-dual-encoder — TFVisionTextDualEncoderModel (VisionTextDualEncoder 模型)
- vit — TFViTModel (ViT 模型)
- vit_mae — TFViTMAEModel (ViTMAE 模型)
- wav2vec2 — TFWav2Vec2Model (Wav2Vec2 模型)
- whisper — TFWhisperModel (Whisper 模型)
- xglm — TFXGLMModel (XGLM 模型)
- xlm — TFXLMModel (XLM 模型)
- xlm-roberta — TFXLMRobertaModel (XLM-RoBERTa 模型)
- xlnet — TFXLNetModel (XLNet 模型)
示例
>>> from transformers import AutoConfig, TFAutoModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModel.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModel.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModel.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModel
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的基础模型类之一。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 根据配置类选择要实例化的模型类:
- AlbertConfig 配置类: FlaxAlbertModel (ALBERT 模型)
- BartConfig 配置类: FlaxBartModel (BART 模型)
- BeitConfig 配置类: FlaxBeitModel (BEiT 模型)
- BertConfig 配置类: FlaxBertModel (BERT 模型)
- BigBirdConfig 配置类: FlaxBigBirdModel (BigBird 模型)
- BlenderbotConfig 配置类: FlaxBlenderbotModel (Blenderbot 模型)
- BlenderbotSmallConfig 配置类: FlaxBlenderbotSmallModel (BlenderbotSmall 模型)
- BloomConfig 配置类: FlaxBloomModel (BLOOM 模型)
- CLIPConfig 配置类: FlaxCLIPModel (CLIP 模型)
- Dinov2Config 配置类: FlaxDinov2Model (DINOv2 模型)
- DistilBertConfig 配置类: FlaxDistilBertModel (DistilBERT 模型)
- ElectraConfig 配置类: FlaxElectraModel (ELECTRA 模型)
- GPT2Config 配置类: FlaxGPT2Model (OpenAI GPT-2 模型)
- GPTJConfig 配置类: FlaxGPTJModel (GPT-J 模型)
- GPTNeoConfig 配置类: FlaxGPTNeoModel (GPT Neo 模型)
- GemmaConfig 配置类: FlaxGemmaModel (Gemma 模型)
- LlamaConfig 配置类: FlaxLlamaModel (LLaMA 模型)
- LongT5Config 配置类: FlaxLongT5Model (LongT5 模型)
- MBartConfig 配置类: FlaxMBartModel (mBART 模型)
- MT5Config 配置类: FlaxMT5Model (MT5 模型)
- MarianConfig 配置类: FlaxMarianModel (Marian 模型)
- MistralConfig 配置类: FlaxMistralModel (Mistral 模型)
- OPTConfig 配置类: FlaxOPTModel (OPT 模型)
- PegasusConfig 配置类: FlaxPegasusModel (Pegasus 模型)
- RegNetConfig 配置类: FlaxRegNetModel (RegNet 模型)
- ResNetConfig 配置类: FlaxResNetModel (ResNet 模型)
- RoFormerConfig 配置类: FlaxRoFormerModel (RoFormer 模型)
- RobertaConfig 配置类: FlaxRobertaModel (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: FlaxRobertaPreLayerNormModel (RoBERTa-PreLayerNorm 模型)
- T5Config 配置类: FlaxT5Model (T5 模型)
- ViTConfig 配置类: FlaxViTModel (ViT 模型)
- VisionTextDualEncoderConfig 配置类: FlaxVisionTextDualEncoderModel (VisionTextDualEncoder 模型)
- Wav2Vec2Config 配置类: FlaxWav2Vec2Model (Wav2Vec2 模型)
- WhisperConfig 配置类: FlaxWhisperModel (Whisper 模型)
- XGLMConfig 配置类: FlaxXGLMModel (XGLM 模型)
- XLMRobertaConfig 配置类: FlaxXLMRobertaModel (XLM-RoBERTa 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention)。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则,默认为手动"eager"
实现。
从配置实例化库的基础模型类之一。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即托管在 huggingface.co 上的模型仓库中的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - PyTorch state_dict 保存文件 的路径或 URL (例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应该设置为True
,并且应该提供一个配置对象作为config
参数。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (额外的 positional arguments,可选) — 将传递给底层模型
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以替代自动加载的配置。当以下情况时,可以自动加载配置:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置文件。
- cache_dir (
str
或os.PathLike
, 可选) — 如果不想使用标准缓存,则应在其中缓存下载的预训练模型配置的目录路径。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用并忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否同时返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许使用 Hub 上自定义模型文件中定义的自定义模型。 此选项仅应设置为True
,用于您信任的存储库以及您已阅读代码的存储库,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。 根据是否提供config
或自动加载config
,行为有所不同:- 如果提供了带
config
的配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果与配置属性相对应,都将用于使用提供的kwargs
值覆盖所述属性。 不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果提供了带
从预训练模型实例化库的基础模型类之一。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — FlaxAlbertModel (ALBERT 模型)
- bart — FlaxBartModel (BART 模型)
- beit — FlaxBeitModel (BEiT 模型)
- bert — FlaxBertModel (BERT 模型)
- big_bird — FlaxBigBirdModel (BigBird 模型)
- blenderbot — FlaxBlenderbotModel (Blenderbot 模型)
- blenderbot-small — FlaxBlenderbotSmallModel (BlenderbotSmall 模型)
- bloom — FlaxBloomModel (BLOOM 模型)
- clip — FlaxCLIPModel (CLIP 模型)
- dinov2 — FlaxDinov2Model (DINOv2 模型)
- distilbert — FlaxDistilBertModel (DistilBERT 模型)
- electra — FlaxElectraModel (ELECTRA 模型)
- gemma — FlaxGemmaModel (Gemma 模型)
- gpt-sw3 — FlaxGPT2Model (GPT-Sw3 模型)
- gpt2 — FlaxGPT2Model (OpenAI GPT-2 模型)
- gpt_neo — FlaxGPTNeoModel (GPT Neo 模型)
- gptj — FlaxGPTJModel (GPT-J 模型)
- llama — FlaxLlamaModel (LLaMA 模型)
- longt5 — FlaxLongT5Model (LongT5 模型)
- marian — FlaxMarianModel (Marian 模型)
- mbart — FlaxMBartModel (mBART 模型)
- mistral — FlaxMistralModel (Mistral 模型)
- mt5 — FlaxMT5Model (MT5 模型)
- opt — FlaxOPTModel (OPT 模型)
- pegasus — FlaxPegasusModel (Pegasus 模型)
- regnet — FlaxRegNetModel (RegNet 模型)
- resnet — FlaxResNetModel (ResNet 模型)
- roberta — FlaxRobertaModel (RoBERTa 模型)
- roberta-prelayernorm — FlaxRobertaPreLayerNormModel (RoBERTa-PreLayerNorm 模型)
- roformer — FlaxRoFormerModel (RoFormer 模型)
- t5 — FlaxT5Model (T5 模型)
- vision-text-dual-encoder — FlaxVisionTextDualEncoderModel (VisionTextDualEncoder 模型)
- vit — FlaxViTModel (ViT 模型)
- wav2vec2 — FlaxWav2Vec2Model (Wav2Vec2 模型)
- whisper — FlaxWhisperModel (Whisper 模型)
- xglm — FlaxXGLMModel (XGLM 模型)
- xlm-roberta — FlaxXLMRobertaModel (XLM-RoBERTa 模型)
示例
>>> from transformers import AutoConfig, FlaxAutoModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModel.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModel.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModel.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
通用预训练类
以下自动类可用于实例化带有预训练头的模型。
AutoModelForPreTraining
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有预训练头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- AlbertConfig 配置类: AlbertForPreTraining (ALBERT 模型)
- BartConfig 配置类: BartForConditionalGeneration (BART 模型)
- BertConfig 配置类: BertForPreTraining (BERT 模型)
- BigBirdConfig 配置类: BigBirdForPreTraining (BigBird 模型)
- BloomConfig 配置类: BloomForCausalLM (BLOOM 模型)
- CTRLConfig 配置类: CTRLLMHeadModel (CTRL 模型)
- CamembertConfig 配置类: CamembertForMaskedLM (CamemBERT 模型)
- ColPaliConfig 配置类: ColPaliForRetrieval (ColPali 模型)
- Data2VecTextConfig 配置类: Data2VecTextForMaskedLM (Data2VecText 模型)
- DebertaConfig 配置类: DebertaForMaskedLM (DeBERTa 模型)
- DebertaV2Config 配置类: DebertaV2ForMaskedLM (DeBERTa-v2 模型)
- DistilBertConfig 配置类: DistilBertForMaskedLM (DistilBERT 模型)
- ElectraConfig 配置类: ElectraForPreTraining (ELECTRA 模型)
- ErnieConfig 配置类: ErnieForPreTraining (ERNIE 模型)
- FNetConfig 配置类: FNetForPreTraining (FNet 模型)
- FSMTConfig 配置类: FSMTForConditionalGeneration (FairSeq 机器翻译模型)
- FalconMambaConfig 配置类: FalconMambaForCausalLM (FalconMamba 模型)
- FlaubertConfig 配置类: FlaubertWithLMHeadModel (FlauBERT 模型)
- FlavaConfig 配置类: FlavaForPreTraining (FLAVA 模型)
- FunnelConfig 配置类: FunnelForPreTraining (Funnel Transformer 模型)
- GPT2Config 配置类: GPT2LMHeadModel (OpenAI GPT-2 模型)
- GPTBigCodeConfig 配置类: GPTBigCodeForCausalLM (GPTBigCode 模型)
- GPTSanJapaneseConfig 配置类: GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese 模型)
- Gemma3Config 配置类: Gemma3ForConditionalGeneration (Gemma3ForConditionalGeneration 模型)
- HieraConfig 配置类: HieraForPreTraining (Hiera 模型)
- IBertConfig 配置类: IBertForMaskedLM (I-BERT 模型)
- Idefics2Config 配置类: Idefics2ForConditionalGeneration (Idefics2 模型)
- Idefics3Config 配置类: Idefics3ForConditionalGeneration (Idefics3 模型)
- IdeficsConfig 配置类: IdeficsForVisionText2Text (IDEFICS 模型)
- LayoutLMConfig 配置类: LayoutLMForMaskedLM (LayoutLM 模型)
- LlavaConfig 配置类: LlavaForConditionalGeneration (LLaVa 模型)
- LlavaNextConfig 配置类: LlavaNextForConditionalGeneration (LLaVA-NeXT 模型)
- LlavaNextVideoConfig 配置类: LlavaNextVideoForConditionalGeneration (LLaVa-NeXT-Video 模型)
- LlavaOnevisionConfig 配置类: LlavaOnevisionForConditionalGeneration (LLaVA-Onevision 模型)
- LongformerConfig 配置类: LongformerForMaskedLM (Longformer 模型)
- LukeConfig 配置类: LukeForMaskedLM (LUKE 模型)
- LxmertConfig 配置类: LxmertForPreTraining (LXMERT 模型)
- MPNetConfig 配置类: MPNetForMaskedLM (MPNet 模型)
- Mamba2Config 配置类: Mamba2ForCausalLM (mamba2 模型)
- MambaConfig 配置类: MambaForCausalLM (Mamba 模型)
- MegaConfig 配置类: MegaForMaskedLM (MEGA 模型)
- MegatronBertConfig 配置类: MegatronBertForPreTraining (Megatron-BERT 模型)
- Mistral3Config 配置类: Mistral3ForConditionalGeneration (Mistral3 模型)
- MllamaConfig 配置类: MllamaForConditionalGeneration (Mllama 模型)
- MobileBertConfig 配置类: MobileBertForPreTraining (MobileBERT 模型)
- MptConfig 配置类: MptForCausalLM (MPT 模型)
- MraConfig 配置类: MraForMaskedLM (MRA 模型)
- MvpConfig 配置类: MvpForConditionalGeneration (MVP 模型)
- NezhaConfig 配置类: NezhaForPreTraining (Nezha 模型)
- NllbMoeConfig 配置类: NllbMoeForConditionalGeneration (NLLB-MOE 模型)
- OpenAIGPTConfig 配置类: OpenAIGPTLMHeadModel (OpenAI GPT 模型)
- PaliGemmaConfig 配置类: PaliGemmaForConditionalGeneration (PaliGemma 模型)
- Qwen2AudioConfig 配置类: Qwen2AudioForConditionalGeneration (Qwen2Audio 模型)
- RetriBertConfig 配置类: RetriBertModel (RetriBERT 模型)
- RoCBertConfig 配置类: RoCBertForPreTraining (RoCBert 模型)
- RobertaConfig 配置类: RobertaForMaskedLM (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: RobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- RwkvConfig 配置类: RwkvForCausalLM (RWKV 模型)
- SplinterConfig 配置类: SplinterForPreTraining (Splinter 模型)
- SqueezeBertConfig 配置类: SqueezeBertForMaskedLM (SqueezeBERT 模型)
- SwitchTransformersConfig 配置类: SwitchTransformersForConditionalGeneration (SwitchTransformers 模型)
- T5Config 配置类: T5ForConditionalGeneration (T5 模型)
- TapasConfig 配置类: TapasForMaskedLM (TAPAS 模型)
- TransfoXLConfig 配置类: TransfoXLLMHeadModel (Transformer-XL 模型)
- TvltConfig 配置类: TvltForPreTraining (TVLT 模型)
- UniSpeechConfig 配置类: UniSpeechForPreTraining (UniSpeech 模型)
- UniSpeechSatConfig 配置类: UniSpeechSatForPreTraining (UniSpeechSat 模型)
- ViTMAEConfig 配置类: ViTMAEForPreTraining (ViTMAE 模型)
- VideoLlavaConfig 配置类: VideoLlavaForConditionalGeneration (VideoLlava 模型)
- VideoMAEConfig 配置类: VideoMAEForPreTraining (VideoMAE 模型)
- VipLlavaConfig 配置类: VipLlavaForConditionalGeneration (VipLlava 模型)
- VisualBertConfig 配置类: VisualBertForPreTraining (VisualBERT 模型)
- Wav2Vec2Config 配置类: Wav2Vec2ForPreTraining (Wav2Vec2 模型)
- Wav2Vec2ConformerConfig 配置类: Wav2Vec2ConformerForPreTraining (Wav2Vec2-Conformer 模型)
- XLMConfig 配置类: XLMWithLMHeadModel (XLM 模型)
- XLMRobertaConfig 配置类: XLMRobertaForMaskedLM (XLM-RoBERTa 模型)
- XLMRobertaXLConfig 配置类: XLMRobertaXLForMaskedLM (XLM-RoBERTa-XL 模型)
- XLNetConfig 配置类: XLNetLMHeadModel (XLNet 模型)
- XmodConfig 配置类: XmodForMaskedLM (X-MOD 模型)
- attn_implementation (
str
, 可选) — 模型中要使用的注意力实现(如果相关)。 可以是"eager"
(手动实现的注意力)、"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention)中的任何一种。 默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。 否则,默认值为手动"eager"
实现。
从配置实例化库的模型类之一(带有预训练头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 字符串,huggingface.co 上模型存储库中托管的预训练模型的模型 ID。
- 目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - tensorflow 索引检查点文件的路径或 URL(例如,
./tf_model/model.ckpt.index
)。 在这种情况下,from_tf
应设置为True
,并且应将配置对象作为config
参数提供。 此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (附加位置参数, 可选) — 将传递给底层模型
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以替代自动加载的配置。当满足以下条件时,可以自动加载配置:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 用于替代从已保存权重文件加载的状态字典的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,则可以使用此选项。但在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 下载的预训练模型配置应缓存到的目录路径,如果不想使用标准缓存。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重 (请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖已缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否同时返回一个包含缺失键、意外键和错误消息的字典。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许在 Hub 上自定义模型中定义的模型文件。此选项仅应针对您信任的存储库以及您已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供了config
或自动加载配置,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设所有相关的配置更新已经完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果对应于配置属性,将用于使用提供的kwargs
值覆盖该属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有预训练头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — AlbertForPreTraining (ALBERT 模型)
- bart — BartForConditionalGeneration (BART 模型)
- bert — BertForPreTraining (BERT 模型)
- big_bird — BigBirdForPreTraining (BigBird 模型)
- bloom — BloomForCausalLM (BLOOM 模型)
- camembert — CamembertForMaskedLM (CamemBERT 模型)
- colpali — ColPaliForRetrieval (ColPali 模型)
- ctrl — CTRLLMHeadModel (CTRL 模型)
- data2vec-text — Data2VecTextForMaskedLM (Data2VecText 模型)
- deberta — DebertaForMaskedLM (DeBERTa 模型)
- deberta-v2 — DebertaV2ForMaskedLM (DeBERTa-v2 模型)
- distilbert — DistilBertForMaskedLM (DistilBERT 模型)
- electra — ElectraForPreTraining (ELECTRA 模型)
- ernie — ErnieForPreTraining (ERNIE 模型)
- falcon_mamba — FalconMambaForCausalLM (FalconMamba 模型)
- flaubert — FlaubertWithLMHeadModel (FlauBERT 模型)
- flava — FlavaForPreTraining (FLAVA 模型)
- fnet — FNetForPreTraining (FNet 模型)
- fsmt — FSMTForConditionalGeneration (FairSeq 机器翻译模型)
- funnel — FunnelForPreTraining (Funnel Transformer 模型)
- gemma3 — Gemma3ForConditionalGeneration (Gemma3ForConditionalGeneration 模型)
- gpt-sw3 — GPT2LMHeadModel (GPT-Sw3 模型)
- gpt2 — GPT2LMHeadModel (OpenAI GPT-2 模型)
- gpt_bigcode — GPTBigCodeForCausalLM (GPTBigCode 模型)
- gptsan-japanese — GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese 模型)
- hiera — HieraForPreTraining (Hiera 模型)
- ibert — IBertForMaskedLM (I-BERT 模型)
- idefics — IdeficsForVisionText2Text (IDEFICS 模型)
- idefics2 — Idefics2ForConditionalGeneration (Idefics2 模型)
- idefics3 — Idefics3ForConditionalGeneration (Idefics3 模型)
- layoutlm — LayoutLMForMaskedLM (LayoutLM 模型)
- llava — LlavaForConditionalGeneration (LLaVa 模型)
- llava_next — LlavaNextForConditionalGeneration (LLaVA-NeXT 模型)
- llava_next_video — LlavaNextVideoForConditionalGeneration (LLaVa-NeXT-Video 模型)
- llava_onevision — LlavaOnevisionForConditionalGeneration (LLaVA-Onevision 模型)
- longformer — LongformerForMaskedLM (Longformer 模型)
- luke — LukeForMaskedLM (LUKE 模型)
- lxmert — LxmertForPreTraining (LXMERT 模型)
- mamba — MambaForCausalLM (Mamba 模型)
- mamba2 — Mamba2ForCausalLM (mamba2 模型)
- mega — MegaForMaskedLM (MEGA 模型)
- megatron-bert — MegatronBertForPreTraining (Megatron-BERT 模型)
- mistral3 — Mistral3ForConditionalGeneration (Mistral3 模型)
- mllama — MllamaForConditionalGeneration (Mllama 模型)
- mobilebert — MobileBertForPreTraining (MobileBERT 模型)
- mpnet — MPNetForMaskedLM (MPNet 模型)
- mpt — MptForCausalLM (MPT 模型)
- mra — MraForMaskedLM (MRA 模型)
- mvp — MvpForConditionalGeneration (MVP 模型)
- nezha — NezhaForPreTraining (Nezha 模型)
- nllb-moe — NllbMoeForConditionalGeneration (NLLB-MOE 模型)
- openai-gpt — OpenAIGPTLMHeadModel (OpenAI GPT 模型)
- paligemma — PaliGemmaForConditionalGeneration (PaliGemma 模型)
- qwen2_audio — Qwen2AudioForConditionalGeneration (Qwen2Audio 模型)
- retribert — RetriBertModel (RetriBERT 模型)
- roberta — RobertaForMaskedLM (RoBERTa 模型)
- roberta-prelayernorm — RobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- roc_bert — RoCBertForPreTraining (RoCBert 模型)
- rwkv — RwkvForCausalLM (RWKV 模型)
- splinter — SplinterForPreTraining (Splinter 模型)
- squeezebert — SqueezeBertForMaskedLM (SqueezeBERT 模型)
- switch_transformers — SwitchTransformersForConditionalGeneration (SwitchTransformers 模型)
- t5 — T5ForConditionalGeneration (T5 模型)
- tapas — TapasForMaskedLM (TAPAS 模型)
- transfo-xl — TransfoXLLMHeadModel (Transformer-XL 模型)
- tvlt — TvltForPreTraining (TVLT 模型)
- unispeech — UniSpeechForPreTraining (UniSpeech 模型)
- unispeech-sat — UniSpeechSatForPreTraining (UniSpeechSat 模型)
- video_llava — VideoLlavaForConditionalGeneration (VideoLlava 模型)
- videomae — VideoMAEForPreTraining (VideoMAE 模型)
- vipllava — VipLlavaForConditionalGeneration (VipLlava 模型)
- visual_bert — VisualBertForPreTraining (VisualBERT 模型)
- vit_mae — ViTMAEForPreTraining (ViTMAE 模型)
- wav2vec2 — Wav2Vec2ForPreTraining (Wav2Vec2 模型)
- wav2vec2-conformer — Wav2Vec2ConformerForPreTraining (Wav2Vec2-Conformer 模型)
- xlm — XLMWithLMHeadModel (XLM 模型)
- xlm-roberta — XLMRobertaForMaskedLM (XLM-RoBERTa 模型)
- xlm-roberta-xl — XLMRobertaXLForMaskedLM (XLM-RoBERTa-XL 模型)
- xlnet — XLNetLMHeadModel (XLNet 模型)
- xmod — XmodForMaskedLM (X-MOD 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForPreTraining
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForPreTraining.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForPreTraining
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有预训练头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 基于配置类选择要实例化的模型类:
- AlbertConfig 配置类: TFAlbertForPreTraining (ALBERT 模型)
- BartConfig 配置类: TFBartForConditionalGeneration (BART 模型)
- BertConfig 配置类: TFBertForPreTraining (BERT 模型)
- CTRLConfig 配置类: TFCTRLLMHeadModel (CTRL 模型)
- CamembertConfig 配置类: TFCamembertForMaskedLM (CamemBERT 模型)
- DistilBertConfig 配置类: TFDistilBertForMaskedLM (DistilBERT 模型)
- ElectraConfig 配置类: TFElectraForPreTraining (ELECTRA 模型)
- FlaubertConfig 配置类: TFFlaubertWithLMHeadModel (FlauBERT 模型)
- FunnelConfig 配置类: TFFunnelForPreTraining (Funnel Transformer 模型)
- GPT2Config 配置类: TFGPT2LMHeadModel (OpenAI GPT-2 模型)
- IdeficsConfig 配置类: TFIdeficsForVisionText2Text (IDEFICS 模型)
- LayoutLMConfig 配置类: TFLayoutLMForMaskedLM (LayoutLM 模型)
- LxmertConfig 配置类: TFLxmertForPreTraining (LXMERT 模型)
- MPNetConfig 配置类: TFMPNetForMaskedLM (MPNet 模型)
- MobileBertConfig 配置类: TFMobileBertForPreTraining (MobileBERT 模型)
- OpenAIGPTConfig 配置类: TFOpenAIGPTLMHeadModel (OpenAI GPT 模型)
- RobertaConfig 配置类: TFRobertaForMaskedLM (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: TFRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- T5Config 配置类: TFT5ForConditionalGeneration (T5 模型)
- TapasConfig 配置类: TFTapasForMaskedLM (TAPAS 模型)
- TransfoXLConfig 配置类: TFTransfoXLLMHeadModel (Transformer-XL 模型)
- ViTMAEConfig 配置类: TFViTMAEForPreTraining (ViTMAE 模型)
- XLMConfig 配置类: TFXLMWithLMHeadModel (XLM 模型)
- XLMRobertaConfig 配置类: TFXLMRobertaForMaskedLM (XLM-RoBERTa 模型)
- XLNetConfig 配置类: TFXLNetLMHeadModel (XLNet 模型)
- attn_implementation (
str
, optional) — 模型中使用的注意力实现方式(如果相关)。可以是"eager"
(手动实现的注意力),"sdpa"
(使用F.scaled_dot_product_attention
), 或者"flash_attention_2"
(使用 Dao-AILab/flash-attention)。默认情况下,如果可用,torch>=2.1.1 将使用 SDPA。否则默认为手动"eager"
实现。
从配置实例化库的模型类之一(带有预训练头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 上模型仓库中托管的预训练模型的 模型 ID 。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - PyTorch state_dict 保存文件的路径或 URL (例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应将配置对象作为config
参数提供。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (额外的 positional 参数, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以替代自动加载的配置。当以下情况时,可以自动加载配置:
- 模型是由库提供的模型(使用预训练模型的 模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置文件。
- cache_dir (
str
或os.PathLike
, 可选) — 如果不应使用标准缓存,则应在其中缓存下载的预训练模型配置的目录路径。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用并忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。此选项仅应针对您信任的存储库以及您已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的 keyword 参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载配置,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设所有相关的配置更新已经完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果与配置属性相对应,将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有预训练头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — TFAlbertForPreTraining (ALBERT 模型)
- bart — TFBartForConditionalGeneration (BART 模型)
- bert — TFBertForPreTraining (BERT 模型)
- camembert — TFCamembertForMaskedLM (CamemBERT 模型)
- ctrl — TFCTRLLMHeadModel (CTRL 模型)
- distilbert — TFDistilBertForMaskedLM (DistilBERT 模型)
- electra — TFElectraForPreTraining (ELECTRA 模型)
- flaubert — TFFlaubertWithLMHeadModel (FlauBERT 模型)
- funnel — TFFunnelForPreTraining (Funnel Transformer 模型)
- gpt-sw3 — TFGPT2LMHeadModel (GPT-Sw3 模型)
- gpt2 — TFGPT2LMHeadModel (OpenAI GPT-2 模型)
- idefics — TFIdeficsForVisionText2Text (IDEFICS 模型)
- layoutlm — TFLayoutLMForMaskedLM (LayoutLM 模型)
- lxmert — TFLxmertForPreTraining (LXMERT 模型)
- mobilebert — TFMobileBertForPreTraining (MobileBERT 模型)
- mpnet — TFMPNetForMaskedLM (MPNet 模型)
- openai-gpt — TFOpenAIGPTLMHeadModel (OpenAI GPT 模型)
- roberta — TFRobertaForMaskedLM (RoBERTa 模型)
- roberta-prelayernorm — TFRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- t5 — TFT5ForConditionalGeneration (T5 模型)
- tapas — TFTapasForMaskedLM (TAPAS 模型)
- transfo-xl — TFTransfoXLLMHeadModel (Transformer-XL 模型)
- vit_mae — TFViTMAEForPreTraining (ViTMAE 模型)
- xlm — TFXLMWithLMHeadModel (XLM 模型)
- xlm-roberta — TFXLMRobertaForMaskedLM (XLM-RoBERTa 模型)
- xlnet — TFXLNetLMHeadModel (XLNet 模型)
示例
>>> from transformers import AutoConfig, TFAutoModelForPreTraining
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForPreTraining.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForPreTraining
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有预训练头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< 源代码 >( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是基于配置类选择的:
- AlbertConfig 配置类: FlaxAlbertForPreTraining (ALBERT 模型)
- BartConfig 配置类: FlaxBartForConditionalGeneration (BART 模型)
- BertConfig 配置类: FlaxBertForPreTraining (BERT 模型)
- BigBirdConfig 配置类: FlaxBigBirdForPreTraining (BigBird 模型)
- ElectraConfig 配置类: FlaxElectraForPreTraining (ELECTRA 模型)
- LongT5Config 配置类: FlaxLongT5ForConditionalGeneration (LongT5 模型)
- MBartConfig 配置类: FlaxMBartForConditionalGeneration (mBART 模型)
- MT5Config 配置类: FlaxMT5ForConditionalGeneration (MT5 模型)
- RoFormerConfig 配置类: FlaxRoFormerForMaskedLM (RoFormer 模型)
- RobertaConfig 配置类: FlaxRobertaForMaskedLM (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: FlaxRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- T5Config 配置类: FlaxT5ForConditionalGeneration (T5 模型)
- Wav2Vec2Config 配置类: FlaxWav2Vec2ForPreTraining (Wav2Vec2 模型)
- WhisperConfig 配置类: FlaxWhisperForConditionalGeneration (Whisper 模型)
- XLMRobertaConfig 配置类: FlaxXLMRobertaForMaskedLM (XLM-RoBERTa 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力实现方式 (如果相关)。可以是"eager"
(手动实现的注意力),"sdpa"
(使用F.scaled_dot_product_attention
), 或者"flash_attention_2"
(使用 Dao-AILab/flash-attention)。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则,默认是手动"eager"
实现。
从配置实例化库的模型类之一(带有预训练头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< 源代码 >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,预训练模型的模型 ID,托管在 huggingface.co 的模型仓库中。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - PyTorch state_dict 保存文件的路径或 URL (例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应该设置为True
,并且应该提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (额外的 positional arguments,可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以代替自动加载的配置。 当以下情况时,可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置文件。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型配置的目录路径,如果应该不使用标准缓存。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用并忽略。现在,所有下载在可能的情况下都默认恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 通过协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许在 Hub 上自定义模型及其自己的建模文件中。此选项仅应为信任的存储库设置为True
,并且您已阅读其中的代码,因为它将在本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码与模型的其余部分位于不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的 keyword arguments,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载配置,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新都已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有预训练头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — FlaxAlbertForPreTraining (ALBERT 模型)
- bart — FlaxBartForConditionalGeneration (BART 模型)
- bert — FlaxBertForPreTraining (BERT 模型)
- big_bird — FlaxBigBirdForPreTraining (BigBird 模型)
- electra — FlaxElectraForPreTraining (ELECTRA 模型)
- longt5 — FlaxLongT5ForConditionalGeneration (LongT5 模型)
- mbart — FlaxMBartForConditionalGeneration (mBART 模型)
- mt5 — FlaxMT5ForConditionalGeneration (MT5 模型)
- roberta — FlaxRobertaForMaskedLM (RoBERTa 模型)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- roformer — FlaxRoFormerForMaskedLM (RoFormer 模型)
- t5 — FlaxT5ForConditionalGeneration (T5 模型)
- wav2vec2 — FlaxWav2Vec2ForPreTraining (Wav2Vec2 模型)
- whisper — FlaxWhisperForConditionalGeneration (Whisper 模型)
- xlm-roberta — FlaxXLMRobertaForMaskedLM (XLM-RoBERTa 模型)
示例
>>> from transformers import AutoConfig, FlaxAutoModelForPreTraining
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForPreTraining.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForPreTraining.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
自然语言处理
以下自动类适用于以下自然语言处理任务。
AutoModelForCausalLM
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有因果语言建模头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< 源代码 >( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是基于配置类选择的:
- AriaTextConfig 配置类: AriaTextForCausalLM (AriaText 模型)
- BambaConfig 配置类: BambaForCausalLM (Bamba 模型)
- BartConfig 配置类: BartForCausalLM (BART 模型)
- BertConfig 配置类: BertLMHeadModel (BERT 模型)
- BertGenerationConfig 配置类: BertGenerationDecoder (Bert Generation 模型)
- BigBirdConfig 配置类: BigBirdForCausalLM (BigBird 模型)
- BigBirdPegasusConfig 配置类: BigBirdPegasusForCausalLM (BigBird-Pegasus 模型)
- BioGptConfig 配置类: BioGptForCausalLM (BioGpt 模型)
- BlenderbotConfig 配置类: BlenderbotForCausalLM (Blenderbot 模型)
- BlenderbotSmallConfig 配置类: BlenderbotSmallForCausalLM (BlenderbotSmall 模型)
- BloomConfig 配置类: BloomForCausalLM (BLOOM 模型)
- CTRLConfig 配置类: CTRLLMHeadModel (CTRL 模型)
- CamembertConfig 配置类: CamembertForCausalLM (CamemBERT 模型)
- CodeGenConfig 配置类: CodeGenForCausalLM (CodeGen 模型)
- Cohere2Config 配置类: Cohere2ForCausalLM (Cohere2 模型)
- CohereConfig 配置类: CohereForCausalLM (Cohere 模型)
- CpmAntConfig 配置类: CpmAntForCausalLM (CPM-Ant 模型)
- Data2VecTextConfig 配置类: Data2VecTextForCausalLM (Data2VecText 模型)
- DbrxConfig 配置类: DbrxForCausalLM (DBRX 模型)
- DiffLlamaConfig 配置类: DiffLlamaForCausalLM (DiffLlama 模型)
- ElectraConfig 配置类: ElectraForCausalLM (ELECTRA 模型)
- Emu3Config 配置类: Emu3ForCausalLM (Emu3 模型)
- ErnieConfig 配置类: ErnieForCausalLM (ERNIE 模型)
- FalconConfig 配置类: FalconForCausalLM (Falcon 模型)
- FalconMambaConfig 配置类: FalconMambaForCausalLM (FalconMamba 模型)
- FuyuConfig 配置类: FuyuForCausalLM (Fuyu 模型)
- GPT2Config 配置类: GPT2LMHeadModel (OpenAI GPT-2 模型)
- GPTBigCodeConfig 配置类: GPTBigCodeForCausalLM (GPTBigCode 模型)
- GPTJConfig 配置类: GPTJForCausalLM (GPT-J 模型)
- GPTNeoConfig 配置类: GPTNeoForCausalLM (GPT Neo 模型)
- GPTNeoXConfig 配置类: GPTNeoXForCausalLM (GPT NeoX 模型)
- GPTNeoXJapaneseConfig 配置类: GPTNeoXJapaneseForCausalLM (GPT NeoX Japanese 模型)
- Gemma2Config 配置类: Gemma2ForCausalLM (Gemma2 模型)
- Gemma3Config 配置类: Gemma3ForCausalLM (Gemma3ForConditionalGeneration 模型)
- Gemma3TextConfig 配置类: Gemma3ForCausalLM (Gemma3ForCausalLM 模型)
- GemmaConfig 配置类: GemmaForCausalLM (Gemma 模型)
- GitConfig 配置类: GitForCausalLM (GIT 模型)
- GlmConfig 配置类: GlmForCausalLM (GLM 模型)
- GotOcr2Config 配置类: GotOcr2ForConditionalGeneration (GOT-OCR2 模型)
- GraniteConfig 配置类: GraniteForCausalLM (Granite 模型)
- GraniteMoeConfig 配置类: GraniteMoeForCausalLM (GraniteMoeMoe 模型)
- GraniteMoeSharedConfig 配置类: GraniteMoeSharedForCausalLM (GraniteMoeSharedMoe 模型)
- HeliumConfig 配置类: HeliumForCausalLM (Helium 模型)
- JambaConfig 配置类: JambaForCausalLM (Jamba 模型)
- JetMoeConfig 配置类: JetMoeForCausalLM (JetMoe 模型)
- LlamaConfig 配置类: LlamaForCausalLM (LLaMA 模型)
- MBartConfig 配置类: MBartForCausalLM (mBART 模型)
- Mamba2Config 配置类: Mamba2ForCausalLM (mamba2 模型)
- MambaConfig 配置类: MambaForCausalLM (Mamba 模型)
- MarianConfig 配置类: MarianForCausalLM (Marian 模型)
- MegaConfig 配置类: MegaForCausalLM (MEGA 模型)
- MegatronBertConfig 配置类: MegatronBertForCausalLM (Megatron-BERT 模型)
- MistralConfig 配置类: MistralForCausalLM (Mistral 模型)
- MixtralConfig 配置类: MixtralForCausalLM (Mixtral 模型)
- MllamaConfig 配置类: MllamaForCausalLM (Mllama 模型)
- MoshiConfig 配置类: MoshiForCausalLM (Moshi 模型)
- MptConfig 配置类: MptForCausalLM (MPT 模型)
- MusicgenConfig 配置类: MusicgenForCausalLM (MusicGen 模型)
- MusicgenMelodyConfig 配置类: MusicgenMelodyForCausalLM (MusicGen Melody 模型)
- MvpConfig 配置类: MvpForCausalLM (MVP 模型)
- NemotronConfig 配置类: NemotronForCausalLM (Nemotron 模型)
- OPTConfig 配置类: OPTForCausalLM (OPT 模型)
- Olmo2Config 配置类: Olmo2ForCausalLM (OLMo2 模型)
- OlmoConfig 配置类: OlmoForCausalLM (OLMo 模型)
- OlmoeConfig 配置类: OlmoeForCausalLM (OLMoE 模型)
- OpenAIGPTConfig 配置类: OpenAIGPTLMHeadModel (OpenAI GPT 模型)
- OpenLlamaConfig 配置类: OpenLlamaForCausalLM (OpenLlama 模型)
- PLBartConfig 配置类: PLBartForCausalLM (PLBart 模型)
- PegasusConfig 配置类: PegasusForCausalLM (Pegasus 模型)
- PersimmonConfig 配置类: PersimmonForCausalLM (Persimmon 模型)
- Phi3Config 配置类: Phi3ForCausalLM (Phi3 模型)
- PhiConfig 配置类: PhiForCausalLM (Phi 模型)
- PhimoeConfig 配置类: PhimoeForCausalLM (Phimoe 模型)
- ProphetNetConfig 配置类: ProphetNetForCausalLM (ProphetNet 模型)
- QDQBertConfig 配置类: QDQBertLMHeadModel (QDQBert 模型)
- Qwen2Config 配置类: Qwen2ForCausalLM (Qwen2 模型)
- Qwen2MoeConfig 配置类: Qwen2MoeForCausalLM (Qwen2MoE 模型)
- RecurrentGemmaConfig 配置类: RecurrentGemmaForCausalLM (RecurrentGemma 模型)
- ReformerConfig 配置类: ReformerModelWithLMHead (Reformer 模型)
- RemBertConfig 配置类: RemBertForCausalLM (RemBERT 模型)
- RoCBertConfig 配置类: RoCBertForCausalLM (RoCBert 模型)
- RoFormerConfig 配置类: RoFormerForCausalLM (RoFormer 模型)
- RobertaConfig 配置类: RobertaForCausalLM (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: RobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm 模型)
- RwkvConfig 配置类: RwkvForCausalLM (RWKV 模型)
- Speech2Text2Config 配置类: Speech2Text2ForCausalLM (Speech2Text2 模型)
- StableLmConfig 配置类: StableLmForCausalLM (StableLm 模型)
- Starcoder2Config 配置类: Starcoder2ForCausalLM (Starcoder2 模型)
- TrOCRConfig 配置类: TrOCRForCausalLM (TrOCR 模型)
- TransfoXLConfig 配置类: TransfoXLLMHeadModel (Transformer-XL 模型)
- WhisperConfig 配置类: WhisperForCausalLM (Whisper 模型)
- XGLMConfig 配置类: XGLMForCausalLM (XGLM 模型)
- XLMConfig 配置类: XLMWithLMHeadModel (XLM 模型)
- XLMProphetNetConfig 配置类: XLMProphetNetForCausalLM (XLM-ProphetNet 模型)
- XLMRobertaConfig 配置类: XLMRobertaForCausalLM (XLM-RoBERTa 模型)
- XLMRobertaXLConfig 配置类: XLMRobertaXLForCausalLM (XLM-RoBERTa-XL 模型)
- XLNetConfig 配置类: XLNetLMHeadModel (XLNet 模型)
- XmodConfig 配置类: XmodForCausalLM (X-MOD 模型)
- Zamba2Config 配置类: Zamba2ForCausalLM (Zamba2 模型)
- ZambaConfig 配置类: ZambaForCausalLM (Zamba 模型)
- attn_implementation (
str
, 可选) — 模型中要使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention)。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则默认为手动"eager"
实现。
从配置实例化库中的一个模型类(带有因果语言建模头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 上模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如
./my_model_directory/
。 - 一个 tensorflow 索引检查点文件的路径或 URL (例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应设置为True
,并且应将配置对象作为config
参数提供。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (额外的positional arguments, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 要使用的模型的配置,而不是自动加载的配置。当以下情况时,可以自动加载配置:
- 该模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 该模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 该模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 要使用的状态字典,而不是从保存的权重文件加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,可以使用此选项。但在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 如果不应使用标准缓存,则应在其中缓存下载的预训练模型配置的目录路径。 - from_tf (
bool
, 可选, defaults toFalse
) — 从 TensorFlow 检查点保存文件加载模型权重 (参见pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, defaults toFalse
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选) — 一个代理服务器字典,用于指定每个协议或端点的代理,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。 代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否同时返回包含缺失键、意外键和错误消息的字典。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上定义的自定义模型在它们自己的建模文件中。 此选项仅应为设置为True
,用于您信任的存储库,并且您已阅读过其中的代码,因为它将在您的本地计算机上执行 Hub 上的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。 根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个对应于配置属性的键将用于使用提供的kwargs
值覆盖所述属性。 与任何配置属性不对应的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有因果语言建模头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- aria_text — AriaTextForCausalLM (AriaText 模型)
- bamba — BambaForCausalLM (Bamba 模型)
- bart — BartForCausalLM (BART 模型)
- bert — BertLMHeadModel (BERT 模型)
- bert-generation — BertGenerationDecoder (Bert Generation 模型)
- big_bird — BigBirdForCausalLM (BigBird 模型)
- bigbird_pegasus — BigBirdPegasusForCausalLM (BigBird-Pegasus 模型)
- biogpt — BioGptForCausalLM (BioGpt 模型)
- blenderbot — BlenderbotForCausalLM (Blenderbot 模型)
- blenderbot-small — BlenderbotSmallForCausalLM (BlenderbotSmall 模型)
- bloom — BloomForCausalLM (BLOOM 模型)
- camembert — CamembertForCausalLM (CamemBERT 模型)
- code_llama — LlamaForCausalLM (CodeLlama 模型)
- codegen — CodeGenForCausalLM (CodeGen 模型)
- cohere — CohereForCausalLM (Cohere 模型)
- cohere2 — Cohere2ForCausalLM (Cohere2 模型)
- cpmant — CpmAntForCausalLM (CPM-Ant 模型)
- ctrl — CTRLLMHeadModel (CTRL 模型)
- data2vec-text — Data2VecTextForCausalLM (Data2VecText 模型)
- dbrx — DbrxForCausalLM (DBRX 模型)
- diffllama — DiffLlamaForCausalLM (DiffLlama 模型)
- electra — ElectraForCausalLM (ELECTRA 模型)
- emu3 — Emu3ForCausalLM (Emu3 模型)
- ernie — ErnieForCausalLM (ERNIE 模型)
- falcon — FalconForCausalLM (Falcon 模型)
- falcon_mamba — FalconMambaForCausalLM (FalconMamba 模型)
- fuyu — FuyuForCausalLM (Fuyu 模型)
- gemma — GemmaForCausalLM (Gemma 模型)
- gemma2 — Gemma2ForCausalLM (Gemma2 模型)
- gemma3 — Gemma3ForCausalLM (Gemma3ForConditionalGeneration 模型)
- gemma3_text — Gemma3ForCausalLM (Gemma3ForCausalLM 模型)
- git — GitForCausalLM (GIT 模型)
- glm — GlmForCausalLM (GLM 模型)
- got_ocr2 — GotOcr2ForConditionalGeneration (GOT-OCR2 模型)
- gpt-sw3 — GPT2LMHeadModel (GPT-Sw3 模型)
- gpt2 — GPT2LMHeadModel (OpenAI GPT-2 模型)
- gpt_bigcode — GPTBigCodeForCausalLM (GPTBigCode 模型)
- gpt_neo — GPTNeoForCausalLM (GPT Neo 模型)
- gpt_neox — GPTNeoXForCausalLM (GPT NeoX 模型)
- gpt_neox_japanese — GPTNeoXJapaneseForCausalLM (GPT NeoX Japanese 模型)
- gptj — GPTJForCausalLM (GPT-J 模型)
- granite — GraniteForCausalLM (Granite 模型)
- granitemoe — GraniteMoeForCausalLM (GraniteMoeMoe 模型)
- granitemoeshared — GraniteMoeSharedForCausalLM (GraniteMoeSharedMoe 模型)
- helium — HeliumForCausalLM (Helium 模型)
- jamba — JambaForCausalLM (Jamba 模型)
- jetmoe — JetMoeForCausalLM (JetMoe 模型)
- llama — LlamaForCausalLM (LLaMA 模型)
- mamba — MambaForCausalLM (Mamba 模型)
- mamba2 — Mamba2ForCausalLM (mamba2 模型)
- marian — MarianForCausalLM (Marian 模型)
- mbart — MBartForCausalLM (mBART 模型)
- mega — MegaForCausalLM (MEGA 模型)
- megatron-bert — MegatronBertForCausalLM (Megatron-BERT 模型)
- mistral — MistralForCausalLM (Mistral 模型)
- mixtral — MixtralForCausalLM (Mixtral 模型)
- mllama — MllamaForCausalLM (Mllama 模型)
- moshi — MoshiForCausalLM (Moshi 模型)
- mpt — MptForCausalLM (MPT 模型)
- musicgen — MusicgenForCausalLM (MusicGen 模型)
- musicgen_melody — MusicgenMelodyForCausalLM (MusicGen Melody 模型)
- mvp — MvpForCausalLM (MVP 模型)
- nemotron — NemotronForCausalLM (Nemotron 模型)
- olmo — OlmoForCausalLM (OLMo 模型)
- olmo2 — Olmo2ForCausalLM (OLMo2 模型)
- olmoe — OlmoeForCausalLM (OLMoE 模型)
- open-llama — OpenLlamaForCausalLM (OpenLlama 模型)
- openai-gpt — OpenAIGPTLMHeadModel (OpenAI GPT 模型)
- opt — OPTForCausalLM (OPT 模型)
- pegasus — PegasusForCausalLM (Pegasus 模型)
- persimmon — PersimmonForCausalLM (Persimmon 模型)
- phi — PhiForCausalLM (Phi 模型)
- phi3 — Phi3ForCausalLM (Phi3 模型)
- phimoe — PhimoeForCausalLM (Phimoe 模型)
- plbart — PLBartForCausalLM (PLBart 模型)
- prophetnet — ProphetNetForCausalLM (ProphetNet 模型)
- qdqbert — QDQBertLMHeadModel (QDQBert 模型)
- qwen2 — Qwen2ForCausalLM (Qwen2 模型)
- qwen2_moe — Qwen2MoeForCausalLM (Qwen2MoE 模型)
- recurrent_gemma — RecurrentGemmaForCausalLM (RecurrentGemma 模型)
- reformer — ReformerModelWithLMHead (Reformer 模型)
- rembert — RemBertForCausalLM (RemBERT 模型)
- roberta — RobertaForCausalLM (RoBERTa 模型)
- roberta-prelayernorm — RobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm 模型)
- roc_bert — RoCBertForCausalLM (RoCBert 模型)
- roformer — RoFormerForCausalLM (RoFormer 模型)
- rwkv — RwkvForCausalLM (RWKV 模型)
- speech_to_text_2 — Speech2Text2ForCausalLM (Speech2Text2 模型)
- stablelm — StableLmForCausalLM (StableLm 模型)
- starcoder2 — Starcoder2ForCausalLM (Starcoder2 模型)
- transfo-xl — TransfoXLLMHeadModel (Transformer-XL 模型)
- trocr — TrOCRForCausalLM (TrOCR 模型)
- whisper — WhisperForCausalLM (Whisper 模型)
- xglm — XGLMForCausalLM (XGLM 模型)
- xlm — XLMWithLMHeadModel (XLM 模型)
- xlm-prophetnet — XLMProphetNetForCausalLM (XLM-ProphetNet 模型)
- xlm-roberta — XLMRobertaForCausalLM (XLM-RoBERTa 模型)
- xlm-roberta-xl — XLMRobertaXLForCausalLM (XLM-RoBERTa-XL 模型)
- xlnet — XLNetLMHeadModel (XLNet 模型)
- xmod — XmodForCausalLM (X-MOD 模型)
- zamba — ZambaForCausalLM (Zamba 模型)
- zamba2 — Zamba2ForCausalLM (Zamba2 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForCausalLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForCausalLM.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForCausalLM.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForCausalLM.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForCausalLM
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有因果语言建模头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- BertConfig 配置类: TFBertLMHeadModel (BERT 模型)
- CTRLConfig 配置类: TFCTRLLMHeadModel (CTRL 模型)
- CamembertConfig 配置类: TFCamembertForCausalLM (CamemBERT 模型)
- GPT2Config 配置类: TFGPT2LMHeadModel (OpenAI GPT-2 模型)
- GPTJConfig 配置类: TFGPTJForCausalLM (GPT-J 模型)
- MistralConfig 配置类: TFMistralForCausalLM (Mistral 模型)
- OPTConfig 配置类: TFOPTForCausalLM (OPT 模型)
- OpenAIGPTConfig 配置类: TFOpenAIGPTLMHeadModel (OpenAI GPT 模型)
- RemBertConfig 配置类: TFRemBertForCausalLM (RemBERT 模型)
- RoFormerConfig 配置类: TFRoFormerForCausalLM (RoFormer 模型)
- RobertaConfig 配置类: TFRobertaForCausalLM (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: TFRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm 模型)
- TransfoXLConfig 配置类: TFTransfoXLLMHeadModel (Transformer-XL 模型)
- XGLMConfig 配置类: TFXGLMForCausalLM (XGLM 模型)
- XLMConfig 配置类: TFXLMWithLMHeadModel (XLM 模型)
- XLMRobertaConfig 配置类: TFXLMRobertaForCausalLM (XLM-RoBERTa 模型)
- XLNetConfig 配置类: TFXLNetLMHeadModel (XLNet 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力实现方式(如果相关)。 可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一种。 默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。 否则,默认为手动"eager"
实现。
从配置实例化库中的一个模型类(带有因果语言建模头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 上的模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - PyTorch state_dict 保存文件的路径或 URL(例如,
./pt_model/pytorch_model.bin
)。 在这种情况下,from_pt
应设置为True
,并且应提供配置对象作为config
参数。 此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型,然后加载 TensorFlow 模型要慢。
- model_args (附加位置参数, 可选) — 将传递给底层模型
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以代替自动加载的配置。 在以下情况下可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, 可选) — 用于缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件中加载模型权重 (请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖已缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。所有下载现在默认情况下尽可能恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回包含缺失键、意外键和错误消息的字典。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。此选项仅应为信任的仓库且您已阅读代码的情况下设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码与模型的其余部分位于不同的仓库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为会有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新都已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果对应于配置属性,将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有因果语言建模头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- bert — TFBertLMHeadModel (BERT 模型)
- camembert — TFCamembertForCausalLM (CamemBERT 模型)
- ctrl — TFCTRLLMHeadModel (CTRL 模型)
- gpt-sw3 — TFGPT2LMHeadModel (GPT-Sw3 模型)
- gpt2 — TFGPT2LMHeadModel (OpenAI GPT-2 模型)
- gptj — TFGPTJForCausalLM (GPT-J 模型)
- mistral — TFMistralForCausalLM (Mistral 模型)
- openai-gpt — TFOpenAIGPTLMHeadModel (OpenAI GPT 模型)
- opt — TFOPTForCausalLM (OPT 模型)
- rembert — TFRemBertForCausalLM (RemBERT 模型)
- roberta — TFRobertaForCausalLM (RoBERTa 模型)
- roberta-prelayernorm — TFRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm 模型)
- roformer — TFRoFormerForCausalLM (RoFormer 模型)
- transfo-xl — TFTransfoXLLMHeadModel (Transformer-XL 模型)
- xglm — TFXGLMForCausalLM (XGLM 模型)
- xlm — TFXLMWithLMHeadModel (XLM 模型)
- xlm-roberta — TFXLMRobertaForCausalLM (XLM-RoBERTa 模型)
- xlnet — TFXLNetLMHeadModel (XLNet 模型)
示例
>>> from transformers import AutoConfig, TFAutoModelForCausalLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForCausalLM.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForCausalLM.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForCausalLM.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForCausalLM
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有因果语言建模头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是基于配置类选择的:
- BartConfig 配置类: FlaxBartForCausalLM (BART 模型)
- BertConfig 配置类: FlaxBertForCausalLM (BERT 模型)
- BigBirdConfig 配置类: FlaxBigBirdForCausalLM (BigBird 模型)
- BloomConfig 配置类: FlaxBloomForCausalLM (BLOOM 模型)
- ElectraConfig 配置类: FlaxElectraForCausalLM (ELECTRA 模型)
- GPT2Config 配置类: FlaxGPT2LMHeadModel (OpenAI GPT-2 模型)
- GPTJConfig 配置类: FlaxGPTJForCausalLM (GPT-J 模型)
- GPTNeoConfig 配置类: FlaxGPTNeoForCausalLM (GPT Neo 模型)
- GemmaConfig 配置类: FlaxGemmaForCausalLM (Gemma 模型)
- LlamaConfig 配置类: FlaxLlamaForCausalLM (LLaMA 模型)
- MistralConfig 配置类: FlaxMistralForCausalLM (Mistral 模型)
- OPTConfig 配置类: FlaxOPTForCausalLM (OPT 模型)
- RobertaConfig 配置类: FlaxRobertaForCausalLM (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: FlaxRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm 模型)
- XGLMConfig 配置类: FlaxXGLMForCausalLM (XGLM 模型)
- XLMRobertaConfig 配置类: FlaxXLMRobertaForCausalLM (XLM-RoBERTa 模型)
- attn_implementation (
str
, 可选) — 模型中要使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一个。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则默认是手动"eager"
实现。
从配置实例化库中的一个模型类(带有因果语言建模头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个指向 PyTorch state_dict 保存文件的路径或 URL (例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (额外的位置参数, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以代替自动加载的配置。在以下情况下,可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型是通过提供本地目录作为
pretrained_model_name_or_path
加载的,并且在目录中找到了名为 config.json 的配置文件。
- cache_dir (
str
或os.PathLike
, 可选) — 用于缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件中加载模型权重 (请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖已缓存的版本(如果存在)。 - resume_download — 已弃用且忽略。现在,所有下载在可能的情况下都默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回包含缺失键、意外键和错误消息的字典。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在它们自己的建模文件中定义。此选项应仅对您信任且已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分在不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加的关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果与配置属性相对应,将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有因果语言建模头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- bart — FlaxBartForCausalLM (BART 模型)
- bert — FlaxBertForCausalLM (BERT 模型)
- big_bird — FlaxBigBirdForCausalLM (BigBird 模型)
- bloom — FlaxBloomForCausalLM (BLOOM 模型)
- electra — FlaxElectraForCausalLM (ELECTRA 模型)
- gemma — FlaxGemmaForCausalLM (Gemma 模型)
- gpt-sw3 — FlaxGPT2LMHeadModel (GPT-Sw3 模型)
- gpt2 — FlaxGPT2LMHeadModel (OpenAI GPT-2 模型)
- gpt_neo — FlaxGPTNeoForCausalLM (GPT Neo 模型)
- gptj — FlaxGPTJForCausalLM (GPT-J 模型)
- llama — FlaxLlamaForCausalLM (LLaMA 模型)
- mistral — FlaxMistralForCausalLM (Mistral 模型)
- opt — FlaxOPTForCausalLM (OPT 模型)
- roberta — FlaxRobertaForCausalLM (RoBERTa 模型)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm 模型)
- xglm — FlaxXGLMForCausalLM (XGLM 模型)
- xlm-roberta — FlaxXLMRobertaForCausalLM (XLM-RoBERTa 模型)
示例
>>> from transformers import AutoConfig, FlaxAutoModelForCausalLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForCausalLM.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForCausalLM.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForCausalLM.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForMaskedLM
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有掩码语言建模头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- AlbertConfig 配置类: AlbertForMaskedLM (ALBERT 模型)
- BartConfig 配置类: BartForConditionalGeneration (BART 模型)
- BertConfig 配置类: BertForMaskedLM (BERT 模型)
- BigBirdConfig 配置类: BigBirdForMaskedLM (BigBird 模型)
- CamembertConfig 配置类: CamembertForMaskedLM (CamemBERT 模型)
- ConvBertConfig 配置类: ConvBertForMaskedLM (ConvBERT 模型)
- Data2VecTextConfig 配置类: Data2VecTextForMaskedLM (Data2VecText 模型)
- DebertaConfig 配置类: DebertaForMaskedLM (DeBERTa 模型)
- DebertaV2Config 配置类: DebertaV2ForMaskedLM (DeBERTa-v2 模型)
- DistilBertConfig 配置类: DistilBertForMaskedLM (DistilBERT 模型)
- ElectraConfig 配置类: ElectraForMaskedLM (ELECTRA 模型)
- ErnieConfig 配置类: ErnieForMaskedLM (ERNIE 模型)
- EsmConfig 配置类: EsmForMaskedLM (ESM 模型)
- FNetConfig 配置类: FNetForMaskedLM (FNet 模型)
- FlaubertConfig 配置类: FlaubertWithLMHeadModel (FlauBERT 模型)
- FunnelConfig 配置类: FunnelForMaskedLM (Funnel Transformer 模型)
- IBertConfig 配置类: IBertForMaskedLM (I-BERT 模型)
- LayoutLMConfig 配置类: LayoutLMForMaskedLM (LayoutLM 模型)
- LongformerConfig 配置类: LongformerForMaskedLM (Longformer 模型)
- LukeConfig 配置类: LukeForMaskedLM (LUKE 模型)
- MBartConfig 配置类: MBartForConditionalGeneration (mBART 模型)
- MPNetConfig 配置类: MPNetForMaskedLM (MPNet 模型)
- MegaConfig 配置类: MegaForMaskedLM (MEGA 模型)
- MegatronBertConfig 配置类: MegatronBertForMaskedLM (Megatron-BERT 模型)
- MobileBertConfig 配置类: MobileBertForMaskedLM (MobileBERT 模型)
- ModernBertConfig 配置类: ModernBertForMaskedLM (ModernBERT 模型)
- MraConfig 配置类: MraForMaskedLM (MRA 模型)
- MvpConfig 配置类: MvpForConditionalGeneration (MVP 模型)
- NezhaConfig 配置类: NezhaForMaskedLM (Nezha 模型)
- NystromformerConfig 配置类: NystromformerForMaskedLM (Nyströmformer 模型)
- PerceiverConfig 配置类: PerceiverForMaskedLM (Perceiver 模型)
- QDQBertConfig 配置类: QDQBertForMaskedLM (QDQBert 模型)
- ReformerConfig 配置类: ReformerForMaskedLM (Reformer 模型)
- RemBertConfig 配置类: RemBertForMaskedLM (RemBERT 模型)
- RoCBertConfig 配置类: RoCBertForMaskedLM (RoCBert 模型)
- RoFormerConfig 配置类: RoFormerForMaskedLM (RoFormer 模型)
- RobertaConfig 配置类: RobertaForMaskedLM (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: RobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- SqueezeBertConfig 配置类: SqueezeBertForMaskedLM (SqueezeBERT 模型)
- TapasConfig 配置类: TapasForMaskedLM (TAPAS 模型)
- Wav2Vec2Config 配置类:
Wav2Vec2ForMaskedLM
(Wav2Vec2 模型) - XLMConfig 配置类: XLMWithLMHeadModel (XLM 模型)
- XLMRobertaConfig 配置类: XLMRobertaForMaskedLM (XLM-RoBERTa 模型)
- XLMRobertaXLConfig 配置类: XLMRobertaXLForMaskedLM (XLM-RoBERTa-XL 模型)
- XmodConfig 配置类: XmodForMaskedLM (X-MOD 模型)
- YosoConfig 配置类: YosoForMaskedLM (YOSO 模型)
- attn_implementation (
str
, 可选) — 模型中要使用的注意力实现方式(如果相关)。可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention)中的任何一种。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则,默认值为手动"eager"
实现。
从配置实例化库的模型类之一(带有掩码语言建模头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 字符串,huggingface.co 上的模型仓库中托管的预训练模型的模型 ID。
- 目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - tensorflow 索引检查点文件的路径或 URL(例如,
./tf_model/model.ckpt.index
)。在这种情况下,应将from_tf
设置为True
,并且应提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (附加的位置参数, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, optional) — 用于模型的配置,以替代自动加载的配置。当满足以下条件时,配置可以自动加载:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], optional) — 一个状态字典,用于替代从已保存权重文件中加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,则可以使用此选项。 但是,在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, optional) — 缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_tf (
bool
, optional, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, optional, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖已缓存的版本(如果存在)。 - resume_download — 已弃用且忽略。 现在,所有下载在可能的情况下都默认恢复。 将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, optional) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。 代理用于每个请求。 - output_loading_info(
bool
, optional, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, optional, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, optional, 默认为"main"
) — 要使用的特定模型版本。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, 默认为False
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。 此选项仅应针对您信任并且已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, optional, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, optional) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。 根据是否提供或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键(对应于配置属性)将用于使用提供的kwargs
值覆盖所述属性。 不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带有掩码语言建模头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — AlbertForMaskedLM (ALBERT 模型)
- bart — BartForConditionalGeneration (BART 模型)
- bert — BertForMaskedLM (BERT 模型)
- big_bird — BigBirdForMaskedLM (BigBird 模型)
- camembert — CamembertForMaskedLM (CamemBERT 模型)
- convbert — ConvBertForMaskedLM (ConvBERT 模型)
- data2vec-text — Data2VecTextForMaskedLM (Data2VecText 模型)
- deberta — DebertaForMaskedLM (DeBERTa 模型)
- deberta-v2 — DebertaV2ForMaskedLM (DeBERTa-v2 模型)
- distilbert — DistilBertForMaskedLM (DistilBERT 模型)
- electra — ElectraForMaskedLM (ELECTRA 模型)
- ernie — ErnieForMaskedLM (ERNIE 模型)
- esm — EsmForMaskedLM (ESM 模型)
- flaubert — FlaubertWithLMHeadModel (FlauBERT 模型)
- fnet — FNetForMaskedLM (FNet 模型)
- funnel — FunnelForMaskedLM (Funnel Transformer 模型)
- ibert — IBertForMaskedLM (I-BERT 模型)
- layoutlm — LayoutLMForMaskedLM (LayoutLM 模型)
- longformer — LongformerForMaskedLM (Longformer 模型)
- luke — LukeForMaskedLM (LUKE 模型)
- mbart — MBartForConditionalGeneration (mBART 模型)
- mega — MegaForMaskedLM (MEGA 模型)
- megatron-bert — MegatronBertForMaskedLM (Megatron-BERT 模型)
- mobilebert — MobileBertForMaskedLM (MobileBERT 模型)
- modernbert — ModernBertForMaskedLM (ModernBERT 模型)
- mpnet — MPNetForMaskedLM (MPNet 模型)
- mra — MraForMaskedLM (MRA 模型)
- mvp — MvpForConditionalGeneration (MVP 模型)
- nezha — NezhaForMaskedLM (Nezha 模型)
- nystromformer — NystromformerForMaskedLM (Nyströmformer 模型)
- perceiver — PerceiverForMaskedLM (Perceiver 模型)
- qdqbert — QDQBertForMaskedLM (QDQBert 模型)
- reformer — ReformerForMaskedLM (Reformer 模型)
- rembert — RemBertForMaskedLM (RemBERT 模型)
- roberta — RobertaForMaskedLM (RoBERTa 模型)
- roberta-prelayernorm — RobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- roc_bert — RoCBertForMaskedLM (RoCBert 模型)
- roformer — RoFormerForMaskedLM (RoFormer 模型)
- squeezebert — SqueezeBertForMaskedLM (SqueezeBERT 模型)
- tapas — TapasForMaskedLM (TAPAS 模型)
- wav2vec2 —
Wav2Vec2ForMaskedLM
(Wav2Vec2 模型) - xlm — XLMWithLMHeadModel (XLM 模型)
- xlm-roberta — XLMRobertaForMaskedLM (XLM-RoBERTa 模型)
- xlm-roberta-xl — XLMRobertaXLForMaskedLM (XLM-RoBERTa-XL 模型)
- xmod — XmodForMaskedLM (X-MOD 模型)
- yoso — YosoForMaskedLM (YOSO 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForMaskedLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForMaskedLM.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForMaskedLM
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有掩码语言建模头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- AlbertConfig 配置类: TFAlbertForMaskedLM (ALBERT 模型)
- BertConfig 配置类: TFBertForMaskedLM (BERT 模型)
- CamembertConfig 配置类: TFCamembertForMaskedLM (CamemBERT 模型)
- ConvBertConfig 配置类: TFConvBertForMaskedLM (ConvBERT 模型)
- DebertaConfig 配置类: TFDebertaForMaskedLM (DeBERTa 模型)
- DebertaV2Config 配置类: TFDebertaV2ForMaskedLM (DeBERTa-v2 模型)
- DistilBertConfig 配置类: TFDistilBertForMaskedLM (DistilBERT 模型)
- ElectraConfig 配置类: TFElectraForMaskedLM (ELECTRA 模型)
- EsmConfig 配置类: TFEsmForMaskedLM (ESM 模型)
- FlaubertConfig 配置类: TFFlaubertWithLMHeadModel (FlauBERT 模型)
- FunnelConfig 配置类: TFFunnelForMaskedLM (Funnel Transformer 模型)
- LayoutLMConfig 配置类: TFLayoutLMForMaskedLM (LayoutLM 模型)
- LongformerConfig 配置类: TFLongformerForMaskedLM (Longformer 模型)
- MPNetConfig 配置类: TFMPNetForMaskedLM (MPNet 模型)
- MobileBertConfig 配置类: TFMobileBertForMaskedLM (MobileBERT 模型)
- RemBertConfig 配置类: TFRemBertForMaskedLM (RemBERT 模型)
- RoFormerConfig 配置类: TFRoFormerForMaskedLM (RoFormer 模型)
- RobertaConfig 配置类: TFRobertaForMaskedLM (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: TFRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- TapasConfig 配置类: TFTapasForMaskedLM (TAPAS 模型)
- XLMConfig 配置类: TFXLMWithLMHeadModel (XLM 模型)
- XLMRobertaConfig 配置类: TFXLMRobertaForMaskedLM (XLM-RoBERTa 模型)
- attn_implementation (
str
, optional) — 模型中要使用的注意力实现(如果相关)。 可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention)中的任何一个。 默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。 否则,默认为手动"eager"
实现。
从配置实例化库的模型类之一(带有掩码语言建模头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 字符串,托管在 huggingface.co 上的模型仓库中的预训练模型的模型 ID。
- 目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - PyTorch state_dict 保存文件的路径或 URL(例如,
./pt_model/pytorch_model.bin
)。 在这种情况下,from_pt
应设置为True
,并且应提供配置对象作为config
参数。 此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (附加的位置参数, 可选) — 将传递给底层模型
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以替代自动加载的配置。 当满足以下条件时,可以自动加载配置:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置文件。
- cache_dir (
str
或os.PathLike
, 可选) — 下载的预训练模型配置应缓存到其中的目录路径,如果不想使用标准缓存。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重 (请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖已缓存的版本(如果存在)。 - resume_download — 已弃用并忽略。 现在,所有下载在可能的情况下默认恢复。 将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。 代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否同时返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许在 Hub 上自定义模型文件中定义的自定义模型。 此选项仅应为信任的存储库和您已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加的关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。 根据是否提供config
或自动加载配置,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新都已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果与配置属性相对应,都将用于使用提供的kwargs
值覆盖该属性。 不对应于任何配置属性的其余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带有掩码语言建模头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — TFAlbertForMaskedLM (ALBERT 模型)
- bert — TFBertForMaskedLM (BERT 模型)
- camembert — TFCamembertForMaskedLM (CamemBERT 模型)
- convbert — TFConvBertForMaskedLM (ConvBERT 模型)
- deberta — TFDebertaForMaskedLM (DeBERTa 模型)
- deberta-v2 — TFDebertaV2ForMaskedLM (DeBERTa-v2 模型)
- distilbert — TFDistilBertForMaskedLM (DistilBERT 模型)
- electra — TFElectraForMaskedLM (ELECTRA 模型)
- esm — TFEsmForMaskedLM (ESM 模型)
- flaubert — TFFlaubertWithLMHeadModel (FlauBERT 模型)
- funnel — TFFunnelForMaskedLM (Funnel Transformer 模型)
- layoutlm — TFLayoutLMForMaskedLM (LayoutLM 模型)
- longformer — TFLongformerForMaskedLM (Longformer 模型)
- mobilebert — TFMobileBertForMaskedLM (MobileBERT 模型)
- mpnet — TFMPNetForMaskedLM (MPNet 模型)
- rembert — TFRemBertForMaskedLM (RemBERT 模型)
- roberta — TFRobertaForMaskedLM (RoBERTa 模型)
- roberta-prelayernorm — TFRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- roformer — TFRoFormerForMaskedLM (RoFormer 模型)
- tapas — TFTapasForMaskedLM (TAPAS 模型)
- xlm — TFXLMWithLMHeadModel (XLM 模型)
- xlm-roberta — TFXLMRobertaForMaskedLM (XLM-RoBERTa 模型)
示例
>>> from transformers import AutoConfig, TFAutoModelForMaskedLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForMaskedLM.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForMaskedLM
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有掩码语言建模头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- AlbertConfig 配置类: FlaxAlbertForMaskedLM (ALBERT 模型)
- BartConfig 配置类: FlaxBartForConditionalGeneration (BART 模型)
- BertConfig 配置类: FlaxBertForMaskedLM (BERT 模型)
- BigBirdConfig 配置类: FlaxBigBirdForMaskedLM (BigBird 模型)
- DistilBertConfig 配置类: FlaxDistilBertForMaskedLM (DistilBERT 模型)
- ElectraConfig 配置类: FlaxElectraForMaskedLM (ELECTRA 模型)
- MBartConfig 配置类: FlaxMBartForConditionalGeneration (mBART 模型)
- RoFormerConfig 配置类: FlaxRoFormerForMaskedLM (RoFormer 模型)
- RobertaConfig 配置类: FlaxRobertaForMaskedLM (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: FlaxRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- XLMRobertaConfig 配置类: FlaxXLMRobertaForMaskedLM (XLM-RoBERTa 模型)
- attn_implementation (
str
, 可选) — 模型中要使用的注意力实现(如果相关)。 可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention)中的任何一种。 默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。 否则,默认值为手动"eager"
实现。
从配置实例化库的模型类之一(带有掩码语言建模头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 字符串,huggingface.co 上的模型仓库中托管的预训练模型的模型 ID 。
- 目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - PyTorch state_dict 保存文件的路径或 URL(例如,
./pt_model/pytorch_model.bin
)。 在这种情况下,from_pt
应设置为True
,并且应将配置对象作为config
参数提供。 此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (附加的位置参数, 可选) — 将传递给底层模型
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以替代自动加载的配置。 当满足以下条件时,可以自动加载配置:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置文件。
- cache_dir (
str
或os.PathLike
, 可选) — 下载的预训练模型配置应缓存到其中的目录路径,如果不想使用标准缓存。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重 (请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 - resume_download — 已弃用且被忽略。现在所有下载在可能的情况下都默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选) — 代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。 代理服务器用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否同时返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许使用 Hub 上定义的自定义模型及其自身的建模文件。此选项应仅对您信任的存储库以及您已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。 根据是否提供config
或自动加载config
,行为会有所不同:- 如果使用
config
提供了配置,**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键(对应于配置属性)将用于使用提供的kwargs
值覆盖所述属性。 不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带有掩码语言建模头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — FlaxAlbertForMaskedLM (ALBERT 模型)
- bart — FlaxBartForConditionalGeneration (BART 模型)
- bert — FlaxBertForMaskedLM (BERT 模型)
- big_bird — FlaxBigBirdForMaskedLM (BigBird 模型)
- distilbert — FlaxDistilBertForMaskedLM (DistilBERT 模型)
- electra — FlaxElectraForMaskedLM (ELECTRA 模型)
- mbart — FlaxMBartForConditionalGeneration (mBART 模型)
- roberta — FlaxRobertaForMaskedLM (RoBERTa 模型)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm 模型)
- roformer — FlaxRoFormerForMaskedLM (RoFormer 模型)
- xlm-roberta — FlaxXLMRobertaForMaskedLM (XLM-RoBERTa 模型)
示例
>>> from transformers import AutoConfig, FlaxAutoModelForMaskedLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForMaskedLM.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForMaskedLM.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForMaskGeneration
TFAutoModelForMaskGeneration
AutoModelForSeq2SeqLM
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有序列到序列的语言建模头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- BartConfig 配置类: BartForConditionalGeneration (BART 模型)
- BigBirdPegasusConfig 配置类: BigBirdPegasusForConditionalGeneration (BigBird-Pegasus 模型)
- BlenderbotConfig 配置类: BlenderbotForConditionalGeneration (Blenderbot 模型)
- BlenderbotSmallConfig 配置类: BlenderbotSmallForConditionalGeneration (BlenderbotSmall 模型)
- EncoderDecoderConfig 配置类: EncoderDecoderModel (Encoder decoder 模型)
- FSMTConfig 配置类: FSMTForConditionalGeneration (FairSeq 机器翻译模型)
- GPTSanJapaneseConfig 配置类: GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese 模型)
- LEDConfig 配置类: LEDForConditionalGeneration (LED 模型)
- LongT5Config 配置类: LongT5ForConditionalGeneration (LongT5 模型)
- M2M100Config 配置类: M2M100ForConditionalGeneration (M2M100 模型)
- MBartConfig 配置类: MBartForConditionalGeneration (mBART 模型)
- MT5Config 配置类: MT5ForConditionalGeneration (MT5 模型)
- MarianConfig 配置类: MarianMTModel (Marian 模型)
- MvpConfig 配置类: MvpForConditionalGeneration (MVP 模型)
- NllbMoeConfig 配置类: NllbMoeForConditionalGeneration (NLLB-MOE 模型)
- PLBartConfig 配置类: PLBartForConditionalGeneration (PLBart 模型)
- PegasusConfig 配置类: PegasusForConditionalGeneration (Pegasus 模型)
- PegasusXConfig 配置类: PegasusXForConditionalGeneration (PEGASUS-X 模型)
- ProphetNetConfig 配置类: ProphetNetForConditionalGeneration (ProphetNet 模型)
- Qwen2AudioConfig 配置类: Qwen2AudioForConditionalGeneration (Qwen2Audio 模型)
- SeamlessM4TConfig 配置类: SeamlessM4TForTextToText (SeamlessM4T 模型)
- SeamlessM4Tv2Config 配置类: SeamlessM4Tv2ForTextToText (SeamlessM4Tv2 模型)
- SwitchTransformersConfig 配置类: SwitchTransformersForConditionalGeneration (SwitchTransformers 模型)
- T5Config 配置类: T5ForConditionalGeneration (T5 模型)
- UMT5Config 配置类: UMT5ForConditionalGeneration (UMT5 模型)
- XLMProphetNetConfig 配置类: XLMProphetNetForConditionalGeneration (XLM-ProphetNet 模型)
- attn_implementation (
str
, 可选) — 模型中要使用的注意力实现(如果相关)。 可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention)中的任何一个。 默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。 否则,默认为手动"eager"
实现。
从配置实例化库的模型类之一(带有序列到序列的语言建模头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即托管在 huggingface.co 上的模型仓库中的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - tensorflow 索引检查点文件的路径或 URL(例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应设置为True
,并且应将配置对象作为config
参数提供。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (额外的位置参数,可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以替代自动加载的配置。当以下情况时,可以自动加载配置:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 用于替代从已保存权重文件加载的 state dictionary。
如果您想从预训练配置创建模型,但加载您自己的权重,则可以使用此选项。但是,在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选项。
- cache_dir (
str
或os.PathLike
, 可选) — 下载的预训练模型配置应缓存到的目录路径,如果不想使用标准缓存。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 - resume_download — 已弃用并忽略。现在所有下载在可能的情况下默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否同时返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。此选项仅应为信任的存储库设置为True
,并且您已阅读过代码,因为它将在您的本地机器上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分在不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的关键字参数,可选) — 可用于更新配置对象(在加载后)并初始化模型(例如,
output_attentions=True
)。行为方式取决于是否提供了config
或自动加载:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新都已经完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带有序列到序列语言建模头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- bart — BartForConditionalGeneration (BART 模型)
- bigbird_pegasus — BigBirdPegasusForConditionalGeneration (BigBird-Pegasus 模型)
- blenderbot — BlenderbotForConditionalGeneration (Blenderbot 模型)
- blenderbot-small — BlenderbotSmallForConditionalGeneration (BlenderbotSmall 模型)
- encoder-decoder — EncoderDecoderModel (Encoder decoder 模型)
- fsmt — FSMTForConditionalGeneration (FairSeq 机器翻译模型)
- gptsan-japanese — GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese 模型)
- led — LEDForConditionalGeneration (LED 模型)
- longt5 — LongT5ForConditionalGeneration (LongT5 模型)
- m2m_100 — M2M100ForConditionalGeneration (M2M100 模型)
- marian — MarianMTModel (Marian 模型)
- mbart — MBartForConditionalGeneration (mBART 模型)
- mt5 — MT5ForConditionalGeneration (MT5 模型)
- mvp — MvpForConditionalGeneration (MVP 模型)
- nllb-moe — NllbMoeForConditionalGeneration (NLLB-MOE 模型)
- pegasus — PegasusForConditionalGeneration (Pegasus 模型)
- pegasus_x — PegasusXForConditionalGeneration (PEGASUS-X 模型)
- plbart — PLBartForConditionalGeneration (PLBart 模型)
- prophetnet — ProphetNetForConditionalGeneration (ProphetNet 模型)
- qwen2_audio — Qwen2AudioForConditionalGeneration (Qwen2Audio 模型)
- seamless_m4t — SeamlessM4TForTextToText (SeamlessM4T 模型)
- seamless_m4t_v2 — SeamlessM4Tv2ForTextToText (SeamlessM4Tv2 模型)
- switch_transformers — SwitchTransformersForConditionalGeneration (SwitchTransformers 模型)
- t5 — T5ForConditionalGeneration (T5 模型)
- umt5 — UMT5ForConditionalGeneration (UMT5 模型)
- xlm-prophetnet — XLMProphetNetForConditionalGeneration (XLM-ProphetNet 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForSeq2SeqLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
>>> # Update configuration during loading
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/t5_tf_model_config.json")
>>> model = AutoModelForSeq2SeqLM.from_pretrained(
... "./tf_model/t5_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForSeq2SeqLM
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有序列到序列的语言建模头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是基于配置类选择的:
- BartConfig 配置类: TFBartForConditionalGeneration (BART 模型)
- BlenderbotConfig 配置类: TFBlenderbotForConditionalGeneration (Blenderbot 模型)
- BlenderbotSmallConfig 配置类: TFBlenderbotSmallForConditionalGeneration (BlenderbotSmall 模型)
- EncoderDecoderConfig 配置类: TFEncoderDecoderModel (Encoder decoder 模型)
- LEDConfig 配置类: TFLEDForConditionalGeneration (LED 模型)
- MBartConfig 配置类: TFMBartForConditionalGeneration (mBART 模型)
- MT5Config 配置类: TFMT5ForConditionalGeneration (MT5 模型)
- MarianConfig 配置类: TFMarianMTModel (Marian 模型)
- PegasusConfig 配置类: TFPegasusForConditionalGeneration (Pegasus 模型)
- T5Config 配置类: TFT5ForConditionalGeneration (T5 模型)
- attn_implementation (
str
, 可选) — 要在模型中使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一个。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则默认值为手动"eager"
实现。
从配置实例化库的模型类之一(带有序列到序列的语言建模头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即托管在 huggingface.co 上的模型仓库中的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - PyTorch state_dict 保存文件的路径或 URL(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应将配置对象作为config
参数提供。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (额外的位置参数,可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, optional) — 用于模型的配置,以替代自动加载的配置。在以下情况下,配置可以自动加载:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, optional) — 如果不想使用标准缓存,则应将下载的预训练模型配置缓存到此目录的路径。 - from_pt (
bool
, optional, defaults toFalse
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, optional, defaults toFalse
) — 是否强制(重新)下载模型权重和配置文件,覆盖已缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下都默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, optional) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, optional, defaults toFalse
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, defaults toFalse
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。此选项应仅对您信任且已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, optional, defaults to"main"
) — 如果代码位于与模型其余部分不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, optional) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载而表现不同:- 如果使用
config
提供配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果对应于配置属性,都将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带有序列到序列语言建模头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- bart — TFBartForConditionalGeneration (BART 模型)
- blenderbot — TFBlenderbotForConditionalGeneration (Blenderbot 模型)
- blenderbot-small — TFBlenderbotSmallForConditionalGeneration (BlenderbotSmall 模型)
- encoder-decoder — TFEncoderDecoderModel (Encoder decoder 模型)
- led — TFLEDForConditionalGeneration (LED 模型)
- marian — TFMarianMTModel (Marian 模型)
- mbart — TFMBartForConditionalGeneration (mBART 模型)
- mt5 — TFMT5ForConditionalGeneration (MT5 模型)
- pegasus — TFPegasusForConditionalGeneration (Pegasus 模型)
- t5 — TFT5ForConditionalGeneration (T5 模型)
示例
>>> from transformers import AutoConfig, TFAutoModelForSeq2SeqLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
>>> # Update configuration during loading
>>> model = TFAutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/t5_pt_model_config.json")
>>> model = TFAutoModelForSeq2SeqLM.from_pretrained(
... "./pt_model/t5_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForSeq2SeqLM
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有序列到序列的语言建模头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- BartConfig 配置类: FlaxBartForConditionalGeneration (BART 模型)
- BlenderbotConfig 配置类: FlaxBlenderbotForConditionalGeneration (Blenderbot 模型)
- BlenderbotSmallConfig 配置类: FlaxBlenderbotSmallForConditionalGeneration (BlenderbotSmall 模型)
- EncoderDecoderConfig 配置类: FlaxEncoderDecoderModel (Encoder decoder 模型)
- LongT5Config 配置类: FlaxLongT5ForConditionalGeneration (LongT5 模型)
- MBartConfig 配置类: FlaxMBartForConditionalGeneration (mBART 模型)
- MT5Config 配置类: FlaxMT5ForConditionalGeneration (MT5 模型)
- MarianConfig 配置类: FlaxMarianMTModel (Marian 模型)
- PegasusConfig 配置类: FlaxPegasusForConditionalGeneration (Pegasus 模型)
- T5Config 配置类: FlaxT5ForConditionalGeneration (T5 模型)
- attn_implementation (
str
, optional) — 要在模型中使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention)中的任何一个。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则默认为手动"eager"
实现。
从配置实例化库的模型类之一(带有序列到序列的语言建模头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即托管在 huggingface.co 上的模型仓库内的预训练模型的 模型 ID 。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如
./my_model_directory/
。 - PyTorch state_dict 保存文件的路径或 URL (例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (附加位置参数, optional) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, optional) — 用于模型的配置,以替代自动加载的配置。在以下情况下,配置可以自动加载:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, optional) — 如果不想使用标准缓存,则应将下载的预训练模型配置缓存到此目录的路径。 - from_pt (
bool
, optional, defaults toFalse
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, optional, defaults toFalse
) — 是否强制(重新)下载模型权重和配置文件,覆盖已缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下都默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, optional) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, optional, defaults toFalse
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, optional, defaults toFalse
) — Whether or not to only look at local files (e.g., not try downloading the model). - revision (
str
, optional, defaults to"main"
) — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, sorevision
can be any identifier allowed by git. - trust_remote_code (
bool
, optional, defaults toFalse
) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set toTrue
for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine. - code_revision (
str
, optional, defaults to"main"
) — The specific revision to use for the code on the Hub, if the code leaves in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, sorevision
can be any identifier allowed by git. - kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
output_attentions=True
). Behaves differently depending on whether aconfig
is provided or automatically loaded:- If a configuration is provided with
config
,**kwargs
will be directly passed to the underlying model’s__init__
method (we assume all relevant updates to the configuration have already been done) - If a configuration is not provided,
kwargs
will be first passed to the configuration class initialization function (from_pretrained()). Each key ofkwargs
that corresponds to a configuration attribute will be used to override said attribute with the suppliedkwargs
value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s__init__
function.
- If a configuration is provided with
从预训练模型实例化库中的一个模型类(带有序列到序列语言建模头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- bart — FlaxBartForConditionalGeneration (BART 模型)
- blenderbot — FlaxBlenderbotForConditionalGeneration (Blenderbot 模型)
- blenderbot-small — FlaxBlenderbotSmallForConditionalGeneration (BlenderbotSmall 模型)
- encoder-decoder — FlaxEncoderDecoderModel (Encoder decoder 模型)
- longt5 — FlaxLongT5ForConditionalGeneration (LongT5 模型)
- marian — FlaxMarianMTModel (Marian 模型)
- mbart — FlaxMBartForConditionalGeneration (mBART 模型)
- mt5 — FlaxMT5ForConditionalGeneration (MT5 模型)
- pegasus — FlaxPegasusForConditionalGeneration (Pegasus 模型)
- t5 — FlaxT5ForConditionalGeneration (T5 模型)
示例
>>> from transformers import AutoConfig, FlaxAutoModelForSeq2SeqLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/t5_pt_model_config.json")
>>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
... "./pt_model/t5_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForSequenceClassification
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库中的模型类之一(带有序列分类头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是基于配置类选择的:
- AlbertConfig 配置类: AlbertForSequenceClassification (ALBERT 模型)
- BartConfig 配置类: BartForSequenceClassification (BART 模型)
- BertConfig 配置类: BertForSequenceClassification (BERT 模型)
- BigBirdConfig 配置类: BigBirdForSequenceClassification (BigBird 模型)
- BigBirdPegasusConfig 配置类: BigBirdPegasusForSequenceClassification (BigBird-Pegasus 模型)
- BioGptConfig 配置类: BioGptForSequenceClassification (BioGpt 模型)
- BloomConfig 配置类: BloomForSequenceClassification (BLOOM 模型)
- CTRLConfig 配置类: CTRLForSequenceClassification (CTRL 模型)
- CamembertConfig 配置类: CamembertForSequenceClassification (CamemBERT 模型)
- CanineConfig 配置类: CanineForSequenceClassification (CANINE 模型)
- ConvBertConfig 配置类: ConvBertForSequenceClassification (ConvBERT 模型)
- Data2VecTextConfig 配置类: Data2VecTextForSequenceClassification (Data2VecText 模型)
- DebertaConfig 配置类: DebertaForSequenceClassification (DeBERTa 模型)
- DebertaV2Config 配置类: DebertaV2ForSequenceClassification (DeBERTa-v2 模型)
- DiffLlamaConfig 配置类: DiffLlamaForSequenceClassification (DiffLlama 模型)
- DistilBertConfig 配置类: DistilBertForSequenceClassification (DistilBERT 模型)
- ElectraConfig 配置类: ElectraForSequenceClassification (ELECTRA 模型)
- ErnieConfig 配置类: ErnieForSequenceClassification (ERNIE 模型)
- ErnieMConfig 配置类: ErnieMForSequenceClassification (ErnieM 模型)
- EsmConfig 配置类: EsmForSequenceClassification (ESM 模型)
- FNetConfig 配置类: FNetForSequenceClassification (FNet 模型)
- FalconConfig 配置类: FalconForSequenceClassification (Falcon 模型)
- FlaubertConfig 配置类: FlaubertForSequenceClassification (FlauBERT 模型)
- FunnelConfig 配置类: FunnelForSequenceClassification (Funnel Transformer 模型)
- GPT2Config 配置类: GPT2ForSequenceClassification (OpenAI GPT-2 模型)
- GPTBigCodeConfig 配置类: GPTBigCodeForSequenceClassification (GPTBigCode 模型)
- GPTJConfig 配置类: GPTJForSequenceClassification (GPT-J 模型)
- GPTNeoConfig 配置类: GPTNeoForSequenceClassification (GPT Neo 模型)
- GPTNeoXConfig 配置类: GPTNeoXForSequenceClassification (GPT NeoX 模型)
- Gemma2Config 配置类: Gemma2ForSequenceClassification (Gemma2 模型)
- GemmaConfig 配置类: GemmaForSequenceClassification (Gemma 模型)
- GlmConfig 配置类: GlmForSequenceClassification (GLM 模型)
- HeliumConfig 配置类: HeliumForSequenceClassification (Helium 模型)
- IBertConfig 配置类: IBertForSequenceClassification (I-BERT 模型)
- JambaConfig 配置类: JambaForSequenceClassification (Jamba 模型)
- JetMoeConfig 配置类: JetMoeForSequenceClassification (JetMoe 模型)
- LEDConfig 配置类: LEDForSequenceClassification (LED 模型)
- LayoutLMConfig 配置类: LayoutLMForSequenceClassification (LayoutLM 模型)
- LayoutLMv2Config 配置类: LayoutLMv2ForSequenceClassification (LayoutLMv2 模型)
- LayoutLMv3Config 配置类: LayoutLMv3ForSequenceClassification (LayoutLMv3 模型)
- LiltConfig 配置类: LiltForSequenceClassification (LiLT 模型)
- LlamaConfig 配置类: LlamaForSequenceClassification (LLaMA 模型)
- LongformerConfig 配置类: LongformerForSequenceClassification (Longformer 模型)
- LukeConfig 配置类: LukeForSequenceClassification (LUKE 模型)
- MBartConfig 配置类: MBartForSequenceClassification (mBART 模型)
- MPNetConfig 配置类: MPNetForSequenceClassification (MPNet 模型)
- MT5Config 配置类: MT5ForSequenceClassification (MT5 模型)
- MarkupLMConfig 配置类: MarkupLMForSequenceClassification (MarkupLM 模型)
- MegaConfig 配置类: MegaForSequenceClassification (MEGA 模型)
- MegatronBertConfig 配置类: MegatronBertForSequenceClassification (Megatron-BERT 模型)
- MistralConfig 配置类: MistralForSequenceClassification (Mistral 模型)
- MixtralConfig 配置类: MixtralForSequenceClassification (Mixtral 模型)
- MobileBertConfig 配置类: MobileBertForSequenceClassification (MobileBERT 模型)
- ModernBertConfig 配置类: ModernBertForSequenceClassification (ModernBERT 模型)
- MptConfig 配置类: MptForSequenceClassification (MPT 模型)
- MraConfig 配置类: MraForSequenceClassification (MRA 模型)
- MvpConfig 配置类: MvpForSequenceClassification (MVP 模型)
- NemotronConfig 配置类: NemotronForSequenceClassification (Nemotron 模型)
- NezhaConfig 配置类: NezhaForSequenceClassification (Nezha 模型)
- NystromformerConfig 配置类: NystromformerForSequenceClassification (Nyströmformer 模型)
- OPTConfig 配置类: OPTForSequenceClassification (OPT 模型)
- OpenAIGPTConfig 配置类: OpenAIGPTForSequenceClassification (OpenAI GPT 模型)
- OpenLlamaConfig 配置类: OpenLlamaForSequenceClassification (OpenLlama 模型)
- PLBartConfig 配置类: PLBartForSequenceClassification (PLBart 模型)
- PerceiverConfig 配置类: PerceiverForSequenceClassification (Perceiver 模型)
- PersimmonConfig 配置类: PersimmonForSequenceClassification (Persimmon 模型)
- Phi3Config 配置类: Phi3ForSequenceClassification (Phi3 模型)
- PhiConfig 配置类: PhiForSequenceClassification (Phi 模型)
- PhimoeConfig 配置类: PhimoeForSequenceClassification (Phimoe 模型)
- QDQBertConfig 配置类: QDQBertForSequenceClassification (QDQBert 模型)
- Qwen2Config 配置类: Qwen2ForSequenceClassification (Qwen2 模型)
- Qwen2MoeConfig 配置类: Qwen2MoeForSequenceClassification (Qwen2MoE 模型)
- ReformerConfig 配置类: ReformerForSequenceClassification (Reformer 模型)
- RemBertConfig 配置类: RemBertForSequenceClassification (RemBERT 模型)
- RoCBertConfig 配置类: RoCBertForSequenceClassification (RoCBert 模型)
- RoFormerConfig 配置类: RoFormerForSequenceClassification (RoFormer 模型)
- RobertaConfig 配置类: RobertaForSequenceClassification (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: RobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm 模型)
- SqueezeBertConfig 配置类: SqueezeBertForSequenceClassification (SqueezeBERT 模型)
- StableLmConfig 配置类: StableLmForSequenceClassification (StableLm 模型)
- Starcoder2Config 配置类: Starcoder2ForSequenceClassification (Starcoder2 模型)
- T5Config 配置类: T5ForSequenceClassification (T5 模型)
- TapasConfig 配置类: TapasForSequenceClassification (TAPAS 模型)
- TransfoXLConfig 配置类: TransfoXLForSequenceClassification (Transformer-XL 模型)
- UMT5Config 配置类: UMT5ForSequenceClassification (UMT5 模型)
- XLMConfig 配置类: XLMForSequenceClassification (XLM 模型)
- XLMRobertaConfig 配置类: XLMRobertaForSequenceClassification (XLM-RoBERTa 模型)
- XLMRobertaXLConfig 配置类: XLMRobertaXLForSequenceClassification (XLM-RoBERTa-XL 模型)
- XLNetConfig 配置类: XLNetForSequenceClassification (XLNet 模型)
- XmodConfig 配置类: XmodForSequenceClassification (X-MOD 模型)
- YosoConfig 配置类: YosoForSequenceClassification (YOSO 模型)
- Zamba2Config 配置类: Zamba2ForSequenceClassification (Zamba2 模型)
- ZambaConfig 配置类: ZambaForSequenceClassification (Zamba 模型)
- attn_implementation (
str
, optional) — 模型中要使用的注意力实现方式(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention)。默认情况下,如果可用,对于 torch>=2.1.1 将使用 SDPA。否则默认为手动"eager"
实现。
从配置实例化库中的模型类之一(带有序列分类头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 上模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个 tensorflow 索引检查点文件的路径或 URL(例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应设置为True
,并且应提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型,然后加载 PyTorch 模型要慢。
- model_args (附加的位置参数, 可选) — 将传递给底层模型
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以替代自动加载的配置。当以下情况时,可以自动加载配置:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型使用 save_pretrained() 保存,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 用于替代从保存的权重文件加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,可以使用此选项。 在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖已缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在所有下载在可能的情况下默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。此选项仅应针对您信任并已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码与模型的其余部分位于不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加的关键字参数, 可选) — 可用于更新配置对象(在加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载配置,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设所有相关的配置更新已经完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果与配置属性对应,将用于使用提供的kwargs
值覆盖所述属性。不与任何配置属性对应的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有序列分类头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — AlbertForSequenceClassification (ALBERT 模型)
- bart — BartForSequenceClassification (BART 模型)
- bert — BertForSequenceClassification (BERT 模型)
- big_bird — BigBirdForSequenceClassification (BigBird 模型)
- bigbird_pegasus — BigBirdPegasusForSequenceClassification (BigBird-Pegasus 模型)
- biogpt — BioGptForSequenceClassification (BioGpt 模型)
- bloom — BloomForSequenceClassification (BLOOM 模型)
- camembert — CamembertForSequenceClassification (CamemBERT 模型)
- canine — CanineForSequenceClassification (CANINE 模型)
- code_llama — LlamaForSequenceClassification (CodeLlama 模型)
- convbert — ConvBertForSequenceClassification (ConvBERT 模型)
- ctrl — CTRLForSequenceClassification (CTRL 模型)
- data2vec-text — Data2VecTextForSequenceClassification (Data2VecText 模型)
- deberta — DebertaForSequenceClassification (DeBERTa 模型)
- deberta-v2 — DebertaV2ForSequenceClassification (DeBERTa-v2 模型)
- diffllama — DiffLlamaForSequenceClassification (DiffLlama 模型)
- distilbert — DistilBertForSequenceClassification (DistilBERT 模型)
- electra — ElectraForSequenceClassification (ELECTRA 模型)
- ernie — ErnieForSequenceClassification (ERNIE 模型)
- ernie_m — ErnieMForSequenceClassification (ErnieM 模型)
- esm — EsmForSequenceClassification (ESM 模型)
- falcon — FalconForSequenceClassification (Falcon 模型)
- flaubert — FlaubertForSequenceClassification (FlauBERT 模型)
- fnet — FNetForSequenceClassification (FNet 模型)
- funnel — FunnelForSequenceClassification (Funnel Transformer 模型)
- gemma — GemmaForSequenceClassification (Gemma 模型)
- gemma2 — Gemma2ForSequenceClassification (Gemma2 模型)
- glm — GlmForSequenceClassification (GLM 模型)
- gpt-sw3 — GPT2ForSequenceClassification (GPT-Sw3 模型)
- gpt2 — GPT2ForSequenceClassification (OpenAI GPT-2 模型)
- gpt_bigcode — GPTBigCodeForSequenceClassification (GPTBigCode 模型)
- gpt_neo — GPTNeoForSequenceClassification (GPT Neo 模型)
- gpt_neox — GPTNeoXForSequenceClassification (GPT NeoX 模型)
- gptj — GPTJForSequenceClassification (GPT-J 模型)
- helium — HeliumForSequenceClassification (Helium 模型)
- ibert — IBertForSequenceClassification (I-BERT 模型)
- jamba — JambaForSequenceClassification (Jamba 模型)
- jetmoe — JetMoeForSequenceClassification (JetMoe 模型)
- layoutlm — LayoutLMForSequenceClassification (LayoutLM 模型)
- layoutlmv2 — LayoutLMv2ForSequenceClassification (LayoutLMv2 模型)
- layoutlmv3 — LayoutLMv3ForSequenceClassification (LayoutLMv3 模型)
- led — LEDForSequenceClassification (LED 模型)
- lilt — LiltForSequenceClassification (LiLT 模型)
- llama — LlamaForSequenceClassification (LLaMA 模型)
- longformer — LongformerForSequenceClassification (Longformer 模型)
- luke — LukeForSequenceClassification (LUKE 模型)
- markuplm — MarkupLMForSequenceClassification (MarkupLM 模型)
- mbart — MBartForSequenceClassification (mBART 模型)
- mega — MegaForSequenceClassification (MEGA 模型)
- megatron-bert — MegatronBertForSequenceClassification (Megatron-BERT 模型)
- mistral — MistralForSequenceClassification (Mistral 模型)
- mixtral — MixtralForSequenceClassification (Mixtral 模型)
- mobilebert — MobileBertForSequenceClassification (MobileBERT 模型)
- modernbert — ModernBertForSequenceClassification (ModernBERT 模型)
- mpnet — MPNetForSequenceClassification (MPNet 模型)
- mpt — MptForSequenceClassification (MPT 模型)
- mra — MraForSequenceClassification (MRA 模型)
- mt5 — MT5ForSequenceClassification (MT5 模型)
- mvp — MvpForSequenceClassification (MVP 模型)
- nemotron — NemotronForSequenceClassification (Nemotron 模型)
- nezha — NezhaForSequenceClassification (Nezha 模型)
- nystromformer — NystromformerForSequenceClassification (Nyströmformer 模型)
- open-llama — OpenLlamaForSequenceClassification (OpenLlama 模型)
- openai-gpt — OpenAIGPTForSequenceClassification (OpenAI GPT 模型)
- opt — OPTForSequenceClassification (OPT 模型)
- perceiver — PerceiverForSequenceClassification (Perceiver 模型)
- persimmon — PersimmonForSequenceClassification (Persimmon 模型)
- phi — PhiForSequenceClassification (Phi 模型)
- phi3 — Phi3ForSequenceClassification (Phi3 模型)
- phimoe — PhimoeForSequenceClassification (Phimoe 模型)
- plbart — PLBartForSequenceClassification (PLBart 模型)
- qdqbert — QDQBertForSequenceClassification (QDQBert 模型)
- qwen2 — Qwen2ForSequenceClassification (Qwen2 模型)
- qwen2_moe — Qwen2MoeForSequenceClassification (Qwen2MoE 模型)
- reformer — ReformerForSequenceClassification (Reformer 模型)
- rembert — RemBertForSequenceClassification (RemBERT 模型)
- roberta — RobertaForSequenceClassification (RoBERTa 模型)
- roberta-prelayernorm — RobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm 模型)
- roc_bert — RoCBertForSequenceClassification (RoCBert 模型)
- roformer — RoFormerForSequenceClassification (RoFormer 模型)
- squeezebert — SqueezeBertForSequenceClassification (SqueezeBERT 模型)
- stablelm — StableLmForSequenceClassification (StableLm 模型)
- starcoder2 — Starcoder2ForSequenceClassification (Starcoder2 模型)
- t5 — T5ForSequenceClassification (T5 模型)
- tapas — TapasForSequenceClassification (TAPAS 模型)
- transfo-xl — TransfoXLForSequenceClassification (Transformer-XL 模型)
- umt5 — UMT5ForSequenceClassification (UMT5 模型)
- xlm — XLMForSequenceClassification (XLM 模型)
- xlm-roberta — XLMRobertaForSequenceClassification (XLM-RoBERTa 模型)
- xlm-roberta-xl — XLMRobertaXLForSequenceClassification (XLM-RoBERTa-XL 模型)
- xlnet — XLNetForSequenceClassification (XLNet 模型)
- xmod — XmodForSequenceClassification (X-MOD 模型)
- yoso — YosoForSequenceClassification (YOSO 模型)
- zamba — ZambaForSequenceClassification (Zamba 模型)
- zamba2 — Zamba2ForSequenceClassification (Zamba2 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForSequenceClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForSequenceClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForSequenceClassification
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库中的模型类之一(带有序列分类头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 基于配置类选择要实例化的模型类:
- AlbertConfig 配置类: TFAlbertForSequenceClassification (ALBERT 模型)
- BartConfig 配置类: TFBartForSequenceClassification (BART 模型)
- BertConfig 配置类: TFBertForSequenceClassification (BERT 模型)
- CTRLConfig 配置类: TFCTRLForSequenceClassification (CTRL 模型)
- CamembertConfig 配置类: TFCamembertForSequenceClassification (CamemBERT 模型)
- ConvBertConfig 配置类: TFConvBertForSequenceClassification (ConvBERT 模型)
- DebertaConfig 配置类: TFDebertaForSequenceClassification (DeBERTa 模型)
- DebertaV2Config 配置类: TFDebertaV2ForSequenceClassification (DeBERTa-v2 模型)
- DistilBertConfig 配置类: TFDistilBertForSequenceClassification (DistilBERT 模型)
- ElectraConfig 配置类: TFElectraForSequenceClassification (ELECTRA 模型)
- EsmConfig 配置类: TFEsmForSequenceClassification (ESM 模型)
- FlaubertConfig 配置类: TFFlaubertForSequenceClassification (FlauBERT 模型)
- FunnelConfig 配置类: TFFunnelForSequenceClassification (Funnel Transformer 模型)
- GPT2Config 配置类: TFGPT2ForSequenceClassification (OpenAI GPT-2 模型)
- GPTJConfig 配置类: TFGPTJForSequenceClassification (GPT-J 模型)
- LayoutLMConfig 配置类: TFLayoutLMForSequenceClassification (LayoutLM 模型)
- LayoutLMv3Config 配置类: TFLayoutLMv3ForSequenceClassification (LayoutLMv3 模型)
- LongformerConfig 配置类: TFLongformerForSequenceClassification (Longformer 模型)
- MPNetConfig 配置类: TFMPNetForSequenceClassification (MPNet 模型)
- MistralConfig 配置类: TFMistralForSequenceClassification (Mistral 模型)
- MobileBertConfig 配置类: TFMobileBertForSequenceClassification (MobileBERT 模型)
- OpenAIGPTConfig 配置类: TFOpenAIGPTForSequenceClassification (OpenAI GPT 模型)
- RemBertConfig 配置类: TFRemBertForSequenceClassification (RemBERT 模型)
- RoFormerConfig 配置类: TFRoFormerForSequenceClassification (RoFormer 模型)
- RobertaConfig 配置类: TFRobertaForSequenceClassification (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: TFRobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm 模型)
- TapasConfig 配置类: TFTapasForSequenceClassification (TAPAS 模型)
- TransfoXLConfig 配置类: TFTransfoXLForSequenceClassification (Transformer-XL 模型)
- XLMConfig 配置类: TFXLMForSequenceClassification (XLM 模型)
- XLMRobertaConfig 配置类: TFXLMRobertaForSequenceClassification (XLM-RoBERTa 模型)
- XLNetConfig 配置类: TFXLNetForSequenceClassification (XLNet 模型)
- attn_implementation (
str
, optional) — 模型中要使用的注意力实现方式(如果相关)。可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
) 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一种。默认情况下,如果可用,torch>=2.1.1 将使用 SDPA。否则,默认为手动"eager"
实现。
从配置实例化库中的模型类之一(带有序列分类头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,pretrained model 的 model id,托管在 huggingface.co 的 model repo 中。
- 目录 的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - PyTorch state_dict 保存文件 的路径或 URL (例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应将配置对象作为config
参数提供。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (额外的定位参数, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,而不是自动加载的配置。当以下情况时,可以自动加载配置:
- 模型是库提供的模型(使用 pretrained model 的 model id 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, 可选) — 下载的 pretrained model 配置应缓存到的目录路径,如果不想使用标准缓存。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 要按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否同时返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。此选项仅应为信任的存储库以及您已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则要用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为会有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递到底层模型的__init__
方法(我们假设配置的所有相关更新都已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果对应于配置属性,将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递到底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有序列分类头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — TFAlbertForSequenceClassification (ALBERT 模型)
- bart — TFBartForSequenceClassification (BART 模型)
- bert — TFBertForSequenceClassification (BERT 模型)
- camembert — TFCamembertForSequenceClassification (CamemBERT 模型)
- convbert — TFConvBertForSequenceClassification (ConvBERT 模型)
- ctrl — TFCTRLForSequenceClassification (CTRL 模型)
- deberta — TFDebertaForSequenceClassification (DeBERTa 模型)
- deberta-v2 — TFDebertaV2ForSequenceClassification (DeBERTa-v2 模型)
- distilbert — TFDistilBertForSequenceClassification (DistilBERT 模型)
- electra — TFElectraForSequenceClassification (ELECTRA 模型)
- esm — TFEsmForSequenceClassification (ESM 模型)
- flaubert — TFFlaubertForSequenceClassification (FlauBERT 模型)
- funnel — TFFunnelForSequenceClassification (Funnel Transformer 模型)
- gpt-sw3 — TFGPT2ForSequenceClassification (GPT-Sw3 模型)
- gpt2 — TFGPT2ForSequenceClassification (OpenAI GPT-2 模型)
- gptj — TFGPTJForSequenceClassification (GPT-J 模型)
- layoutlm — TFLayoutLMForSequenceClassification (LayoutLM 模型)
- layoutlmv3 — TFLayoutLMv3ForSequenceClassification (LayoutLMv3 模型)
- longformer — TFLongformerForSequenceClassification (Longformer 模型)
- mistral — TFMistralForSequenceClassification (Mistral 模型)
- mobilebert — TFMobileBertForSequenceClassification (MobileBERT 模型)
- mpnet — TFMPNetForSequenceClassification (MPNet 模型)
- openai-gpt — TFOpenAIGPTForSequenceClassification (OpenAI GPT 模型)
- rembert — TFRemBertForSequenceClassification (RemBERT 模型)
- roberta — TFRobertaForSequenceClassification (RoBERTa 模型)
- roberta-prelayernorm — TFRobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm 模型)
- roformer — TFRoFormerForSequenceClassification (RoFormer 模型)
- tapas — TFTapasForSequenceClassification (TAPAS 模型)
- transfo-xl — TFTransfoXLForSequenceClassification (Transformer-XL 模型)
- xlm — TFXLMForSequenceClassification (XLM 模型)
- xlm-roberta — TFXLMRobertaForSequenceClassification (XLM-RoBERTa 模型)
- xlnet — TFXLNetForSequenceClassification (XLNet 模型)
示例
>>> from transformers import AutoConfig, TFAutoModelForSequenceClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForSequenceClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForSequenceClassification
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库中的模型类之一(带有序列分类头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 基于配置类选择要实例化的模型类:
- AlbertConfig 配置类: FlaxAlbertForSequenceClassification (ALBERT 模型)
- BartConfig 配置类: FlaxBartForSequenceClassification (BART 模型)
- BertConfig 配置类: FlaxBertForSequenceClassification (BERT 模型)
- BigBirdConfig 配置类: FlaxBigBirdForSequenceClassification (BigBird 模型)
- DistilBertConfig 配置类: FlaxDistilBertForSequenceClassification (DistilBERT 模型)
- ElectraConfig 配置类: FlaxElectraForSequenceClassification (ELECTRA 模型)
- MBartConfig 配置类: FlaxMBartForSequenceClassification (mBART 模型)
- RoFormerConfig 配置类: FlaxRoFormerForSequenceClassification (RoFormer 模型)
- RobertaConfig 配置类: FlaxRobertaForSequenceClassification (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: FlaxRobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm 模型)
- XLMRobertaConfig 配置类: FlaxXLMRobertaForSequenceClassification (XLM-RoBERTa 模型)
- attn_implementation (
str
, optional) — 模型中使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
),或"flash_attention_2"
(使用 Dao-AILab/flash-attention)。默认情况下,如果可用,torch>=2.1.1 将使用 SDPA。否则,默认使用手动"eager"
实现。
从配置实例化库中的模型类之一(带有序列分类头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,指 huggingface.co 模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个指向 PyTorch state_dict 保存文件 的路径或 URL (例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应将配置对象作为config
参数提供。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (额外的 positional arguments,可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,而不是自动加载的配置。 当以下情况时,可以自动加载配置:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, 可选) — 如果不想使用标准缓存,则应在其中缓存下载的预训练模型配置的目录路径。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在它们自己的建模文件中定义。此选项仅应针对您信任的存储库以及您已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的 keyword arguments,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载而表现不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递到底层模型的__init__
方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果与配置属性对应,将用于使用提供的kwargs
值覆盖该属性。不与任何配置属性对应的剩余键将传递到底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有序列分类头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — FlaxAlbertForSequenceClassification (ALBERT 模型)
- bart — FlaxBartForSequenceClassification (BART 模型)
- bert — FlaxBertForSequenceClassification (BERT 模型)
- big_bird — FlaxBigBirdForSequenceClassification (BigBird 模型)
- distilbert — FlaxDistilBertForSequenceClassification (DistilBERT 模型)
- electra — FlaxElectraForSequenceClassification (ELECTRA 模型)
- mbart — FlaxMBartForSequenceClassification (mBART 模型)
- roberta — FlaxRobertaForSequenceClassification (RoBERTa 模型)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm 模型)
- roformer — FlaxRoFormerForSequenceClassification (RoFormer 模型)
- xlm-roberta — FlaxXLMRobertaForSequenceClassification (XLM-RoBERTa 模型)
示例
>>> from transformers import AutoConfig, FlaxAutoModelForSequenceClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForSequenceClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForMultipleChoice
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有 multiple choice 头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 模型类实例化的选择基于配置类:
- AlbertConfig 配置类: AlbertForMultipleChoice (ALBERT 模型)
- BertConfig 配置类: BertForMultipleChoice (BERT 模型)
- BigBirdConfig 配置类: BigBirdForMultipleChoice (BigBird 模型)
- CamembertConfig 配置类: CamembertForMultipleChoice (CamemBERT 模型)
- CanineConfig 配置类: CanineForMultipleChoice (CANINE 模型)
- ConvBertConfig 配置类: ConvBertForMultipleChoice (ConvBERT 模型)
- Data2VecTextConfig 配置类: Data2VecTextForMultipleChoice (Data2VecText 模型)
- DebertaV2Config 配置类: DebertaV2ForMultipleChoice (DeBERTa-v2 模型)
- DistilBertConfig 配置类: DistilBertForMultipleChoice (DistilBERT 模型)
- ElectraConfig 配置类: ElectraForMultipleChoice (ELECTRA 模型)
- ErnieConfig 配置类: ErnieForMultipleChoice (ERNIE 模型)
- ErnieMConfig 配置类: ErnieMForMultipleChoice (ErnieM 模型)
- FNetConfig 配置类: FNetForMultipleChoice (FNet 模型)
- FlaubertConfig 配置类: FlaubertForMultipleChoice (FlauBERT 模型)
- FunnelConfig 配置类: FunnelForMultipleChoice (Funnel Transformer 模型)
- IBertConfig 配置类: IBertForMultipleChoice (I-BERT 模型)
- LongformerConfig 配置类: LongformerForMultipleChoice (Longformer 模型)
- LukeConfig 配置类: LukeForMultipleChoice (LUKE 模型)
- MPNetConfig 配置类: MPNetForMultipleChoice (MPNet 模型)
- MegaConfig 配置类: MegaForMultipleChoice (MEGA 模型)
- MegatronBertConfig 配置类: MegatronBertForMultipleChoice (Megatron-BERT 模型)
- MobileBertConfig 配置类: MobileBertForMultipleChoice (MobileBERT 模型)
- MraConfig 配置类: MraForMultipleChoice (MRA 模型)
- NezhaConfig 配置类: NezhaForMultipleChoice (Nezha 模型)
- NystromformerConfig 配置类: NystromformerForMultipleChoice (Nyströmformer 模型)
- QDQBertConfig 配置类: QDQBertForMultipleChoice (QDQBert 模型)
- RemBertConfig 配置类: RemBertForMultipleChoice (RemBERT 模型)
- RoCBertConfig 配置类: RoCBertForMultipleChoice (RoCBert 模型)
- RoFormerConfig 配置类: RoFormerForMultipleChoice (RoFormer 模型)
- RobertaConfig 配置类: RobertaForMultipleChoice (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: RobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm 模型)
- SqueezeBertConfig 配置类: SqueezeBertForMultipleChoice (SqueezeBERT 模型)
- XLMConfig 配置类: XLMForMultipleChoice (XLM 模型)
- XLMRobertaConfig 配置类: XLMRobertaForMultipleChoice (XLM-RoBERTa 模型)
- XLMRobertaXLConfig 配置类: XLMRobertaXLForMultipleChoice (XLM-RoBERTa-XL 模型)
- XLNetConfig 配置类: XLNetForMultipleChoice (XLNet 模型)
- XmodConfig 配置类: XmodForMultipleChoice (X-MOD 模型)
- YosoConfig 配置类: YosoForMultipleChoice (YOSO 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力实现方式 (如果相关)。可以是"eager"
(手动实现的注意力),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention)。默认情况下,如果可用,torch>=2.1.1 将使用 SDPA。否则,默认使用手动"eager"
实现。
从配置实例化库中的一个模型类 (带有多项选择头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,表示托管在 huggingface.co 模型仓库中的预训练模型的 模型 ID 。
- 一个指向目录的路径,该目录包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个指向 TensorFlow 索引检查点文件 的路径或 URL (例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应设置为True
,并且应提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (附加的位置参数, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以替代自动加载的配置。当以下情况时,可以自动加载配置:
- 模型是由库提供的模型 (使用预训练模型的 模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 一个状态字典,用于替代从保存的权重文件加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,则可以使用此选项。 但是,在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 下载的预训练模型配置应缓存到的目录路径,如果不想使用标准缓存。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重 (请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下都默认恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 代理服务器字典,用于按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件 (例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义的模型在其自己的建模文件中定义。此选项仅应设置为True
,用于您信任的仓库,并且您已阅读其中的代码,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的仓库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加的关键字参数, 可选) — 可用于更新配置对象 (在加载后) 并初始化模型 (例如,
output_attentions=True
)。根据是否提供config
或自动加载而表现不同:- 如果使用
config
提供了配置,**kwargs
将直接传递到底层模型的__init__
方法 (我们假设配置的所有相关更新已经完成) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递到底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类 (带有多项选择头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — AlbertForMultipleChoice (ALBERT 模型)
- bert — BertForMultipleChoice (BERT 模型)
- big_bird — BigBirdForMultipleChoice (BigBird 模型)
- camembert — CamembertForMultipleChoice (CamemBERT 模型)
- canine — CanineForMultipleChoice (CANINE 模型)
- convbert — ConvBertForMultipleChoice (ConvBERT 模型)
- data2vec-text — Data2VecTextForMultipleChoice (Data2VecText 模型)
- deberta-v2 — DebertaV2ForMultipleChoice (DeBERTa-v2 模型)
- distilbert — DistilBertForMultipleChoice (DistilBERT 模型)
- electra — ElectraForMultipleChoice (ELECTRA 模型)
- ernie — ErnieForMultipleChoice (ERNIE 模型)
- ernie_m — ErnieMForMultipleChoice (ErnieM 模型)
- flaubert — FlaubertForMultipleChoice (FlauBERT 模型)
- fnet — FNetForMultipleChoice (FNet 模型)
- funnel — FunnelForMultipleChoice (Funnel Transformer 模型)
- ibert — IBertForMultipleChoice (I-BERT 模型)
- longformer — LongformerForMultipleChoice (Longformer 模型)
- luke — LukeForMultipleChoice (LUKE 模型)
- mega — MegaForMultipleChoice (MEGA 模型)
- megatron-bert — MegatronBertForMultipleChoice (Megatron-BERT 模型)
- mobilebert — MobileBertForMultipleChoice (MobileBERT 模型)
- mpnet — MPNetForMultipleChoice (MPNet 模型)
- mra — MraForMultipleChoice (MRA 模型)
- nezha — NezhaForMultipleChoice (Nezha 模型)
- nystromformer — NystromformerForMultipleChoice (Nyströmformer 模型)
- qdqbert — QDQBertForMultipleChoice (QDQBert 模型)
- rembert — RemBertForMultipleChoice (RemBERT 模型)
- roberta — RobertaForMultipleChoice (RoBERTa 模型)
- roberta-prelayernorm — RobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm 模型)
- roc_bert — RoCBertForMultipleChoice (RoCBert 模型)
- roformer — RoFormerForMultipleChoice (RoFormer 模型)
- squeezebert — SqueezeBertForMultipleChoice (SqueezeBERT 模型)
- xlm — XLMForMultipleChoice (XLM 模型)
- xlm-roberta — XLMRobertaForMultipleChoice (XLM-RoBERTa 模型)
- xlm-roberta-xl — XLMRobertaXLForMultipleChoice (XLM-RoBERTa-XL 模型)
- xlnet — XLNetForMultipleChoice (XLNet 模型)
- xmod — XmodForMultipleChoice (X-MOD 模型)
- yoso — YosoForMultipleChoice (YOSO 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForMultipleChoice
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForMultipleChoice.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForMultipleChoice
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有 multiple choice 头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 实例化的模型类是基于配置类选择的:
- AlbertConfig 配置类: TFAlbertForMultipleChoice (ALBERT 模型)
- BertConfig 配置类: TFBertForMultipleChoice (BERT 模型)
- CamembertConfig 配置类: TFCamembertForMultipleChoice (CamemBERT 模型)
- ConvBertConfig 配置类: TFConvBertForMultipleChoice (ConvBERT 模型)
- DebertaV2Config 配置类: TFDebertaV2ForMultipleChoice (DeBERTa-v2 模型)
- DistilBertConfig 配置类: TFDistilBertForMultipleChoice (DistilBERT 模型)
- ElectraConfig 配置类: TFElectraForMultipleChoice (ELECTRA 模型)
- FlaubertConfig 配置类: TFFlaubertForMultipleChoice (FlauBERT 模型)
- FunnelConfig 配置类: TFFunnelForMultipleChoice (Funnel Transformer 模型)
- LongformerConfig 配置类: TFLongformerForMultipleChoice (Longformer 模型)
- MPNetConfig 配置类: TFMPNetForMultipleChoice (MPNet 模型)
- MobileBertConfig 配置类: TFMobileBertForMultipleChoice (MobileBERT 模型)
- RemBertConfig 配置类: TFRemBertForMultipleChoice (RemBERT 模型)
- RoFormerConfig 配置类: TFRoFormerForMultipleChoice (RoFormer 模型)
- RobertaConfig 配置类: TFRobertaForMultipleChoice (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: TFRobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm 模型)
- XLMConfig 配置类: TFXLMForMultipleChoice (XLM 模型)
- XLMRobertaConfig 配置类: TFXLMRobertaForMultipleChoice (XLM-RoBERTa 模型)
- XLNetConfig 配置类: TFXLNetForMultipleChoice (XLNet 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一个。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则,默认值为手动"eager"
实现。
从配置实例化库中的一个模型类 (带有多项选择头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即托管在 huggingface.co 上的模型仓库中的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如
./my_model_directory/
。 - 一个 PyTorch state_dict 保存文件 的路径或 URL (例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应该设置为True
,并且应该提供一个配置对象作为config
参数。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (额外的 positional arguments, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以代替自动加载的配置。 当以下情况时,可以自动加载配置:
- 该模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 该模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path
加载模型,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, 可选) — 如果不应使用标准缓存,则用于缓存下载的预训练模型配置的目录的路径。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重 (参见pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用并忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 版本中删除。
- proxies (
Dict[str, str]
, 可选) — 要按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, defaults toFalse
) — 是否允许在 Hub 上自定义的模型在其自身的建模文件中定义。此选项应仅对您信任的仓库以及您已阅读代码的仓库设置为True
,因为它将在您的本地机器上执行 Hub 上的代码。 - code_revision (
str
, optional, defaults to"main"
) — 用于 Hub 上代码的特定修订版本,如果代码位于与模型其余部分不同的仓库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, optional) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类 (带有多项选择头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — TFAlbertForMultipleChoice (ALBERT 模型)
- bert — TFBertForMultipleChoice (BERT 模型)
- camembert — TFCamembertForMultipleChoice (CamemBERT 模型)
- convbert — TFConvBertForMultipleChoice (ConvBERT 模型)
- deberta-v2 — TFDebertaV2ForMultipleChoice (DeBERTa-v2 模型)
- distilbert — TFDistilBertForMultipleChoice (DistilBERT 模型)
- electra — TFElectraForMultipleChoice (ELECTRA 模型)
- flaubert — TFFlaubertForMultipleChoice (FlauBERT 模型)
- funnel — TFFunnelForMultipleChoice (Funnel Transformer 模型)
- longformer — TFLongformerForMultipleChoice (Longformer 模型)
- mobilebert — TFMobileBertForMultipleChoice (MobileBERT 模型)
- mpnet — TFMPNetForMultipleChoice (MPNet 模型)
- rembert — TFRemBertForMultipleChoice (RemBERT 模型)
- roberta — TFRobertaForMultipleChoice (RoBERTa 模型)
- roberta-prelayernorm — TFRobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm 模型)
- roformer — TFRoFormerForMultipleChoice (RoFormer 模型)
- xlm — TFXLMForMultipleChoice (XLM 模型)
- xlm-roberta — TFXLMRobertaForMultipleChoice (XLM-RoBERTa 模型)
- xlnet — TFXLNetForMultipleChoice (XLNet 模型)
示例
>>> from transformers import AutoConfig, TFAutoModelForMultipleChoice
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForMultipleChoice.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForMultipleChoice
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有 multiple choice 头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- AlbertConfig 配置类: FlaxAlbertForMultipleChoice (ALBERT 模型)
- BertConfig 配置类: FlaxBertForMultipleChoice (BERT 模型)
- BigBirdConfig 配置类: FlaxBigBirdForMultipleChoice (BigBird 模型)
- DistilBertConfig 配置类: FlaxDistilBertForMultipleChoice (DistilBERT 模型)
- ElectraConfig 配置类: FlaxElectraForMultipleChoice (ELECTRA 模型)
- RoFormerConfig 配置类: FlaxRoFormerForMultipleChoice (RoFormer 模型)
- RobertaConfig 配置类: FlaxRobertaForMultipleChoice (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: FlaxRobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm 模型)
- XLMRobertaConfig 配置类: FlaxXLMRobertaForMultipleChoice (XLM-RoBERTa 模型)
- attn_implementation (
str
, optional) — 模型中要使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一个。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则默认为手动"eager"
实现。
从配置实例化库中的一个模型类 (带有多项选择头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,托管在 huggingface.co 上的模型仓库中的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - PyTorch state_dict 保存文件的路径或 URL(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (附加位置参数, optional) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, optional) — 用于模型的配置,以替代自动加载的配置。当以下情况时,可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置文件。
- cache_dir (
str
或os.PathLike
, optional) — 缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_pt (
bool
, optional, defaults toFalse
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, optional, defaults toFalse
) — 是否强制(重新)下载模型权重和配置文件,覆盖已缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers 的 v5 版本中删除。
- proxies (
Dict[str, str]
, optional) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, optional, defaults toFalse
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, defaults toFalse
) — 是否允许在 Hub 上自定义的模型在其自身的建模文件中定义。此选项应仅对您信任的仓库以及您已阅读代码的仓库设置为True
,因为它将在您的本地机器上执行 Hub 上的代码。 - code_revision (
str
, optional, defaults to"main"
) — 用于 Hub 上代码的特定修订版本,如果代码位于与模型其余部分不同的仓库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, optional) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类 (带有多项选择头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — FlaxAlbertForMultipleChoice (ALBERT 模型)
- bert — FlaxBertForMultipleChoice (BERT 模型)
- big_bird — FlaxBigBirdForMultipleChoice (BigBird 模型)
- distilbert — FlaxDistilBertForMultipleChoice (DistilBERT 模型)
- electra — FlaxElectraForMultipleChoice (ELECTRA 模型)
- roberta — FlaxRobertaForMultipleChoice (RoBERTa 模型)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForMultipleChoice (RoBERTa-PreLayerNorm 模型)
- roformer — FlaxRoFormerForMultipleChoice (RoFormer 模型)
- xlm-roberta — FlaxXLMRobertaForMultipleChoice (XLM-RoBERTa 模型)
示例
>>> from transformers import AutoConfig, FlaxAutoModelForMultipleChoice
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForMultipleChoice.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForNextSentencePrediction
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库中的模型类之一(带有下一句预测头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- BertConfig 配置类: BertForNextSentencePrediction (BERT 模型)
- ErnieConfig 配置类: ErnieForNextSentencePrediction (ERNIE 模型)
- FNetConfig 配置类: FNetForNextSentencePrediction (FNet 模型)
- MegatronBertConfig 配置类: MegatronBertForNextSentencePrediction (Megatron-BERT 模型)
- MobileBertConfig 配置类: MobileBertForNextSentencePrediction (MobileBERT 模型)
- NezhaConfig 配置类: NezhaForNextSentencePrediction (Nezha 模型)
- QDQBertConfig 配置类: QDQBertForNextSentencePrediction (QDQBert 模型)
- attn_implementation (
str
, 可选) — 模型中要使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一个。默认情况下,如果可用,torch>=2.1.1 将使用 SDPA。否则,默认值为手动"eager"
实现。
从配置实例化库中的模型类之一(带有下一句预测头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即托管在 huggingface.co 模型仓库中的预训练模型的 模型 ID 。
- 一个指向 目录 的路径,该目录包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个指向 tensorflow 索引检查点文件 的路径或 URL (例如,
./tf_model/model.ckpt.index
)。在这种情况下,应将from_tf
设置为True
,并且应提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (附加的位置参数,可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 要使用的模型配置,而不是自动加载的配置。当以下情况时,可以自动加载配置:
- 该模型是由库提供的模型(使用预训练模型的 模型 ID 字符串加载)。
- 该模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 该模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 用于代替从保存的权重文件加载的状态字典的状态字典。
如果您想从预训练配置创建模型但加载自己的权重,则可以使用此选项。但在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在所有下载在可能的情况下都默认恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 要按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理服务器用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。此选项仅应为信任的存储库且您已阅读代码的情况下设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加的关键字参数,可选) — 可用于更新配置对象(在加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为有所不同:- 如果通过
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新都已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。与任何配置属性都不对应的剩余键将传递给底层模型的__init__
函数。
- 如果通过
从预训练模型实例化库中的模型类之一(带有下一句预测头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- bert — BertForNextSentencePrediction (BERT 模型)
- ernie — ErnieForNextSentencePrediction (ERNIE 模型)
- fnet — FNetForNextSentencePrediction (FNet 模型)
- megatron-bert — MegatronBertForNextSentencePrediction (Megatron-BERT 模型)
- mobilebert — MobileBertForNextSentencePrediction (MobileBERT 模型)
- nezha — NezhaForNextSentencePrediction (Nezha 模型)
- qdqbert — QDQBertForNextSentencePrediction (QDQBert 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForNextSentencePrediction
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForNextSentencePrediction.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForNextSentencePrediction.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForNextSentencePrediction.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForNextSentencePrediction
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库中的模型类之一(带有下一句预测头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- BertConfig 配置类: TFBertForNextSentencePrediction (BERT 模型)
- MobileBertConfig 配置类: TFMobileBertForNextSentencePrediction (MobileBERT 模型)
- attn_implementation (
str
, optional) — 模型中使用的注意力机制实现方式(如果相关)。 可以是"eager"
(手动实现的注意力机制),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention)。 默认情况下,如果可用,torch>=2.1.1 将使用 SDPA。 否则,默认使用手动"eager"
实现。
从配置实例化库中的模型类之一(带有下一句预测头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 字符串,huggingface.co 上模型仓库中托管的预训练模型的模型 ID。
- 目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - PyTorch state_dict 保存文件的路径或 URL (例如,
./pt_model/pytorch_model.bin
)。 在这种情况下,from_pt
应设置为True
,并且应提供配置对象作为config
参数。 此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型,然后加载 TensorFlow 模型要慢。
- model_args (额外的**位置参数**, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以替代自动加载的配置。 在以下情况下,可以自动加载配置:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 - resume_download — 已弃用且被忽略。 现在,所有下载在可能的情况下默认恢复。 将在 Transformers v5 版本中删除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。 代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。 此选项仅应为信任的存储库以及您已阅读代码的存储库设置为True
,因为它将在本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码与模型的其余部分位于不同的存储库中。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的**关键字参数**, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。 根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。 与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。 与任何配置属性都不对应的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的模型类之一(带有下一句预测头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- bert — TFBertForNextSentencePrediction (BERT 模型)
- mobilebert — TFMobileBertForNextSentencePrediction (MobileBERT 模型)
示例
>>> from transformers import AutoConfig, TFAutoModelForNextSentencePrediction
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForNextSentencePrediction.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForNextSentencePrediction
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库中的模型类之一(带有下一句预测头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- BertConfig 配置类: FlaxBertForNextSentencePrediction (BERT 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力机制实现方式(如果相关)。 可以是"eager"
(手动实现的注意力机制),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention)。 默认情况下,如果可用,torch>=2.1.1 将使用 SDPA。 否则,默认使用手动"eager"
实现。
从配置实例化库中的模型类之一(带有下一句预测头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 字符串,托管在 huggingface.co 模型仓库中的预训练模型的模型 ID。
- 目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - PyTorch state_dict 保存文件的路径或 URL (例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (额外的 positional arguments,可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,而不是自动加载的配置。当以下情况时,可以自动加载配置:
- 模型是库提供的模型 (使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, 可选) — 目录的路径,如果不想使用标准缓存,则应在其中缓存下载的预训练模型配置。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重 (请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 - resume_download — 已弃用并忽略。现在,所有下载在可能的情况下都默认恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否还返回包含缺失键、意外键和错误消息的字典。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。此选项仅应针对您信任的存储库设置为True
,并且您已阅读过代码,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码与模型的其余部分位于不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的 keyword arguments,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新都已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的模型类之一(带有下一句预测头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- bert — FlaxBertForNextSentencePrediction (BERT 模型)
示例
>>> from transformers import AutoConfig, FlaxAutoModelForNextSentencePrediction
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForNextSentencePrediction.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForNextSentencePrediction.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForNextSentencePrediction.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForTokenClassification
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将实例化为库的模型类之一(带有 token classification 头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 基于配置类选择要实例化的模型类:
- AlbertConfig 配置类: AlbertForTokenClassification (ALBERT 模型)
- BertConfig 配置类: BertForTokenClassification (BERT 模型)
- BigBirdConfig 配置类: BigBirdForTokenClassification (BigBird 模型)
- BioGptConfig 配置类: BioGptForTokenClassification (BioGpt 模型)
- BloomConfig 配置类: BloomForTokenClassification (BLOOM 模型)
- BrosConfig 配置类: BrosForTokenClassification (BROS 模型)
- CamembertConfig 配置类: CamembertForTokenClassification (CamemBERT 模型)
- CanineConfig 配置类: CanineForTokenClassification (CANINE 模型)
- ConvBertConfig 配置类: ConvBertForTokenClassification (ConvBERT 模型)
- Data2VecTextConfig 配置类: Data2VecTextForTokenClassification (Data2VecText 模型)
- DebertaConfig 配置类: DebertaForTokenClassification (DeBERTa 模型)
- DebertaV2Config 配置类: DebertaV2ForTokenClassification (DeBERTa-v2 模型)
- DiffLlamaConfig 配置类: DiffLlamaForTokenClassification (DiffLlama 模型)
- DistilBertConfig 配置类: DistilBertForTokenClassification (DistilBERT 模型)
- ElectraConfig 配置类: ElectraForTokenClassification (ELECTRA 模型)
- ErnieConfig 配置类: ErnieForTokenClassification (ERNIE 模型)
- ErnieMConfig 配置类: ErnieMForTokenClassification (ErnieM 模型)
- EsmConfig 配置类: EsmForTokenClassification (ESM 模型)
- FNetConfig 配置类: FNetForTokenClassification (FNet 模型)
- FalconConfig 配置类: FalconForTokenClassification (Falcon 模型)
- FlaubertConfig 配置类: FlaubertForTokenClassification (FlauBERT 模型)
- FunnelConfig 配置类: FunnelForTokenClassification (Funnel Transformer 模型)
- GPT2Config 配置类: GPT2ForTokenClassification (OpenAI GPT-2 模型)
- GPTBigCodeConfig 配置类: GPTBigCodeForTokenClassification (GPTBigCode 模型)
- GPTNeoConfig 配置类: GPTNeoForTokenClassification (GPT Neo 模型)
- GPTNeoXConfig 配置类: GPTNeoXForTokenClassification (GPT NeoX 模型)
- Gemma2Config 配置类: Gemma2ForTokenClassification (Gemma2 模型)
- GemmaConfig 配置类: GemmaForTokenClassification (Gemma 模型)
- GlmConfig 配置类: GlmForTokenClassification (GLM 模型)
- HeliumConfig 配置类: HeliumForTokenClassification (Helium 模型)
- IBertConfig 配置类: IBertForTokenClassification (I-BERT 模型)
- LayoutLMConfig 配置类: LayoutLMForTokenClassification (LayoutLM 模型)
- LayoutLMv2Config 配置类: LayoutLMv2ForTokenClassification (LayoutLMv2 模型)
- LayoutLMv3Config 配置类: LayoutLMv3ForTokenClassification (LayoutLMv3 模型)
- LiltConfig 配置类: LiltForTokenClassification (LiLT 模型)
- LlamaConfig 配置类: LlamaForTokenClassification (LLaMA 模型)
- LongformerConfig 配置类: LongformerForTokenClassification (Longformer 模型)
- LukeConfig 配置类: LukeForTokenClassification (LUKE 模型)
- MPNetConfig 配置类: MPNetForTokenClassification (MPNet 模型)
- MT5Config 配置类: MT5ForTokenClassification (MT5 模型)
- MarkupLMConfig 配置类: MarkupLMForTokenClassification (MarkupLM 模型)
- MegaConfig 配置类: MegaForTokenClassification (MEGA 模型)
- MegatronBertConfig 配置类: MegatronBertForTokenClassification (Megatron-BERT 模型)
- MistralConfig 配置类: MistralForTokenClassification (Mistral 模型)
- MixtralConfig 配置类: MixtralForTokenClassification (Mixtral 模型)
- MobileBertConfig 配置类: MobileBertForTokenClassification (MobileBERT 模型)
- ModernBertConfig 配置类: ModernBertForTokenClassification (ModernBERT 模型)
- MptConfig 配置类: MptForTokenClassification (MPT 模型)
- MraConfig 配置类: MraForTokenClassification (MRA 模型)
- NemotronConfig 配置类: NemotronForTokenClassification (Nemotron 模型)
- NezhaConfig 配置类: NezhaForTokenClassification (Nezha 模型)
- NystromformerConfig 配置类: NystromformerForTokenClassification (Nyströmformer 模型)
- PersimmonConfig 配置类: PersimmonForTokenClassification (Persimmon 模型)
- Phi3Config 配置类: Phi3ForTokenClassification (Phi3 模型)
- PhiConfig 配置类: PhiForTokenClassification (Phi 模型)
- QDQBertConfig 配置类: QDQBertForTokenClassification (QDQBert 模型)
- Qwen2Config 配置类: Qwen2ForTokenClassification (Qwen2 模型)
- Qwen2MoeConfig 配置类: Qwen2MoeForTokenClassification (Qwen2MoE 模型)
- RemBertConfig 配置类: RemBertForTokenClassification (RemBERT 模型)
- RoCBertConfig 配置类: RoCBertForTokenClassification (RoCBert 模型)
- RoFormerConfig 配置类: RoFormerForTokenClassification (RoFormer 模型)
- RobertaConfig 配置类: RobertaForTokenClassification (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: RobertaPreLayerNormForTokenClassification (RoBERTa-PreLayerNorm 模型)
- SqueezeBertConfig 配置类: SqueezeBertForTokenClassification (SqueezeBERT 模型)
- StableLmConfig 配置类: StableLmForTokenClassification (StableLm 模型)
- Starcoder2Config 配置类: Starcoder2ForTokenClassification (Starcoder2 模型)
- T5Config 配置类: T5ForTokenClassification (T5 模型)
- UMT5Config 配置类: UMT5ForTokenClassification (UMT5 模型)
- XLMConfig 配置类: XLMForTokenClassification (XLM 模型)
- XLMRobertaConfig 配置类: XLMRobertaForTokenClassification (XLM-RoBERTa 模型)
- XLMRobertaXLConfig 配置类: XLMRobertaXLForTokenClassification (XLM-RoBERTa-XL 模型)
- XLNetConfig 配置类: XLNetForTokenClassification (XLNet 模型)
- XmodConfig 配置类: XmodForTokenClassification (X-MOD 模型)
- YosoConfig 配置类: YosoForTokenClassification (YOSO 模型)
- attn_implementation (
str
, optional) — 模型中要使用的注意力实现方式(如果相关)。可以是"eager"
(手动实现的注意力),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention)。默认情况下,如果可用,torch>=2.1.1 将使用 SDPA。否则默认为手动"eager"
实现。
从配置实例化库中的一个模型类(带有一个 token 分类头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个tensorflow 索引检查点文件的路径或 URL (例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应该设置为True
,并且应该将配置对象作为config
参数提供。这种加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (附加的位置参数, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以替代自动加载的配置。当满足以下条件时,可以自动加载配置:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 一个状态字典,用于代替从保存的权重文件加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,则可以使用此选项。在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是一个更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 如果不想使用标准缓存,则应将下载的预训练模型配置缓存到此目录的路径。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖现有缓存版本。 - resume_download — 已弃用且忽略。现在,所有下载在可能的情况下都默认恢复。将在 Transformers v5 中移除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许加载 Hub 上自定义模型文件定义的模型。 此选项应仅对您信任的代码仓库设置为True
,并且您已阅读过其中的代码,因为它会在您的本地机器上执行 Hub 上的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码与模型的其余部分存放在不同的仓库中。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。 根据是否提供了config
或自动加载而表现不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新已经完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果与配置属性相对应,将用于使用提供的kwargs
值覆盖所述属性。 不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带 token 分类头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — AlbertForTokenClassification (ALBERT 模型)
- bert — BertForTokenClassification (BERT 模型)
- big_bird — BigBirdForTokenClassification (BigBird 模型)
- biogpt — BioGptForTokenClassification (BioGpt 模型)
- bloom — BloomForTokenClassification (BLOOM 模型)
- bros — BrosForTokenClassification (BROS 模型)
- camembert — CamembertForTokenClassification (CamemBERT 模型)
- canine — CanineForTokenClassification (CANINE 模型)
- convbert — ConvBertForTokenClassification (ConvBERT 模型)
- data2vec-text — Data2VecTextForTokenClassification (Data2VecText 模型)
- deberta — DebertaForTokenClassification (DeBERTa 模型)
- deberta-v2 — DebertaV2ForTokenClassification (DeBERTa-v2 模型)
- diffllama — DiffLlamaForTokenClassification (DiffLlama 模型)
- distilbert — DistilBertForTokenClassification (DistilBERT 模型)
- electra — ElectraForTokenClassification (ELECTRA 模型)
- ernie — ErnieForTokenClassification (ERNIE 模型)
- ernie_m — ErnieMForTokenClassification (ErnieM 模型)
- esm — EsmForTokenClassification (ESM 模型)
- falcon — FalconForTokenClassification (Falcon 模型)
- flaubert — FlaubertForTokenClassification (FlauBERT 模型)
- fnet — FNetForTokenClassification (FNet 模型)
- funnel — FunnelForTokenClassification (Funnel Transformer 模型)
- gemma — GemmaForTokenClassification (Gemma 模型)
- gemma2 — Gemma2ForTokenClassification (Gemma2 模型)
- glm — GlmForTokenClassification (GLM 模型)
- gpt-sw3 — GPT2ForTokenClassification (GPT-Sw3 模型)
- gpt2 — GPT2ForTokenClassification (OpenAI GPT-2 模型)
- gpt_bigcode — GPTBigCodeForTokenClassification (GPTBigCode 模型)
- gpt_neo — GPTNeoForTokenClassification (GPT Neo 模型)
- gpt_neox — GPTNeoXForTokenClassification (GPT NeoX 模型)
- helium — HeliumForTokenClassification (Helium 模型)
- ibert — IBertForTokenClassification (I-BERT 模型)
- layoutlm — LayoutLMForTokenClassification (LayoutLM 模型)
- layoutlmv2 — LayoutLMv2ForTokenClassification (LayoutLMv2 模型)
- layoutlmv3 — LayoutLMv3ForTokenClassification (LayoutLMv3 模型)
- lilt — LiltForTokenClassification (LiLT 模型)
- llama — LlamaForTokenClassification (LLaMA 模型)
- longformer — LongformerForTokenClassification (Longformer 模型)
- luke — LukeForTokenClassification (LUKE 模型)
- markuplm — MarkupLMForTokenClassification (MarkupLM 模型)
- mega — MegaForTokenClassification (MEGA 模型)
- megatron-bert — MegatronBertForTokenClassification (Megatron-BERT 模型)
- mistral — MistralForTokenClassification (Mistral 模型)
- mixtral — MixtralForTokenClassification (Mixtral 模型)
- mobilebert — MobileBertForTokenClassification (MobileBERT 模型)
- modernbert — ModernBertForTokenClassification (ModernBERT 模型)
- mpnet — MPNetForTokenClassification (MPNet 模型)
- mpt — MptForTokenClassification (MPT 模型)
- mra — MraForTokenClassification (MRA 模型)
- mt5 — MT5ForTokenClassification (MT5 模型)
- nemotron — NemotronForTokenClassification (Nemotron 模型)
- nezha — NezhaForTokenClassification (Nezha 模型)
- nystromformer — NystromformerForTokenClassification (Nyströmformer 模型)
- persimmon — PersimmonForTokenClassification (Persimmon 模型)
- phi — PhiForTokenClassification (Phi 模型)
- phi3 — Phi3ForTokenClassification (Phi3 模型)
- qdqbert — QDQBertForTokenClassification (QDQBert 模型)
- qwen2 — Qwen2ForTokenClassification (Qwen2 模型)
- qwen2_moe — Qwen2MoeForTokenClassification (Qwen2MoE 模型)
- rembert — RemBertForTokenClassification (RemBERT 模型)
- roberta — RobertaForTokenClassification (RoBERTa 模型)
- roberta-prelayernorm — RobertaPreLayerNormForTokenClassification (RoBERTa-PreLayerNorm 模型)
- roc_bert — RoCBertForTokenClassification (RoCBert 模型)
- roformer — RoFormerForTokenClassification (RoFormer 模型)
- squeezebert — SqueezeBertForTokenClassification (SqueezeBERT 模型)
- stablelm — StableLmForTokenClassification (StableLm 模型)
- starcoder2 — Starcoder2ForTokenClassification (Starcoder2 模型)
- t5 — T5ForTokenClassification (T5 模型)
- umt5 — UMT5ForTokenClassification (UMT5 模型)
- xlm — XLMForTokenClassification (XLM 模型)
- xlm-roberta — XLMRobertaForTokenClassification (XLM-RoBERTa 模型)
- xlm-roberta-xl — XLMRobertaXLForTokenClassification (XLM-RoBERTa-XL 模型)
- xlnet — XLNetForTokenClassification (XLNet 模型)
- xmod — XmodForTokenClassification (X-MOD 模型)
- yoso — YosoForTokenClassification (YOSO 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForTokenClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForTokenClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForTokenClassification
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将实例化为库的模型类之一(带有 token classification 头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- AlbertConfig 配置类: TFAlbertForTokenClassification (ALBERT 模型)
- BertConfig 配置类: TFBertForTokenClassification (BERT 模型)
- CamembertConfig 配置类: TFCamembertForTokenClassification (CamemBERT 模型)
- ConvBertConfig 配置类: TFConvBertForTokenClassification (ConvBERT 模型)
- DebertaConfig 配置类: TFDebertaForTokenClassification (DeBERTa 模型)
- DebertaV2Config 配置类: TFDebertaV2ForTokenClassification (DeBERTa-v2 模型)
- DistilBertConfig 配置类: TFDistilBertForTokenClassification (DistilBERT 模型)
- ElectraConfig 配置类: TFElectraForTokenClassification (ELECTRA 模型)
- EsmConfig 配置类: TFEsmForTokenClassification (ESM 模型)
- FlaubertConfig 配置类: TFFlaubertForTokenClassification (FlauBERT 模型)
- FunnelConfig 配置类: TFFunnelForTokenClassification (Funnel Transformer 模型)
- LayoutLMConfig 配置类: TFLayoutLMForTokenClassification (LayoutLM 模型)
- LayoutLMv3Config 配置类: TFLayoutLMv3ForTokenClassification (LayoutLMv3 模型)
- LongformerConfig 配置类: TFLongformerForTokenClassification (Longformer 模型)
- MPNetConfig 配置类: TFMPNetForTokenClassification (MPNet 模型)
- MobileBertConfig 配置类: TFMobileBertForTokenClassification (MobileBERT 模型)
- RemBertConfig 配置类: TFRemBertForTokenClassification (RemBERT 模型)
- RoFormerConfig 配置类: TFRoFormerForTokenClassification (RoFormer 模型)
- RobertaConfig 配置类: TFRobertaForTokenClassification (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: TFRobertaPreLayerNormForTokenClassification (RoBERTa-PreLayerNorm 模型)
- XLMConfig 配置类: TFXLMForTokenClassification (XLM 模型)
- XLMRobertaConfig 配置类: TFXLMRobertaForTokenClassification (XLM-RoBERTa 模型)
- XLNetConfig 配置类: TFXLNetForTokenClassification (XLNet 模型)
- attn_implementation (
str
, 可选) — 模型中要使用的注意力实现(如果相关)。 可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一个。 默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。 否则,默认值为手动"eager"
实现。
从配置实例化库中的一个模型类(带有一个 token 分类头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source >( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 上模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如
./my_model_directory/
。 - 一个 PyTorch state_dict 保存文件 的路径或 URL (例如,
./pt_model/pytorch_model.bin
)。 在这种情况下,from_pt
应设置为True
,并且应提供配置对象作为config
参数。 此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (附加位置参数, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以代替自动加载的配置。 在以下情况下,可以自动加载配置:
- 该模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 该模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 该模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_pt (
bool
, optional, 默认为False
) — 从 PyTorch 检查点保存文件中加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, optional, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在所有下载在可能的情况下都默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, optional) — 一个代理服务器字典,用于指定每个协议或端点的代理,例如:{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。 代理用于每个请求。 - output_loading_info(
bool
, optional, 默认为False
) — 是否同时返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, optional, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, optional, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, 默认为False
) — 是否允许 Hub 上自定义的模型在其自己的建模文件中定义。此选项应仅针对您信任的存储库以及您已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, optional, 默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码与模型的其余部分在不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加的关键字参数, optional) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带 token 分类头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — TFAlbertForTokenClassification (ALBERT 模型)
- bert — TFBertForTokenClassification (BERT 模型)
- camembert — TFCamembertForTokenClassification (CamemBERT 模型)
- convbert — TFConvBertForTokenClassification (ConvBERT 模型)
- deberta — TFDebertaForTokenClassification (DeBERTa 模型)
- deberta-v2 — TFDebertaV2ForTokenClassification (DeBERTa-v2 模型)
- distilbert — TFDistilBertForTokenClassification (DistilBERT 模型)
- electra — TFElectraForTokenClassification (ELECTRA 模型)
- esm — TFEsmForTokenClassification (ESM 模型)
- flaubert — TFFlaubertForTokenClassification (FlauBERT 模型)
- funnel — TFFunnelForTokenClassification (Funnel Transformer 模型)
- layoutlm — TFLayoutLMForTokenClassification (LayoutLM 模型)
- layoutlmv3 — TFLayoutLMv3ForTokenClassification (LayoutLMv3 模型)
- longformer — TFLongformerForTokenClassification (Longformer 模型)
- mobilebert — TFMobileBertForTokenClassification (MobileBERT 模型)
- mpnet — TFMPNetForTokenClassification (MPNet 模型)
- rembert — TFRemBertForTokenClassification (RemBERT 模型)
- roberta — TFRobertaForTokenClassification (RoBERTa 模型)
- roberta-prelayernorm — TFRobertaPreLayerNormForTokenClassification (RoBERTa-PreLayerNorm 模型)
- roformer — TFRoFormerForTokenClassification (RoFormer 模型)
- xlm — TFXLMForTokenClassification (XLM 模型)
- xlm-roberta — TFXLMRobertaForTokenClassification (XLM-RoBERTa 模型)
- xlnet — TFXLNetForTokenClassification (XLNet 模型)
示例
>>> from transformers import AutoConfig, TFAutoModelForTokenClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForTokenClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForTokenClassification
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将实例化为库的模型类之一(带有 token classification 头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source >( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- AlbertConfig 配置类: FlaxAlbertForTokenClassification (ALBERT 模型)
- BertConfig 配置类: FlaxBertForTokenClassification (BERT 模型)
- BigBirdConfig 配置类: FlaxBigBirdForTokenClassification (BigBird 模型)
- DistilBertConfig 配置类: FlaxDistilBertForTokenClassification (DistilBERT 模型)
- ElectraConfig 配置类: FlaxElectraForTokenClassification (ELECTRA 模型)
- RoFormerConfig 配置类: FlaxRoFormerForTokenClassification (RoFormer 模型)
- RobertaConfig 配置类: FlaxRobertaForTokenClassification (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: FlaxRobertaPreLayerNormForTokenClassification (RoBERTa-PreLayerNorm 模型)
- XLMRobertaConfig 配置类: FlaxXLMRobertaForTokenClassification (XLM-RoBERTa 模型)
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
oros.PathLike
) — Can be either:- A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
- A path to a directory containing model weights saved using save_pretrained(), e.g.,
./my_model_directory/
. - A path or url to a PyTorch state_dict save file (e.g,
./pt_model/pytorch_model.bin
). In this case,from_pt
should be set toTrue
and a configuration object should be provided asconfig
argument. This loading path is slower than converting the PyTorch model in a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model
__init__()
method. - config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
- The model is a model provided by the library (loaded with the model id string of a pretrained model).
- The model was saved using save_pretrained() and is reloaded by supplying the save directory.
- The model is loaded by supplying a local directory as
pretrained_model_name_or_path
and a configuration JSON file named config.json is found in the directory.
- cache_dir (
str
oros.PathLike
, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. - from_pt (
bool
, optional, defaults toFalse
) — Load the model weights from a PyTorch checkpoint save file (see docstring ofpretrained_model_name_or_path
argument). - force_download (
bool
, optional, defaults toFalse
) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. - resume_download — Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (
Dict[str, str]
, optional) — 用于按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。 代理服务器用于每个请求。 - output_loading_info(
bool
, optional, defaults toFalse
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, optional, defaults to"main"
) — 要使用的特定模型版本。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, defaults toFalse
) — 是否允许 Hub 上自定义模型在它们自己的建模文件中定义。 此选项仅应针对您信任并且已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, optional, defaults to"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (additional keyword arguments, optional) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。 根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键(对应于配置属性)将用于使用提供的kwargs
值覆盖所述属性。 不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带 token 分类头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — FlaxAlbertForTokenClassification (ALBERT 模型)
- bert — FlaxBertForTokenClassification (BERT 模型)
- big_bird — FlaxBigBirdForTokenClassification (BigBird 模型)
- distilbert — FlaxDistilBertForTokenClassification (DistilBERT 模型)
- electra — FlaxElectraForTokenClassification (ELECTRA 模型)
- roberta — FlaxRobertaForTokenClassification (RoBERTa 模型)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForTokenClassification (RoBERTa-PreLayerNorm 模型)
- roformer — FlaxRoFormerForTokenClassification (RoFormer 模型)
- xlm-roberta — FlaxXLMRobertaForTokenClassification (XLM-RoBERTa 模型)
示例
>>> from transformers import AutoConfig, FlaxAutoModelForTokenClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForTokenClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForTokenClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForQuestionAnswering
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有问答头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- AlbertConfig 配置类: AlbertForQuestionAnswering (ALBERT 模型)
- BartConfig 配置类: BartForQuestionAnswering (BART 模型)
- BertConfig 配置类: BertForQuestionAnswering (BERT 模型)
- BigBirdConfig 配置类: BigBirdForQuestionAnswering (BigBird 模型)
- BigBirdPegasusConfig 配置类: BigBirdPegasusForQuestionAnswering (BigBird-Pegasus 模型)
- BloomConfig 配置类: BloomForQuestionAnswering (BLOOM 模型)
- CamembertConfig 配置类: CamembertForQuestionAnswering (CamemBERT 模型)
- CanineConfig 配置类: CanineForQuestionAnswering (CANINE 模型)
- ConvBertConfig 配置类: ConvBertForQuestionAnswering (ConvBERT 模型)
- Data2VecTextConfig 配置类: Data2VecTextForQuestionAnswering (Data2VecText 模型)
- DebertaConfig 配置类: DebertaForQuestionAnswering (DeBERTa 模型)
- DebertaV2Config 配置类: DebertaV2ForQuestionAnswering (DeBERTa-v2 模型)
- DiffLlamaConfig 配置类: DiffLlamaForQuestionAnswering (DiffLlama 模型)
- DistilBertConfig 配置类: DistilBertForQuestionAnswering (DistilBERT 模型)
- ElectraConfig 配置类: ElectraForQuestionAnswering (ELECTRA 模型)
- ErnieConfig 配置类: ErnieForQuestionAnswering (ERNIE 模型)
- ErnieMConfig 配置类: ErnieMForQuestionAnswering (ErnieM 模型)
- FNetConfig 配置类: FNetForQuestionAnswering (FNet 模型)
- FalconConfig 配置类: FalconForQuestionAnswering (Falcon 模型)
- FlaubertConfig 配置类: FlaubertForQuestionAnsweringSimple (FlauBERT 模型)
- FunnelConfig 配置类: FunnelForQuestionAnswering (Funnel Transformer 模型)
- GPT2Config 配置类: GPT2ForQuestionAnswering (OpenAI GPT-2 模型)
- GPTJConfig 配置类: GPTJForQuestionAnswering (GPT-J 模型)
- GPTNeoConfig 配置类: GPTNeoForQuestionAnswering (GPT Neo 模型)
- GPTNeoXConfig 配置类: GPTNeoXForQuestionAnswering (GPT NeoX 模型)
- IBertConfig 配置类: IBertForQuestionAnswering (I-BERT 模型)
- LEDConfig 配置类: LEDForQuestionAnswering (LED 模型)
- LayoutLMv2Config 配置类: LayoutLMv2ForQuestionAnswering (LayoutLMv2 模型)
- LayoutLMv3Config 配置类: LayoutLMv3ForQuestionAnswering (LayoutLMv3 模型)
- LiltConfig 配置类: LiltForQuestionAnswering (LiLT 模型)
- LlamaConfig 配置类: LlamaForQuestionAnswering (LLaMA 模型)
- LongformerConfig 配置类: LongformerForQuestionAnswering (Longformer 模型)
- LukeConfig 配置类: LukeForQuestionAnswering (LUKE 模型)
- LxmertConfig 配置类: LxmertForQuestionAnswering (LXMERT 模型)
- MBartConfig 配置类: MBartForQuestionAnswering (mBART 模型)
- MPNetConfig 配置类: MPNetForQuestionAnswering (MPNet 模型)
- MT5Config 配置类: MT5ForQuestionAnswering (MT5 模型)
- MarkupLMConfig 配置类: MarkupLMForQuestionAnswering (MarkupLM 模型)
- MegaConfig 配置类: MegaForQuestionAnswering (MEGA 模型)
- MegatronBertConfig 配置类: MegatronBertForQuestionAnswering (Megatron-BERT 模型)
- MistralConfig 配置类: MistralForQuestionAnswering (Mistral 模型)
- MixtralConfig 配置类: MixtralForQuestionAnswering (Mixtral 模型)
- MobileBertConfig 配置类: MobileBertForQuestionAnswering (MobileBERT 模型)
- MptConfig 配置类: MptForQuestionAnswering (MPT 模型)
- MraConfig 配置类: MraForQuestionAnswering (MRA 模型)
- MvpConfig 配置类: MvpForQuestionAnswering (MVP 模型)
- NemotronConfig 配置类: NemotronForQuestionAnswering (Nemotron 模型)
- NezhaConfig 配置类: NezhaForQuestionAnswering (Nezha 模型)
- NystromformerConfig 配置类: NystromformerForQuestionAnswering (Nyströmformer 模型)
- OPTConfig 配置类: OPTForQuestionAnswering (OPT 模型)
- QDQBertConfig 配置类: QDQBertForQuestionAnswering (QDQBert 模型)
- Qwen2Config 配置类: Qwen2ForQuestionAnswering (Qwen2 模型)
- Qwen2MoeConfig 配置类: Qwen2MoeForQuestionAnswering (Qwen2MoE 模型)
- ReformerConfig 配置类: ReformerForQuestionAnswering (Reformer 模型)
- RemBertConfig 配置类: RemBertForQuestionAnswering (RemBERT 模型)
- RoCBertConfig 配置类: RoCBertForQuestionAnswering (RoCBert 模型)
- RoFormerConfig 配置类: RoFormerForQuestionAnswering (RoFormer 模型)
- RobertaConfig 配置类: RobertaForQuestionAnswering (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: RobertaPreLayerNormForQuestionAnswering (RoBERTa-PreLayerNorm 模型)
- SplinterConfig 配置类: SplinterForQuestionAnswering (Splinter 模型)
- SqueezeBertConfig 配置类: SqueezeBertForQuestionAnswering (SqueezeBERT 模型)
- T5Config 配置类: T5ForQuestionAnswering (T5 模型)
- UMT5Config 配置类: UMT5ForQuestionAnswering (UMT5 模型)
- XLMConfig 配置类: XLMForQuestionAnsweringSimple (XLM 模型)
- XLMRobertaConfig 配置类: XLMRobertaForQuestionAnswering (XLM-RoBERTa 模型)
- XLMRobertaXLConfig 配置类: XLMRobertaXLForQuestionAnswering (XLM-RoBERTa-XL 模型)
- XLNetConfig 配置类: XLNetForQuestionAnsweringSimple (XLNet 模型)
- XmodConfig 配置类: XmodForQuestionAnswering (X-MOD 模型)
- YosoConfig 配置类: YosoForQuestionAnswering (YOSO 模型)
- attn_implementation (
str
, optional) — 模型中要使用的注意力实现方式(如果相关)。 可以是"eager"
(手动实现注意力)、"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention)中的任何一种。 默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。 否则,默认为手动"eager"
实现。
从配置实例化库的模型类之一(带有问答头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< 源码 > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,预训练模型的模型 ID,托管在 huggingface.co 的模型仓库中。
- 一个指向目录的路径,该目录包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个指向 tensorflow 索引检查点文件 的路径或 URL (例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应设置为True
,并且应提供一个配置对象作为config
参数。这种加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型,然后再加载 PyTorch 模型要慢。
- model_args (额外的 positional arguments,可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以替代自动加载的配置。在以下情况下可以自动加载配置:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在该目录中找到了名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 一个状态字典,用于代替从保存的权重文件加载的状态字典。
如果您想从预训练配置创建模型但加载自己的权重,则可以使用此选项。 不过,在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖可能存在的缓存版本。 - resume_download — 已弃用且被忽略。现在所有下载在可能的情况下默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选) — 一个代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。 代理服务器用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义的模型在其自己的建模文件中定义。 此选项仅应为信任的存储库设置为True
,并且您已阅读其中的代码,因为它将在您的本地计算机上执行 Hub 上的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码与模型的其余部分位于不同的存储库中。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的 keyword arguments,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。 根据是否提供config
或自动加载而表现不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个对应于配置属性的键将用于使用提供的kwargs
值覆盖所述属性。 不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有问答头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — AlbertForQuestionAnswering (ALBERT 模型)
- bart — BartForQuestionAnswering (BART 模型)
- bert — BertForQuestionAnswering (BERT 模型)
- big_bird — BigBirdForQuestionAnswering (BigBird 模型)
- bigbird_pegasus — BigBirdPegasusForQuestionAnswering (BigBird-Pegasus 模型)
- bloom — BloomForQuestionAnswering (BLOOM 模型)
- camembert — CamembertForQuestionAnswering (CamemBERT 模型)
- canine — CanineForQuestionAnswering (CANINE 模型)
- convbert — ConvBertForQuestionAnswering (ConvBERT 模型)
- data2vec-text — Data2VecTextForQuestionAnswering (Data2VecText 模型)
- deberta — DebertaForQuestionAnswering (DeBERTa 模型)
- deberta-v2 — DebertaV2ForQuestionAnswering (DeBERTa-v2 模型)
- diffllama — DiffLlamaForQuestionAnswering (DiffLlama 模型)
- distilbert — DistilBertForQuestionAnswering (DistilBERT 模型)
- electra — ElectraForQuestionAnswering (ELECTRA 模型)
- ernie — ErnieForQuestionAnswering (ERNIE 模型)
- ernie_m — ErnieMForQuestionAnswering (ErnieM 模型)
- falcon — FalconForQuestionAnswering (Falcon 模型)
- flaubert — FlaubertForQuestionAnsweringSimple (FlauBERT 模型)
- fnet — FNetForQuestionAnswering (FNet 模型)
- funnel — FunnelForQuestionAnswering (Funnel Transformer 模型)
- gpt2 — GPT2ForQuestionAnswering (OpenAI GPT-2 模型)
- gpt_neo — GPTNeoForQuestionAnswering (GPT Neo 模型)
- gpt_neox — GPTNeoXForQuestionAnswering (GPT NeoX 模型)
- gptj — GPTJForQuestionAnswering (GPT-J 模型)
- ibert — IBertForQuestionAnswering (I-BERT 模型)
- layoutlmv2 — LayoutLMv2ForQuestionAnswering (LayoutLMv2 模型)
- layoutlmv3 — LayoutLMv3ForQuestionAnswering (LayoutLMv3 模型)
- led — LEDForQuestionAnswering (LED 模型)
- lilt — LiltForQuestionAnswering (LiLT 模型)
- llama — LlamaForQuestionAnswering (LLaMA 模型)
- longformer — LongformerForQuestionAnswering (Longformer 模型)
- luke — LukeForQuestionAnswering (LUKE 模型)
- lxmert — LxmertForQuestionAnswering (LXMERT 模型)
- markuplm — MarkupLMForQuestionAnswering (MarkupLM 模型)
- mbart — MBartForQuestionAnswering (mBART 模型)
- mega — MegaForQuestionAnswering (MEGA 模型)
- megatron-bert — MegatronBertForQuestionAnswering (Megatron-BERT 模型)
- mistral — MistralForQuestionAnswering (Mistral 模型)
- mixtral — MixtralForQuestionAnswering (Mixtral 模型)
- mobilebert — MobileBertForQuestionAnswering (MobileBERT 模型)
- mpnet — MPNetForQuestionAnswering (MPNet 模型)
- mpt — MptForQuestionAnswering (MPT 模型)
- mra — MraForQuestionAnswering (MRA 模型)
- mt5 — MT5ForQuestionAnswering (MT5 模型)
- mvp — MvpForQuestionAnswering (MVP 模型)
- nemotron — NemotronForQuestionAnswering (Nemotron 模型)
- nezha — NezhaForQuestionAnswering (Nezha 模型)
- nystromformer — NystromformerForQuestionAnswering (Nyströmformer 模型)
- opt — OPTForQuestionAnswering (OPT 模型)
- qdqbert — QDQBertForQuestionAnswering (QDQBert 模型)
- qwen2 — Qwen2ForQuestionAnswering (Qwen2 模型)
- qwen2_moe — Qwen2MoeForQuestionAnswering (Qwen2MoE 模型)
- reformer — ReformerForQuestionAnswering (Reformer 模型)
- rembert — RemBertForQuestionAnswering (RemBERT 模型)
- roberta — RobertaForQuestionAnswering (RoBERTa 模型)
- roberta-prelayernorm — RobertaPreLayerNormForQuestionAnswering (RoBERTa-PreLayerNorm 模型)
- roc_bert — RoCBertForQuestionAnswering (RoCBert 模型)
- roformer — RoFormerForQuestionAnswering (RoFormer 模型)
- splinter — SplinterForQuestionAnswering (Splinter 模型)
- squeezebert — SqueezeBertForQuestionAnswering (SqueezeBERT 模型)
- t5 — T5ForQuestionAnswering (T5 模型)
- umt5 — UMT5ForQuestionAnswering (UMT5 模型)
- xlm — XLMForQuestionAnsweringSimple (XLM 模型)
- xlm-roberta — XLMRobertaForQuestionAnswering (XLM-RoBERTa 模型)
- xlm-roberta-xl — XLMRobertaXLForQuestionAnswering (XLM-RoBERTa-XL 模型)
- xlnet — XLNetForQuestionAnsweringSimple (XLNet 模型)
- xmod — XmodForQuestionAnswering (X-MOD 模型)
- yoso — YosoForQuestionAnswering (YOSO 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForQuestionAnswering.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForQuestionAnswering
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有问答头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 基于配置类选择要实例化的模型类:
- AlbertConfig 配置类: TFAlbertForQuestionAnswering (ALBERT 模型)
- BertConfig 配置类: TFBertForQuestionAnswering (BERT 模型)
- CamembertConfig 配置类: TFCamembertForQuestionAnswering (CamemBERT 模型)
- ConvBertConfig 配置类: TFConvBertForQuestionAnswering (ConvBERT 模型)
- DebertaConfig 配置类: TFDebertaForQuestionAnswering (DeBERTa 模型)
- DebertaV2Config 配置类: TFDebertaV2ForQuestionAnswering (DeBERTa-v2 模型)
- DistilBertConfig 配置类: TFDistilBertForQuestionAnswering (DistilBERT 模型)
- ElectraConfig 配置类: TFElectraForQuestionAnswering (ELECTRA 模型)
- FlaubertConfig 配置类: TFFlaubertForQuestionAnsweringSimple (FlauBERT 模型)
- FunnelConfig 配置类: TFFunnelForQuestionAnswering (Funnel Transformer 模型)
- GPTJConfig 配置类: TFGPTJForQuestionAnswering (GPT-J 模型)
- LayoutLMv3Config 配置类: TFLayoutLMv3ForQuestionAnswering (LayoutLMv3 模型)
- LongformerConfig 配置类: TFLongformerForQuestionAnswering (Longformer 模型)
- MPNetConfig 配置类: TFMPNetForQuestionAnswering (MPNet 模型)
- MobileBertConfig 配置类: TFMobileBertForQuestionAnswering (MobileBERT 模型)
- RemBertConfig 配置类: TFRemBertForQuestionAnswering (RemBERT 模型)
- RoFormerConfig 配置类: TFRoFormerForQuestionAnswering (RoFormer 模型)
- RobertaConfig 配置类: TFRobertaForQuestionAnswering (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: TFRobertaPreLayerNormForQuestionAnswering (RoBERTa-PreLayerNorm 模型)
- XLMConfig 配置类: TFXLMForQuestionAnsweringSimple (XLM 模型)
- XLMRobertaConfig 配置类: TFXLMRobertaForQuestionAnswering (XLM-RoBERTa 模型)
- XLNetConfig 配置类: TFXLNetForQuestionAnsweringSimple (XLNet 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力实现(如果相关)。可以是以下任何一种:"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
),或"flash_attention_2"
(使用 Dao-AILab/flash-attention)。默认情况下,如果可用,对于 torch>=2.1.1 将使用 SDPA。否则默认为手动"eager"
实现。
从配置实例化库的模型类之一(带有问答头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 上的模型仓库中托管的预训练模型的模型 ID。
- 一个指向目录的路径,该目录包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个指向 PyTorch state_dict 保存文件 的路径或 URL(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应该设置为True
,并且应该提供一个配置对象作为config
参数。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (额外的 positional arguments,可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以替代自动加载的配置。当满足以下条件时,可以自动加载配置:
- 该模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 该模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 该模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, 可选) — 如果不想使用标准缓存,则应在其中缓存下载的预训练模型配置的目录路径。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 版本中删除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否同时返回一个包含缺失键、意外键和错误消息的字典。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上定义的自定义模型在其自己的建模文件中。 此选项仅应针对您信任的仓库以及您已阅读代码的仓库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的仓库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的 keyword arguments,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载而表现不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新都已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果与配置属性相对应,将用于使用提供的kwargs
值覆盖所述属性。与任何配置属性都不对应的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有问答头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — TFAlbertForQuestionAnswering (ALBERT 模型)
- bert — TFBertForQuestionAnswering (BERT 模型)
- camembert — TFCamembertForQuestionAnswering (CamemBERT 模型)
- convbert — TFConvBertForQuestionAnswering (ConvBERT 模型)
- deberta — TFDebertaForQuestionAnswering (DeBERTa 模型)
- deberta-v2 — TFDebertaV2ForQuestionAnswering (DeBERTa-v2 模型)
- distilbert — TFDistilBertForQuestionAnswering (DistilBERT 模型)
- electra — TFElectraForQuestionAnswering (ELECTRA 模型)
- flaubert — TFFlaubertForQuestionAnsweringSimple (FlauBERT 模型)
- funnel — TFFunnelForQuestionAnswering (Funnel Transformer 模型)
- gptj — TFGPTJForQuestionAnswering (GPT-J 模型)
- layoutlmv3 — TFLayoutLMv3ForQuestionAnswering (LayoutLMv3 模型)
- longformer — TFLongformerForQuestionAnswering (Longformer 模型)
- mobilebert — TFMobileBertForQuestionAnswering (MobileBERT 模型)
- mpnet — TFMPNetForQuestionAnswering (MPNet 模型)
- rembert — TFRemBertForQuestionAnswering (RemBERT 模型)
- roberta — TFRobertaForQuestionAnswering (RoBERTa 模型)
- roberta-prelayernorm — TFRobertaPreLayerNormForQuestionAnswering (RoBERTa-PreLayerNorm 模型)
- roformer — TFRoFormerForQuestionAnswering (RoFormer 模型)
- xlm — TFXLMForQuestionAnsweringSimple (XLM 模型)
- xlm-roberta — TFXLMRobertaForQuestionAnswering (XLM-RoBERTa 模型)
- xlnet — TFXLNetForQuestionAnsweringSimple (XLNet 模型)
示例
>>> from transformers import AutoConfig, TFAutoModelForQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForQuestionAnswering.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForQuestionAnswering
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有问答头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — config (PretrainedConfig) — 实例化模型所用的模型类会根据配置类进行选择:
- AlbertConfig 配置类: FlaxAlbertForQuestionAnswering (ALBERT 模型)
- BartConfig 配置类: FlaxBartForQuestionAnswering (BART 模型)
- BertConfig 配置类: FlaxBertForQuestionAnswering (BERT 模型)
- BigBirdConfig 配置类: FlaxBigBirdForQuestionAnswering (BigBird 模型)
- DistilBertConfig 配置类: FlaxDistilBertForQuestionAnswering (DistilBERT 模型)
- ElectraConfig 配置类: FlaxElectraForQuestionAnswering (ELECTRA 模型)
- MBartConfig 配置类: FlaxMBartForQuestionAnswering (mBART 模型)
- RoFormerConfig 配置类: FlaxRoFormerForQuestionAnswering (RoFormer 模型)
- RobertaConfig 配置类: FlaxRobertaForQuestionAnswering (RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类: FlaxRobertaPreLayerNormForQuestionAnswering (RoBERTa-PreLayerNorm 模型)
- XLMRobertaConfig 配置类: FlaxXLMRobertaForQuestionAnswering (XLM-RoBERTa 模型)
- attn_implementation (
str
, 可选) — attn_implementation (str
, optional) — 模型中使用的 attention 实现方式(如果相关)。可以是"eager"
(attention 的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一种。默认情况下,如果可用,对于 torch>=2.1.1 将使用 SDPA。否则默认为手动"eager"
实现。
从配置实例化库的模型类之一(带有问答头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — pretrained_model_name_or_path (str
oros.PathLike
) — 可以是:- 一个字符串,表示 huggingface.co 上模型仓库中托管的预训练模型的 模型 ID 。
- 一个指向 目录 的路径,该目录包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 指向 PyTorch state_dict 保存文件 的路径或 URL(例如
./pt_model/pytorch_model.bin
)。在这种情况下,应将from_pt
设置为True
,并且应将配置对象作为config
参数提供。与使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型,然后加载 TensorFlow 模型相比,此加载路径速度较慢。
- model_args (其他位置参数, 可选) — model_args (additional positional arguments, optional) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — config (PretrainedConfig, optional) — 用于模型的配置,以替代自动加载的配置。在以下情况下可以自动加载配置:
- 模型是由库提供的模型(使用预训练模型的 模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, 可选) — cache_dir (str
oros.PathLike
, optional) — 如果不应使用标准缓存,则应在其中缓存下载的预训练模型配置的目录的路径。 - from_pt (
bool
, 可选, defaults toFalse
) — from_pt (bool
, optional, defaults toFalse
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, defaults toFalse
) — force_download (bool
, optional, defaults toFalse
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 - resume_download — resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下都默认恢复。将在 Transformers v5 版本中删除。
- proxies (
Dict[str, str]
, 可选) — proxies (Dict[str, str]
, optional) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理在每个请求上使用。 - output_loading_info(
bool
, 可选, defaults toFalse
) — output_loading_info(bool
, optional, defaults toFalse
) — 是否还返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, defaults toFalse
) — local_files_only(bool
, optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, defaults to"main"
) — revision (str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标记名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, defaults toFalse
) — trust_remote_code (bool
, optional, defaults toFalse
) — 是否允许在 Hub 上自定义模型文件中定义的自定义模型。此选项仅应为设置为True
,用于您信任的存储库,并且您已阅读过代码,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, defaults to"main"
) — code_revision (str
, optional, defaults to"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标记名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (其他关键字参数, 可选) — kwargs (additional keyword arguments, optional) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
而表现不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递到底层模型的__init__
方法(我们假设对配置的所有相关更新都已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained()) 。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。与任何配置属性都不对应的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有问答头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- albert — FlaxAlbertForQuestionAnswering (ALBERT 模型)
- bart — FlaxBartForQuestionAnswering (BART 模型)
- bert — FlaxBertForQuestionAnswering (BERT 模型)
- big_bird — FlaxBigBirdForQuestionAnswering (BigBird 模型)
- distilbert — FlaxDistilBertForQuestionAnswering (DistilBERT 模型)
- electra — FlaxElectraForQuestionAnswering (ELECTRA 模型)
- mbart — FlaxMBartForQuestionAnswering (mBART 模型)
- roberta — FlaxRobertaForQuestionAnswering (RoBERTa 模型)
- roberta-prelayernorm — FlaxRobertaPreLayerNormForQuestionAnswering (RoBERTa-PreLayerNorm 模型)
- roformer — FlaxRoFormerForQuestionAnswering (RoFormer 模型)
- xlm-roberta — FlaxXLMRobertaForQuestionAnswering (XLM-RoBERTa 模型)
示例
>>> from transformers import AutoConfig, FlaxAutoModelForQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForQuestionAnswering.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForQuestionAnswering.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForTextEncoding
TFAutoModelForTextEncoding
计算机视觉
以下自动类可用于以下计算机视觉任务。
AutoModelForDepthEstimation
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有深度估计头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< 源代码 > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是基于配置类选择的:
- DPTConfig 配置类: DPTForDepthEstimation (DPT 模型)
- DepthAnythingConfig 配置类: DepthAnythingForDepthEstimation (Depth Anything 模型)
- DepthProConfig 配置类: DepthProForDepthEstimation (DepthPro 模型)
- GLPNConfig 配置类: GLPNForDepthEstimation (GLPN 模型)
- PromptDepthAnythingConfig 配置类: PromptDepthAnythingForDepthEstimation (PromptDepthAnything 模型)
- ZoeDepthConfig 配置类: ZoeDepthForDepthEstimation (ZoeDepth 模型)
- attn_implementation (
str
, optional) — 模型中使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
),或"flash_attention_2"
(使用 Dao-AILab/flash-attention)中的任何一个。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则默认为手动"eager"
实现。
从配置实例化库的模型类之一(带有深度估计头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< 源代码 > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即托管在 huggingface.co 上的模型仓库内的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如
./my_model_directory/
。 - 一个tensorflow 索引检查点文件的路径或 URL(例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应设置为True
,并且应提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (额外的 positional 参数,可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, optional) — 用于模型的配置,以代替自动加载的配置。在以下情况下,可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置文件。
- state_dict (Dict[str, torch.Tensor], optional) — 用于代替从保存的权重文件加载的状态字典的状态字典。
如果要从预训练配置创建模型但加载自己的权重,则可以使用此选项。但是,在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, optional) — 目录的路径,如果不想使用标准缓存,则应在其中缓存下载的预训练模型配置。 - from_tf (
bool
, optional, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, optional, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, optional) — 要按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, optional, 默认为False
) — 是否也返回包含缺失键、意外键和错误消息的字典。 - local_files_only(
bool
, optional, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, optional, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, 默认为False
) — 是否允许在 Hub 上自定义模型在其自己的建模文件中定义。此选项应仅对您信任且已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, optional, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的关键字参数,optional) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载而行为不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新都已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有深度估计头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- depth_anything — DepthAnythingForDepthEstimation (Depth Anything 模型)
- depth_pro — DepthProForDepthEstimation (DepthPro 模型)
- dpt — DPTForDepthEstimation (DPT 模型)
- glpn — GLPNForDepthEstimation (GLPN 模型)
- prompt_depth_anything — PromptDepthAnythingForDepthEstimation (PromptDepthAnything 模型)
- zoedepth — ZoeDepthForDepthEstimation (ZoeDepth 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForDepthEstimation
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForDepthEstimation.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForDepthEstimation.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForDepthEstimation.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForImageClassification
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有图像分类头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
从配置实例化库中的一个模型类 (带图像分类头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
从预训练模型实例化库中的一个模型类 (带图像分类头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- beit — BeitForImageClassification (BEiT 模型)
- bit — BitForImageClassification (BiT 模型)
- clip — CLIPForImageClassification (CLIP 模型)
- convnext — ConvNextForImageClassification (ConvNeXT 模型)
- convnextv2 — ConvNextV2ForImageClassification (ConvNeXTV2 模型)
- cvt — CvtForImageClassification (CvT 模型)
- data2vec-vision — Data2VecVisionForImageClassification (Data2VecVision 模型)
- deit — DeiTForImageClassification 或 DeiTForImageClassificationWithTeacher (DeiT 模型)
- dinat — DinatForImageClassification (DiNAT 模型)
- dinov2 — Dinov2ForImageClassification (DINOv2 模型)
- dinov2_with_registers — Dinov2WithRegistersForImageClassification (带有 Registers 的 DINOv2 模型)
- efficientformer — EfficientFormerForImageClassification 或 EfficientFormerForImageClassificationWithTeacher (EfficientFormer 模型)
- efficientnet — EfficientNetForImageClassification (EfficientNet 模型)
- focalnet — FocalNetForImageClassification (FocalNet 模型)
- hiera — HieraForImageClassification (Hiera 模型)
- ijepa — IJepaForImageClassification (I-JEPA 模型)
- imagegpt — ImageGPTForImageClassification (ImageGPT 模型)
- levit — LevitForImageClassification 或 LevitForImageClassificationWithTeacher (LeViT 模型)
- mobilenet_v1 — MobileNetV1ForImageClassification (MobileNetV1 模型)
- mobilenet_v2 — MobileNetV2ForImageClassification (MobileNetV2 模型)
- mobilevit — MobileViTForImageClassification (MobileViT 模型)
- mobilevitv2 — MobileViTV2ForImageClassification (MobileViTV2 模型)
- nat — NatForImageClassification (NAT 模型)
- perceiver — PerceiverForImageClassificationLearned 或 PerceiverForImageClassificationFourier 或 PerceiverForImageClassificationConvProcessing (Perceiver 模型)
- poolformer — PoolFormerForImageClassification (PoolFormer 模型)
- pvt — PvtForImageClassification (PVT 模型)
- pvt_v2 — PvtV2ForImageClassification (PVTv2 模型)
- regnet — RegNetForImageClassification (RegNet 模型)
- resnet — ResNetForImageClassification (ResNet 模型)
- segformer — SegformerForImageClassification (SegFormer 模型)
- shieldgemma2 — ShieldGemma2ForImageClassification (Shieldgemma2 模型)
- siglip — SiglipForImageClassification (SigLIP 模型)
- siglip2 — Siglip2ForImageClassification (SigLIP2 模型)
- swiftformer — SwiftFormerForImageClassification (SwiftFormer 模型)
- swin — SwinForImageClassification (Swin Transformer 模型)
- swinv2 — Swinv2ForImageClassification (Swin Transformer V2 模型)
- textnet — TextNetForImageClassification (TextNet 模型)
- timm_wrapper — TimmWrapperForImageClassification (TimmWrapperModel 模型)
- van — VanForImageClassification (VAN 模型)
- vit — ViTForImageClassification (ViT 模型)
- vit_hybrid — ViTHybridForImageClassification (ViT Hybrid 模型)
- vit_msn — ViTMSNForImageClassification (ViTMSN 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForImageClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForImageClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForImageClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForImageClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForImageClassification
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有图像分类头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是基于配置类选择的:
- ConvNextConfig 配置类: TFConvNextForImageClassification (ConvNeXT 模型)
- ConvNextV2Config 配置类: TFConvNextV2ForImageClassification (ConvNeXTV2 模型)
- CvtConfig 配置类: TFCvtForImageClassification (CvT 模型)
- Data2VecVisionConfig 配置类: TFData2VecVisionForImageClassification (Data2VecVision 模型)
- DeiTConfig 配置类: TFDeiTForImageClassification 或 TFDeiTForImageClassificationWithTeacher (DeiT 模型)
- EfficientFormerConfig 配置类: TFEfficientFormerForImageClassification 或 TFEfficientFormerForImageClassificationWithTeacher (EfficientFormer 模型)
- MobileViTConfig 配置类: TFMobileViTForImageClassification (MobileViT 模型)
- RegNetConfig 配置类: TFRegNetForImageClassification (RegNet 模型)
- ResNetConfig 配置类: TFResNetForImageClassification (ResNet 模型)
- SegformerConfig 配置类: TFSegformerForImageClassification (SegFormer 模型)
- SwiftFormerConfig 配置类: TFSwiftFormerForImageClassification (SwiftFormer 模型)
- SwinConfig 配置类: TFSwinForImageClassification (Swin Transformer 模型)
- ViTConfig 配置类: TFViTForImageClassification (ViT 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力实现 (如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention)。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则默认为手动"eager"
实现。
从配置实例化库中的一个模型类 (带图像分类头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 模型仓库中托管的预训练模型的 模型 ID 。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个PyTorch state_dict 保存文件的路径或 URL(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (额外的positional参数, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,而不是自动加载的配置。当以下情况时,可以自动加载配置:
- 模型是库提供的模型 (使用预训练模型的 模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, 可选) — 如果不应使用标准缓存,则应在其中缓存下载的预训练模型配置的目录的路径。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重 (参见pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 要按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理在每个请求上使用。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许在其自己的建模文件中定义 Hub 上的自定义模型。此选项仅应针对您信任的仓库以及您已阅读代码的仓库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码位于与模型其余部分不同的仓库中,则要用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的关键字参数, 可选) — 可用于更新配置对象(在加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。与任何配置属性不对应的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类 (带图像分类头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- convnext — TFConvNextForImageClassification (ConvNeXT 模型)
- convnextv2 — TFConvNextV2ForImageClassification (ConvNeXTV2 模型)
- cvt — TFCvtForImageClassification (CvT 模型)
- data2vec-vision — TFData2VecVisionForImageClassification (Data2VecVision 模型)
- deit — TFDeiTForImageClassification 或 TFDeiTForImageClassificationWithTeacher (DeiT 模型)
- efficientformer — TFEfficientFormerForImageClassification 或 TFEfficientFormerForImageClassificationWithTeacher (EfficientFormer 模型)
- mobilevit — TFMobileViTForImageClassification (MobileViT 模型)
- regnet — TFRegNetForImageClassification (RegNet 模型)
- resnet — TFResNetForImageClassification (ResNet 模型)
- segformer — TFSegformerForImageClassification (SegFormer 模型)
- swiftformer — TFSwiftFormerForImageClassification (SwiftFormer 模型)
- swin — TFSwinForImageClassification (Swin Transformer 模型)
- vit — TFViTForImageClassification (ViT 模型)
示例
>>> from transformers import AutoConfig, TFAutoModelForImageClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForImageClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForImageClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForImageClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForImageClassification
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有图像分类头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 模型类,用于实例化是基于配置类选择的:
- BeitConfig 配置类: FlaxBeitForImageClassification (BEiT 模型)
- Dinov2Config 配置类: FlaxDinov2ForImageClassification (DINOv2 模型)
- RegNetConfig 配置类: FlaxRegNetForImageClassification (RegNet 模型)
- ResNetConfig 配置类: FlaxResNetForImageClassification (ResNet 模型)
- ViTConfig 配置类: FlaxViTForImageClassification (ViT 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力实现(如果相关)。可以是"eager"
(手动实现的注意力),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一个。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则默认为手动"eager"
实现。
从配置实例化库中的一个模型类 (带图像分类头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,预训练模型的模型 ID,托管在 huggingface.co 上的模型仓库中。
- 一个指向目录的路径,该目录包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个指向 PyTorch state_dict 保存文件 的路径或 URL (例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应该设置为True
,并且应该提供一个配置对象作为config
参数。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型,然后加载 TensorFlow 模型要慢。
- model_args (额外的 positional arguments, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以代替自动加载的配置。在以下情况下,可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重 (参见pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖可能存在的缓存版本。 - resume_download — 已弃用并忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选) — 代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件 (例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上定义的自定义模型在其自己的建模文件中。此选项仅应设置为True
,用于您信任的存储库,并且您已阅读过代码,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分存放在不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的 keyword arguments, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类 (带图像分类头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- beit — FlaxBeitForImageClassification (BEiT 模型)
- dinov2 — Dinov2Config (DINOv2 模型)
- regnet — FlaxRegNetForImageClassification (RegNet 模型)
- resnet — FlaxResNetForImageClassification (ResNet 模型)
- vit — FlaxViTForImageClassification (ViT 模型)
示例
>>> from transformers import AutoConfig, FlaxAutoModelForImageClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForImageClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForImageClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForImageClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForVideoClassification
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有视频分类头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 模型类,用于实例化是基于配置类选择的:
- TimesformerConfig 配置类: TimesformerForVideoClassification (TimeSformer 模型)
- VideoMAEConfig 配置类: VideoMAEForVideoClassification (VideoMAE 模型)
- VivitConfig 配置类: VivitForVideoClassification (ViViT 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力实现(如果相关)。可以是"eager"
(手动实现的注意力),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一个。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则默认为手动"eager"
实现。
从配置实例化库的模型类之一(带有视频分类头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,预训练模型的模型 ID,托管在 huggingface.co 上的模型仓库中。
- 一个指向目录的路径,该目录包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个指向 TensorFlow 索引检查点文件 的路径或 URL (例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应该设置为True
,并且应该提供一个配置对象作为config
参数。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型,然后加载 PyTorch 模型要慢。
- model_args (额外的 positional arguments, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, optional) — 用于模型的配置,以替代自动加载的配置。当满足以下条件时,配置可以自动加载:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], optional) — 一个状态字典,用于替代从已保存的权重文件加载的状态字典。
如果您想从预训练配置创建模型但加载自己的权重,可以使用此选项。但是,在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, optional) — 缓存已下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_tf (
bool
, optional, defaults toFalse
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, optional, defaults toFalse
) — 是否强制(重新)下载模型权重和配置文件,覆盖已缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。所有下载现在在可能的情况下默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, optional) — 一个代理服务器字典,用于按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, optional, defaults toFalse
) — 是否同时返回一个字典,其中包含缺失键、意外键和错误消息。 - local_files_only(
bool
, optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, defaults toFalse
) — 是否允许在 Hub 上以其自己的建模文件定义的自定义模型。此选项仅应在您信任的存储库中设置为True
,并且您已阅读其中的代码,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, optional, defaults to"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (additional keyword arguments, optional) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载配置,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果对应于配置属性,都将用于使用提供的kwargs
值覆盖该属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有视频分类头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- timesformer — TimesformerForVideoClassification (TimeSformer 模型)
- videomae — VideoMAEForVideoClassification (VideoMAE 模型)
- vivit — VivitForVideoClassification (ViViT 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForVideoClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForVideoClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForVideoClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForVideoClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForKeypointDetection
AutoModelForMaskedImageModeling
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将实例化为库的模型类之一(带有掩码图像建模头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- DeiTConfig 配置类:DeiTForMaskedImageModeling (DeiT 模型)
- FocalNetConfig 配置类:FocalNetForMaskedImageModeling (FocalNet 模型)
- SwinConfig 配置类:SwinForMaskedImageModeling (Swin Transformer 模型)
- Swinv2Config 配置类:Swinv2ForMaskedImageModeling (Swin Transformer V2 模型)
- ViTConfig 配置类:ViTForMaskedImageModeling (ViT 模型)
- attn_implementation (
str
, optional) — 模型中要使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
) 或"flash_attention_2"
(使用 Dao-AILab/flash-attention)。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则,默认为手动"eager"
实现。
从配置实例化库的模型类之一(带有掩码图像建模头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 字符串,托管在 huggingface.co 模型仓库中的预训练模型的模型 ID。
- 目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - TensorFlow 索引检查点文件的路径或 URL(例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应设置为True
,并且应将配置对象作为config
参数提供。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (additional positional arguments, optional) — 将传递给底层模型
__init__()
方法的其他位置参数。 - config (PretrainedConfig, 可选) — 用于模型的配置,以替代自动加载的配置。 当满足以下条件时,可以自动加载配置:
- 模型是由库提供的模型(使用预训练模型的模型ID字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到了名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 一个状态字典,用于代替从保存的权重文件中加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,则可以使用此选项。 然而,在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是一个更简单的选项。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 - resume_download — 已弃用且被忽略。现在所有下载在可能的情况下默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选) — 一个代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。 代理在每个请求上使用。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否同时返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上定义的自定义模型在其自己的建模文件中。 此选项仅应针对您信任的存储库以及您已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分保留在不同的存储库中,则用于 Hub 上代码的特定修订版本。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。 根据是否提供config
或自动加载而表现不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新都已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果对应于配置属性,将用于使用提供的kwargs
值覆盖所述属性。 不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有掩码图像建模头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- deit — DeiTForMaskedImageModeling (DeiT 模型)
- focalnet — FocalNetForMaskedImageModeling (FocalNet 模型)
- swin — SwinForMaskedImageModeling (Swin Transformer 模型)
- swinv2 — Swinv2ForMaskedImageModeling (Swin Transformer V2 模型)
- vit — ViTForMaskedImageModeling (ViT 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForMaskedImageModeling
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForMaskedImageModeling.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForMaskedImageModeling.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForMaskedImageModeling.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForMaskedImageModeling
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将实例化为库的模型类之一(带有掩码图像建模头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是基于配置类选择的:
- DeiTConfig 配置类: TFDeiTForMaskedImageModeling (DeiT 模型)
- SwinConfig 配置类: TFSwinForMaskedImageModeling (Swin Transformer 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力实现(如果相关)。 可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
) 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一个。 默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。 否则,默认值为手动"eager"
实现。
从配置实例化库的模型类之一(带有掩码图像建模头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是:- 一个字符串,huggingface.co 上的模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - PyTorch state_dict 保存文件的路径或 URL(例如,
./pt_model/pytorch_model.bin
)。 在这种情况下,应将from_pt
设置为True
,并且应提供配置对象作为config
参数。 此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (附加位置参数, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以替代自动加载的配置。 当满足以下条件时,可以自动加载配置:
- 模型是由库提供的模型(使用预训练模型的模型ID字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到了名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 - resume_download — 已弃用且被忽略。现在所有下载在可能的情况下默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选) — 一个代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。 代理在每个请求上使用。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否同时返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, defaults toFalse
) — 是否允许在 Hub 上自定义模型,这些模型在它们自己的建模文件中定义。此选项仅应针对您信任并在其中阅读过代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, optional, defaults to"main"
) — 用于 Hub 上代码的特定修订版本,如果代码与模型的其余部分位于不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设所有相关的配置更新已经完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果对应于配置属性,将用于使用提供的kwargs
值覆盖该属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有掩码图像建模头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- deit — TFDeiTForMaskedImageModeling (DeiT 模型)
- swin — TFSwinForMaskedImageModeling (Swin Transformer 模型)
示例
>>> from transformers import AutoConfig, TFAutoModelForMaskedImageModeling
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForMaskedImageModeling.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForMaskedImageModeling.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForMaskedImageModeling.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForObjectDetection
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有对象检测头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- ConditionalDetrConfig 配置类: ConditionalDetrForObjectDetection (Conditional DETR 模型)
- DabDetrConfig 配置类: DabDetrForObjectDetection (DAB-DETR 模型)
- DeformableDetrConfig 配置类: DeformableDetrForObjectDetection (Deformable DETR 模型)
- DetaConfig 配置类: DetaForObjectDetection (DETA 模型)
- DetrConfig 配置类: DetrForObjectDetection (DETR 模型)
- RTDetrConfig 配置类: RTDetrForObjectDetection (RT-DETR 模型)
- RTDetrV2Config 配置类: RTDetrV2ForObjectDetection (RT-DETRv2 模型)
- TableTransformerConfig 配置类: TableTransformerForObjectDetection (Table Transformer 模型)
- YolosConfig 配置类: YolosForObjectDetection (YOLOS 模型)
- attn_implementation (
str
, optional) — 模型中要使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention)。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则,默认值为手动"eager"
实现。
从配置实例化库的模型类之一(带有对象检测头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 上的模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个tensorflow 索引检查点文件(例如,
./tf_model/model.ckpt.index
)的路径或 URL。在这种情况下,应将from_tf
设置为True
,并且应将配置对象作为config
参数提供。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (附加位置参数, 可选) — 将传递给底层模型
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,而不是自动加载的配置。在以下情况下可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 状态字典,用于代替从保存的权重文件加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,则可以使用此选项。但是,在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是一个更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 如果不应使用标准缓存,则应在其中缓存下载的预训练模型配置的目录的路径。 - from_tf (
bool
, 可选, defaults toFalse
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, defaults toFalse
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 要按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, defaults toFalse
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, defaults toFalse
) — 是否允许 Hub 上定义的自定义模型在其自己的建模文件中。 此选项仅应针对您信任的存储库以及您已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, optional, defaults to"main"
) — 用于 Hub 上代码的特定修订版本,如果代码与模型的其余部分位于不同的存储库中。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。 根据是否提供config
或自动加载而表现不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设所有相关的配置更新已经完成) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果对应于一个配置属性,都将用于使用提供的kwargs
值覆盖该属性。 不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的一个模型类(带有对象检测头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- conditional_detr — ConditionalDetrForObjectDetection (条件 DETR 模型)
- dab-detr — DabDetrForObjectDetection (DAB-DETR 模型)
- deformable_detr — DeformableDetrForObjectDetection (可变形 DETR 模型)
- deta — DetaForObjectDetection (DETA 模型)
- detr — DetrForObjectDetection (DETR 模型)
- rt_detr — RTDetrForObjectDetection (RT-DETR 模型)
- rt_detr_v2 — RTDetrV2ForObjectDetection (RT-DETRv2 模型)
- table-transformer — TableTransformerForObjectDetection (表格 Transformer 模型)
- yolos — YolosForObjectDetection (YOLOS 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForObjectDetection
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForObjectDetection.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForObjectDetection.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForObjectDetection.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForImageSegmentation
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有图像分割头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- DetrConfig 配置类: DetrForSegmentation (DETR 模型)
- attn_implementation (
str
, optional) — 模型中使用的注意力实现方式(如果相关)。 可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention)。 默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。 否则,默认为手动"eager"
实现。
从配置实例化库的模型类之一(带有图像分割头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即托管在 huggingface.co 上的模型仓库内的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个tensorflow 索引检查点文件的路径或 URL(例如,
./tf_model/model.ckpt.index
)。 在这种情况下,from_tf
应设置为True
,并且应提供配置对象作为config
参数。 此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (附加位置参数,可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,而不是自动加载的配置。 当以下情况时,可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 一个状态字典,用于代替从保存的权重文件加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,则可以使用此选项。 但是,在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 如果不应使用标准缓存,则应在其中缓存下载的预训练模型配置的目录的路径。 - from_tf (
bool
, optional, defaults toFalse
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, optional, defaults toFalse
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 - resume_download — 已弃用且被忽略。 现在,所有下载在可能的情况下默认恢复。 将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 要按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。 代理在每个请求上使用。 - output_loading_info(
bool
, 可选, defaults toFalse
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, optional, defaults to"main"
) — 要使用的特定模型版本。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, defaults toFalse
) — 是否允许 Hub 上定义的自定义模型在其自己的建模文件中。 此选项仅应针对您信任的存储库以及您已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, optional, defaults to"main"
) — 用于 Hub 上代码的特定修订版本,如果代码与模型的其余部分位于不同的存储库中。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。 根据是否提供config
或自动加载而表现不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设所有相关的配置更新已经完成) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果对应于一个配置属性,都将用于使用提供的kwargs
值覆盖该属性。 不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有图像分割头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- detr — DetrForSegmentation (DETR 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForImageSegmentation
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForImageSegmentation.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForImageSegmentation.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForImageSegmentation.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForImageToImage
AutoModelForSemanticSegmentation
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有一个语义分割头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- BeitConfig 配置类: BeitForSemanticSegmentation (BEiT 模型)
- DPTConfig 配置类: DPTForSemanticSegmentation (DPT 模型)
- Data2VecVisionConfig 配置类: Data2VecVisionForSemanticSegmentation (Data2VecVision 模型)
- MobileNetV2Config 配置类: MobileNetV2ForSemanticSegmentation (MobileNetV2 模型)
- MobileViTConfig 配置类: MobileViTForSemanticSegmentation (MobileViT 模型)
- MobileViTV2Config 配置类: MobileViTV2ForSemanticSegmentation (MobileViTV2 模型)
- SegformerConfig 配置类: SegformerForSemanticSegmentation (SegFormer 模型)
- UperNetConfig 配置类: UperNetForSemanticSegmentation (UPerNet 模型)
- attn_implementation (
str
, optional) — 模型中要使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一个。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则,默认为手动"eager"
实现。
从配置实例化库的模型类之一(带有一个语义分割头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 上模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - tensorflow 索引检查点文件的路径或 URL (例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应设置为True
,并且应提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (额外的positional arguments,可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以代替自动加载的配置。当以下情况时,可以自动加载配置:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 用于代替从已保存的权重文件加载的状态字典的状态字典。
如果要从预训练配置创建模型但加载自己的权重,则可以使用此选项。但是,在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 如果不应使用标准缓存,则应在其中缓存下载的预训练模型配置的目录的路径。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载都在可能的情况下默认恢复。将在 Transformers v5 版本中删除。
- proxies (
Dict[str, str]
, 可选) — 要按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许在 Hub 上以其自己的建模文件定义的自定义模型。此选项仅应针对您信任并在其中阅读了代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的keyword arguments,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供了config
或自动加载而表现不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。与任何配置属性都不对应的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有一个语义分割头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- beit — BeitForSemanticSegmentation (BEiT 模型)
- data2vec-vision — Data2VecVisionForSemanticSegmentation (Data2VecVision 模型)
- dpt — DPTForSemanticSegmentation (DPT 模型)
- mobilenet_v2 — MobileNetV2ForSemanticSegmentation (MobileNetV2 模型)
- mobilevit — MobileViTForSemanticSegmentation (MobileViT 模型)
- mobilevitv2 — MobileViTV2ForSemanticSegmentation (MobileViTV2 模型)
- segformer — SegformerForSemanticSegmentation (SegFormer 模型)
- upernet — UperNetForSemanticSegmentation (UPerNet 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForSemanticSegmentation
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForSemanticSegmentation.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForSemanticSegmentation.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForSemanticSegmentation.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForSemanticSegmentation
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有一个语义分割头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是基于配置类选择的:
- Data2VecVisionConfig 配置类: TFData2VecVisionForSemanticSegmentation (Data2VecVision 模型)
- MobileViTConfig 配置类: TFMobileViTForSemanticSegmentation (MobileViT 模型)
- SegformerConfig 配置类: TFSegformerForSemanticSegmentation (SegFormer 模型)
- attn_implementation (
str
, 可选) — 模型中要使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一个。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则,默认值为手动"eager"
实现。
从配置实例化库的模型类之一(带有一个语义分割头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 上模型仓库中预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个PyTorch state_dict 保存文件的路径或 URL (例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应该设置为True
,并且应该提供一个配置对象作为config
参数。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (额外的的位置参数, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,而不是自动加载的配置。当以下情况时,可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, 可选) — 如果不想使用标准缓存,则应在其中缓存下载的预训练模型配置的目录路径。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用并被忽略。现在,所有下载在可能的情况下都默认恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许使用 Hub 上自定义的模型,这些模型在其自己的建模文件中定义。此选项应仅对您信任且已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码与模型的其余部分位于不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的关键字参数, 可选) — 可用于更新配置对象(在加载后)并初始化模型(例如,
output_attentions=True
)。行为方式取决于是否提供config
或自动加载:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果对应于配置属性,将用于使用提供的kwargs
值覆盖该属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有一个语义分割头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- data2vec-vision — TFData2VecVisionForSemanticSegmentation (Data2VecVision 模型)
- mobilevit — TFMobileViTForSemanticSegmentation (MobileViT 模型)
- segformer — TFSegformerForSemanticSegmentation (SegFormer 模型)
示例
>>> from transformers import AutoConfig, TFAutoModelForSemanticSegmentation
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForSemanticSegmentation.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForSemanticSegmentation.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForSemanticSegmentation.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForInstanceSegmentation
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有实例分割头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 模型类将根据配置类进行实例化选择:
- MaskFormerConfig 配置类: MaskFormerForInstanceSegmentation (MaskFormer 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力机制实现 (如果相关)。可以是"eager"
(手动实现的注意力机制),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一个。默认情况下,如果可用,torch>=2.1.1 将使用 SDPA。否则默认使用手动"eager"
实现。
从配置实例化库中的一个模型类 (带有一个实例分割头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,表示托管在 huggingface.co 上的模型仓库中的预训练模型的 模型 ID 。
- 一个指向目录的路径,该目录包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个指向 tensorflow 索引检查点文件 的路径或 URL (例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应该设置为True
,并且应该提供一个配置对象作为config
参数。这种加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (额外的 positional 参数, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以代替自动加载的配置。在以下情况下可以自动加载配置:
- 该模型是库提供的模型 (使用预训练模型的 模型 ID 字符串加载)。
- 该模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 该模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 状态字典,用于代替从保存的权重文件加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,可以使用此选项。但在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重 (请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制 (重新) 下载模型权重和配置文件,覆盖缓存的版本 (如果存在)。 - resume_download — 已弃用且忽略。所有下载现在在可能的情况下默认恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理服务器用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否同时返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件 (例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许在 Hub 上使用自定义模型,这些模型在它们自己的建模文件中定义。此选项仅应针对您信任且已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的 keyword 参数, 可选) — 可用于更新配置对象 (在加载后) 并初始化模型 (例如,
output_attentions=True
)。根据是否提供config
或自动加载配置,行为有所不同:- 如果使用
config
提供了配置,**kwargs
将直接传递给底层模型的__init__
方法 (我们假设已经完成了对配置的所有相关更新) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类 (带有一个实例分割头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- maskformer — MaskFormerForInstanceSegmentation (MaskFormer 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForInstanceSegmentation
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForInstanceSegmentation.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForInstanceSegmentation.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForInstanceSegmentation.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForUniversalSegmentation
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库中的一个模型类 (带有一个通用图像分割头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 模型类将根据配置类进行实例化选择:
- DetrConfig 配置类: DetrForSegmentation (DETR 模型)
- Mask2FormerConfig 配置类: Mask2FormerForUniversalSegmentation (Mask2Former 模型)
- MaskFormerConfig 配置类: MaskFormerForInstanceSegmentation (MaskFormer 模型)
- OneFormerConfig 配置类: OneFormerForUniversalSegmentation (OneFormer 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力机制实现 (如果相关)。可以是"eager"
(手动实现的注意力机制),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一个。默认情况下,如果可用,torch>=2.1.1 将使用 SDPA。否则默认使用手动"eager"
实现。
从配置实例化库中的一个模型类 (带有一个通用图像分割头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,表示托管在 huggingface.co 上的模型仓库中的预训练模型的 模型 ID 。
- 一个指向目录的路径,该目录包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个指向 tensorflow 索引检查点文件 的路径或 URL (例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应该设置为True
,并且应该提供一个配置对象作为config
参数。这种加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (额外的 positional 参数, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以替代自动加载的配置。 当满足以下条件时,配置可以自动加载:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 状态字典,用于替代从已保存的权重文件中加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,可以使用此选项。 但在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 目录路径,用于缓存下载的预训练模型配置,如果不想使用标准缓存。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖已缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。所有下载现在在可能的情况下默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选) — 代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许在 Hub 上自定义模型文件中定义的自定义模型。此选项仅应为信任的存储库以及您已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码在与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载而表现不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新已经完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果对应于配置属性,将用于使用提供的kwargs
值覆盖该属性。 不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有通用图像分割头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- detr — DetrForSegmentation (DETR 模型)
- mask2former — Mask2FormerForUniversalSegmentation (Mask2Former 模型)
- maskformer — MaskFormerForInstanceSegmentation (MaskFormer 模型)
- oneformer — OneFormerForUniversalSegmentation (OneFormer 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForUniversalSegmentation
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForUniversalSegmentation.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForUniversalSegmentation.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForUniversalSegmentation.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForZeroShotImageClassification
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有零样本图像分类头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- AlignConfig 配置类: AlignModel (ALIGN 模型)
- AltCLIPConfig 配置类: AltCLIPModel (AltCLIP 模型)
- Blip2Config 配置类: Blip2ForImageTextRetrieval (BLIP-2 模型)
- BlipConfig 配置类: BlipModel (BLIP 模型)
- CLIPConfig 配置类: CLIPModel (CLIP 模型)
- CLIPSegConfig 配置类: CLIPSegModel (CLIPSeg 模型)
- ChineseCLIPConfig 配置类: ChineseCLIPModel (Chinese-CLIP 模型)
- Siglip2Config 配置类: Siglip2Model (SigLIP2 模型)
- SiglipConfig 配置类: SiglipModel (SigLIP 模型)
- attn_implementation (
str
, 可选) — 模型中要使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention)中的任何一个。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则,默认为手动"eager"
实现。
从配置实例化库的模型类之一(带有零样本图像分类头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 字符串,huggingface.co 上的模型仓库中托管的预训练模型的模型 ID。
- 目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - TensorFlow 索引检查点文件的路径或 URL(例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应设置为True
,并且应将配置对象作为config
参数提供。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (附加位置参数, 可选) — 将传递给底层模型
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以替代自动加载的配置。 当满足以下条件时,配置可以自动加载:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 状态字典,用于替代从已保存的权重文件中加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,可以使用此选项。 但在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 用于指定一个目录路径,下载的预训练模型配置应缓存到该目录中,以防不使用标准缓存。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件中加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖已缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在所有下载在可能的情况下都默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选) — 一个代理服务器字典,用于指定按协议或端点使用的代理,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否同时返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许使用 Hub 上自定义模型文件中定义的自定义模型。此选项仅应在您信任的仓库中设置为True
,并且您已阅读过其中的代码,因为它将在您的本地计算机上执行 Hub 上的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 用于指定 Hub 上代码的特定版本,如果代码与模型的其余部分位于不同的仓库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
中与配置属性对应的每个键将用于使用提供的kwargs
值覆盖该属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带有 zero-shot 图像分类头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- align — AlignModel (ALIGN 模型)
- altclip — AltCLIPModel (AltCLIP 模型)
- blip — BlipModel (BLIP 模型)
- blip-2 — Blip2ForImageTextRetrieval (BLIP-2 模型)
- chinese_clip — ChineseCLIPModel (Chinese-CLIP 模型)
- clip — CLIPModel (CLIP 模型)
- clipseg — CLIPSegModel (CLIPSeg 模型)
- siglip — SiglipModel (SigLIP 模型)
- siglip2 — Siglip2Model (SigLIP2 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForZeroShotImageClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForZeroShotImageClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForZeroShotImageClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForZeroShotImageClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForZeroShotImageClassification
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有零样本图像分类头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类基于配置类选择:
- BlipConfig 配置类: TFBlipModel (BLIP 模型)
- CLIPConfig 配置类: TFCLIPModel (CLIP 模型)
- attn_implementation (
str
, 可选) — 模型中要使用的注意力实现方式(如果相关)。可以是以下任何一种:"eager"
(手动实现注意力),"sdpa"
(使用F.scaled_dot_product_attention
),或"flash_attention_2"
(使用 Dao-AILab/flash-attention)。默认情况下,如果可用,对于 torch>=2.1.1 将使用 SDPA。否则,默认使用手动"eager"
实现。
从配置实例化库的模型类之一(带有零样本图像分类头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即托管在 huggingface.co 模型仓库中的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个 PyTorch state_dict 保存文件 的路径或 URL(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,应将from_pt
设置为True
,并应提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (附加位置参数, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以替代自动加载的配置。在以下情况下可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型是通过提供本地目录作为
pretrained_model_name_or_path
加载的,并且在目录中找到了名为 config.json 的配置文件。
- cache_dir (
str
或os.PathLike
, 可选) — 用于指定一个目录路径,下载的预训练模型配置应缓存到该目录中,以防不使用标准缓存。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件中加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖已缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在所有下载在可能的情况下都默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。 代理服务器用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回包含缺失键、意外键和错误消息的字典。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在它们自己的建模文件中定义。 此选项仅应针对您信任且已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, 可选) — 可用于更新配置对象(在加载后)并初始化模型(例如,
output_attentions=True
)。 根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新都已经完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个对应于配置属性的键将用于使用提供的kwargs
值覆盖所述属性。 不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带有 zero-shot 图像分类头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- blip — TFBlipModel (BLIP 模型)
- clip — TFCLIPModel (CLIP 模型)
示例
>>> from transformers import AutoConfig, TFAutoModelForZeroShotImageClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForZeroShotImageClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForZeroShotImageClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForZeroShotImageClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForZeroShotObjectDetection
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有零样本对象检测头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- GroundingDinoConfig 配置类: GroundingDinoForObjectDetection (Grounding DINO 模型)
- OmDetTurboConfig 配置类: OmDetTurboForObjectDetection (OmDet-Turbo 模型)
- OwlViTConfig 配置类: OwlViTForObjectDetection (OWL-ViT 模型)
- Owlv2Config 配置类: Owlv2ForObjectDetection (OWLv2 模型)
- attn_implementation (
str
, 可选) — 模型中要使用的注意力实现(如果相关)。 可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention)中的任何一个。 默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。 否则,默认值为手动"eager"
实现。
从配置实例化库的模型类之一(带有零样本对象检测头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即托管在 huggingface.co 上的模型仓库内的预训练模型的 模型 ID 。
- 目录 的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - tensorflow 索引检查点文件 的路径或 URL(例如,
./tf_model/model.ckpt.index
)。 在这种情况下,from_tf
应设置为True
,并且应将配置对象作为config
参数提供。 此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型,然后加载 PyTorch 模型要慢。
- model_args (附加位置参数, 可选) — 将传递给底层模型
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,而不是自动加载的配置。 当以下情况时,可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的 模型 ID 字符串加载)。
- 模型已使用 save_pretrained() 保存,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 要使用的状态字典,而不是从已保存的权重文件加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,则可以使用此选项。 但是,在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 如果应不使用标准缓存,则应在其中缓存下载的预训练模型配置的目录的路径。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用并忽略。 现在,所有下载在可能的情况下默认恢复。 将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。 代理服务器用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回包含缺失键、意外键和错误消息的字典。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许使用 Hub 上自定义的模型,这些模型在其自身的建模文件中定义。此选项仅应针对您信任的存储库设置为True
,并且在您已阅读代码的情况下使用,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码与模型的其余部分位于不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。 根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新都已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果对应于配置属性,将用于使用提供的kwargs
值覆盖所述属性。 不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有零样本对象检测头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- grounding-dino — GroundingDinoForObjectDetection (Grounding DINO 模型)
- omdet-turbo — OmDetTurboForObjectDetection (OmDet-Turbo 模型)
- owlv2 — Owlv2ForObjectDetection (OWLv2 模型)
- owlvit — OwlViTForObjectDetection (OWL-ViT 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForZeroShotObjectDetection
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForZeroShotObjectDetection.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForZeroShotObjectDetection.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForZeroShotObjectDetection.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
音频
以下自动类可用于以下音频任务。
AutoModelForAudioClassification
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有音频分类头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- ASTConfig 配置类: ASTForAudioClassification (音频频谱图 Transformer 模型)
- Data2VecAudioConfig 配置类: Data2VecAudioForSequenceClassification (Data2VecAudio 模型)
- HubertConfig 配置类: HubertForSequenceClassification (Hubert 模型)
- SEWConfig 配置类: SEWForSequenceClassification (SEW 模型)
- SEWDConfig 配置类: SEWDForSequenceClassification (SEW-D 模型)
- UniSpeechConfig 配置类: UniSpeechForSequenceClassification (UniSpeech 模型)
- UniSpeechSatConfig 配置类: UniSpeechSatForSequenceClassification (UniSpeechSat 模型)
- Wav2Vec2BertConfig 配置类: Wav2Vec2BertForSequenceClassification (Wav2Vec2-BERT 模型)
- Wav2Vec2Config 配置类: Wav2Vec2ForSequenceClassification (Wav2Vec2 模型)
- Wav2Vec2ConformerConfig 配置类: Wav2Vec2ConformerForSequenceClassification (Wav2Vec2-Conformer 模型)
- WavLMConfig 配置类: WavLMForSequenceClassification (WavLM 模型)
- WhisperConfig 配置类: WhisperForAudioClassification (Whisper 模型)
- attn_implementation (
str
, 可选) — 模型中要使用的注意力实现(如果相关)。 可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
) 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一个。 默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。 否则,默认值为手动"eager"
实现。
从配置实例化库的模型类之一(带有音频分类头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即托管在 huggingface.co 上的模型仓库中的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个tensorflow 索引检查点文件的路径或 URL(例如,
./tf_model/model.ckpt.index
)。 在这种情况下,from_tf
应设置为True
,并且应提供配置对象作为config
参数。 此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (附加位置参数, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,而不是自动加载的配置。 当以下情况时,可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 一个状态字典,用于代替从保存的权重文件加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,则可以使用此选项。 然而,在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖已缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。 现在,所有下载在可能的情况下默认恢复。 将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。 代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否同时返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许在 Hub 上自定义模型文件中定义的自定义模型。此选项应仅对您信任且已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 用于 Hub 上代码的特定修订版本,如果代码与模型的其余部分位于不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键(对应于配置属性)将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带音频分类头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- audio-spectrogram-transformer — ASTForAudioClassification (音频频谱图转换器模型)
- data2vec-audio — Data2VecAudioForSequenceClassification (Data2VecAudio 模型)
- hubert — HubertForSequenceClassification (Hubert 模型)
- sew — SEWForSequenceClassification (SEW 模型)
- sew-d — SEWDForSequenceClassification (SEW-D 模型)
- unispeech — UniSpeechForSequenceClassification (UniSpeech 模型)
- unispeech-sat — UniSpeechSatForSequenceClassification (UniSpeechSat 模型)
- wav2vec2 — Wav2Vec2ForSequenceClassification (Wav2Vec2 模型)
- wav2vec2-bert — Wav2Vec2BertForSequenceClassification (Wav2Vec2-BERT 模型)
- wav2vec2-conformer — Wav2Vec2ConformerForSequenceClassification (Wav2Vec2-Conformer 模型)
- wavlm — WavLMForSequenceClassification (WavLM 模型)
- whisper — WhisperForAudioClassification (Whisper 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForAudioClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForAudioClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForAudioClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForAudioClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForAudioFrameClassification
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有音频分类头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- Wav2Vec2Config 配置类: TFWav2Vec2ForSequenceClassification (Wav2Vec2 模型)
- attn_implementation (
str
, 可选) — 模型中要使用的注意力实现方式(如果相关)。可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention)中的任何一种。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则,默认值为手动"eager"
实现方式。
从配置实例化库的模型类之一(带有音频分类头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即托管在 huggingface.co 上的模型仓库中的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - PyTorch state_dict 保存文件的路径或 URL (例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (附加位置参数, 可选) — 将传递给底层模型
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以代替自动加载的配置。在以下情况下可以自动加载配置:
- 该模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 该模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 该模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在该目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 版本中删除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否同时返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, defaults toFalse
) — 是否允许使用 Hub 上自定义的模型文件。 此选项应仅对您信任的仓库设置 True,并且您已阅读过其中的代码,因为它会在您的本地机器上执行 Hub 上的代码。 - code_revision (
str
, optional, defaults to"main"
) — Hub 上代码的特定修订版本,如果代码与模型的其余部分放在不同的仓库中。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, optional) — 可以用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。 根据是否提供config
或自动加载配置而表现不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果对应于配置属性,都将用于使用提供的kwargs
值覆盖该属性。 不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带音频分类头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- wav2vec2 — TFWav2Vec2ForSequenceClassification (Wav2Vec2 模型)
示例
>>> from transformers import AutoConfig, TFAutoModelForAudioClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForAudioClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForAudioClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForAudioClassification.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
TFAutoModelForAudioFrameClassification
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带音频帧(token)分类头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- Data2VecAudioConfig 配置类: Data2VecAudioForAudioFrameClassification (Data2VecAudio 模型)
- UniSpeechSatConfig 配置类: UniSpeechSatForAudioFrameClassification (UniSpeechSat 模型)
- Wav2Vec2BertConfig 配置类: Wav2Vec2BertForAudioFrameClassification (Wav2Vec2-BERT 模型)
- Wav2Vec2Config 配置类: Wav2Vec2ForAudioFrameClassification (Wav2Vec2 模型)
- Wav2Vec2ConformerConfig 配置类: Wav2Vec2ConformerForAudioFrameClassification (Wav2Vec2-Conformer 模型)
- WavLMConfig 配置类: WavLMForAudioFrameClassification (WavLM 模型)
- attn_implementation (
str
, optional) — 模型中要使用的注意力实现方式(如果相关)。 可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
) 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一种。 默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。 否则默认为手动"eager"
实现。
从配置实例化库的模型类之一(带音频帧(token)分类头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 上的模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个 tensorflow 索引检查点文件的路径或 URL (例如,
./tf_model/model.ckpt.index
)。 在这种情况下,from_tf
应设置为True
,并且应提供配置对象作为config
参数。 此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型,然后加载 PyTorch 模型要慢。
- model_args (附加位置参数, optional) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, optional) — 用于模型的配置,而不是自动加载的配置。 当以下情况时,可以自动加载配置:
- 该模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 该模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 该模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], optional) — 要使用的状态字典,而不是从保存的权重文件加载的状态字典。
如果您想从预训练配置创建模型,但加载您自己的权重,则可以使用此选项。 然而,在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, optional) — 如果不想使用标准缓存,则应在其中缓存下载的预训练模型配置的目录路径。 - from_tf (
bool
, optional, defaults toFalse
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, optional, defaults toFalse
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用并忽略。 现在,所有下载都在可能的情况下默认恢复。 将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, optional) — 要按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。 代理用于每个请求。 - output_loading_info(
bool
, optional, defaults toFalse
) — 是否同时返回包含缺失键、意外键和错误消息的字典。 - local_files_only(
bool
, optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, optional, defaults to"main"
) — 要使用的特定模型版本。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, defaults toFalse
) — 是否允许使用 Hub 上自定义的模型文件。 此选项应仅对您信任的仓库设置 True,并且您已阅读过其中的代码,因为它会在您的本地机器上执行 Hub 上的代码。 - code_revision (
str
, optional, defaults to"main"
) — Hub 上代码的特定修订版本,如果代码与模型的其余部分放在不同的仓库中。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, optional) — 可以用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。 根据是否提供config
或自动加载配置而表现不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果对应于配置属性,都将用于使用提供的kwargs
值覆盖该属性。 不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带音频帧(token)分类头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- data2vec-audio — Data2VecAudioForAudioFrameClassification (Data2VecAudio 模型)
- unispeech-sat — UniSpeechSatForAudioFrameClassification (UniSpeechSat 模型)
- wav2vec2 — Wav2Vec2ForAudioFrameClassification (Wav2Vec2 模型)
- wav2vec2-bert — Wav2Vec2BertForAudioFrameClassification (Wav2Vec2-BERT 模型)
- wav2vec2-conformer — Wav2Vec2ConformerForAudioFrameClassification (Wav2Vec2-Conformer 模型)
- wavlm — WavLMForAudioFrameClassification (WavLM 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForAudioFrameClassification
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForAudioFrameClassification.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForAudioFrameClassification.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForAudioFrameClassification.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForCTC
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有一个连接主义时间分类头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是基于配置类选择的:
- Data2VecAudioConfig 配置类: Data2VecAudioForCTC (Data2VecAudio 模型)
- HubertConfig 配置类: HubertForCTC (Hubert 模型)
- MCTCTConfig 配置类: MCTCTForCTC (M-CTC-T 模型)
- SEWConfig 配置类: SEWForCTC (SEW 模型)
- SEWDConfig 配置类: SEWDForCTC (SEW-D 模型)
- UniSpeechConfig 配置类: UniSpeechForCTC (UniSpeech 模型)
- UniSpeechSatConfig 配置类: UniSpeechSatForCTC (UniSpeechSat 模型)
- Wav2Vec2BertConfig 配置类: Wav2Vec2BertForCTC (Wav2Vec2-BERT 模型)
- Wav2Vec2Config 配置类: Wav2Vec2ForCTC (Wav2Vec2 模型)
- Wav2Vec2ConformerConfig 配置类: Wav2Vec2ConformerForCTC (Wav2Vec2-Conformer 模型)
- WavLMConfig 配置类: WavLMForCTC (WavLM 模型)
- attn_implementation (
str
, optional) — 模型中使用的注意力实现方式(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一种。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则,默认为手动"eager"
实现。
从配置实例化库的模型类之一(带连接主义时间分类头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 上模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如
./my_model_directory/
。 - 一个 tensorflow 索引检查点文件的路径或 URL (例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应该设置为True
,并且应该提供一个配置对象作为config
参数。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (额外的 positional 参数, 可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以代替自动加载的配置。当满足以下条件时,可以自动加载配置:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在该目录中找到了名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 一个状态字典,用于代替从保存的权重文件加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,可以使用此选项。 在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 - resume_download — 已弃用并忽略。现在,所有下载在可能的情况下都默认恢复。将在 Transformers v5 版本中删除。
- proxies (
Dict[str, str]
, 可选) — 要按协议或端点使用的代理服务器字典,例如{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否同时返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。此选项仅应为信任的存储库设置为True
,并且您已阅读其中的代码,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载而表现不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。与任何配置属性不对应的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带连接主义时间分类头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- data2vec-audio — Data2VecAudioForCTC (Data2VecAudio 模型)
- hubert — HubertForCTC (Hubert 模型)
- mctct — MCTCTForCTC (M-CTC-T 模型)
- sew — SEWForCTC (SEW 模型)
- sew-d — SEWDForCTC (SEW-D 模型)
- unispeech — UniSpeechForCTC (UniSpeech 模型)
- unispeech-sat — UniSpeechSatForCTC (UniSpeechSat 模型)
- wav2vec2 — Wav2Vec2ForCTC (Wav2Vec2 模型)
- wav2vec2-bert — Wav2Vec2BertForCTC (Wav2Vec2-BERT 模型)
- wav2vec2-conformer — Wav2Vec2ConformerForCTC (Wav2Vec2-Conformer 模型)
- wavlm — WavLMForCTC (WavLM 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForCTC
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForCTC.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForCTC.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForCTC.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForSpeechSeq2Seq
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有一个序列到序列的语音到文本建模头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 基于配置类选择要实例化的模型类:
- MoonshineConfig 配置类: MoonshineForConditionalGeneration (Moonshine 模型)
- Pop2PianoConfig 配置类: Pop2PianoForConditionalGeneration (Pop2Piano 模型)
- SeamlessM4TConfig 配置类: SeamlessM4TForSpeechToText (SeamlessM4T 模型)
- SeamlessM4Tv2Config 配置类: SeamlessM4Tv2ForSpeechToText (SeamlessM4Tv2 模型)
- Speech2TextConfig 配置类: Speech2TextForConditionalGeneration (Speech2Text 模型)
- SpeechEncoderDecoderConfig 配置类: SpeechEncoderDecoderModel (语音编码器-解码器模型)
- SpeechT5Config 配置类: SpeechT5ForSpeechToText (SpeechT5 模型)
- WhisperConfig 配置类: WhisperForConditionalGeneration (Whisper 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一个。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则默认为手动"eager"
实现。
从配置实例化库中的一个模型类(带有一个序列到序列的语音到文本建模头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即托管在 huggingface.co 模型仓库中的预训练模型的 模型 ID 。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个指向tensorflow 索引检查点文件的路径或 URL (例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应该设置为True
并且应该提供一个配置对象作为config
参数。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (额外的positional arguments,可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,而不是自动加载的配置。当以下情况时,可以自动加载配置:
- 该模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 该模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 该模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到了名为 config.json 的配置文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 一个状态字典,用于代替从保存的权重文件加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,则可以使用此选项。 但是,在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是一个更简单的选项。
- cache_dir (
str
或os.PathLike
, 可选) — 如果不应使用标准缓存,则应在其中缓存下载的预训练模型配置的目录的路径。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用并忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 要按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个包含缺失键、意外键和错误消息的字典。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许在 Hub 上以其自己的建模文件定义的自定义模型。此选项仅应为信任的存储库设置True
,并且在您已阅读代码的情况下,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的keyword arguments,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。行为因是否提供config
或自动加载而异:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新都已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。与任何配置属性都不对应的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带有一个序列到序列的语音到文本建模头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- moonshine — MoonshineForConditionalGeneration (Moonshine 模型)
- pop2piano — Pop2PianoForConditionalGeneration (Pop2Piano 模型)
- seamless_m4t — SeamlessM4TForSpeechToText (SeamlessM4T 模型)
- seamless_m4t_v2 — SeamlessM4Tv2ForSpeechToText (SeamlessM4Tv2 模型)
- speech-encoder-decoder — SpeechEncoderDecoderModel (语音编码器-解码器模型)
- speech_to_text — Speech2TextForConditionalGeneration (Speech2Text 模型)
- speecht5 — SpeechT5ForSpeechToText (SpeechT5 模型)
- whisper — WhisperForConditionalGeneration (Whisper 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForSpeechSeq2Seq
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForSpeechSeq2Seq.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForSpeechSeq2Seq.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForSpeechSeq2Seq.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForSpeechSeq2Seq
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有一个序列到序列的语音到文本建模头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 基于配置类选择要实例化的模型类:
- Speech2TextConfig 配置类: TFSpeech2TextForConditionalGeneration (Speech2Text 模型)
- WhisperConfig 配置类: TFWhisperForConditionalGeneration (Whisper 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一个。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则默认为手动"eager"
实现。
从配置实例化库中的一个模型类(带有一个序列到序列的语音到文本建模头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
oros.PathLike
) — 可以是以下之一:- 一个字符串,即 huggingface.co 上模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个PyTorch state_dict 保存文件的路径或 URL(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应提供配置对象作为config
参数。与使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型,然后加载 TensorFlow 模型相比,此加载路径速度较慢。
- model_args (额外的 positional arguments,可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以代替自动加载的配置。 在以下情况下,可以自动加载配置:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下都默认恢复。将在 Transformers v5 中移除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。此选项应仅对您信任并且已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码位于与模型其余部分不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的 keyword arguments,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新都已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带有一个序列到序列的语音到文本建模头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- speech_to_text — TFSpeech2TextForConditionalGeneration (Speech2Text 模型)
- whisper — TFWhisperForConditionalGeneration (Whisper 模型)
示例
>>> from transformers import AutoConfig, TFAutoModelForSpeechSeq2Seq
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForSpeechSeq2Seq.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForSpeechSeq2Seq.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForSpeechSeq2Seq.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForSpeechSeq2Seq
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有一个序列到序列的语音到文本建模头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- SpeechEncoderDecoderConfig 配置类: FlaxSpeechEncoderDecoderModel (语音编码器-解码器模型)
- WhisperConfig 配置类: FlaxWhisperForConditionalGeneration (Whisper 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力实现(如果相关)。可以是"eager"
(手动实现注意力)、"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention)中的任何一个。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则,默认为手动"eager"
实现。
从配置实例化库中的一个模型类(带有一个序列到序列的语音到文本建模头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即 huggingface.co 上模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个PyTorch state_dict 保存文件的路径或 URL(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应提供配置对象作为config
参数。与使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型,然后加载 TensorFlow 模型相比,此加载路径速度较慢。
- model_args (额外的 positional arguments,可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以代替自动加载的配置。 在以下情况下,可以自动加载配置:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下都默认恢复。将在 Transformers v5 中移除。
- proxies (
Dict[str, str]
, optional) — 用于按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, optional, defaults toFalse
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, defaults toFalse
) — 是否允许在 Hub 上自定义模型,这些模型在其自己的建模文件中定义。此选项仅应为信任的存储库设置为True
,并且您已阅读代码,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, optional, defaults to"main"
) — 用于 Hub 上代码的特定修订版本,如果代码保留在与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (additional keyword arguments, optional) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新都已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果对应于配置属性,都将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带有一个序列到序列的语音到文本建模头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- speech-encoder-decoder — FlaxSpeechEncoderDecoderModel (语音编码器-解码器模型)
- whisper — FlaxWhisperForConditionalGeneration (Whisper 模型)
示例
>>> from transformers import AutoConfig, FlaxAutoModelForSpeechSeq2Seq
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForAudioXVector
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有通过 x-vector 头的音频检索)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- Data2VecAudioConfig 配置类: Data2VecAudioForXVector (Data2VecAudio 模型)
- UniSpeechSatConfig 配置类: UniSpeechSatForXVector (UniSpeechSat 模型)
- Wav2Vec2BertConfig 配置类: Wav2Vec2BertForXVector (Wav2Vec2-BERT 模型)
- Wav2Vec2Config 配置类: Wav2Vec2ForXVector (Wav2Vec2 模型)
- Wav2Vec2ConformerConfig 配置类: Wav2Vec2ConformerForXVector (Wav2Vec2-Conformer 模型)
- WavLMConfig 配置类: WavLMForXVector (WavLM 模型)
- attn_implementation (
str
, optional) — 模型中要使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一个。 默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则,默认值为手动"eager"
实现。
从配置实例化库的模型类之一(带有通过 x-vector 头的音频检索)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是:- 一个字符串,huggingface.co 上模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个 tensorflow 索引检查点文件 的路径或 url (例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应设置为True
,并且应将配置对象作为config
参数提供。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (additional positional arguments, optional) — 将传递给底层模型
__init__()
方法的其他位置参数。 - config (PretrainedConfig, optional) — 用于模型的配置,以代替自动加载的配置。当以下情况时,可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型已使用 save_pretrained() 保存,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], optional) — 状态字典,用于代替从保存的权重文件加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,则可以使用此选项。但是,在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, optional) — 如果不应使用标准缓存,则应在其中缓存下载的预训练模型配置的目录的路径。 - from_tf (
bool
, optional, defaults toFalse
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, optional, defaults toFalse
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, optional) — 用于按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, optional, defaults toFalse
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。此选项仅应针对您信任且已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为有所不同:- 如果提供了带有
config
的配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设所有相关的配置更新已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果与配置属性相对应,将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果提供了带有
从预训练模型实例化库的模型类之一(带有通过 x-vector 头部的音频检索)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- data2vec-audio — Data2VecAudioForXVector (Data2VecAudio 模型)
- unispeech-sat — UniSpeechSatForXVector (UniSpeechSat 模型)
- wav2vec2 — Wav2Vec2ForXVector (Wav2Vec2 模型)
- wav2vec2-bert — Wav2Vec2BertForXVector (Wav2Vec2-BERT 模型)
- wav2vec2-conformer — Wav2Vec2ConformerForXVector (Wav2Vec2-Conformer 模型)
- wavlm — WavLMForXVector (WavLM 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForAudioXVector
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForAudioXVector.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForAudioXVector.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForAudioXVector.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForTextToSpectrogram
AutoModelForTextToWaveform
多模态
以下自动类可用于以下多模态任务。
AutoModelForTableQuestionAnswering
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将实例化为库的模型类之一(带有表格问答头部)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- TapasConfig 配置类: TapasForQuestionAnswering (TAPAS 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention)中的任何一个。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则默认为手动"eager"
实现。
从配置实例化库的模型类之一(带有表格问答头部)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 字符串,huggingface.co 上模型仓库中托管的预训练模型的模型 ID。
- 目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - tensorflow 索引检查点文件的路径或 URL (例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应设置为True
,并且应提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (附加位置参数, 可选) — 将传递给底层模型
__init__()
方法。 - config (PretrainedConfig, 可选) — 要使用的模型的配置,而不是自动加载的配置。当以下情况时,可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 要使用的状态字典,而不是从已保存的权重文件加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,则可以使用此选项。但在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 如果不应使用标准缓存,则应在其中缓存下载的预训练模型配置的目录路径。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, defaults toFalse
) — 是否允许在 Hub 上自定义模型,这些模型在其自己的建模文件中定义。此选项仅应针对您信任的存储库以及您已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, optional, defaults to"main"
) — 用于 Hub 上代码的特定修订版本,如果代码与模型的其余部分位于不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, optional) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果与配置属性对应,都将用于使用提供的kwargs
值覆盖所述属性。不与任何配置属性对应的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带表格问答头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- tapas — TapasForQuestionAnswering (TAPAS 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForTableQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq")
>>> # Update configuration during loading
>>> model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/tapas_tf_model_config.json")
>>> model = AutoModelForTableQuestionAnswering.from_pretrained(
... "./tf_model/tapas_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForTableQuestionAnswering
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将实例化为库的模型类之一(带有表格问答头部)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- TapasConfig 配置类: TFTapasForQuestionAnswering (TAPAS 模型)
- attn_implementation (
str
, optional) — 模型中要使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention)。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则,默认为手动"eager"
实现。
从配置实例化库的模型类之一(带有表格问答头部)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即托管在 huggingface.co 上的模型仓库内的预训练模型的 model id。
- 一个 directory 的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - PyTorch state_dict 保存文件 的路径或 URL (例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应将配置对象作为config
参数提供。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (附加位置参数, optional) — 将传递给底层模型
__init__()
方法。 - config (PretrainedConfig, optional) — 用于模型的配置,而不是自动加载的配置。当满足以下条件时,可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的 model id 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置文件。
- cache_dir (
str
或os.PathLike
, optional) — 如果不应使用标准缓存,则应在其中缓存下载的预训练模型配置的目录路径。 - from_pt (
bool
, optional, defaults toFalse
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, optional, defaults toFalse
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在默认情况下,所有下载都在可能的情况下恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, optional) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理在每个请求上使用。 - output_loading_info(
bool
, optional, defaults toFalse
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, defaults toFalse
) — 是否允许在 Hub 上自定义模型,这些模型在其自己的建模文件中定义。此选项仅应针对您信任的存储库以及您已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, optional, defaults to"main"
) — 用于 Hub 上代码的特定修订版本,如果代码与模型的其余部分位于不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, optional) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果与配置属性对应,都将用于使用提供的kwargs
值覆盖所述属性。不与任何配置属性对应的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带表格问答头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- tapas — TFTapasForQuestionAnswering (TAPAS 模型)
示例
>>> from transformers import AutoConfig, TFAutoModelForTableQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForTableQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq")
>>> # Update configuration during loading
>>> model = TFAutoModelForTableQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/tapas_pt_model_config.json")
>>> model = TFAutoModelForTableQuestionAnswering.from_pretrained(
... "./pt_model/tapas_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForDocumentQuestionAnswering
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库中的模型类之一(带有文档问答头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是基于配置类选择的:
- LayoutLMConfig 配置类:LayoutLMForQuestionAnswering (LayoutLM 模型)
- LayoutLMv2Config 配置类:LayoutLMv2ForQuestionAnswering (LayoutLMv2 模型)
- LayoutLMv3Config 配置类:LayoutLMv3ForQuestionAnswering (LayoutLMv3 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力实现方式(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一种。默认情况下,如果可用,对于 torch>=2.1.1 将使用 SDPA。否则,默认使用手动"eager"
实现。
从配置实例化库中的模型类之一(带有文档问答头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
示例
>>> from transformers import AutoConfig, AutoModelForDocumentQuestionAnswering
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("impira/layoutlm-document-qa", revision="52e01b3")
>>> model = AutoModelForDocumentQuestionAnswering.from_config(config)
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即托管在 huggingface.co 上的模型仓库内的预训练模型的 模型 ID 。
- 一个指向目录的路径,该目录包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个指向 tensorflow 索引检查点文件 的路径或 URL (例如,
./tf_model/model.ckpt.index
)。在这种情况下,应将from_tf
设置为True
,并且应将配置对象作为config
参数提供。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并在之后加载 PyTorch 模型要慢。
- model_args (额外的 positional 参数,可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以替代自动加载的配置。在以下情况下可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 一个状态字典,用于代替从已保存的权重文件加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,则可以使用此选项。不过,在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 如果不应使用标准缓存,则应在其中缓存下载的预训练模型配置的目录的路径。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下都默认恢复。将在 Transformers 的 v5 版本中删除。
- proxies (
Dict[str, str]
, 可选) — 要按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许在 Hub 上自定义的模型在其自己的建模文件中定义。此选项仅应针对您信任的存储库以及您已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的 keyword 参数,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。行为因是否提供config
或自动加载而异:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的其余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的模型类之一(带有文档问答头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- layoutlm — LayoutLMForQuestionAnswering (LayoutLM 模型)
- layoutlmv2 — LayoutLMv2ForQuestionAnswering (LayoutLMv2 模型)
- layoutlmv3 — LayoutLMv3ForQuestionAnswering (LayoutLMv3 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForDocumentQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="52e01b3")
>>> # Update configuration during loading
>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="52e01b3", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/layoutlm_tf_model_config.json")
>>> model = AutoModelForDocumentQuestionAnswering.from_pretrained(
... "./tf_model/layoutlm_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForDocumentQuestionAnswering
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库中的模型类之一(带有文档问答头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是基于配置类选择的:
- LayoutLMConfig 配置类:TFLayoutLMForQuestionAnswering (LayoutLM 模型)
- LayoutLMv3Config 配置类:TFLayoutLMv3ForQuestionAnswering (LayoutLMv3 模型)
- attn_implementation (
str
, 可选) — 模型中使用的注意力实现方式(如果相关)。可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention) 中的任何一种。默认情况下,如果可用,对于 torch>=2.1.1 将使用 SDPA。否则,默认使用手动"eager"
实现。
从配置实例化库中的模型类之一(带有文档问答头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
示例
>>> from transformers import AutoConfig, TFAutoModelForDocumentQuestionAnswering
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("impira/layoutlm-document-qa", revision="52e01b3")
>>> model = TFAutoModelForDocumentQuestionAnswering.from_config(config)
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即托管在 huggingface.co 模型仓库中的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个PyTorch state_dict 保存文件的路径或 URL (例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应该设置为True
,并且应该提供一个配置对象作为config
参数。这种加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (额外的的位置参数,可选) — 将传递给底层模型
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以代替自动加载的配置。当以下情况时,配置可以自动加载:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, 可选) — 目录的路径,在该目录中应缓存下载的预训练模型配置,如果不想使用标准缓存。 - from_pt (
bool
, 可选, 默认为False
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下默认恢复。将在 Transformers v5 版本中删除。
- proxies (
Dict[str, str]
, 可选) — 代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许在 Hub 上自定义模型在其自己的建模文件中定义。此选项仅应为信任的存储库设置True
,并且您已阅读了代码,因为它将在本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统来存储 huggingface.co 上的模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (额外的关键字参数,可选) — 可用于更新配置对象(在加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载而表现不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设配置的所有相关更新都已经完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。与任何配置属性都不对应的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的模型类之一(带有文档问答头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- layoutlm — TFLayoutLMForQuestionAnswering (LayoutLM 模型)
- layoutlmv3 — TFLayoutLMv3ForQuestionAnswering (LayoutLMv3 模型)
示例
>>> from transformers import AutoConfig, TFAutoModelForDocumentQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForDocumentQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="52e01b3")
>>> # Update configuration during loading
>>> model = TFAutoModelForDocumentQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="52e01b3", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/layoutlm_pt_model_config.json")
>>> model = TFAutoModelForDocumentQuestionAnswering.from_pretrained(
... "./pt_model/layoutlm_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForVisualQuestionAnswering
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有视觉问答头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- Blip2Config 配置类: Blip2ForConditionalGeneration (BLIP-2 模型)
- BlipConfig 配置类: BlipForQuestionAnswering (BLIP 模型)
- ViltConfig 配置类: ViltForQuestionAnswering (ViLT 模型)
- attn_implementation (
str
, 可选) — 模型中要使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention)中的任何一个。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则,默认值为手动"eager"
实现。
从配置实例化库的模型类之一(带有视觉问答头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,即托管在 huggingface.co 模型仓库中的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个tensorflow 索引检查点文件的路径或 URL (例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应该设置为True
,并且应该提供一个配置对象作为config
参数。这种加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (额外的的位置参数,可选) — 将传递给底层模型
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以代替自动加载的配置。当以下情况时,配置可以自动加载:
- 模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], optional) — 一个状态字典,用于替代从已保存权重文件中加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,则可以使用此选项。不过在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, optional) — 下载的预训练模型配置应缓存到的目录路径,如果不想使用标准缓存。 - from_tf (
bool
, optional, defaults toFalse
) — 从 TensorFlow 检查点保存文件中加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, optional, defaults toFalse
) — 是否强制(重新)下载模型权重和配置文件,覆盖可能存在的缓存版本。 - resume_download — 已弃用并忽略。现在所有下载在可能的情况下默认恢复。将在 Transformers v5 版本中移除。
- proxies (
Dict[str, str]
, optional) — 按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, optional, defaults toFalse
) — 是否同时返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, defaults toFalse
) — 是否允许 Hub 上定义的自定义模型在其自己的建模文件中存在。此选项仅应针对您信任的存储库以及您已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, optional, defaults to"main"
) — 用于 Hub 上代码的特定修订版本,如果代码与模型的其余部分位于不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数, optional) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载config
,行为有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新都已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个对应于配置属性的键将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有视觉问答头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- blip — BlipForQuestionAnswering (BLIP 模型)
- blip-2 — Blip2ForConditionalGeneration (BLIP-2 模型)
- vilt — ViltForQuestionAnswering (ViLT 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForVisualQuestionAnswering
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForVisualQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
>>> # Update configuration during loading
>>> model = AutoModelForVisualQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/vilt_tf_model_config.json")
>>> model = AutoModelForVisualQuestionAnswering.from_pretrained(
... "./tf_model/vilt_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
AutoModelForVision2Seq
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有视觉到文本建模头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- Blip2Config 配置类: Blip2ForConditionalGeneration (BLIP-2 模型)
- BlipConfig 配置类: BlipForConditionalGeneration (BLIP 模型)
- ChameleonConfig 配置类: ChameleonForConditionalGeneration (Chameleon 模型)
- GitConfig 配置类: GitForCausalLM (GIT 模型)
- Idefics2Config 配置类: Idefics2ForConditionalGeneration (Idefics2 模型)
- Idefics3Config 配置类: Idefics3ForConditionalGeneration (Idefics3 模型)
- InstructBlipConfig 配置类: InstructBlipForConditionalGeneration (InstructBLIP 模型)
- InstructBlipVideoConfig 配置类: InstructBlipVideoForConditionalGeneration (InstructBlipVideo 模型)
- Kosmos2Config 配置类: Kosmos2ForConditionalGeneration (KOSMOS-2 模型)
- LlavaConfig 配置类: LlavaForConditionalGeneration (LLaVa 模型)
- attn_implementation (
str
, optional) — 模型中要使用的注意力实现(如果相关)。可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention)中的任何一种。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则,默认值为手动"eager"
实现。
从配置实例化库的模型类之一(带有视觉到文本建模头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 上的模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个tensorflow 索引检查点文件的路径或 URL(例如,
./tf_model/model.ckpt.index
)。在这种情况下,from_tf
应设置为True
,并且应提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (附加位置参数, optional) — 将传递给底层模型
__init__()
方法。 - config (PretrainedConfig, optional) — 用于模型的配置,以代替自动加载的配置。在以下情况下,可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], optional) — 用于替代从已保存的权重文件中加载的状态字典。
如果您想从预训练配置创建模型,但加载您自己的权重,则可以使用此选项。不过,在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否是更简单的选择。
- cache_dir (*`str` 或 `os.PathLike`*, *可选*) — 用于缓存已下载的预训练模型配置的目录路径,如果不想使用标准缓存。
- from_tf (*`bool`*, *可选*, 默认为 `False`*) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅 `pretrained_model_name_or_path` 参数的文档字符串)。
- force_download (*`bool`*, *可选*, 默认为 `False`*) — 是否强制(重新)下载模型权重和配置文件,覆盖已缓存的版本(如果存在)。
- resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下都会默认恢复。将在 Transformers v5 版本中移除。
- proxies (*Dict[str, str]*, *可选*) — 一个代理服务器字典,用于指定按协议或端点使用的代理,例如,
{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。 代理将用于每个请求。 - output_loading_info(*`bool`*, *可选*, 默认为 `False`*) — 是否同时返回一个字典,其中包含缺失的键、意外的键和错误消息。
- local_files_only(*`bool`*, *可选*, 默认为 `False`*) — 是否仅查看本地文件(例如,不尝试下载模型)。
- revision (*`str`*, *可选*, 默认为 `"main"`) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此 `revision` 可以是 git 允许的任何标识符。
- trust_remote_code (*`bool`*, *可选*, 默认为 `False`*) — 是否允许使用 Hub 上定义的自定义模型,这些模型在其自身的建模文件中定义。 此选项应仅对您信任的仓库且您已阅读过其中的代码设置为 `True`,因为它会在您的本地计算机上执行 Hub 上的代码。
- code_revision (*`str`*, *可选*, 默认为 `"main"`) — 如果代码与模型的其余部分位于不同的仓库中,则指定用于 Hub 上代码的特定版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此 `revision` 可以是 git 允许的任何标识符。
- kwargs (附加关键字参数, *可选*) — 可用于更新配置对象(在加载后)和初始化模型(例如,`output_attentions=True`)。 其行为取决于是否提供了 `config` 或自动加载了配置:
- 如果使用
config
提供了配置,**kwargs**
将直接传递给底层模型的__init__
方法(我们假设所有相关的配置更新都已完成) - 如果未提供配置,
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
中与配置属性对应的每个键将用于使用提供的kwargs
值覆盖该属性。 不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带视觉到文本建模头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- blip — BlipForConditionalGeneration (BLIP模型)
- blip-2 — Blip2ForConditionalGeneration (BLIP-2 模型)
- chameleon — ChameleonForConditionalGeneration (Chameleon模型)
- git — GitForCausalLM (GIT 模型)
- idefics2 — Idefics2ForConditionalGeneration (Idefics2 模型)
- idefics3 — Idefics3ForConditionalGeneration (Idefics3 模型)
- instructblip — InstructBlipForConditionalGeneration (InstructBLIP模型)
- instructblipvideo — InstructBlipVideoForConditionalGeneration (InstructBlipVideo模型)
- kosmos-2 — Kosmos2ForConditionalGeneration (KOSMOS-2 模型)
- llava — LlavaForConditionalGeneration (LLaVa 模型)
- llava_next — LlavaNextForConditionalGeneration (LLaVA-NeXT 模型)
- llava_next_video — LlavaNextVideoForConditionalGeneration (LLaVa-NeXT-Video 模型)
- llava_onevision — LlavaOnevisionForConditionalGeneration (LLaVA-Onevision 模型)
- mistral3 — Mistral3ForConditionalGeneration (Mistral3 模型)
- mllama — MllamaForConditionalGeneration (Mllama 模型)
- paligemma — PaliGemmaForConditionalGeneration (PaliGemma 模型)
- pix2struct — Pix2StructForConditionalGeneration (Pix2Struct 模型)
- qwen2_5_vl — Qwen2_5_VLForConditionalGeneration (Qwen2_5_VL 模型)
- qwen2_vl — Qwen2VLForConditionalGeneration (Qwen2VL 模型)
- video_llava — VideoLlavaForConditionalGeneration (VideoLlava 模型)
- vipllava — VipLlavaForConditionalGeneration (VipLlava 模型)
- vision-encoder-decoder — VisionEncoderDecoderModel (Vision Encoder decoder 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForVision2Seq
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForVision2Seq.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForVision2Seq.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForVision2Seq.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForVision2Seq
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有视觉到文本建模头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类会根据配置类进行选择:
- BlipConfig 配置类: TFBlipForConditionalGeneration (BLIP 模型)
- VisionEncoderDecoderConfig 配置类: TFVisionEncoderDecoderModel (Vision Encoder decoder 模型)
- attn_implementation (*`str`*, *可选*) — 模型中使用的注意力机制实现方式(如果相关)。 可以是
"eager"
(手动实现的注意力),"sdpa"
(使用F.scaled_dot_product_attention
), 或"flash_attention_2"
(使用 Dao-AILab/flash-attention)。 默认情况下,如果可用,对于 torch>=2.1.1 版本将使用 SDPA。否则,默认使用手动"eager"
实现方式。
从配置实例化库的模型类之一(带有视觉到文本建模头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (*`str` 或 `os.PathLike`*) — 可以是:
- 一个字符串,表示 huggingface.co 模型仓库中托管的预训练模型的 *模型 ID*。
- 一个*目录*的路径,该目录包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个 *PyTorch state_dict 保存文件*的路径或 URL (例如,
./pt_model/pytorch_model.bin
)。 在这种情况下,from_pt
应设置为True
,并且应将配置对象作为config
参数提供。 此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并在之后加载 TensorFlow 模型要慢。
- model_args (附加位置参数, *可选*) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, *可选*) — 用于模型的配置,以替代自动加载的配置。 在以下情况下,配置可以自动加载:
- 模型是由库提供的模型(使用预训练模型的 *模型 ID* 字符串加载)。
- 模型已使用 save_pretrained() 保存,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置文件。
- cache_dir (*`str` 或 `os.PathLike`*, *可选*) — 用于缓存已下载的预训练模型配置的目录路径,如果不想使用标准缓存。
- from_pt (*`bool`*, *可选*, 默认为 `False`*) — 从 PyTorch 检查点保存文件加载模型权重(请参阅 `pretrained_model_name_or_path` 参数的文档字符串)。
- force_download (*`bool`*, *可选*, 默认为 `False`*) — 是否强制(重新)下载模型权重和配置文件,覆盖已缓存的版本(如果存在)。
- resume_download — 已弃用且被忽略。现在,所有下载在可能的情况下都会默认恢复。将在 Transformers v5 版本中移除。
- proxies (*Dict[str, str]*, *可选*) — 一个代理服务器字典,用于指定按协议或端点使用的代理,例如,
{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。 代理将用于每个请求。 - output_loading_info(*`bool`*, *可选*, 默认为 `False`*) — 是否同时返回一个字典,其中包含缺失的键、意外的键和错误消息。
- local_files_only(
bool
, optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, defaults toFalse
) — 是否允许在 Hub 上自定义模型,这些模型在其自己的建模文件中定义。此选项仅应针对您信任且已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, optional, defaults to"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加关键字参数,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。根据是否提供config
或自动加载配置,行为会有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设已完成对配置的所有相关更新) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。与配置属性对应的kwargs
的每个键将用于使用提供的kwargs
值覆盖所述属性。不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带视觉到文本建模头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- blip — TFBlipForConditionalGeneration (BLIP 模型)
- vision-encoder-decoder — TFVisionEncoderDecoderModel (视觉编码器解码器模型)
示例
>>> from transformers import AutoConfig, TFAutoModelForVision2Seq
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForVision2Seq.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForVision2Seq.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForVision2Seq.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForVision2Seq
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库的模型类之一(带有视觉到文本建模头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- VisionEncoderDecoderConfig 配置类: FlaxVisionEncoderDecoderModel (视觉编码器解码器模型)
- attn_implementation (
str
, optional) — 模型中要使用的注意力实现方式(如果相关)。可以是"eager"
(注意力的手动实现)、"sdpa"
(使用F.scaled_dot_product_attention
)或"flash_attention_2"
(使用 Dao-AILab/flash-attention)中的任何一种。默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。否则,默认为手动"eager"
实现。
从配置实例化库的模型类之一(带有视觉到文本建模头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是:- 字符串,托管在 huggingface.co 上的模型仓库中的预训练模型的模型 ID。
- 目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - PyTorch state_dict 保存文件的路径或 URL(例如,
./pt_model/pytorch_model.bin
)。在这种情况下,from_pt
应设置为True
,并且应提供配置对象作为config
参数。此加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
- model_args (附加位置参数,可选) — 将传递给底层模型
__init__()
方法。 - config (PretrainedConfig, optional) — 要使用的模型配置,而不是自动加载的配置。在以下情况下,可以自动加载配置:
- 该模型是由库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 该模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 该模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到名为 config.json 的配置 JSON 文件。
- cache_dir (
str
或os.PathLike
, optional) — 如果不应使用标准缓存,则应在其中缓存下载的预训练模型配置的目录的路径。 - from_pt (
bool
, optional, defaults toFalse
) — 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, optional, defaults toFalse
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。 - resume_download — 已弃用并忽略。现在,所有下载都在可能的情况下默认恢复。将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, optional) — 要按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - output_loading_info(
bool
, optional, defaults toFalse
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, optional, defaults toFalse
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, optional, defaults to"main"
) — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, optional, defaults toFalse
) — 是否允许在 Hub 上自定义模型,这些模型在其自己的建模文件中定义。此选项仅应针对您信任且已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, optional, defaults to"main"
) — 如果代码与模型的其余部分位于不同的存储库中,则用于 Hub 上代码的特定修订版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (其他关键字参数,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。 其行为方式取决于是否提供了config
或自动加载了配置:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新都已完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果对应于一个配置属性,都将用于使用提供的kwargs
值覆盖该属性。 其余不对应于任何配置属性的键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库中的一个模型类(带视觉到文本建模头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- vision-encoder-decoder — FlaxVisionEncoderDecoderModel (视觉编码器-解码器模型)
示例
>>> from transformers import AutoConfig, FlaxAutoModelForVision2Seq
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForVision2Seq.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForVision2Seq.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForVision2Seq.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
AutoModelForImageTextToText
这是一个通用模型类,当使用 from_pretrained() 类方法或 from_config() 类方法创建时,它将被实例化为库中的模型类之一(带有图像-文本到文本建模头)。
此类不能使用 __init__()
直接实例化(会抛出错误)。
from_config
< source > ( **kwargs )
参数
- config (PretrainedConfig) — 要实例化的模型类是根据配置类选择的:
- AriaConfig 配置类:AriaForConditionalGeneration (Aria 模型)
- AyaVisionConfig 配置类:AyaVisionForConditionalGeneration (AyaVision 模型)
- Blip2Config 配置类:Blip2ForConditionalGeneration (BLIP-2 模型)
- BlipConfig 配置类:BlipForConditionalGeneration (BLIP 模型)
- ChameleonConfig 配置类:ChameleonForConditionalGeneration (Chameleon 模型)
- Emu3Config 配置类:Emu3ForConditionalGeneration (Emu3 模型)
- FuyuConfig 配置类:FuyuForCausalLM (Fuyu 模型)
- Gemma3Config 配置类:Gemma3ForConditionalGeneration (Gemma3ForConditionalGeneration 模型)
- GitConfig 配置类:GitForCausalLM (GIT 模型)
- GotOcr2Config 配置类:GotOcr2ForConditionalGeneration (GOT-OCR2 模型)
- Idefics2Config 配置类:Idefics2ForConditionalGeneration (Idefics2 模型)
- Idefics3Config 配置类:Idefics3ForConditionalGeneration (Idefics3 模型)
- IdeficsConfig 配置类:IdeficsForVisionText2Text (IDEFICS 模型)
- InstructBlipConfig 配置类:InstructBlipForConditionalGeneration (InstructBLIP 模型)
- Kosmos2Config 配置类:Kosmos2ForConditionalGeneration (KOSMOS-2 模型)
- LlavaConfig 配置类:LlavaForConditionalGeneration (LLaVa 模型)
- LlavaNextConfig 配置类:LlavaNextForConditionalGeneration (LLaVA-NeXT 模型)
- LlavaOnevisionConfig 配置类:LlavaOnevisionForConditionalGeneration (LLaVA-Onevision 模型)
- Mistral3Config 配置类:Mistral3ForConditionalGeneration (Mistral3 模型)
- MllamaConfig 配置类:MllamaForConditionalGeneration (Mllama 模型)
- PaliGemmaConfig 配置类:PaliGemmaForConditionalGeneration (PaliGemma 模型)
- Pix2StructConfig 配置类:Pix2StructForConditionalGeneration (Pix2Struct 模型)
- PixtralVisionConfig 配置类:LlavaForConditionalGeneration (Pixtral 模型)
- Qwen2VLConfig 配置类:Qwen2VLForConditionalGeneration (Qwen2VL 模型)
- Qwen2_5_VLConfig 配置类:Qwen2_5_VLForConditionalGeneration (Qwen2_5_VL 模型)
- ShieldGemma2Config 配置类:Gemma3ForConditionalGeneration (Shieldgemma2 模型)
- SmolVLMConfig 配置类:SmolVLMForConditionalGeneration (SmolVLM 模型)
- UdopConfig 配置类:UdopForConditionalGeneration (UDOP 模型)
- VipLlavaConfig 配置类:VipLlavaForConditionalGeneration (VipLlava 模型)
- VisionEncoderDecoderConfig 配置类:VisionEncoderDecoderModel (视觉编码器-解码器模型)
- attn_implementation (
str
, 可选) — 模型中要使用的注意力实现(如果相关)。 可以是"eager"
(注意力的手动实现),"sdpa"
(使用F.scaled_dot_product_attention
),或"flash_attention_2"
(使用 Dao-AILab/flash-attention)。 默认情况下,如果可用,SDPA 将用于 torch>=2.1.1。 否则,默认值为手动"eager"
实现。
从配置实例化库中的模型类之一(带有图像-文本到文本建模头)。
注意: 从其配置文件加载模型**不会**加载模型权重。它仅影响模型的配置。使用 from_pretrained() 加载模型权重。
from_pretrained
< source > ( *model_args **kwargs )
参数
- pretrained_model_name_or_path (
str
或os.PathLike
) — 可以是以下之一:- 一个字符串,huggingface.co 上的模型仓库中托管的预训练模型的模型 ID。
- 一个目录的路径,其中包含使用 save_pretrained() 保存的模型权重,例如,
./my_model_directory/
。 - 一个tensorflow 索引检查点文件的路径或 URL(例如,
./tf_model/model.ckpt.index
)。 在这种情况下,应将from_tf
设置为True
,并应提供配置对象作为config
参数。 此加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并随后加载 PyTorch 模型要慢。
- model_args (其他位置参数,可选) — 将传递给底层模型的
__init__()
方法。 - config (PretrainedConfig, 可选) — 用于模型的配置,以代替自动加载的配置。 当以下情况时,可以自动加载配置:
- 模型是库提供的模型(使用预训练模型的模型 ID 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 模型通过提供本地目录作为
pretrained_model_name_or_path
加载,并且在目录中找到了名为 config.json 的配置 JSON 文件。
- state_dict (Dict[str, torch.Tensor], 可选) — 一个状态字典,用于代替从已保存权重文件加载的状态字典。
如果您想从预训练配置创建模型但加载您自己的权重,则可以使用此选项。 但在这种情况下,您应该检查使用 save_pretrained() 和 from_pretrained() 是否不是更简单的选择。
- cache_dir (
str
或os.PathLike
, 可选) — 缓存下载的预训练模型配置的目录路径,如果不想使用标准缓存。 - from_tf (
bool
, 可选, 默认为False
) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path
参数的文档字符串)。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存的版本(如果存在)。 - resume_download — 已弃用且被忽略。 现在,所有下载在可能的情况下默认恢复。 将在 Transformers v5 中删除。
- proxies (
Dict[str, str]
, 可选) — 一个代理服务器字典,用于按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。 代理用于每个请求。 - output_loading_info(
bool
, 可选, 默认为False
) — 是否也返回一个字典,其中包含缺失的键、意外的键和错误消息。 - local_files_only(
bool
, 可选, 默认为False
) — 是否仅查看本地文件(例如,不尝试下载模型)。 - revision (
str
, 可选, 默认为"main"
) — 要使用的特定模型版本。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - trust_remote_code (
bool
, 可选, 默认为False
) — 是否允许 Hub 上自定义模型在其自己的建模文件中定义。 此选项仅应为信任的存储库以及您已阅读代码的存储库设置为True
,因为它将在您的本地计算机上执行 Hub 上存在的代码。 - code_revision (
str
, 可选, 默认为"main"
) — Hub 上代码的特定修订版本,如果代码与模型的其余部分位于不同的仓库中。 它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,因此revision
可以是 git 允许的任何标识符。 - kwargs (附加的关键字参数, 可选) — 可用于更新配置对象(加载后)并初始化模型(例如,
output_attentions=True
)。 根据是否提供config
或自动加载config
的方式有所不同:- 如果使用
config
提供了配置,则**kwargs
将直接传递给底层模型的__init__
方法(我们假设对配置的所有相关更新已经完成) - 如果未提供配置,则
kwargs
将首先传递给配置类初始化函数 (from_pretrained())。kwargs
的每个键,如果对应于配置属性,将用于使用提供的kwargs
值覆盖该属性。 不对应于任何配置属性的剩余键将传递给底层模型的__init__
函数。
- 如果使用
从预训练模型实例化库的模型类之一(带有图像-文本到文本建模头)。
要实例化的模型类是根据配置对象的 model_type
属性(作为参数传递或从 pretrained_model_name_or_path
加载,如果可能)选择的,或者当它丢失时,通过回退到在 pretrained_model_name_or_path
上使用模式匹配来选择。
- aria — AriaForConditionalGeneration (Aria 模型)
- aya_vision — AyaVisionForConditionalGeneration (AyaVision 模型)
- blip — BlipForConditionalGeneration (BLIP模型)
- blip-2 — Blip2ForConditionalGeneration (BLIP-2 模型)
- chameleon — ChameleonForConditionalGeneration (Chameleon模型)
- emu3 — Emu3ForConditionalGeneration (Emu3 模型)
- fuyu — FuyuForCausalLM (Fuyu 模型)
- gemma3 — Gemma3ForConditionalGeneration (Gemma3ForConditionalGeneration 模型)
- git — GitForCausalLM (GIT 模型)
- got_ocr2 — GotOcr2ForConditionalGeneration (GOT-OCR2 模型)
- idefics — IdeficsForVisionText2Text (IDEFICS 模型)
- idefics2 — Idefics2ForConditionalGeneration (Idefics2 模型)
- idefics3 — Idefics3ForConditionalGeneration (Idefics3 模型)
- instructblip — InstructBlipForConditionalGeneration (InstructBLIP模型)
- kosmos-2 — Kosmos2ForConditionalGeneration (KOSMOS-2 模型)
- llava — LlavaForConditionalGeneration (LLaVa 模型)
- llava_next — LlavaNextForConditionalGeneration (LLaVA-NeXT 模型)
- llava_onevision — LlavaOnevisionForConditionalGeneration (LLaVA-Onevision 模型)
- mistral3 — Mistral3ForConditionalGeneration (Mistral3 模型)
- mllama — MllamaForConditionalGeneration (Mllama 模型)
- paligemma — PaliGemmaForConditionalGeneration (PaliGemma 模型)
- pix2struct — Pix2StructForConditionalGeneration (Pix2Struct 模型)
- pixtral — LlavaForConditionalGeneration (Pixtral 模型)
- qwen2_5_vl — Qwen2_5_VLForConditionalGeneration (Qwen2_5_VL 模型)
- qwen2_vl — Qwen2VLForConditionalGeneration (Qwen2VL 模型)
- shieldgemma2 — Gemma3ForConditionalGeneration (Shieldgemma2 模型)
- smolvlm — SmolVLMForConditionalGeneration (SmolVLM 模型)
- udop — UdopForConditionalGeneration (UDOP 模型)
- vipllava — VipLlavaForConditionalGeneration (VipLlava 模型)
- vision-encoder-decoder — VisionEncoderDecoderModel (Vision Encoder decoder 模型)
模型默认设置为评估模式,使用 model.eval()
(例如,dropout 模块被禁用)。要训练模型,您应该首先使用 model.train()
将其设置回训练模式。
示例
>>> from transformers import AutoConfig, AutoModelForImageTextToText
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForImageTextToText.from_pretrained("google-bert/bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForImageTextToText.from_pretrained("google-bert/bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForImageTextToText.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )