CNN / src /model.py
sheethal0703's picture
Upload folder using huggingface_hub
e85e22c verified
raw
history blame contribute delete
737 Bytes
import tensorflow as tf
from tensorflow.keras import layers, models
def create_model(input_shape=(32, 32, 3), num_classes=10):
    """Build a small, uncompiled CNN image classifier.

    The defaults target CIFAR-10 (32x32 RGB images, 10 classes). The network
    is three 3x3 ReLU convolutions (32, 64 and 64 filters) with 2x2 max
    pooling after the first two, followed by a 64-unit ReLU dense layer and
    a final dense layer with one unit per class.

    Args:
        input_shape: Shape of one input sample, without the batch dimension.
        num_classes: Number of units in the output layer.

    Returns:
        A Keras ``Sequential`` model. It is not compiled, and the output
        layer has no activation, so predictions are raw logits; configure
        the loss accordingly (e.g. ``from_logits=True``) or apply softmax
        downstream.
    """
    return models.Sequential([
        # Declare the input explicitly instead of passing ``input_shape`` to
        # the first Conv2D, which newer Keras releases warn about.
        layers.Input(shape=input_shape),
        # Convolutional base
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        # Dense classification head
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(num_classes),  # one logit per class, no activation
    ])