SetFit 文档
主要类别
并获得增强的文档体验
开始使用
主类
SetFitModel
class setfit.SetFitModel
< 来源 >( model_body: typing.Optional[sentence_transformers.SentenceTransformer.SentenceTransformer] = None model_head: typing.Union[setfit.modeling.SetFitHead, sklearn.linear_model._logistic.LogisticRegression, NoneType] = None multi_target_strategy: typing.Optional[str] = None normalize_embeddings: bool = False labels: typing.Optional[typing.List[str]] = None model_card_data: typing.Optional[setfit.model_card.SetFitModelCardData] = None sentence_transformers_kwargs: typing.Optional[typing.Dict] = None **kwargs )
一个集成了 Hugging Face Hub 的 SetFit 模型。
示例
>>> from setfit import SetFitModel
>>> model = SetFitModel.from_pretrained("tomaarsen/setfit-bge-small-v1.5-sst2-8-shot")
>>> model.predict([
... "It's a charming and often affecting journey.",
... "It's slow -- very, very slow.",
... "A sometimes tedious film.",
... ])
['positive', 'negative', 'negative']
from_pretrained
< 来源 >( force_download: bool = False resume_download: typing.Optional[bool] = None proxies: typing.Optional[typing.Dict] = None token: typing.Union[bool, str, NoneType] = None cache_dir: typing.Union[str, pathlib.Path, NoneType] = None local_files_only: bool = False revision: typing.Optional[str] = None **model_kwargs )
参数
- pretrained_model_name_or_path (str, Path) —
- Hub 上模型的 model_id (字符串),例如 bigscience/bloom。
- 或者包含使用 [~transformers.PreTrainedModel.save_pretrained] 保存的模型权重的 目录 路径,例如 ../path/to/my_model_directory/。
- revision (str, 可选) — Hub 上模型的修订版本。可以是分支名称、git 标签或任何提交 ID。默认为 main 分支上的最新提交。
- force_download (bool, 可选, 默认为 False) — 是否强制(重新)从 Hub 下载模型权重和配置文件,覆盖现有缓存。
- proxies (Dict[str, str], 可选) — 要按协议或端点使用的代理服务器字典,例如 {‘http’: ‘foo.bar:3128’, ‘http://hostname’: ‘foo.bar:4012’}。每个请求都会使用代理。
- token (str 或 bool, 可选) — 用于远程文件的 HTTP Bearer 授权令牌。默认情况下,它将使用运行 hf auth login 时缓存的令牌。
- cache_dir (str, Path, 可选) — 缓存文件存储的文件夹路径。
- local_files_only (bool, 可选, 默认为 False) — 如果为 True,则避免下载文件,如果本地缓存文件存在则返回其路径。
- labels (List[str], 可选) — 如果标签是 0 到 num_classes-1 范围内的整数,则这些标签表示相应的标签。
- model_card_data (SetFitModelCardData, 可选) — 一个 SetFitModelCardData 实例,存储模型语言、许可证、数据集名称等数据,用于自动生成的模型卡。
- multi_target_strategy (str, 可选) — 与多标签分类一起使用的策略。可以是 “one-vs-rest”、“multi-output” 或 “classifier-chain” 之一。
- use_differentiable_head (bool, 可选) — 是否使用可微分(即 Torch)头部而不是逻辑回归来加载 SetFit。
- normalize_embeddings (bool, 可选) — 是否对 Sentence Transformer 主体生成的嵌入应用归一化。
- device (Union[torch.device, str], 可选) — 加载 SetFit 模型的设备,例如 “cuda:0”、 “mps” 或 torch.device(“cuda”)。
- trust_remote_code (bool, 默认为 False) — 是否允许在 Hub 上自己的建模文件中定义的自定义 Sentence Transformers 模型。此选项仅应设置为您信任且已阅读其代码的仓库,因为它将在您的本地机器上执行 Hub 上存在的代码。默认为 False。
从 Huggingface Hub 下载模型并实例化它。
save_pretrained
< 来源 >( save_directory: typing.Union[str, pathlib.Path] config: typing.Union[dict, huggingface_hub.hub_mixin.DataclassInstance, NoneType] = None repo_id: typing.Optional[str] = None push_to_hub: bool = False model_card_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None **push_to_hub_kwargs ) → str
或 None
参数
- save_directory (
str
或Path
) — 保存模型权重和配置的目录路径。 - config (
dict
或DataclassInstance
, 可选) — 指定为键/值字典或数据类实例的模型配置。 - push_to_hub (
bool
, 可选, 默认为False
) — 保存模型后是否将其推送到 Huggingface Hub。 - repo_id (
str
, 可选) — 您在 Hub 上的仓库 ID。仅在push_to_hub=True
时使用。如果未提供,将默认为文件夹名称。 - model_card_kwargs (
Dict[str, Any]
, 可选) — 传递给模型卡模板的其他参数,用于自定义模型卡。 - push_to_hub_kwargs — 传递给
~ModelHubMixin.push_to_hub
方法的其他关键字参数。
返回
str
或 None
如果 push_to_hub=True
,则为 Hub 上提交的 URL,否则为 None
。
将权重保存到本地目录。
push_to_hub
< 来源 >( repo_id: str config: typing.Union[dict, huggingface_hub.hub_mixin.DataclassInstance, NoneType] = None commit_message: str = '使用 huggingface_hub 推送模型。' private: typing.Optional[bool] = None token: typing.Optional[str] = None branch: typing.Optional[str] = None create_pr: typing.Optional[bool] = None allow_patterns: typing.Union[str, typing.List[str], NoneType] = None ignore_patterns: typing.Union[str, typing.List[str], NoneType] = None delete_patterns: typing.Union[str, typing.List[str], NoneType] = None model_card_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None )
参数
- repo_id (
str
) — 要推送到的仓库 ID(例如:"username/my-model"
)。 - config (
dict
或DataclassInstance
, 可选) — 指定为键/值字典或数据类实例的模型配置。 - commit_message (
str
, 可选) — 推送时提交的消息。 - private (
bool
, 可选) — 创建的仓库是否应为私有。如果为None
(默认),则仓库将为公开,除非组织的默认设置为私有。 - token (
str
, 可选) — 用于远程文件的 HTTP Bearer 授权令牌。默认情况下,它将使用运行hf auth login
时缓存的令牌。 - branch (
str
, 可选) — 推送模型的 git 分支。默认为"main"
。 - create_pr (
boolean
, 可选) — 是否从branch
创建带有该提交的 Pull Request。默认为False
。 - allow_patterns (
List[str]
或str
, 可选) — 如果提供,则只推送至少匹配一个模式的文件。 - ignore_patterns (
List[str]
或str
, 可选) — 如果提供,则不推送匹配任何模式的文件。 - delete_patterns (
List[str]
或str
, 可选) — 如果提供,则匹配任何模式的远程文件将从仓库中删除。 - model_card_kwargs (
Dict[str, Any]
, 可选) — 传递给模型卡模板的其他参数,用于自定义模型卡。
将模型检查点上传到 Hub。
使用 allow_patterns
和 ignore_patterns
精确筛选哪些文件应推送到 Hub。使用 delete_patterns
在同一提交中删除现有远程文件。有关更多详细信息,请参阅 upload_folder
参考。
__call__
< 来源 >( inputs: typing.Union[str, typing.List[str]] batch_size: int = 32 as_numpy: bool = False use_labels: bool = True show_progress_bar: typing.Optional[bool] = None ) → Union[torch.Tensor, np.ndarray, List[str], int, str]
参数
- inputs (Union[str, List[str]]) — 用于预测类别的输入句子或句子列表。
- batch_size (int, 默认为 32) — 用于将句子编码为嵌入的批大小。越大通常意味着更快的处理速度,但内存使用量也越大。
- as_numpy (bool, 默认为 False) — 是否输出为 numpy 数组。
- use_labels (bool, 默认为 True) — 是否尝试返回 SetFitModel.labels 的元素。
- show_progress_bar (Optional[bool], 默认为 None) — 编码时是否显示进度条。
返回
Union[torch.Tensor, np.ndarray, List[str], int, str]
如果 use_labels 为 True 且 SetFitModel.labels 已定义,则返回与输入长度相同的字符串标签列表。否则返回与输入长度相同的向量,表示每个输入所属的预测类别。如果输入是单个字符串,则输出也是单个标签。
预测各种类别。
返回从字符串标签到整数 ID 的映射。
返回从整数 ID 到字符串标签的映射。
创建模型卡片
< 来源 >( path: str model_name: typing.Optional[str] = 'SetFit 模型' )
为 SetFit 模型创建并保存模型卡。
编码
< 源 >( inputs: typing.List[str] batch_size: int = 32 show_progress_bar: typing.Optional[bool] = None ) → Union[torch.Tensor, np.ndarray]
使用 SentenceTransformer
主体将输入句子转换为嵌入。
fit
< 源 >( x_train: typing.List[str] y_train: typing.Union[typing.List[int], typing.List[typing.List[int]]] num_epochs: int batch_size: typing.Optional[int] = None body_learning_rate: typing.Optional[float] = None head_learning_rate: typing.Optional[float] = None end_to_end: bool = False l2_weight: typing.Optional[float] = None max_length: typing.Optional[int] = None show_progress_bar: bool = True )
参数
- x_train (
List[str]
) — 训练句子列表。 - y_train (
Union[List[int], List[List[int]]]
) — 对应于训练句子的标签列表。 - num_epochs (
int
) — 训练的 epoch 数量。 - batch_size (
int
, 可选) — 要使用的批次大小。 - body_learning_rate (
float
, 可选) —AdamW
优化器中SentenceTransformer
主体的学习率。如果end_to_end=False
则忽略。 - head_learning_rate (
float
, 可选) —AdamW
优化器中可微分 torch 头的学习率。 - end_to_end (
bool
, 默认为False
) — 如果为 True,则端到端训练整个模型。否则,冻结SentenceTransformer
主体,仅训练头部。 - l2_weight (
float
, 可选) —AdamW
优化器中模型主体和头部的 L2 权重。 - max_length (
int
, 可选) — 分词器可以生成的最大 token 长度。如果未提供,则使用SentenceTransformer
主体的最大长度。 - show_progress_bar (
bool
, 默认为True
) — 是否显示训练 epoch 和迭代的进度条。
训练分类器头部,仅在使用可微分 PyTorch 头部时使用。
冻结模型主体和/或头部,防止在该组件上进行进一步训练,直到解冻。
根据模型卡数据生成并返回模型卡字符串。
predict
< 源 >( inputs: typing.Union[str, typing.List[str]] batch_size: int = 32 as_numpy: bool = False use_labels: bool = True show_progress_bar: typing.Optional[bool] = None ) → Union[torch.Tensor, np.ndarray, List[str], int, str]
参数
- inputs (Union[str, List[str]]) — 要预测类别的输入句子或句子列表。
- batch_size (int, 默认为 32) — 用于将句子编码为嵌入的批次大小。通常,更高的值意味着更快的处理速度,但内存使用量更高。
- as_numpy (bool, 默认为 False) — 是否输出为 numpy 数组。
- use_labels (bool, 默认为 True) — 是否尝试返回 SetFitModel.labels 的元素。
- show_progress_bar (Optional[bool], 默认为 None) — 编码时是否显示进度条。
返回
Union[torch.Tensor, np.ndarray, List[str], int, str]
如果 use_labels 为 True 且 SetFitModel.labels 已定义,则返回与输入长度相同的字符串标签列表。否则返回与输入长度相同的向量,表示每个输入所属的预测类别。如果输入是单个字符串,则输出也是单个标签。
预测各种类别。
predict_proba
< 源 >( inputs: typing.Union[str, typing.List[str]] batch_size: int = 32 as_numpy: bool = False show_progress_bar: typing.Optional[bool] = None ) → Union[torch.Tensor, np.ndarray]
参数
- inputs (Union[str, List[str]]) — 要预测类概率的输入句子。
- batch_size (int, 默认为 32) — 用于将句子编码为嵌入的批次大小。通常,更高的值意味着更快的处理速度,但内存使用量更高。
- as_numpy (bool, 默认为 False) — 是否输出为 numpy 数组。
- show_progress_bar (Optional[bool], 默认为 None) — 编码时是否显示进度条。
返回
Union[torch.Tensor, np.ndarray]
一个形状为 [INPUT_LENGTH, NUM_CLASSES] 的矩阵,表示将输入预测为某个类别的概率。如果输入是字符串,则输出是形状为 [NUM_CLASSES,] 的向量。
预测各种类别的概率。
示例
>>> model = SetFitModel.from_pretrained(...)
>>> model.predict_proba(["What a boring display", "Exhilarating through and through", "I'm wowed!"])
tensor([[0.9367, 0.0633],
[0.0627, 0.9373],
[0.0890, 0.9110]], dtype=torch.float64)
>>> model.predict_proba("That was cool!")
tensor([0.8421, 0.1579], dtype=torch.float64)
到
< 源 >( device: typing.Union[str, torch.device] ) → SetFitModel
将此 SetFitModel 移动到 device,然后返回 self。此方法不进行复制。
unfreeze
< 源 >( component: typing.Optional[typing.Literal['body', 'head']] = None keep_body_frozen: typing.Optional[bool] = None )
解冻模型主体和/或头部,允许在该组件上进行进一步训练。
SetFitHead
class setfit.SetFitHead
< 源 >( in_features: typing.Optional[int] = None out_features: int = 2 temperature: float = 1.0 eps: float = 1e-05 bias: bool = True device: typing.Union[torch.device, str, NoneType] = None multitarget: bool = False )
参数
- in_features (
int
, 可选) — SetFit 主体输出的嵌入维度。如果为None
,则默认为LazyLinear
。 - out_features (
int
, 默认为2
) — 目标数量。如果将out_features
设置为 1 用于二分类,它将被更改为 2,因为是二分类。 - temperature (
float
, 默认为1.0
) — Logits 的缩放因子。值越高,模型信心越低;值越低,模型信心越高。 - eps (
float
, 默认为1e-5
) — 用于缩放 logits 的数值稳定性值。 - bias (
bool
, 可选, 默认为True
) — 是否在头部添加偏差。 - device (
torch.device
, str, 可选) — 模型将被发送到的设备。如果为None
,将检查 GPU 是否可用。 - multitarget (
bool
, 默认为False
) — 通过将out_features
设置为二元预测而不是单一多项式预测来启用多目标分类。
一个支持多类别分类的 SetFit 头部,用于端到端训练。二元分类被视为二类别分类。
为了与 Sentence Transformers 兼容,我们继承了 Dense
自:https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/models/Dense.py
forward
< 源 >( features: typing.Union[typing.Dict[str, torch.Tensor], torch.Tensor] temperature: typing.Optional[float] = None )
SetFitHead 可以接受以下嵌入:
- Sentence-Transformers 的输出格式 (
dict
)。 - 纯
torch.Tensor
。
SetFitModelCardData
class setfit.SetFitModelCardData
< 源 >( language: typing.Union[str, typing.List[str], NoneType] = None license: typing.Optional[str] = None tags: typing.Optional[typing.List[str]] = <factory> model_name: typing.Optional[str] = None model_id: typing.Optional[str] = None dataset_name: typing.Optional[str] = None dataset_id: typing.Optional[str] = None dataset_revision: typing.Optional[str] = None task_name: typing.Optional[str] = None st_id: typing.Optional[str] = None )
参数
- language (Optional[Union[str, List[str]]]) — 模型语言,可以是字符串或列表,例如“en”或“[“en”, “de”, “nl”]”
- license (Optional[str]) — 模型的许可证,例如“apache-2.0”、“mit”或“cc-by-nc-sa-4.0”
- model_name (Optional[str]) — 模型的漂亮名称,例如“SetFit with mBERT-base on SST2”。如果未定义,则使用 encoder_name/encoder_id 和 dataset_name/dataset_id 生成模型名称。
- model_id (Optional[str]) — 将模型推送到 Hub 时的模型 ID,例如“tomaarsen/span-marker-mbert-base-multinerd”。
- dataset_name (Optional[str]) — 数据集的漂亮名称,例如“SST2”。
- dataset_id (Optional[str]) — 数据集的 ID,例如“dair-ai/emotion”。
- dataset_revision (Optional[str]) — 用于训练/评估的数据集修订/提交。
- st_id (Optional[str]) — Sentence Transformers 模型 ID。
存储模型卡中使用的数据的 dataclass。
安装 codecarbon
可自动跟踪碳排放使用情况并将其包含在模型卡中。
示例
>>> model = SetFitModel.from_pretrained(
... "sentence-transformers/paraphrase-mpnet-base-v2",
... labels=["negative", "positive"],
... # Model card variables
... model_card_data=SetFitModelCardData(
... model_id="tomaarsen/setfit-paraphrase-mpnet-base-v2-sst2",
... dataset_name="SST2",
... dataset_id="sst2",
... license="apache-2.0",
... language="en",
... ),
... )
AbsaModel
from_pretrained
< 源 >( 模型ID: str 极性模型ID: typing.Optional[str] = None spaCy模型: typing.Optional[str] = None span上下文: typing.Tuple[typing.Optional[int], typing.Optional[int]] = (None, None) 强制下载: bool = None 恢复下载: bool = None 代理: typing.Optional[typing.Dict] = None 令牌: typing.Union[bool, str, NoneType] = None 缓存目录: typing.Optional[str] = None 仅限本地文件: bool = None 使用可微分头: bool = None 归一化嵌入: bool = None **模型参数 )
predict
< 源 >( 输入: typing.Union[str, typing.List[str], datasets.arrow_dataset.Dataset] ) → Union[List[Dict[str, Any]], Dataset]
预测给定输入的方面及其极性。
示例
>>> from setfit import AbsaModel
>>> model = AbsaModel.from_pretrained(
... "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
... "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
... )
>>> model.predict("The food and wine are just exquisite.")
[{'span': 'food', 'polarity': 'positive'}, {'span': 'wine', 'polarity': 'positive'}]
>>> from setfit import AbsaModel
>>> from datasets import load_dataset
>>> model = AbsaModel.from_pretrained(
... "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
... "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
... )
>>> dataset = load_dataset("tomaarsen/setfit-absa-semeval-restaurants", split="train")
>>> model.predict(dataset)
Dataset({
features: ['text', 'span', 'label', 'ordinal', 'pred_polarity'],
num_rows: 3693
})
save_pretrained
< 源 >( 保存目录: typing.Union[str, pathlib.Path] 极性保存目录: typing.Union[str, pathlib.Path, NoneType] = None 推送到集线器: bool = False **kwargs )
AspectModel
获取此模型所在的Torch设备。
from_pretrained
< 源 >( )
参数
- pretrained_model_name_or_path (
str
,Path
) —- 可以是托管在Hub上的模型的
model_id
(字符串),例如bigscience/bloom
。 - 或者是包含使用
save_pretrained
保存的模型权重的directory
路径,例如../path/to/my_model_directory/
。
- 可以是托管在Hub上的模型的
- revision (
str
, 可选) — Hub上模型的修订版本。可以是分支名称、git标签或任何提交ID。默认为main
分支的最新提交。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)从Hub下载模型权重和配置文件,覆盖现有缓存。 - proxies (
Dict[str, str]
, 可选) — 一个字典,其中包含按协议或端点使用的代理服务器,例如{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - token (
str
或bool
, 可选) — 用作远程文件HTTP bearer授权的令牌。默认情况下,它将使用运行hf auth login
时缓存的令牌。 - cache_dir (
str
,Path
, 可选) — 缓存文件存储的文件夹路径。 - local_files_only (
bool
, 可选, 默认为False
) — 如果为True
,则避免下载文件,如果本地缓存文件存在,则返回其路径。 - labels (
List[str]
, 可选) — 如果标签是0到num_classes-1
之间的整数,则这些标签表示相应的标签。 - model_card_data (
SetFitModelCardData
, 可选) — 一个SetFitModelCardData
实例,存储模型语言、许可证、数据集名称等数据,用于自动生成的模型卡。 - model_card_data (
SetFitModelCardData
, 可选) — 一个SetFitModelCardData
实例,存储模型语言、许可证、数据集名称等数据,用于自动生成的模型卡。 - use_differentiable_head (
bool
, 可选) — 是否使用可微分(即 Torch)头部而不是逻辑回归加载SetFit。 - normalize_embeddings (
bool
, 可选) — 是否对Sentence Transformer主体生成的嵌入进行归一化。 - span_context (
int
, 默认为0
) — 跨度候选前后应添加的单词数。对于方面模型,默认为0;对于极性模型,默认为3。 - device (
Union[torch.device, str]
, 可选) — 加载SetFit模型的设备,例如"cuda:0"
、"mps"
或torch.device("cuda")
。
从 Huggingface Hub 下载模型并实例化它。
predict
< 源 >( inputs: typing.Union[str, typing.List[str]] batch_size: int = 32 as_numpy: bool = False use_labels: bool = True show_progress_bar: typing.Optional[bool] = None ) → Union[torch.Tensor, np.ndarray, List[str], int, str]
参数
- inputs (Union[str, List[str]]) — 用于预测类别的输入句子或句子列表。
- batch_size (int, 默认为32) — 用于将句子编码为嵌入的批处理大小。通常,值越大处理速度越快,但内存使用量也越高。
- as_numpy (bool, 默认为False) — 是否输出为numpy数组。
- use_labels (bool, 默认为True) — 是否尝试返回SetFitModel.labels的元素。
- show_progress_bar (Optional[bool], 默认为None) — 编码时是否显示进度条。
返回
Union[torch.Tensor, np.ndarray, List[str], int, str]
如果 use_labels 为 True 且 SetFitModel.labels 已定义,则返回与输入长度相同的字符串标签列表。否则返回与输入长度相同的向量,表示每个输入所属的预测类别。如果输入是单个字符串,则输出也是单个标签。
预测各种类别。
push_to_hub
< 源 >( repo_id: str config: typing.Union[dict, huggingface_hub.hub_mixin.DataclassInstance, NoneType] = None commit_message: str = '使用 huggingface_hub 推送模型。' private: typing.Optional[bool] = None token: typing.Optional[str] = None branch: typing.Optional[str] = None create_pr: typing.Optional[bool] = None allow_patterns: typing.Union[str, typing.List[str], NoneType] = None ignore_patterns: typing.Union[str, typing.List[str], NoneType] = None delete_patterns: typing.Union[str, typing.List[str], NoneType] = None model_card_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None )
参数
- repo_id (
str
) — 要推送到的仓库ID(例如:"username/my-model"
)。 - config (
dict
或DataclassInstance
, 可选) — 模型配置,指定为键/值字典或数据类实例。 - commit_message (
str
, 可选) — 推送时提交的消息。 - private (
bool
, 可选) — 创建的仓库是否应为私有。如果为None
(默认),则除非组织默认为私有,否则仓库将是公共的。 - token (
str
, 可选) — 用作远程文件HTTP bearer授权的令牌。默认情况下,它将使用运行hf auth login
时缓存的令牌。 - branch (
str
, 可选) — 要推送模型的git分支。默认为"main"
。 - create_pr (
boolean
, 可选) — 是否从branch
创建拉取请求并提交。默认为False
。 - allow_patterns (
List[str]
或str
, 可选) — 如果提供,则只推送至少匹配一个模式的文件。 - ignore_patterns (
List[str]
或str
, 可选) — 如果提供,则不推送匹配任何模式的文件。 - delete_patterns (
List[str]
或str
, 可选) — 如果提供,远程文件中匹配任何模式的文件将从仓库中删除。 - model_card_kwargs (
Dict[str, Any]
, 可选) — 传递给模型卡模板的其他参数,用于自定义模型卡。
将模型检查点上传到 Hub。
使用 allow_patterns
和 ignore_patterns
精确筛选哪些文件应推送到 Hub。使用 delete_patterns
在同一提交中删除现有远程文件。有关更多详细信息,请参阅 upload_folder
参考。
save_pretrained
< 源 >( save_directory: typing.Union[str, pathlib.Path] config: typing.Union[dict, huggingface_hub.hub_mixin.DataclassInstance, NoneType] = None repo_id: typing.Optional[str] = None push_to_hub: bool = False model_card_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None **push_to_hub_kwargs ) → str
或 None
参数
- save_directory (
str
或Path
) — 保存模型权重和配置的目录路径。 - config (
dict
或DataclassInstance
, 可选) — 模型配置,指定为键/值字典或数据类实例。 - push_to_hub (
bool
, 可选, 默认为False
) — 保存模型后是否将其推送到Huggingface Hub。 - repo_id (
str
, 可选) — 您在Hub上的仓库ID。仅在push_to_hub=True
时使用。如果未提供,则默认为文件夹名称。 - model_card_kwargs (
Dict[str, Any]
, 可选) — 传递给模型卡模板的其他参数,用于自定义模型卡。 - push_to_hub_kwargs — 传递给
~ModelHubMixin.push_to_hub
方法的其他关键字参数。
返回
str
或 None
如果 push_to_hub=True
,则为 Hub 上提交的 URL,否则为 None
。
将权重保存到本地目录。
到
< 源 >( device: typing.Union[str, torch.device] ) → SetFitModel
将此 SetFitModel 移动到 device,然后返回 self。此方法不进行复制。
极性模型
获取此模型所在的Torch设备。
from_pretrained
< source >( )
参数
- pretrained_model_name_or_path (
str
,Path
) —- 模型在Hub上的
model_id
(字符串),例如bigscience/bloom
。 - 或者包含使用
save_pretrained
保存的模型权重的directory
路径,例如../path/to/my_model_directory/
。
- 模型在Hub上的
- revision (
str
, 可选) — Hub上模型的修订版本。可以是分支名称、git标签或任何提交ID。默认为main
分支上的最新提交。 - force_download (
bool
, 可选, 默认为False
) — 是否强制(重新)从Hub下载模型权重和配置文件,覆盖现有缓存。 - proxies (
Dict[str, str]
, 可选) — 要按协议或端点使用的代理服务器字典,例如{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}
。代理用于每个请求。 - token (
str
或bool
, 可选) — 用作远程文件HTTP bearer授权的token。默认情况下,它将使用运行hf auth login
时缓存的token。 - cache_dir (
str
,Path
, 可选) — 缓存文件存储的文件夹路径。 - local_files_only (
bool
, 可选, 默认为False
) — 如果为True
,则避免下载文件,如果本地缓存文件存在,则返回其路径。 - labels (
List[str]
, 可选) — 如果标签是介于0
到num_classes-1
之间的整数,则这些标签表示相应的标签。 - model_card_data (
SetFitModelCardData
, 可选) — 一个SetFitModelCardData
实例,用于存储模型语言、许可证、数据集名称等数据,以用于自动生成的模型卡。 - model_card_data (
SetFitModelCardData
, 可选) — 一个SetFitModelCardData
实例,用于存储模型语言、许可证、数据集名称等数据,以用于自动生成的模型卡。 - use_differentiable_head (
bool
, 可选) — 是否使用可微分(即 Torch)头部而不是逻辑回归来加载SetFit。 - normalize_embeddings (
bool
, 可选) — 是否对Sentence Transformer主体生成的嵌入应用归一化。 - span_context (
int
, 默认为0
) — 在完整句子中,应预置在 span 候选之前的单词数量。对于方面模型默认为 0,对于极性模型默认为 3。 - device (
Union[torch.device, str]
, 可选) — 加载SetFit模型时使用的设备,例如"cuda:0"
、"mps"
或torch.device("cuda")
。
从 Huggingface Hub 下载模型并实例化它。
predict
< source >( inputs: typing.Union[str, typing.List[str]] batch_size: int = 32 as_numpy: bool = False use_labels: bool = True show_progress_bar: typing.Optional[bool] = None ) → Union[torch.Tensor, np.ndarray, List[str], int, str]
参数
- inputs (Union[str, List[str]]) — 用于预测类别的输入句子或句子列表。
- batch_size (int, 默认为 32) — 用于将句子编码为嵌入的批次大小。越高通常意味着处理速度越快,但内存使用量也越高。
- as_numpy (bool, 默认为 False) — 是否以numpy数组形式输出。
- use_labels (bool, 默认为 True) — 是否尝试返回SetFitModel.labels的元素。
- show_progress_bar (Optional[bool], 默认为 None) — 编码时是否显示进度条。
返回
Union[torch.Tensor, np.ndarray, List[str], int, str]
如果 use_labels 为 True 且 SetFitModel.labels 已定义,则返回与输入长度相同的字符串标签列表。否则返回与输入长度相同的向量,表示每个输入所属的预测类别。如果输入是单个字符串,则输出也是单个标签。
预测各种类别。
push_to_hub
< source >( repo_id: str config: typing.Union[dict, huggingface_hub.hub_mixin.DataclassInstance, NoneType] = None commit_message: str = '使用 huggingface_hub 推送模型。' private: typing.Optional[bool] = None token: typing.Optional[str] = None branch: typing.Optional[str] = None create_pr: typing.Optional[bool] = None allow_patterns: typing.Union[str, typing.List[str], NoneType] = None ignore_patterns: typing.Union[str, typing.List[str], NoneType] = None delete_patterns: typing.Union[str, typing.List[str], NoneType] = None model_card_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None )
参数
- repo_id (
str
) — 要推送到的存储库ID(例如:"username/my-model"
)。 - config (
dict
或DataclassInstance
, 可选) — 指定为键/值字典或数据类实例的模型配置。 - commit_message (
str
, 可选) — 推送时提交的消息。 - private (
bool
, 可选) — 创建的存储库是否为私有。如果为None
(默认),除非组织的默认设置为私有,否则存储库将是公开的。 - token (
str
, 可选) — 用作远程文件HTTP bearer授权的token。默认情况下,它将使用运行hf auth login
时缓存的token。 - branch (
str
, 可选) — 推送模型所用的git分支。默认为"main"
。 - create_pr (
boolean
, 可选) — 是否从branch
创建带有该提交的拉取请求。默认为False
。 - allow_patterns (
List[str]
或str
, 可选) — 如果提供,则只推送至少匹配一个模式的文件。 - ignore_patterns (
List[str]
或str
, 可选) — 如果提供,则不推送与任何模式匹配的文件。 - delete_patterns (
List[str]
或str
, 可选) — 如果提供,则与任何模式匹配的远程文件将从仓库中删除。 - model_card_kwargs (
Dict[str, Any]
, 可选) — 传递给模型卡模板的附加参数,用于自定义模型卡。
将模型检查点上传到 Hub。
使用 allow_patterns
和 ignore_patterns
精确筛选哪些文件应推送到 Hub。使用 delete_patterns
在同一提交中删除现有远程文件。有关更多详细信息,请参阅 upload_folder
参考。
save_pretrained
< source >( save_directory: typing.Union[str, pathlib.Path] config: typing.Union[dict, huggingface_hub.hub_mixin.DataclassInstance, NoneType] = None repo_id: typing.Optional[str] = None push_to_hub: bool = False model_card_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None **push_to_hub_kwargs ) → str
或 None
参数
- save_directory (
str
或Path
) — 保存模型权重和配置的目录路径。 - config (
dict
或DataclassInstance
, 可选) — 指定为键/值字典或数据类实例的模型配置。 - push_to_hub (
bool
, 可选, 默认为False
) — 是否在保存模型后将其推送到Huggingface Hub。 - repo_id (
str
, 可选) — 您在Hub上的存储库ID。仅当push_to_hub=True
时使用。如果未提供,将默认为文件夹名称。 - model_card_kwargs (
Dict[str, Any]
, 可选) — 传递给模型卡模板的附加参数,用于自定义模型卡。 - push_to_hub_kwargs — 传递给
~ModelHubMixin.push_to_hub
方法的附加关键字参数。
返回
str
或 None
如果 push_to_hub=True
,则为 Hub 上提交的 URL,否则为 None
。
将权重保存到本地目录。
到
< source >( device: typing.Union[str, torch.device] ) → SetFitModel
将此 SetFitModel 移动到 device,然后返回 self。此方法不进行复制。