Skip to content
Snippets Groups Projects
DrawImages.py 563 B
Newer Older
Michiel_VE's avatar
Michiel_VE committed
import matplotlib.pyplot as plt
import numpy as np


def plot_images(images, labels, class_indices):
    num_images = len(images)
    grid_size = int(np.ceil(np.sqrt(num_images)))

    plt.figure(figsize=(15, 15))
    for i in range(num_images):
        plt.subplot(grid_size, grid_size, i + 1)
        img = images[i]
        plt.imshow(img)

        label = np.argmax(labels[i])
        label_name = list(class_indices.keys())[list(class_indices.values()).index(label)]
        plt.title(label_name)
        plt.axis('off')
    plt.tight_layout()
    plt.show()