import os
import time
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from keras.preprocessing.image import ImageDataGenerator
from keras.losses import CategoricalCrossentropy
from keras.callbacks import EarlyStopping
from keras.optimizers import Adam
import tensorflow as tf  # only used by the commented-out createLayers() alternative below
from keras import layers, models  # likewise only needed by createLayers()

from Func.getSubFolders import count_sub_folders
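
# count_sub_folders(path) is assumed to return the number of class
# subdirectories under `path` (one folder per class); a minimal equivalent:
#
#   def count_sub_folders(folder):
#       return sum(os.path.isdir(os.path.join(folder, e)) for e in os.listdir(folder))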

path = 'Data'
output = 'Model/keras_model.h5'
start_time = time.time()

# Step 1: Set up preprocessing and augmentation for the training images
# (note: ImageDataGenerator is deprecated in recent Keras releases in favor
# of tf.keras.utils.image_dataset_from_directory, but works here)
datagen = ImageDataGenerator(
    rescale=1. / 255,
    validation_split=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
)

# The validation generator only rescales (no augmentation), but it must carry
# the same validation_split so flow_from_directory can select the same subset.
test_datagen = ImageDataGenerator(rescale=1. / 255, validation_split=0.2)

# Step 2: Load the data; labels are inferred from the subfolder names
train_set = datagen.flow_from_directory(
    path,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    subset='training'
)

test_set = test_datagen.flow_from_directory(
    path,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    subset='validation'
)
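
# Sanity check: both generators must agree on the class-to-index mapping, and
# the output layer below must have exactly this many units.
print(f'Classes: {train_set.class_indices}')
print(f'Training images: {train_set.samples}, validation images: {test_set.samples}')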

# Step 3: Build the Model
model = Sequential()
model.add(Conv2D(32, (3, 3), input_shape=(224, 224, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))

model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))

model.add(Conv2D(256, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))

model.add(Flatten())
model.add(Dense(units=256, activation='relu'))
model.add(Dense(units=128, activation='relu'))
model.add(Dense(units=count_sub_folders(path), activation='softmax'))
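
# Optional shape check: with 224x224 inputs and four 2x2 poolings, Flatten sees
# a 12x12x256 feature map and the output layer is (None, num_classes).
# print(model.output_shape)

# Alternative model kept for reference: a small MobileNetV2-style stack of
# inverted-residual blocks (1x1 expand -> 3x3 depthwise -> 1x1 project with
# residual adds). Uncommenting it replaces the Sequential model above.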
# def createLayers(input_shape=(224, 224, 3)):
#     inputs = tf.keras.Input(shape=input_shape)
#
#     x = layers.Conv2D(48, (1, 1), padding='same', use_bias=False, name='block_1_expand')(inputs)
#     x = layers.BatchNormalization(name='block_1_expand_BN')(x)
#     x = layers.ReLU(6., name='block_1_expand_relu')(x)
#
#     x = layers.DepthwiseConv2D((3, 3), padding='same', use_bias=False, name='block_1_depthwise')(x)
#     x = layers.BatchNormalization(name='block_1_depthwise_BN')(x)
#     x = layers.ReLU(6., name='block_1_depthwise_relu')(x)
#
#     x = layers.Conv2D(8, (1, 1), padding='same', use_bias=False, name='block_1_project')(x)
#     x = layers.BatchNormalization(name='block_1_project_BN')(x)
#
#     for i in range(2,5):
#         x1 = layers.Conv2D(48, (1, 1), padding='same', use_bias=False, name=f'block_{i}_expand')(x)
#         x1 = layers.BatchNormalization(name=f'block_{i}_expand_BN')(x1)
#         x1 = layers.ReLU(6., name=f'block_{i}_expand_relu')(x1)
#
#         x1 = layers.DepthwiseConv2D((3, 3), padding='same', use_bias=False, name=f'block_{i}_depthwise')(x1)
#         x1 = layers.BatchNormalization(name=f'block_{i}_depthwise_BN')(x1)
#         x1 = layers.ReLU(6., name=f'block_{i}_depthwise_relu')(x1)
#
#         x1 = layers.Conv2D(8, (1, 1), padding='same', use_bias=False, name=f'block_{i}_project')(x1)
#         x1 = layers.BatchNormalization(name=f'block_{i}_project_BN')(x1)
#
#         x = layers.Add(name=f'block_{i}_add')([x, x1])
#
#     x = tf.keras.layers.GlobalAveragePooling2D()(x)
#     outputs = tf.keras.layers.Dense(count_sub_folders(path), activation='softmax')(x)
#     model = models.Model(inputs, outputs, name='testModel')
#
#     return model
#
#
# model = createLayers()

# Step 4: Compile the Model (softmax output, so from_logits=False is correct)
model.compile(optimizer=Adam(learning_rate=0.001),
              loss=CategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])
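
# If magnitude pruning were wanted before compiling, a minimal sketch with the
# tensorflow_model_optimization package (an assumption; it is not imported or
# used anywhere in this script) would be:
#
#   import tensorflow_model_optimization as tfmot
#   model = tfmot.sparsity.keras.prune_low_magnitude(model)
#   model.compile(...)  # recompile the wrapped model
#   # and fit() would need callbacks=[tfmot.sparsity.keras.UpdatePruningStep()]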

# Step 5: Train the Model (the callback must be passed to fit(),
# otherwise early stopping never runs)
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

model.fit(train_set, validation_data=test_set, epochs=10, callbacks=[early_stopping])

# Step 6: Evaluate the Model
loss, accuracy = model.evaluate(test_set)
print(f'Test loss: {loss}, Test accuracy: {accuracy}')

# Step 7: Save the trained model (HDF5 format, per the .h5 extension)
os.makedirs(os.path.dirname(output), exist_ok=True)  # ensure Model/ exists
model.save(output)
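
# Reloading the saved model later (sketch; `sample_batch` is a hypothetical
# array of shape (N, 224, 224, 3) with pixels already rescaled to [0, 1]):
#   from keras.models import load_model
#   loaded = load_model(output)
#   probabilities = loaded.predict(sample_batch)  # one softmax row per image
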
end_time = time.time()

execute_time = (end_time - start_time) / 60

model.summary()

# Print the result
print(f"It took: {execute_time:0.2f} minutes")