Newer
Older
import matplotlib.pyplot as plt
import numpy as np
def plot_images(images, labels, class_indices, figsize=(15, 15)):
    """Show a grid of images, each titled with its class name.

    Args:
        images: Sequence of images; each element is passed unchanged to
            ``Axes.imshow``.
        labels: Sequence parallel to ``images``. Each entry is either a
            vector of per-class scores / a one-hot vector (the class index is
            its ``argmax``) or a scalar integer class index.
        class_indices: Mapping of class name -> integer class index,
            e.g. ``{"cat": 0, "dog": 1}``.
        figsize: Figure size in inches, forwarded to ``plt.figure``.
            Defaults to the previously hard-coded ``(15, 15)``.

    Raises:
        IndexError: If ``labels`` has fewer entries than ``images``.
    """
    num_images = len(images)
    if num_images == 0:
        # Nothing to draw; also avoids a 0/0 in the row computation below.
        return

    # Invert name -> index once instead of rebuilding two lists and doing a
    # linear .index() search per image. setdefault keeps the FIRST name for a
    # duplicated index, matching what list.index() returned previously.
    index_to_name = {}
    for name, idx in class_indices.items():
        index_to_name.setdefault(idx, name)

    # Columns from the square root; only as many rows as actually needed,
    # so e.g. 5 images use a 2x3 grid instead of 3x3 with an empty row.
    n_cols = int(np.ceil(np.sqrt(num_images)))
    n_rows = int(np.ceil(num_images / n_cols))

    fig = plt.figure(figsize=figsize)
    for i, img in enumerate(images):
        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        ax.imshow(img)

        label = np.asarray(labels[i])
        # np.argmax on a 0-d value always returns 0, so only reduce vector
        # labels; a scalar label is already the class index.
        class_idx = int(label) if label.ndim == 0 else int(np.argmax(label))
        # Fall back to the raw index rather than aborting the whole plot
        # when the index has no entry in class_indices.
        ax.set_title(index_to_name.get(class_idx, str(class_idx)))
        ax.axis('off')

    fig.tight_layout()
    plt.show()