Optimum 文档
模型
并获得增强的文档体验
开始使用
模型
通用模型类
以下 Furiosa 类可用于实例化不带特定头的基本模型类。
FuriosaAI模型
class optimum.furiosa.FuriosaAIModel
< 来源 >( model config: PretrainedConfig = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict]] = None label_names: typing.Optional[typing.List[str]] = None **kwargs )
运行评估并返回指标和预测结果。
使用指定的 device
进行推理。例如:“cpu”或“gpu”。device
可以是大写或小写。为了加快首次推理,请在调用 .to()
后调用 .compile()
。
计算机视觉
以下类可用于以下计算机视觉任务。
FuriosaAIModelForImageClassification
class optimum.furiosa.FuriosaAIModelForImageClassification
< 来源 >( model = None config = None **kwargs )
参数
- 模型 (
furiosa.runtime.model
) — 用于运行推理的主要类。 - 配置 (
transformers.PretrainedConfig
) — PretrainedConfig 是模型配置类,包含模型的所有参数。使用配置文件初始化不会加载与模型相关的权重,只加载配置。请查看~furiosa.modeling.FuriosaAIBaseModel.from_pretrained
方法以加载模型权重。 - 设备 (
str
, 默认为"CPU"
) — 模型将为其优化的设备类型。生成的编译模型将包含特定于该设备的节点。 - furiosa_config (
Optional[Dict]
, 默认为None
) — 包含模型编译相关信息的字典。 - 编译 (
bool
, 默认为True
) — 设置为False
时,禁用加载步骤中的模型编译。
带有 ImageClassifierOutput 的 FuriosaAI 模型,用于图像分类任务。
该模型继承自 optimum.furiosa.FuriosaAIBaseModel
。有关库为其所有模型实现的通用方法(如下载或保存),请查看超类文档。
前向
< 来源 >( pixel_values: typing.Union[torch.Tensor, numpy.ndarray] **kwargs )
参数
- 像素值 (
torch.Tensor
) — 当前批次图像对应的像素值。像素值可以使用AutoFeatureExtractor
从编码图像中获取。
FuriosaAIModelForImageClassification 的前向方法,覆盖了 __call__
特殊方法。
虽然前向传递的实现需要在该函数内部定义,但之后应该调用 Module
实例而不是直接调用此函数,因为前者负责运行预处理和后处理步骤,而后者则默默忽略它们。
使用 transformers.pipelines
进行图像分类的示例
>>> from transformers import AutoFeatureExtractor, pipeline
>>> from optimum.furiosa import FuriosaAIModelForImageClassification
>>> preprocessor = AutoFeatureExtractor.from_pretrained("microsoft/resnet50")
>>> model = FuriosaAIModelForImageClassification.from_pretrained("microsoft/resnet50", export=True, input_shape_dict="dict('pixel_values': [1, 3, 224, 224])", output_shape_dict="dict("logits": [1, 1000])",)
>>> pipe = pipeline("image-classification", model=model, feature_extractor=preprocessor)
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> outputs = pipe(url)