| from typing import Dict
|
|
|
| from transformers.configuration_utils import PretrainedConfig
|
| from optimum.exporters.onnx.model_configs import ViTOnnxConfig
|
|
|
# Accepted values for ``EfficientNetConfig.model_name``:
# efficientnet_b0 ... efficientnet_b8, followed by efficientnet_l2.
MODEL_NAMES = [f'efficientnet_b{idx}' for idx in range(9)] + ['efficientnet_l2']
|
|
|
|
|
class EfficientNetConfig(PretrainedConfig):
    """Configuration object for an EfficientNet model.

    Args:
        model_name: EfficientNet variant name; must be an entry of ``MODEL_NAMES``.
        pretrained: Stored unchanged on the config. How it is used is up to the
            model code (not visible here).
        num_classes: Stored unchanged on the config. Presumably the classifier
            output size; confirm against the model class.
        global_pool: Stored unchanged on the config. Presumably the pooling mode
            handed to the backbone; confirm against the model class.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``model_name`` is not listed in ``MODEL_NAMES``.
    """

    # Class-level model type identifier for this config.
    model_type = 'efficientnet'

    def __init__(
        self,
        model_name: str = 'efficientnet_b0',
        pretrained: bool = False,
        num_classes: int = 1000,
        global_pool: str = 'avg',
        **kwargs,
    ):
        # Reject unknown variants before any attribute is assigned.
        is_known_variant = model_name in MODEL_NAMES
        if not is_known_variant:
            raise ValueError(f'`model_name` must be one of these: {MODEL_NAMES}, but got {model_name}')

        # Keep this assignment order: it determines the attribute order on the instance.
        self.model_name = model_name
        self.pretrained = pretrained
        self.num_classes = num_classes
        self.global_pool = global_pool

        # Everything else is handled by the base configuration class.
        super().__init__(**kwargs)
|
|
|
|
|
class EfficientNetOnnxConfig(ViTOnnxConfig):
    """ONNX export configuration for EfficientNet, built on the ViT ONNX config.

    The only change from the parent: for the ``image-classification`` task,
    the ``logits`` output is declared with dynamic axes named ``batch_size``
    (axis 0) and ``num_classes`` (axis 1).
    """

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the mapping of output names to their dynamic-axis names."""
        output_axes = super().outputs

        # Any other task keeps the parent's outputs untouched.
        if self.task != "image-classification":
            return output_axes

        # NOTE(review): this mutates the mapping returned by the parent property.
        # It assumes the parent returns a fresh copy on each access; confirm.
        output_axes["logits"] = {0: "batch_size", 1: "num_classes"}
        return output_axes
|
|
|
|
|
# Public names exported by ``from <module> import *``.
__all__ = [
    "EfficientNetConfig",
    "EfficientNetOnnxConfig",
]