import keras
import numpy as np
import cv2
from cvzone.HandTrackingModule import HandDetector


class Classifier:
    """Thin wrapper around a Keras model that classifies single images.

    The model is loaded once at construction time; each call to
    ``getPrediction`` resizes the given image to 224x224, rescales its
    pixel values and runs it through the model.
    """

    def __init__(self, model_path, label_path):
        """Load the model and, when a path is given, its label file.

        Args:
            model_path: Path handed to ``keras.models.load_model``.
            label_path: Path to a text file with one label per line. A falsy
                value (``None`` or ``""``) skips loading and leaves
                ``list_labels`` empty.
        """
        self.model_path = model_path
        self.labels_path = label_path

        # NOTE: this is a process-wide NumPy print setting, not local to
        # this instance.
        np.set_printoptions(suppress=True)

        self.model = keras.models.load_model(self.model_path)
        # Reusable single-image input batch; slot 0 is overwritten on every
        # prediction, so the uninitialised contents are never read.
        self.data = np.ndarray(shape=(1, 224, 224, 3), dtype=np.float32)

        # Always define list_labels so later attribute access cannot raise
        # AttributeError when no label file was supplied.
        self.list_labels = self._load_labels(self.labels_path)

    @staticmethod
    def _load_labels(label_path):
        """Return the stripped lines of *label_path*, or ``[]`` if it is falsy."""
        if not label_path:
            print("No Labels Found")
            return []
        with open(label_path, "r") as label_file:
            return [line.strip() for line in label_file]

    def getPrediction(self, img):
        """Classify one image and return the raw scores and the best index.

        The image is also shown in an OpenCV window named
        "imgFixed classifier" as a side effect.

        Args:
            img: Image array accepted by ``cv2.imshow`` / ``cv2.resize``.
                Presumably a 3-channel frame, since it must fit the
                (224, 224, 3) input buffer.

        Returns:
            A tuple ``(scores, index)`` where ``scores`` is the first row of
            the model output as a list and ``index`` is ``np.argmax`` of the
            model output.
        """
        cv2.imshow("imgFixed classifier", img)
        img_resized = cv2.resize(img, (224, 224))
        image_array = np.asarray(img_resized)
        # Rescale pixel values by dividing by 127.0 and shifting down by 1.
        normalized_image_array = (image_array.astype(np.float32) / 127.0) - 1

        self.data[0] = normalized_image_array

        prediction = self.model.predict(self.data)
        index_val = np.argmax(prediction)

        return list(prediction[0]), index_val


def main():
    """Run a webcam loop that classifies each frame and prints the result.

    Frames are read from camera 0, passed through the cvzone hand detector,
    and the detector's output image is handed to the classifier. The loop
    ends when a frame cannot be read or when 'q' is pressed in the OpenCV
    window. The camera and all windows are released on exit, including when
    an exception is raised.
    """
    cap = cv2.VideoCapture(0)
    try:
        detector = HandDetector(maxHands=1)
        mask_classifier = Classifier('Model/keras_model.h5', 'Model/labels.txt')

        while True:
            ret, frame = cap.read()

            # Validate the grab BEFORE using the frame: on failure the frame
            # is not a usable image and the detector would crash on it.
            if not ret:
                print("Failed to capture frame")
                break

            # Only the detector's output image is needed here; the detected
            # hands themselves are not used.
            _hands, img = detector.findHands(frame)

            prediction = mask_classifier.getPrediction(img)
            print("Prediction:", prediction)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the camera and close windows even if setup or the loop
        # raised, so the device is not left locked.
        cap.release()
        cv2.destroyAllWindows()


# Start the webcam demo only when this file is executed as a script, so that
# importing the module does not open the camera.
if __name__ == "__main__":
    main()