# trainModel.py
# Author: Michiel_VE
import os

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from keras.preprocessing.image import ImageDataGenerator
from keras.losses import CategoricalCrossentropy

from Func.getSubFolders import count_sub_folders

# Train a small CNN image classifier on a class-per-sub-directory dataset,
# report its loss/accuracy on a held-out subset, and save it to `output`.

# Dataset root: flow_from_directory treats each sub-directory as one class.
path = 'Data_test'
# Destination of the trained model. NOTE(review): despite the file name, this
# script does not apply any pruning.
output = 'Model/pruned.h5'

IMG_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS = 10
# Fraction of each class held out for validation/evaluation. Reading the same
# directory for both sets without a split would evaluate on training images,
# making the reported "test" metrics meaningless.
VALIDATION_SPLIT = 0.2

# Step 1: Load and preprocess images.
# Training images are rescaled to [0, 1] and randomly augmented; held-out images
# are only rescaled. Both generators use the same validation_split, so their
# 'training' and 'validation' subsets are disjoint parts of the same files.
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    validation_split=VALIDATION_SPLIT,
)
test_datagen = ImageDataGenerator(
    rescale=1. / 255,
    validation_split=VALIDATION_SPLIT,
)

# Step 2: Label the data (labels are one-hot encoded from the directory names).
train_set = train_datagen.flow_from_directory(
    path,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training',
)
test_set = test_datagen.flow_from_directory(
    path,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation',
    shuffle=False,  # order does not matter for evaluation; keep it deterministic
)

if train_set.samples == 0 or test_set.samples == 0:
    raise ValueError(
        f'Not enough images in {path!r} to build both a training and a '
        f'held-out subset (VALIDATION_SPLIT={VALIDATION_SPLIT}).'
    )

# Step 3: Build the model.
# NOTE(review): with 'valid' padding, 224x224 -> 222x222 after Conv2D -> 111x111
# after pooling, so Flatten feeds 111*111*32 features into Dense(128), roughly
# 50M weights. Consider more conv/pool stages or global pooling.
model = Sequential()
model.add(Conv2D(32, (3, 3), input_shape=(*IMG_SIZE, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(units=128, activation='relu'))
# Size the output layer from the generator so it always matches the width of
# the one-hot labels it yields.
model.add(Dense(units=train_set.num_classes, activation='softmax'))

# Step 4: Compile the model. The last layer already applies softmax, so the loss
# receives probabilities rather than logits.
model.compile(optimizer='adam',
              loss=CategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])

# Step 5: Train, monitoring the held-out subset after each epoch.
model.fit(train_set, epochs=EPOCHS, validation_data=test_set)

# Step 6: Evaluate on the held-out subset.
loss, accuracy = model.evaluate(test_set)
print(f'Test loss: {loss}, Test accuracy: {accuracy}')

# Step 7: Save the trained model, creating its directory if it does not exist
# (saving into a missing directory fails).
output_dir = os.path.dirname(output)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
model.save(output)