| import tensorflow as tf | |
| import os | |
def train_model(model, train_data, test_data, epochs=10, batch_size=64,
                model_save_path='models/cifar10_cnn.h5', from_logits=True):
    """Compile, train, and save a Keras classification model.

    The model is compiled with the ``'adam'`` optimizer, a
    ``SparseCategoricalCrossentropy`` loss and an ``'accuracy'`` metric, then
    fit on ``train_data``. ``test_data`` is passed to ``fit`` as
    ``validation_data``, so the ``val_*`` metrics recorded in the returned
    history are measured on that split. Finally the model is saved to
    ``model_save_path``.

    Args:
        model: Keras model to compile, train, and save.
        train_data: ``(x_train, y_train)`` pair passed to ``model.fit``. Labels
            must be integer class indices, as required by
            ``SparseCategoricalCrossentropy``.
        test_data: ``(x_test, y_test)`` pair used as validation data.
        epochs: Number of training epochs.
        batch_size: Mini-batch size used by ``model.fit``.
        model_save_path: Destination passed to ``model.save``. Missing parent
            directories are created; a bare filename saves into the current
            working directory.
        from_logits: Whether the model outputs raw logits (default ``True``,
            matching the previous hard-coded behavior). Pass ``False`` when the
            model's last layer already applies softmax; otherwise the loss is
            computed on probabilities as if they were logits.

    Returns:
        The object returned by ``model.fit`` (a Keras ``History``).
    """
    x_train, y_train = train_data
    x_test, y_test = test_data

    model.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=from_logits),
        metrics=['accuracy'],
    )
    history = model.fit(
        x_train, y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, y_test),
    )

    # os.path.dirname() returns '' for a bare filename such as 'model.h5', and
    # os.makedirs('') raises FileNotFoundError -- which would discard a model
    # that has just finished training. Only create a directory when one is named.
    save_dir = os.path.dirname(model_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    model.save(model_save_path)
    print(f"Model saved to {model_save_path}")
    return history