CNN / src /train.py
sheethal0703's picture
Upload folder using huggingface_hub
e85e22c verified
raw
history blame contribute delete
833 Bytes
import tensorflow as tf
import os
def train_model(model, train_data, test_data, epochs=10, batch_size=64,
                model_save_path='models/cifar10_cnn.h5', *, from_logits=True):
    """
    Compile, train, and save a Keras classification model.

    The model is (re)compiled with the Adam optimizer, sparse categorical
    cross-entropy loss and an accuracy metric, trained with ``model.fit`` on
    ``train_data``, and scored on ``test_data`` after every epoch via
    ``validation_data``. The trained model is then saved to
    ``model_save_path``.

    Args:
        model: A Keras model; it is compiled inside this function.
        train_data: ``(x_train, y_train)`` pair passed to ``model.fit``.
        test_data: ``(x_test, y_test)`` pair passed as ``validation_data``.
        epochs: Number of training epochs.
        batch_size: Mini-batch size for ``model.fit``.
        model_save_path: Destination for ``model.save``. Missing parent
            directories are created; a bare filename (no directory part)
            saves into the current working directory.
        from_logits: Passed to ``SparseCategoricalCrossentropy``. ``True``
            (the default, matching the original behavior) means the model
            outputs raw logits; use ``False`` if the model already ends in a
            softmax, otherwise the loss is computed on the wrong scale.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    (x_train, y_train) = train_data
    (x_test, y_test) = test_data

    model.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=from_logits),
        metrics=['accuracy'],
    )

    # NOTE(review): the test split doubles as validation data. That is fine for
    # monitoring, but these metrics should not drive model selection/tuning.
    history = model.fit(
        x_train, y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, y_test),
    )

    # os.path.dirname('model.h5') == '' and os.makedirs('') raises
    # FileNotFoundError, which would lose the freshly trained model. Only
    # create a directory when the path actually has a parent component.
    save_dir = os.path.dirname(model_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    model.save(model_save_path)
    print(f"Model saved to {model_save_path}")
    return history