"""Train a small CNN on MNIST, then plot training curves and sample predictions."""

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import callbacks, layers, models

# Download the MNIST dataset.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalise the pixel data.
x_train, x_test = x_train / 255.0, x_test / 255.0

# Append a trailing channel axis and cast to float32.
x_train = x_train[..., tf.newaxis].astype("float32")
x_test = x_test[..., tf.newaxis].astype("float32")

# Data augmentation: small random rotations, zooms and translations.
data_augmentation = tf.keras.Sequential([
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
    layers.RandomTranslation(0.1, 0.1),
])

# Build the dataset objects and set the batch size. Prefetching lets the input
# pipeline prepare the next batch while the current one is being consumed.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(10000)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)
# The test pipeline is deliberately NOT shuffled so that the rows returned by
# model.predict(test_ds) line up with x_test / y_test.
test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

# Build the model. The input shape is declared with an explicit Input object
# as the first entry; previously it was passed as `input_shape` to the Conv2D
# layer, which is not the first layer of the stack (the augmentation block is).
# The final Dense layer has no activation, i.e. the model outputs raw logits.
model = models.Sequential([
    layers.Input(shape=(28, 28, 1)),
    data_augmentation,
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dropout(0.5),
    layers.Dense(128, activation='relu'),
    layers.Dense(10),
])

# Compile the model. from_logits=True matches the activation-free output layer.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# Training callbacks.
reduce_lr = callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.2, patience=5, min_lr=0.0001
)
# NOTE(review): patience=10 with epochs=10 below means this callback cannot
# run out of patience before training ends on its own -- confirm the intended
# epoch budget / patience.
early_stopping = callbacks.EarlyStopping(
    monitor='val_loss', patience=10, restore_best_weights=True
)
model_checkpoint = callbacks.ModelCheckpoint(
    'best_model.keras', monitor='val_loss', save_best_only=True
)

# Train the model.
# NOTE(review): the test set doubles as validation data here, so the checkpoint
# is selected on the same data that is later reported as "test accuracy".
# Consider holding out part of the training data for validation instead.
history = model.fit(
    train_ds,
    epochs=10,
    validation_data=test_ds,
    callbacks=[reduce_lr, early_stopping, model_checkpoint],
)

# Load the best weights saved by the checkpoint callback.
model.load_weights('best_model.keras')

# Evaluate the model.
test_loss, test_acc = model.evaluate(test_ds, verbose=2)
print('\nTest accuracy:', test_acc)

# Plot the loss and accuracy curves of the run.
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc='upper right')

plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')

plt.show()


# Predict and plot some predicted samples.
def plot_image(i, predictions_array, true_label, img):
    """Draw sample ``i`` on the current axes with a prediction caption.

    The caption reads ``"<true> <confidence>% (<predicted>)"`` and is blue when
    the predicted class equals the true class, red otherwise.

    Args:
        i: Index of the sample inside ``true_label`` and ``img``.
        predictions_array: Per-class scores for sample ``i`` only. The largest
            entry is shown as a percentage, so pass probabilities (not logits).
        true_label: Indexable collection of labels; ``true_label[i]`` is used.
        img: Indexable collection of images; ``img[i]`` is reshaped to 28x28.
    """
    label = true_label[i]
    image = img[i].reshape((28, 28))

    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(image, cmap=plt.cm.binary)

    predicted_label = np.argmax(predictions_array)
    color = 'blue' if predicted_label == label else 'red'

    plt.xlabel(
        f"{label} {100 * np.max(predictions_array):2.0f}% ({predicted_label})",
        color=color,
    )


# Predict on the test set. The model emits logits, so convert them to
# probabilities before they are displayed as percentages by plot_image.
logits = model.predict(test_ds)
predictions = tf.nn.softmax(logits).numpy()

# Plot the first num_rows * num_cols test samples on a dense grid.
num_rows = 5
num_cols = 3
num_images = num_rows * num_cols
plt.figure(figsize=(2 * num_cols, 2 * num_rows))
for i in range(num_images):
    plt.subplot(num_rows, num_cols, i + 1)
    plot_image(i, predictions[i], y_test, x_test)
plt.tight_layout()
plt.show()