CNN / src /data.py
sheethal0703's picture
Upload folder using huggingface_hub
e85e22c verified
raw
history blame contribute delete
623 Bytes
import tensorflow as tf
import numpy as np
def load_data():
    """Fetch CIFAR-10 through Keras and rescale the images to floats in [0, 1].

    Returns:
        Two ``(images, labels)`` tuples, training split first, then test split:
        ``(x_train, y_train), (x_test, y_test)``. Images are converted to
        float32 and divided by 255; labels are passed through exactly as
        Keras returns them.
    """
    train_split, test_split = tf.keras.datasets.cifar10.load_data()

    def _to_unit_range(images):
        # Cast before dividing so the division happens in float32 rather than
        # on the raw integer pixel array.
        return images.astype('float32') / 255.0

    train_images, train_labels = train_split
    test_images, test_labels = test_split
    return (
        (_to_unit_range(train_images), train_labels),
        (_to_unit_range(test_images), test_labels),
    )
def get_class_names():
    """Return the ten CIFAR-10 class names as a newly built list.

    The order follows the Keras CIFAR-10 label table, so index ``i`` names
    integer label ``i``. A fresh list is created on every call, so callers
    may modify the result freely.
    """
    names = [
        'airplane',    # 0
        'automobile',  # 1
        'bird',        # 2
        'cat',         # 3
        'deer',        # 4
        'dog',         # 5
        'frog',        # 6
        'horse',       # 7
        'ship',        # 8
        'truck',       # 9
    ]
    return names