import os

import tensorflow as tf


def train_model(model, train_data, test_data, epochs: int = 10,
                batch_size: int = 64,
                model_save_path: str = 'models/cifar10_cnn.h5'):
    """Compile and fit ``model``, save it to disk, and return the fit history.

    The model is compiled with the Adam optimizer, a sparse categorical
    cross-entropy loss configured with ``from_logits=True``, and an
    ``accuracy`` metric. It is then fitted on ``train_data`` while
    ``test_data`` is passed to ``fit`` as the validation set.

    NOTE(review): ``from_logits=True`` presumably means the model's final
    layer emits raw logits (no softmax) and the labels are integer class
    ids -- confirm against the model definition and the data loader.

    Args:
        model: Object exposing ``compile``, ``fit`` and ``save`` (presumably
            a ``tf.keras.Model`` -- verify against the caller).
        train_data: Pair ``(x_train, y_train)`` used for fitting.
        test_data: Pair ``(x_test, y_test)`` passed as ``validation_data``.
        epochs: Number of epochs forwarded to ``model.fit``.
        batch_size: Batch size forwarded to ``model.fit``.
        model_save_path: Destination passed to ``model.save``. Its parent
            directory is created when the path has one.

    Returns:
        Whatever ``model.fit`` returns (the training history object).
    """
    x_train, y_train = train_data
    x_test, y_test = test_data

    model.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )

    history = model.fit(
        x_train,
        y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, y_test),
    )

    # A bare filename such as "model.h5" has an empty dirname, and
    # os.makedirs("") raises FileNotFoundError; only create the directory
    # when the path actually contains one.
    save_dir = os.path.dirname(model_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    model.save(model_save_path)
    print(f"Model saved to {model_save_path}")

    return history