File size: 623 Bytes
e85e22c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 | import tensorflow as tf
import numpy as np
def load_data():
    """Fetch CIFAR-10 through Keras and scale pixel intensities into [0, 1].

    Returns:
        A pair of ``(images, labels)`` tuples: the training split first,
        then the test split. Image arrays are converted to float32 and
        divided by 255; label arrays are returned exactly as Keras
        provides them.
    """
    train_split, test_split = tf.keras.datasets.cifar10.load_data()

    def _scale(images):
        # Cast before dividing so the result stays float32 (dividing the raw
        # integer array by a Python float would promote to float64 in NumPy).
        return images.astype('float32') / 255.0

    train_images, train_labels = train_split
    test_images, test_labels = test_split

    scaled_train = _scale(train_images)
    scaled_test = _scale(test_images)
    return (scaled_train, train_labels), (scaled_test, test_labels)
def get_class_names():
    """Return the ten CIFAR-10 class names as a new list.

    The order is airplane, automobile, bird, cat, deer, dog, frog, horse,
    ship, truck (presumably matching the integer label indices 0-9 used by
    the dataset -- confirm against the label source). A fresh list is
    created on each call, so callers may modify the result safely.
    """
    labels = "airplane automobile bird cat deer dog frog horse ship truck"
    return labels.split()
|