import matplotlib.pyplot as plt
import numpy as np


def plot_images(images, labels, class_indices, *, figsize=(15, 15)):
    """Show ``images`` in a square grid, each titled with its class name.

    Args:
        images: Indexable sequence of images accepted by ``plt.imshow``.
            Arrays with a trailing single channel, shape ``(H, W, 1)``, are
            squeezed to ``(H, W)`` because ``imshow`` rejects that shape.
        labels: Indexable sequence parallel to ``images``. Each entry is
            either a vector (one-hot / per-class scores), decoded with
            ``argmax``, or a scalar that is already the class index.
        class_indices: Mapping from class name to integer class index; the
            name whose index matches a label is used as that image's title.
        figsize: Figure size in inches, passed to ``plt.figure``.

    Raises:
        ValueError: If a decoded label index is not a value of
            ``class_indices``.
        IndexError: If ``labels`` is shorter than ``images``.
    """
    num_images = len(images)
    if num_images == 0:
        # Nothing to draw; avoid opening an empty figure window.
        return

    # Invert the name -> index mapping once instead of scanning it per image.
    # setdefault keeps the FIRST name for a duplicated index, matching the
    # list.index() lookup this replaces.
    index_to_name = {}
    for name, index in class_indices.items():
        index_to_name.setdefault(index, name)

    grid_size = int(np.ceil(np.sqrt(num_images)))
    plt.figure(figsize=figsize)
    for i, img in enumerate(images):
        plt.subplot(grid_size, grid_size, i + 1)

        shape = getattr(img, "shape", None)
        if shape is not None and len(shape) == 3 and shape[-1] == 1:
            # imshow accepts (H, W), (H, W, 3) or (H, W, 4), but not (H, W, 1).
            img = img[..., 0]
        plt.imshow(img)

        label = np.asarray(labels[i])
        # A 0-d label is already a class index; argmax of a scalar would
        # always be 0 and mislabel every image.
        label_index = int(label) if label.ndim == 0 else int(np.argmax(label))
        try:
            label_name = index_to_name[label_index]
        except KeyError:
            raise ValueError(
                f"label index {label_index} is not in class_indices"
            ) from None

        plt.title(label_name)
        plt.axis('off')
    plt.tight_layout()
    plt.show()