| | from torch.utils.data import Dataset |
| | import os |
| | import pathlib |
| | import torch |
| |
|
| | from PIL import Image |
| | from torch.utils.data import Dataset |
| | from torchvision import transforms |
| | from typing import Tuple, Dict, List |
| |
|
| | import torch.utils.data as data |
| | import numpy as np |
| | |
| | import random |
| | |
| | |
| |
|
| | |
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Find the class sub-folders of ``directory`` and map each to an integer label.

    Class folders must be named with integers (e.g. ``"0"``, ``"000017"``).
    The label of each class is ``int(folder_name)``, NOT its position in the
    sorted list. Stray files in ``directory`` are ignored.

    Note: the returned list is sorted as strings, so ``"10"`` comes before
    ``"2"`` unless the names are zero-padded. Use the returned dict, not the
    list position, to obtain a class's label.

    Args:
        directory (str): target directory to load class names from.

    Returns:
        Tuple[List[str], Dict[str, int]]: (class names sorted as strings,
        dict(class_name: int(class_name))).

    Raises:
        FileNotFoundError: if ``directory`` does not exist or contains no
            sub-folders.
        ValueError: if a sub-folder name is not an integer.

    Example:
        find_classes("food_images/train")  # sub-folders "0", "1", "10"
        >>> (["0", "1", "10"], {"0": 0, "1": 1, "10": 10})
    """
    # The with-block closes the scandir iterator deterministically.
    with os.scandir(directory) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())

    if not classes:
        raise FileNotFoundError(f"Couldn't find any classes in {directory}.")

    class_to_idx: Dict[str, int] = {}
    for cls_name in classes:
        try:
            class_to_idx[cls_name] = int(cls_name)
        except ValueError as err:
            # Name the offending folder instead of a bare int() parse error.
            raise ValueError(
                f"Class folder {cls_name!r} in {directory} is not an integer label."
            ) from err
    return classes, class_to_idx
| |
|
| |
|
class SamData(Dataset):
    """Image dataset over a directory tree of ``.jpg`` files with ``sa_`` names.

    Images are discovered with the glob ``<targ_dir>/*/*.jpg`` (exactly one
    sub-folder level) and sorted by path. For every image an integer id
    (``self.indexes``) and a 6-character fold string (``self.folds``) are
    parsed from its ``<folder>/<file>`` part.

    NOTE(review): the fixed parsing offsets imply a layout such as
    ``sa_000020/sa_223750.jpg`` (presumably extracted SA-1B tar shards).
    Confirm this against the real data.

    Args:
        targ_dir: root directory containing the image sub-folders.
        transform: optional callable applied to each PIL image in
            ``__getitem__``; skipped when falsy.
    """

    def __init__(self, targ_dir: str, transform=None) -> None:
        self.paths = sorted(pathlib.Path(targ_dir).glob("*/*.jpg"))

        self.indexes = []
        self.folds = []
        for path in self.paths:
            # Parse only "<folder>/<file>": searching the full path would let an
            # "sa_" (or ".jpg") inside targ_dir itself shift the fixed offsets.
            image_id, fold = self._parse_sa_name(f"{path.parent.name}/{path.name}")
            self.indexes.append(image_id)
            self.folds.append(fold)

        self.transform = transform

    @staticmethod
    def _parse_sa_name(rel_path: str) -> Tuple[int, str]:
        """Return ``(image_id, fold)`` parsed from a ``<folder>/<file>.jpg`` string.

        Offsets are relative to the first ``"sa_"``. The fold is the 6
        characters right after it, and the id runs from 13 characters after it
        up to ``".jpg"``. For ``"sa_000020/sa_223750.jpg"`` this yields
        ``(223750, "000020")``.

        Raises:
            ValueError: if ``"sa_"``/``".jpg"`` is missing or the id is not an int.
        """
        start = rel_path.index("sa_")
        image_id = int(rel_path[start + 13 : rel_path.index(".jpg")])
        fold = rel_path[start + 3 : start + 9]
        return image_id, fold

    def load_image(self, index: int) -> Image.Image:
        """Open the image at ``self.paths[index]`` with its pixels loaded in memory.

        ``Image.open`` is lazy and keeps the file open until the image is
        garbage collected. Calling ``load()`` inside the ``with`` block decodes
        the pixels and releases the file handle right away, so descriptors do
        not pile up across many samples.
        """
        with Image.open(self.paths[index]) as img:
            img.load()
        return img

    def __len__(self) -> int:
        """Return the total number of samples."""
        return len(self.paths)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """Return ``(image, image_id)`` for sample ``index``.

        ``image`` is ``self.transform(pil_image)`` when a transform is set,
        otherwise the raw PIL image (despite the ``torch.Tensor`` annotation).
        ``image_id`` is the integer parsed from the file name, not ``index``.
        """
        img = self.load_image(index)
        image_id = self.indexes[index]

        if self.transform:
            return self.transform(img), image_id
        return img, image_id