Spaces:
Build error
Build error
import pathlib
from typing import Any, Callable, Optional, Tuple, Union

import numpy as np
import pandas as pd
from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive, download_url
class StanfordCarsClass(VisionDataset):
    """`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset.

    The Cars dataset contains 16,185 images of 196 classes of cars. The data is
    split into 8,144 training images and 8,041 testing images, where each class
    has been split roughly in a 50-50 split.

    .. note::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

    Args:
        train (bool, optional): If True (default), use the training split
            (``cars_train`` images, ``devkit/cars_train_annos.mat`` labels); otherwise
            use the test split (``cars_test`` images, ``cars_test_annos_withlabels.mat`` labels).
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If True (default), downloads the dataset from the
            internet and puts it in the root directory. If the dataset is already
            downloaded, it is not downloaded again.
        root (str or pathlib.Path, optional): Root directory of the dataset. Defaults to
            the class attribute ``root`` (``~/tmp/Datasets/StanfordCars``). Files are
            stored under ``<root>/stanford_cars``.

    Attributes:
        samples: list of ``(image_path, class_index)`` pairs with 0-based class indices.
        targets: numpy array of the class indices, in ``samples`` order.
        classes: list of class names read from ``devkit/cars_meta.mat``.
        class_to_idx: mapping from class name to its position in ``classes``.

    Raises:
        RuntimeError: if scipy is not installed, or if the dataset files are missing
            after the optional download step.
    """

    # Default dataset location, used when no ``root`` is passed to ``__init__``.
    root = pathlib.Path.home() / "tmp" / "Datasets" / "StanfordCars"

    def __init__(
        self,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = True,
        root: Optional[Union[str, pathlib.Path]] = None,
    ) -> None:
        try:
            import scipy.io as sio
        except ImportError as err:
            raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy") from err

        # Fall back to the class-level default so existing callers keep working unchanged.
        dataset_root = self.root if root is None else pathlib.Path(root).expanduser()
        super().__init__(dataset_root, transform=transform, target_transform=target_transform)

        self.train = train
        self._base_folder = pathlib.Path(self.root) / "stanford_cars"
        devkit = self._base_folder / "devkit"
        if train:
            # Training annotations ship inside the devkit archive.
            self._annotations_mat_path = devkit / "cars_train_annos.mat"
            self._images_base_path = self._base_folder / "cars_train"
        else:
            # Test labels are a separate .mat download stored next to the image folder.
            self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
            self._images_base_path = self._base_folder / "cars_test"

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        annotations = sio.loadmat(str(self._annotations_mat_path), squeeze_me=True)["annotations"]
        self.samples = [
            (
                str(self._images_base_path / annotation["fname"]),
                # Original target mapping starts from 1, hence -1. Cast to a plain int so
                # the label does not keep whatever numpy scalar dtype loadmat produced.
                int(annotation["class"]) - 1,
            )
            for annotation in annotations
        ]
        self.targets = np.array([target for _, target in self.samples])
        self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for the sample at ``idx``.

        The image is loaded as an RGB PIL image and passed through ``transform``.
        The 0-based class index is passed through ``target_transform``.
        """
        image_path, target = self.samples[idx]
        # Context manager closes the underlying file deterministically; convert()
        # returns an independent image, so it stays usable after the file is closed.
        with Image.open(image_path) as img:
            pil_image = img.convert("RGB")
        if self.transform is not None:
            pil_image = self.transform(pil_image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return pil_image, target

    def download(self) -> None:
        """Download the devkit plus the selected split's images (and test labels).

        Does nothing when the required files already exist.
        """
        # NOTE(review): the ai.stanford.edu URLs below have reportedly gone offline
        # (torchvision disabled automatic download for StanfordCars). Confirm they
        # are reachable before relying on ``download=True``.
        if self._check_exists():
            return

        download_and_extract_archive(
            url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
            download_root=str(self._base_folder),
            md5="c3b158d763b6e2245038c8ad08e45376",
        )
        if self.train:
            download_and_extract_archive(
                url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
                download_root=str(self._base_folder),
                md5="065e5b463ae28d29e77c1b4b166cfe61",
            )
        else:
            download_and_extract_archive(
                url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
                download_root=str(self._base_folder),
                md5="4ce7ebf6a94d07f1952d94dd34c4d501",
            )
            download_url(
                url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
                root=str(self._base_folder),
                md5="b0a2b23655a3edd16d84508592a98d10",
            )

    def _check_exists(self) -> bool:
        """Return True when every file ``__init__`` reads is present.

        That means the devkit folder with ``cars_meta.mat``, the split's
        annotation file, and the split's image folder.
        """
        devkit = self._base_folder / "devkit"
        # cars_meta.mat is always loaded in __init__, so a devkit folder without it is incomplete.
        if not devkit.is_dir() or not (devkit / "cars_meta.mat").exists():
            return False
        return self._annotations_mat_path.exists() and self._images_base_path.is_dir()