# BBD's Krita Script Starter Feb 2018
'''
    SPDX-FileCopyrightText: 2024 Agata Cacko <cacko.azh@gmail.com>

    This file is part of the Fast Sketch Cleanup Plugin for Krita

    SPDX-License-Identifier: GPL-3.0-or-later
'''

from krita import Extension

from PyQt5.QtWidgets import QDialog, QHBoxLayout, QSlider, QSpinBox, QVBoxLayout, QPushButton, QToolButton, QCheckBox, QFileDialog, QLabel, QLineEdit, QWidget
from PyQt5.QtWidgets import QRadioButton, QMessageBox, QGroupBox, QFrame, QComboBox, QProgressBar, QDoubleSpinBox
from PyQt5.QtGui import QImage, QPixmap
from PyQt5.QtCore import Qt, QCoreApplication

import threading

from os.path import isfile
import sys

import openvino as ov


import openvino.properties.hint as hints
import numpy as np
import sys

from os.path import isfile, join, exists
from os import listdir


#print (sys.path)
import sys
import os

sys.path.append(os.path.dirname(__file__))

import converter as conv
from inferenceRunner import *
from levels_widget import *

from preview_image_label import *
from preview_images_viewer import *
import image_processing as improc



# Identity of the menu action registered in createActions().
EXTENSION_ID = 'pykrita_fast_sketch_cleanup'
MENU_ENTRY = 'Fast Sketch Cleanup'

# Krita settings group under which every option below is persisted.
settingGroup = "fast_sketch_cleanup_plugin"

# --- Setting keys: pre-processing -------------------------------------------
invertSettingName = "invert"
scaleSettingName = "scale"
preprocessingLevelsLowerSettingName = "pre_levels_lower"
preprocessingLevelsUpperSettingName = "pre_levels_upper"

# --- Setting keys: post-processing ------------------------------------------
# (invert and scale are shared with pre-processing and use the keys above)
postprocessingLevelsLowerSettingName = "post_levels_lower"
postprocessingLevelsUpperSettingName = "post_levels_upper"
useSharpenSettingName = "use_sharpen"
sharpenStrengthSettingName = "sharpen_strength"

# --- Setting keys: model, device and preview --------------------------------
modelSettingName = "model_file"
modelFolderSettingName = "model_folder"
deviceSettingName = "device"
samplesCountSettingName = "samples_count"




class Parameters:
    """All user-adjustable options of the plugin.

    Values are persisted as strings in the ``settingGroup`` section of the
    Krita configuration via loadFromConfig() / saveToConfig().
    """

    preprocessing: improc.PreProcessingInfo = None
    postprocessing: improc.PostProcessingInfo = None
    modelFile: str = None
    modelFolder: str = None
    deviceToUse: str = None
    samplesCount: int = None

    def readSetting(self, settingName, defaultValue):
        """Return the raw string stored for *settingName*, or *defaultValue*."""
        return Krita.readSetting(settingGroup, settingName, defaultValue)

    def boolFromStr(self, string):
        """Parse a stored boolean: only "true" (in any letter case) is True."""
        return string.lower() == "true"

    def _readFloat(self, settingName, defaultValue):
        """Read a float setting; fall back to *defaultValue* if it cannot be parsed.

        *defaultValue* is the string form, as passed to readSetting().
        """
        try:
            return float(self.readSetting(settingName, defaultValue))
        except (TypeError, ValueError):
            # A hand-edited or corrupted config must not make the dialog unusable.
            return float(defaultValue)

    def _readInt(self, settingName, defaultValue):
        """Read an int setting; fall back to *defaultValue* if it cannot be parsed."""
        try:
            return int(self.readSetting(settingName, defaultValue))
        except (TypeError, ValueError):
            return int(defaultValue)

    def loadFromConfig(self):
        """Populate every field from the Krita configuration (with defaults)."""
        self.preprocessing = improc.PreProcessingInfo()
        self.postprocessing = improc.PostProcessingInfo()

        # pre processing
        self.preprocessing.invert = self.boolFromStr(self.readSetting(invertSettingName, "False"))
        self.preprocessing.preprocessingLevelsLowerValue = self._readFloat(preprocessingLevelsLowerSettingName, "0.0")
        self.preprocessing.preprocessingLevelsUpperValue = self._readFloat(preprocessingLevelsUpperSettingName, "1.0")
        self.preprocessing.resizeInPreprocessingValue = self._readFloat(scaleSettingName, "1.0")

        # post processing
        self.postprocessing.postprocessingLevelsLowerValue = self._readFloat(postprocessingLevelsLowerSettingName, "0.0")
        self.postprocessing.postprocessingLevelsUpperValue = self._readFloat(postprocessingLevelsUpperSettingName, "1.0")
        self.postprocessing.isSharpenChecked = self.boolFromStr(self.readSetting(useSharpenSettingName, "False"))
        self.postprocessing.sharpenStrength = self._readFloat(sharpenStrengthSettingName, "1.0")

        # Invert and scale are stored once and shared by both stages.
        self.postprocessing.resizeInPreprocessingValue = self.preprocessing.resizeInPreprocessingValue
        self.postprocessing.invert = self.preprocessing.invert

        # the rest
        self.deviceToUse = self.readSetting(deviceSettingName, "cpu")
        self.modelFile = self.readSetting(modelSettingName, "")
        self.modelFolder = self.readSetting(modelFolderSettingName, "")
        self.samplesCount = self._readInt(samplesCountSettingName, "1")

    def writeSetting(self, settingName, value):
        """Store the string *value* under *settingName* in the plugin's group."""
        Krita.writeSetting(settingGroup, settingName, value)

    def saveToConfig(self):
        """Persist every field to the Krita configuration as strings."""
        # pre processing
        self.writeSetting(invertSettingName, str(self.preprocessing.invert))
        self.writeSetting(preprocessingLevelsLowerSettingName, str(self.preprocessing.preprocessingLevelsLowerValue))
        self.writeSetting(preprocessingLevelsUpperSettingName, str(self.preprocessing.preprocessingLevelsUpperValue))
        self.writeSetting(scaleSettingName, str(self.preprocessing.resizeInPreprocessingValue))

        # post processing
        self.writeSetting(postprocessingLevelsLowerSettingName, str(self.postprocessing.postprocessingLevelsLowerValue))
        self.writeSetting(postprocessingLevelsUpperSettingName, str(self.postprocessing.postprocessingLevelsUpperValue))
        self.writeSetting(useSharpenSettingName, str(self.postprocessing.isSharpenChecked))
        self.writeSetting(sharpenStrengthSettingName, str(self.postprocessing.sharpenStrength))

        # the rest
        self.writeSetting(deviceSettingName, self.deviceToUse)
        self.writeSetting(modelSettingName, self.modelFile)
        self.writeSetting(modelFolderSettingName, self.modelFolder)
        self.writeSetting(samplesCountSettingName, str(self.samplesCount))

    def isEqualToSamplePreviewWise(self, cached):
        """Return True if *cached* would produce the same preview as self.

        The device is deliberately ignored: it affects speed, not the image.
        Returns False when *cached* is None.
        """
        if cached is None:
            return False
        return (self.preprocessing.isEqualTo(cached.preprocessing)
                and self.postprocessing.isEqualTo(cached.postprocessing)
                and self.modelFile == cached.modelFile
                and self.modelFolder == cached.modelFolder
                and self.samplesCount == cached.samplesCount)

    def __str__(self):
        return f"[{self.preprocessing}] [{self.postprocessing}] [{self.deviceToUse}] [{self.samplesCount}] [{self.modelFolder}] [{self.modelFile}]"



class FastSketchCleanup(Extension):

    # --- Widgets / state filled in while the dialog is built ----------------
    # "Run" button (created in setupModelChooserUi).
    executeButton : QPushButton = None
    # Mirrors the "Invert input and output" checkbox.
    invert : bool = False
    # Full path of the selected model file (set in updateFromFolderOrFile).
    modelFilename : str = ""
    # The plugin dialog (set in createDialog).
    dialog : QDialog = None

    # NOTE(review): margin and divisableBy are not referenced in the visible
    # part of this file; confirm their use elsewhere before removing.
    margin = 40
    divisableBy = 16

    # Base side length (px) of the preview crop; updatePreview multiplies it
    # by the number of samples and divides it by the "scale before" factor.
    imageSampleSize = 256

    # Sliders hold ints only, so float values are stored multiplied by this.
    slidersRangeScaler = 1000

    # Class-level defaults: __init__ gives each instance its own lock and
    # action_triggered resets the counter to 1 while holding it.
    executionMutex = threading.Lock()
    executionsAllowedCounter = 1

    # Folder the model list is read from.
    basePath = ""

    currentInference = None
    # True while a preview is being computed; requests arriving meanwhile are
    # counted in previewUpdatesPending instead of starting another one.
    previewInProgress = False
    samplesGenerated = 0

    # Images of the last preview; imagesInOrder feeds the enlarged viewer.
    images = None
    imagesInOrder = None

    # Parent passed to __init__. The former bare annotation `_parent : None`
    # declared no attribute at all; give it a real default.
    _parent = None

    # Defaults so updatePreview() is safe even before createDialog() runs.
    setupFinished = False
    previewUpdatesPending = 0

    cachedPreInfo = None
    cachedPostInfo = None
    cachedSampleCount = None
    cachedModelName = None
    # Parameters of the last preview; lets updatePreview skip redundant work.
    cachedLastPreviewUpdateParameters = None
    
    

    def __init__(self, parent):
        """Create the extension; *parent* is forwarded to Krita's Extension."""
        # The base constructor has to run first: it creates the underlying
        # C++ object the rest of the plugin relies on.
        super().__init__(parent)
        self._parent = parent
        # Per-instance lock, replacing the class-level default.
        self.executionMutex = threading.Lock()

    def setup(self):
        """Krita start-up hook; this plugin needs no initialisation here."""

    def createActions(self, window):
        """Register the menu action that opens the plugin dialog.

        The action is identified by EXTENSION_ID, labelled MENU_ENTRY and
        placed at the "tools/scripts" menu location.
        """
        menuAction = window.createAction(EXTENSION_ID, MENU_ENTRY, "tools/scripts")
        menuAction.triggered.connect(self.action_triggered)

    def createDialog(self):
        """Build the plugin dialog, restore the saved options and open it.

        The dialog is shown with QDialog.open(), which returns immediately;
        self.dialogFinished is connected to its ``finished`` signal.
        """
        # While False, updatePreview() returns early, so signals fired by the
        # widgets during construction do not start a preview.
        self.setupFinished = False
        dialog = QDialog()
        self.dialog = dialog

        # One slot per preview label (before, pre-processed, after plugin,
        # post-processed), indexed like previewImageLabelClicked(n).
        self.imagesInOrder = [None, None, None, None]

        mainLayout = QVBoxLayout()

        # NOTE: blockSignals only silences signals emitted by the dialog
        # object itself, not by the child widgets created below.
        dialog.blockSignals(True)

        self.setupModelChooserUi(mainLayout)
        self.setupDeviceToUseUi(mainLayout)
        
        
        # Preview images and preview-size controls.
        self.setupPreviewImagesUi(mainLayout)

        # Pre- and post-processing option groups.
        self.setupPreAndPostprocessingUi(mainLayout)
        
        # Advanced options, then separator / progress bar / Run button.
        self.setupAdvancedOptionsUi(mainLayout)
        self.setupBottomOfUi(mainLayout)
        

        dialog.setLayout(mainLayout)

        # Restore the last used options from the Krita configuration.
        parameters = Parameters()
        parameters.loadFromConfig()

        dialog.blockSignals(False)

        self.setValuesFromParameters(parameters)

        # From here on widget changes trigger previews; force the first one.
        self.setupFinished = True
        self.previewUpdatesPending = 0
        self.updatePreview(True)

        dialog.open()
        dialog.finished.connect(self.dialogFinished)

    
    def getPreProcessingInfo(self) -> improc.PreProcessingInfo:
        """Collect the current pre-processing options from the dialog widgets."""
        levels = self.preprocessingLevels
        info = improc.PreProcessingInfo()
        info.invert = self.invert
        info.preprocessingLevelsLowerValue = levels.lowerValue
        info.preprocessingLevelsUpperValue = levels.upperValue
        info.resizeInPreprocessingValue = self.resizeInPreSpinBox.value()
        return info
    

    def getPostProcessingInfo(self) -> improc.PostProcessingInfo:
        """Collect the current post-processing options from the dialog widgets.

        Invert and the scale value are shared with pre-processing; the scale is
        deliberately read from the *pre*-processing ("scale before") spin box.
        """
        levels = self.postprocessingLevels
        info = improc.PostProcessingInfo()
        info.invert = self.invert
        info.isSharpenChecked = self.sharpenCheckbox.isChecked()
        info.postprocessingLevelsLowerValue = levels.lowerValue
        info.postprocessingLevelsUpperValue = levels.upperValue
        info.sharpenStrength = self.sharpenFactorSpinBox.value()
        info.resizeInPreprocessingValue = self.resizeInPreSpinBox.value()
        return info

    def getParameters(self) -> Parameters:
        """Gather every option currently shown in the dialog into a Parameters object."""
        current = Parameters()
        current.preprocessing = self.getPreProcessingInfo()
        current.postprocessing = self.getPostProcessingInfo()
        current.deviceToUse = self.device
        current.samplesCount = self.previewSamplesNumberSpinBox.value()
        current.modelFile = self.modelCombobox.currentText()
        current.modelFolder = self.modelFolderTextbox.text()
        return current
        
    def setValuesFromParameters(self, parameters: Parameters):
        """Push *parameters* into the dialog widgets.

        A saved device whose radio button is disabled (not reported by
        OpenVINO) falls back to CPU. If the saved model folder is empty (first
        run) or no longer exists, the default model folder is used. Without
        this fallback the model list would stay empty until the user clicked
        "Reset to default".
        """
        pre = parameters.preprocessing
        post = parameters.postprocessing

        self.invertCheckbox.setChecked(pre.invert)
        self.preprocessingLevels.setValues(pre.preprocessingLevelsLowerValue, pre.preprocessingLevelsUpperValue)
        self.resizeInPreSpinBox.setValue(pre.resizeInPreprocessingValue)

        self.postprocessingLevels.setValues(post.postprocessingLevelsLowerValue, post.postprocessingLevelsUpperValue)
        self.sharpenCheckbox.setChecked(post.isSharpenChecked)
        self.sharpenFactorSpinBox.setValue(post.sharpenStrength)

        self.previewSamplesNumberSpinBox.setValue(parameters.samplesCount)

        device = parameters.deviceToUse.lower()
        if device == "npu" and self.npuRadioButton.isEnabled():
            self.npuRadioButton.setChecked(True)
        elif device == "gpu" and self.gpuRadioButton.isEnabled():
            self.gpuRadioButton.setChecked(True)
        else:
            self.cpuRadioButton.setChecked(True)

        modelFolder = parameters.modelFolder
        if not modelFolder or not exists(modelFolder):
            modelFolder = self.getDefaultFolderForModels()
        if exists(modelFolder):
            self.updateFromFolderOrFile(modelFolder, parameters.modelFile)
        


    def setupModelChooserUi(self, mainLayout) -> None:
        """Create the Run button, the invert checkbox and the "Model" group.

        The "Model" group holds the model file combobox, the read-only model
        folder field with its chooser / "Reset to default" buttons, and a label
        showing the selected model's note.
        """
        self.basePath = self.getDefaultFolderForModels()

        self.executeButton = QPushButton("Run")
        self.executeButton.clicked.connect(self.updateGuiAndRun)

        # Case-insensitive, matching Parameters.boolFromStr: saveToConfig()
        # stores str(True) == "True", while invertCheckboxChanged stores "true".
        self.invert = Krita.readSetting(settingGroup, invertSettingName, "false").lower() == "true"
        self.invertCheckbox = QCheckBox("Invert input and output")
        self.invertCheckbox.stateChanged.connect(self.invertCheckboxChanged)
        self.invertCheckbox.setChecked(self.invert)

        self.modelCombobox = QComboBox()
        self.modelCombobox.currentTextChanged.connect(self.modelComboboxTextChanged)

        modelFileGroupBox = QGroupBox("Model:")
        modelFileGroupBoxLayout = QVBoxLayout()
        modelFileGroupBox.setLayout(modelFileGroupBoxLayout)

        # Row 1: "File:" + model combobox.
        modelFileComboboxLayout = QHBoxLayout()
        fileComboboxLabel = QLabel("File:")
        fileComboboxLabel.setFixedWidth(80)
        modelFileComboboxLayout.addWidget(fileComboboxLabel)
        modelFileComboboxLayout.addWidget(self.modelCombobox)
        modelFileGroupBoxLayout.addLayout(modelFileComboboxLayout)

        # Row 2: "From folder:" + read-only path + chooser + reset buttons.
        modelFolderLayout = QHBoxLayout()

        self.modelFolderChooserButton = QToolButton()
        self.modelFolderChooserButton.clicked.connect(self.chooseModelFolderButtonClicked)
        self.modelFolderChooserButton.setIcon(Krita.icon("folder"))

        self.modelFolderTextbox = QLineEdit()
        self.modelFolder = Krita.readSetting(settingGroup, modelFolderSettingName, "")
        # The folder is only changed through the buttons, never typed in.
        self.modelFolderTextbox.setReadOnly(True)

        modelFolderLabel = QLabel("From folder:")
        modelFolderLabel.setFixedWidth(80)
        modelFolderSetToDefaultButton = QPushButton("Reset to default")
        modelFolderSetToDefaultButton.clicked.connect(self.folderSetToDefaultButtonClicked)
        modelFolderLayout.addWidget(modelFolderLabel)
        modelFolderLayout.addWidget(self.modelFolderTextbox)
        modelFolderLayout.addWidget(self.modelFolderChooserButton)
        modelFolderLayout.addWidget(modelFolderSetToDefaultButton)

        modelFileGroupBoxLayout.addLayout(modelFolderLayout)

        # Row 3: note text read from the selected model's config.
        modelFileGroupBoxLayout.addWidget(QLabel("Note about the model:"))
        self.modelNoteLabel = QLabel("(no note)")
        self.modelNoteLabel.setWordWrap(True)
        modelFileGroupBoxLayout.addWidget(self.modelNoteLabel)

        mainLayout.addWidget(modelFileGroupBox)
        

    def setupDeviceToUseUi(self, mainLayout) -> None:
        """Build the "Device to use" group with CPU / GPU / NPU radio buttons.

        The saved device is pre-selected. Each button is enabled only if
        OpenVINO reports a device of that type.
        """
        self.device = Krita.readSetting(settingGroup, deviceSettingName, "CPU")

        self.deviceGroupBox = QGroupBox("Device to use:")

        self.cpuRadioButton = QRadioButton("CPU")
        self.npuRadioButton = QRadioButton("NPU")
        self.gpuRadioButton = QRadioButton("GPU")

        # Pre-select before the handlers are connected, so nothing is written back yet.
        savedDevice = self.device.upper()
        if savedDevice == "CPU":
            self.cpuRadioButton.setChecked(True)
        elif savedDevice == "NPU":
            self.npuRadioButton.setChecked(True)
        else:
            self.gpuRadioButton.setChecked(True)

        self.cpuRadioButton.toggled.connect(self.deviceRadioButtonChanged)
        self.npuRadioButton.toggled.connect(self.deviceRadioButtonChanged)
        self.gpuRadioButton.toggled.connect(self.deviceRadioButtonChanged)

        buttonsByDeviceType = {
            "CPU": self.cpuRadioButton,
            "GPU": self.gpuRadioButton,
            "NPU": self.npuRadioButton,
        }
        for button in buttonsByDeviceType.values():
            button.setEnabled(False)

        # With several devices of one type OpenVINO reports them as "GPU.0",
        # "GPU.1", ...; only the type prefix decides which button to enable.
        # (ov.Core is the same class as the deprecated ov.runtime.Core.)
        for availableDevice in ov.Core().get_available_devices():
            button = buttonsByDeviceType.get(availableDevice.split(".")[0])
            if button is not None:
                button.setEnabled(True)

        radioButtonLayout = QVBoxLayout()
        radioButtonLayout.addWidget(self.cpuRadioButton)
        radioButtonLayout.addWidget(self.gpuRadioButton)
        radioButtonLayout.addWidget(self.npuRadioButton)

        self.deviceGroupBox.setLayout(radioButtonLayout)
        mainLayout.addWidget(self.deviceGroupBox)

    def setupPreviewImagesUi(self, mainLayout) -> None:
        """Build the preview area.

        This creates four fixed-size preview labels, the "Update Preview" button
        with its busy indicator, and the preview-size slider / spin box. The
        labels are named for the stage they show (before, pre-processed, after
        plugin, post-processed). Clicking one opens the enlarged viewer at that
        index. The "before" label is filled immediately with a centred crop of
        the active document.
        """
        self.updatePreviewButton = QPushButton("Update Preview")
        # The button was previously never connected and did nothing. The wrapper
        # forces a refresh and keeps clicked(bool) out of updatePreview's
        # `forced` parameter.
        self.updatePreviewButton.clicked.connect(self.updatePreviewButtonClicked)

        # A maximum of 0 makes the bar an indeterminate "busy" indicator.
        self.previewProgressBar = QProgressBar()
        self.previewProgressBar.setMaximum(0)
        self.previewProgressBar.setVisible(False)

        imageSamplePreviewLabelSize = 200

        def makePreviewLabel():
            # Square label of fixed size whose pixmap is scaled to fit it.
            label = PreviewImageLabel()
            label.setMinimumSize(imageSamplePreviewLabelSize, imageSamplePreviewLabelSize)
            label.setMaximumSize(imageSamplePreviewLabelSize, imageSamplePreviewLabelSize)
            label.setScaledContents(True)
            return label

        self.beforePictureLabel = makePreviewLabel()
        self.preprocessedPictureLabel = makePreviewLabel()
        self.afterPluginPictureLabel = makePreviewLabel()
        self.postprocessedPictureLabel = makePreviewLabel()

        examplePicturesLayout = QHBoxLayout()
        for label in (self.beforePictureLabel, self.preprocessedPictureLabel,
                      self.afterPluginPictureLabel, self.postprocessedPictureLabel):
            examplePicturesLayout.addWidget(label)
        mainLayout.addLayout(examplePicturesLayout)

        previewSamplesNumberLayout = QHBoxLayout()

        self.maxSamplesCount = 8
        self.previewSamplesNumberSlider = QSlider(Qt.Horizontal)
        self.previewSamplesNumberSlider.setRange(1, self.maxSamplesCount)
        self.previewSamplesNumberSlider.setMaximumWidth(300)
        # Emit valueChanged only when the handle is released.
        self.previewSamplesNumberSlider.setTracking(False)
        self.previewSamplesNumberSlider.setTickPosition(QSlider.TicksAbove)

        self.previewSamplesNumberSpinBox = QSpinBox()
        self.previewSamplesNumberSpinBox.setRange(1, self.maxSamplesCount)

        previewSamplesNumberLayout.addWidget(QLabel("Preview size (more = slower update)"))
        previewSamplesNumberLayout.addWidget(self.previewSamplesNumberSlider)
        previewSamplesNumberLayout.addWidget(self.previewSamplesNumberSpinBox)

        # Slider and spin box mirror each other (see their change handlers).
        self.previewSamplesNumberSlider.valueChanged.connect(self.previewSamplesSliderChanged)
        self.previewSamplesNumberSpinBox.valueChanged.connect(self.previewSamplesSpinBoxValueChanged)

        mainLayout.addWidget(self.updatePreviewButton)
        mainLayout.addWidget(self.previewProgressBar)

        mainLayout.addLayout(previewSamplesNumberLayout)

        sampleImage = self.getImageSample(self.imageSampleSize)
        self.beforePictureLabel.setPixmap(QPixmap(sampleImage))

        self.beforePictureLabel.labelClicked.connect(self.previewImageLabelClickedBefore)
        self.preprocessedPictureLabel.labelClicked.connect(self.previewImageLabelClickedPreprocessed)
        self.afterPluginPictureLabel.labelClicked.connect(self.previewImageLabelClickedAfterPlugin)
        self.postprocessedPictureLabel.labelClicked.connect(self.previewImageLabelClickedPostprocessed)
        
    def previewImageLabelClickedBefore(self):
        """Open the viewer on the original ("before") sample, index 0."""
        self.previewImageLabelClicked(whichImage=0)
    
    def previewImageLabelClickedPreprocessed(self):
        """Open the viewer on the pre-processed sample, index 1."""
        self.previewImageLabelClicked(whichImage=1)
    
    def previewImageLabelClickedAfterPlugin(self):
        """Open the viewer on the "after plugin" image, index 2."""
        self.previewImageLabelClicked(whichImage=2)
    
    def previewImageLabelClickedPostprocessed(self):
        """Open the viewer on the post-processed image, index 3."""
        self.previewImageLabelClicked(whichImage=3)
    
    


    def setupPreAndPostprocessingUi(self, mainLayout) -> None:
        """Build the "Pre-processing" and "Post-processing" option groups.

        Pre-processing: input levels, the "scale before" slider / spin box
        (0.1x-8x) and a warning label for large scales. Post-processing: the
        "scale after" display (1 / scale before), output levels and the
        optional sharpen filter with its strength (0-2).

        Sliders hold ints, so float values are stored on them multiplied by
        ``self.slidersRangeScaler``.
        """
        self.preprocessingOptionsGroupBox = QGroupBox("Pre-processing:")
        preprocessingOptionsLayout = QVBoxLayout()

        self.preprocessingLevels = LevelsWidget()
        self.preprocessingLevels.changedValues.connect(self.updatePreview)
        preprocessingOptionsLayout.addWidget(self.preprocessingLevels)

        self.preprocessingOptionsGroupBox.setLayout(preprocessingOptionsLayout)

        self.postprocessingOptionsGroupBox = QGroupBox("Post-processing:")
        self.postprocessingLevels = LevelsWidget()
        postprocessingOptionsLayout = QVBoxLayout()

        # --- "Scale before" ---------------------------------------------------
        resizeInPreLayout = QHBoxLayout()
        resizeInPreLayout.addWidget(QLabel("Scale before:"))
        self.resizeInPreSlider = QSlider()
        self.resizeInPreSlider.setOrientation(Qt.Horizontal)
        rangeScaler = self.slidersRangeScaler
        maxResize = 8
        minResize = 0.1
        self.resizeInPreSlider.setRange(int(minResize*rangeScaler), int(maxResize*rangeScaler)) # range * 1000
        self.resizeInPreSlider.setTickPosition(QSlider.TicksAbove)
        self.resizeInPreSlider.setTickInterval(int(rangeScaler/10))
        self.resizeInPreSlider.setValue(int(1.0*rangeScaler))
        self.resizeInPreSpinBox = QDoubleSpinBox()
        self.resizeInPreSpinBox.setRange(minResize, maxResize)
        self.resizeInPreSpinBox.setValue(1.0)
        self.resizeInPreSpinBox.valueChanged.connect(self.resizeInPreSpinBoxValueChanged)
        self.resizeInPreSlider.valueChanged.connect(self.resizeInPreSliderValueChanged)
        # Emit valueChanged only when the handle is released.
        self.resizeInPreSlider.setTracking(False)

        resizeInPreLayout.addWidget(self.resizeInPreSlider)
        resizeInPreLayout.addWidget(self.resizeInPreSpinBox)

        # --- "Scale after": mirrors 1 / "scale before"; slider is display-only --
        resizeInPostLayout = QHBoxLayout()
        resizeInPostLayout.addWidget(QLabel("Scale after (based on scale before):"))
        self.resizeInPostSlider = QSlider(Qt.Horizontal)
        self.resizeInPostSlider.setRange(int(1/maxResize*rangeScaler), int(1/minResize*rangeScaler)) # range * 1000
        self.resizeInPostSlider.setTickPosition(QSlider.TicksAbove)
        self.resizeInPostSlider.setTickInterval(int(rangeScaler/10))
        self.resizeInPostSlider.setValue(int(1.0*rangeScaler))
        self.resizeInPostSlider.setDisabled(True)

        self.resizeInPostSpinBox = QDoubleSpinBox()
        self.resizeInPostSpinBox.setRange(1/maxResize, 1/minResize)
        self.resizeInPostSpinBox.setValue(1.0)
        # NOTE(review): nothing is connected to this spin box here, so editing
        # it has no effect; it may be meant to be disabled like its slider.

        resizeInPostLayout.addWidget(self.resizeInPostSlider)
        resizeInPostLayout.addWidget(self.resizeInPostSpinBox)

        resizeInPostWidget = QWidget()
        resizeInPostWidget.setLayout(resizeInPostLayout)

        resizeInPreWidget = QWidget()
        resizeInPreWidget.setLayout(resizeInPreLayout)

        self.sharpenCheckbox = QCheckBox("Apply Sharpen filter")
        self.sharpenCheckbox.stateChanged.connect(self.sharpenCheckboxStateChanged)

        preprocessingOptionsLayout.addWidget(resizeInPreWidget)
        self.resizeWarningLabel = QLabel("")
        preprocessingOptionsLayout.addWidget(self.resizeWarningLabel)

        postprocessingOptionsLayout.addWidget(resizeInPostWidget)
        postprocessingOptionsLayout.addWidget(self.postprocessingLevels)
        postprocessingOptionsLayout.addWidget(self.sharpenCheckbox)

        # NOTE(review): stateChanged passes the int check state, which lands
        # in updatePreview's `forced` parameter (non-zero when checked).
        self.sharpenCheckbox.stateChanged.connect(self.updatePreview)
        self.postprocessingLevels.changedValues.connect(self.updatePreview)

        # --- Sharpen strength ---------------------------------------------------
        self.sharpenFactorLayout = QHBoxLayout()
        self.maxSharpenFactor = 2.0
        self.sharpenFactorSlider = QSlider(Qt.Horizontal)
        self.sharpenFactorSlider.setRange(0, int(self.slidersRangeScaler*self.maxSharpenFactor))

        self.sharpenFactorSpinBox = QDoubleSpinBox()
        self.sharpenFactorSpinBox.setRange(0.0, self.maxSharpenFactor)

        self.sharpenFactorSpinBox.setValue(1.0)
        self.sharpenFactorSlider.setValue(int(1.0*self.slidersRangeScaler))
        self.sharpenFactorSlider.setTracking(False)

        # The checkbox starts unchecked and setChecked(False) emits no
        # stateChanged, so sharpenCheckboxStateChanged never runs here.
        # Disable BOTH strength widgets explicitly; previously only the slider
        # was disabled, leaving the spin box editable while sharpen was off.
        self.sharpenCheckbox.setChecked(False)
        self.sharpenFactorSlider.setEnabled(False)
        self.sharpenFactorSpinBox.setEnabled(False)

        self.sharpenFactorSpinBox.valueChanged.connect(self.sharpenFactorSpinBoxValueChanged)
        self.sharpenFactorSlider.valueChanged.connect(self.sharpenFactorSliderValueChanged)

        self.sharpenFactorLayout.addWidget(self.sharpenFactorSlider)
        self.sharpenFactorLayout.addWidget(self.sharpenFactorSpinBox)

        postprocessingOptionsLayout.addLayout(self.sharpenFactorLayout)
        self.postprocessingOptionsGroupBox.setLayout(postprocessingOptionsLayout)

        mainLayout.addWidget(self.preprocessingOptionsGroupBox)
        mainLayout.addWidget(self.postprocessingOptionsGroupBox)

    def setupAdvancedOptionsUi(self, mainLayout) -> None:
        """Add the "Advanced options" group, which holds the invert checkbox."""
        optionsRow = QHBoxLayout()
        optionsRow.addWidget(self.invertCheckbox)

        self.advancedOptionsGroupBox = QGroupBox("Advanced options:")
        self.advancedOptionsGroupBox.setLayout(optionsRow)
        mainLayout.addWidget(self.advancedOptionsGroupBox)


    def setupBottomOfUi(self, mainLayout) -> None:
        """Append a separator, an info note, the (hidden) run progress bar and the Run button."""
        separator = QFrame()
        separator.setFrameShape(QFrame.HLine)
        separator.setFrameShadow(QFrame.Plain)
        mainLayout.addWidget(separator)

        closingNote = QLabel("(The dialog will close after converting the image)")
        closingNote.setAlignment(Qt.AlignCenter)
        mainLayout.addWidget(closingNote)

        self.progressBar = QProgressBar()
        self.progressBar.setVisible(False)
        mainLayout.addWidget(self.progressBar)

        mainLayout.addWidget(self.executeButton)


    def action_triggered(self) -> None:
        """Menu-action handler: open the plugin dialog.

        Building the dialog samples the active document (getImageSample), so
        with no document open a warning is shown instead of failing on None.
        """
        if Krita.instance().activeDocument() is None:
            QMessageBox.warning(None, MENU_ENTRY, "Please open a document before running Fast Sketch Cleanup.")
            return

        self.createDialog()
        with self.executionMutex:
            self.executionsAllowedCounter = 1

    def previewSamplesSliderChanged(self, value):
        """Mirror the preview-size slider into its spin box and refresh the preview."""
        spinBox = self.previewSamplesNumberSpinBox
        if spinBox.value() == value:
            return  # already in sync (e.g. the spin box drove this change)
        spinBox.setValue(value)
        self.updatePreview()
        

    def previewSamplesSpinBoxValueChanged(self, value):
        """Mirror the preview-size spin box into its slider and refresh the preview."""
        slider = self.previewSamplesNumberSlider
        if slider.value() == value:
            return  # already in sync (e.g. the slider drove this change)
        slider.setValue(value)
        self.updatePreview()

    def sharpenCheckboxStateChanged(self, state):
        """Enable the sharpen-strength widgets only while sharpening is on."""
        enabled = self.sharpenCheckbox.isChecked()
        for widget in (self.sharpenFactorSlider, self.sharpenFactorSpinBox):
            widget.setEnabled(enabled)

    def sharpenFactorSpinBoxValueChanged(self, value):
        """Sync the sharpen slider to spin-box *value* and refresh the preview.

        The slider stores value * slidersRangeScaler as an int. round() replaces
        int() truncation and raw float comparison, so binary float error
        cannot move the slider one step short or make equal values look
        different.
        """
        sliderValue = round(value * self.slidersRangeScaler)
        if self.sharpenFactorSlider.value() != sliderValue:
            self.sharpenFactorSlider.setValue(sliderValue)
            self.updatePreview()

    def sharpenFactorSliderValueChanged(self, value):
        """Sync the sharpen spin box to slider *value* and refresh the preview.

        The comparison is done in slider units with round(), so the float
        spin-box value does not produce spurious mismatches.
        """
        if round(self.sharpenFactorSpinBox.value() * self.slidersRangeScaler) != value:
            self.sharpenFactorSpinBox.setValue(value / self.slidersRangeScaler)
            self.updatePreview()

    def previewImageLabelClicked(self, whichImage = 0):
        """Open the enlarged preview viewer positioned on image *whichImage*.

        Indices follow self.imagesInOrder (0 is the original sample). Does
        nothing until a preview has produced images.
        """
        if self.images is None:
            return
        dialog = PreviewImagesViewer(self.imagesInOrder, whichImage)
        dialog.exec()
        

    def updateModelsCombobox(self, guessCurrent=False) -> None:
        """Refill the model combobox with the ``.xml`` files in self.basePath (sorted).

        Signals are blocked while the list is rebuilt. With *guessCurrent*,
        the model saved in the settings is selected if it exists in the
        folder; otherwise the first entry is selected.
        """
        combobox = self.modelCombobox
        combobox.blockSignals(True)
        combobox.clear()
        combobox.addItems([name for name in sorted(listdir(self.basePath))
                           if name.endswith(".xml") and isfile(join(self.basePath, name))])
        combobox.blockSignals(False)

        if not guessCurrent:
            return
        lastUsedModel = Krita.readSetting(settingGroup, modelSettingName, "")
        if isfile(join(self.basePath, lastUsedModel)):
            combobox.setCurrentText(lastUsedModel)
        else:
            combobox.setCurrentIndex(0)

    def chooseModelFolderButtonClicked(self):
        """Let the user pick a model folder and switch to it if it is a different one."""
        currentFolder = self.modelFolderTextbox.text()
        folder = QFileDialog.getExistingDirectory(None, "Open Directory", currentFolder, QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks)
        # "" means the dialog was cancelled. The textbox may hold the same
        # folder with a trailing "/" (added below) or other separators, so
        # compare normalised paths rather than raw strings.
        if folder == "" or os.path.normpath(folder) == os.path.normpath(currentFolder):
            return
        self.updateFromFolderOrFile(folder + "/", "")
            


    def getDefaultFolderForModels(self) -> str:
        """Return this plugin's own directory (symlinks resolved), the default model folder."""
        pluginFile = os.path.realpath(__file__)
        return os.path.dirname(pluginFile)

    def folderSetToDefaultButtonClicked(self):
        """Switch the model folder back to the default one, keeping the model name if possible."""
        defaultFolder = self.getDefaultFolderForModels()
        self.updateFromFolderOrFile(defaultFolder, "")

    def invertCheckboxChanged(self, value) -> None:
        """Store the invert flag (as "true"/"false") and refresh the preview."""
        isInverted = self.invertCheckbox.isChecked()
        self.invert = isInverted
        storedValue = "true" if isInverted else "false"
        Krita.writeSetting(settingGroup, invertSettingName, storedValue)
        self.updatePreview()
        

    def resizeInPreSpinBoxValueChanged(self, value: float) -> None:
        """React to a new "scale before" *value* typed in the spin box.

        Updates the large-scale warning, then syncs the slider and the
        "scale after" widgets (1 / value) and refreshes the preview when the
        slider does not already show this value. The original code compared
        the float *value* to the raw slider int (value * slidersRangeScaler),
        so the check never held; both sides are now in slider units.
        """
        # Always reflect the current value, whichever widget changed it.
        if value >= 2.0:
            self.resizeWarningLabel.setText("Warning: it's exponential; don't set the value too big or your device might not be able to handle it.")
        else:
            self.resizeWarningLabel.setText("")

        sliderValue = round(value * self.slidersRangeScaler)
        if sliderValue != self.resizeInPreSlider.value():
            self.resizeInPreSlider.setValue(sliderValue)
            self.resizeInPostSlider.setValue(round(1 / value * self.slidersRangeScaler))
            self.resizeInPostSpinBox.setValue(1 / value)
            self.updatePreview()
            
    
    def resizeInPreSliderValueChanged(self, value: int) -> None:
        """React to the "scale before" slider (int *value* = scale * slidersRangeScaler).

        Syncs the spin box and the "scale after" widgets, then refreshes the
        preview. Values are compared in slider units; the original compared
        the int slider value to the float spin-box value, so the check was
        always true.
        """
        if round(self.resizeInPreSpinBox.value() * self.slidersRangeScaler) != value:
            self.resizeInPreSpinBox.setValue(value / self.slidersRangeScaler)
            # The spin box rounds to its own precision; derive "scale after"
            # from what it actually shows so both displays agree.
            trueValue = self.resizeInPreSpinBox.value()
            self.resizeInPostSlider.setValue(round(1 / trueValue * self.slidersRangeScaler))
            self.resizeInPostSpinBox.setValue(1 / trueValue)
            self.updatePreview()

    def updateFieldsFromModelFileConfig(self, modelFile):
        """Apply the per-model config of *modelFile*: its invert flag and its note.

        A None entry leaves the invert checkbox untouched, or shows "(no note)".
        """
        config = self.readConfigForModel(modelFile)
        invertFromConfig = config["invert"]
        if invertFromConfig is not None:
            self.invertCheckbox.setChecked(invertFromConfig)
        note = config["note"]
        self.modelNoteLabel.setText("(no note)" if note is None else note)

    def _showModelsFromFolder(self, folder, file) -> bool:
        """List the models of *folder* and select *file* there, or the first model.

        Returns False, touching no widget, when *folder* does not exist.
        """
        if not os.path.exists(folder):
            return False
        self.basePath = folder
        self.modelFolderTextbox.setText(folder)
        self.updateModelsCombobox(False)
        if os.path.exists(os.path.join(folder, file)):
            self.modelCombobox.setCurrentText(os.path.basename(file))
        else:
            self.modelCombobox.setCurrentIndex(0)
        return True

    def updateFromFolderOrFile(self, folder, file):
        """Switch the model folder and/or file, then refresh the dependent UI.

        - folder and file given: show *folder* and select *file* in it.
        - only folder given: show *folder*, keeping the currently selected
          model name if that file exists there.
        - only file given: select *file* in the current folder if it exists.

        When something was switched, self.modelFilename is updated, the
        model's config (invert flag, note) is applied and the preview is
        refreshed. A folder that does not exist is ignored.
        """
        wasModelChanged = False
        if folder != "":
            if file == "":
                # Capture before the combobox is rebuilt for the new folder.
                file = self.modelCombobox.currentText()
            wasModelChanged = self._showModelsFromFolder(folder, file)
        elif file != "":
            currentFolder = self.modelFolderTextbox.text()
            if os.path.exists(os.path.join(currentFolder, file)):
                self.modelCombobox.setCurrentText(os.path.basename(file))
                wasModelChanged = True

        if wasModelChanged:
            self.modelFilename = os.path.join(self.modelFolderTextbox.text(), self.modelCombobox.currentText())
            self.updateFieldsFromModelFileConfig(self.modelFilename)
            self.updatePreview()


    def modelComboboxTextChanged(self, text):
        """Select the model named *text* within the current folder."""
        self.updateFromFolderOrFile(folder="", file=text)
        
    def deviceRadioButtonChanged(self, value):
        """Remember the checked device ("CPU"/"GPU"/"NPU") and persist it."""
        self.device = (
            "CPU" if self.cpuRadioButton.isChecked()
            else "GPU" if self.gpuRadioButton.isChecked()
            else "NPU"
        )
        Krita.writeSetting(settingGroup, deviceSettingName, self.device)
    
    def getImageSample(self, sampleSize) -> QImage:
        """Return a ``sampleSize`` x ``sampleSize`` crop of the active document's projection.

        The crop is centred on the document; offsets are truncated with ``int``.
        """
        document = Krita.instance().activeDocument()

        side = int(sampleSize)
        left = int((document.width() - sampleSize) / 2)
        top = int((document.height() - sampleSize) / 2)

        return document.projection(left, top, side, side)

        
        

    def updatePreviewButtonClicked(self) -> None:
        """Recompute the preview on explicit request, even if the parameters did not change."""
        forcePreview = True
        self.updatePreview(forced=forcePreview)

    def updatePreview(self, forced = False) -> None:
        """Regenerate the preview images for a sample of the active document.

        If a preview is already running, the request is only counted in
        ``previewUpdatesPending`` and nothing else happens. Nothing happens
        before setup has finished or when no model is selected. Unless
        ``forced`` is True, nothing happens when the preview-relevant
        parameters equal those of the last update.

        Otherwise the "before" and "preprocessed" images are produced here, and
        the model is run asynchronously on the sample split into parts via an
        ``ov.AsyncInferQueue``. Its callback,
        ``previewMultiSampleRequestCallback``, assembles the output. If the
        model cannot be loaded, the in-progress state and the preview button
        are restored.
        """
        if self.previewInProgress:
            self.previewUpdatesPending += 1
            return
        if not self.setupFinished or self.modelCombobox.currentText() == "":
            return

        preinfo = self.getPreProcessingInfo()

        sampleCount = self.previewSamplesNumberSpinBox.value()
        parameters = self.getParameters()

        if not forced and parameters.isEqualToSamplePreviewWise(self.cachedLastPreviewUpdateParameters):
            return

        self.previewInProgress = True
        self.cachedLastPreviewUpdateParameters = parameters
        parameters.saveToConfig()

        # Swap the button for the progress bar and let Qt repaint before the work starts.
        self.updatePreviewButton.setVisible(False)
        self.previewProgressBar.setVisible(True)
        QCoreApplication.processEvents()

        # The requested sample side is divided by the preprocessing scale factor.
        scale = self.resizeInPreSpinBox.value()
        sample : QImage = self.getImageSample(int(self.imageSampleSize*sampleCount/scale))
        self.beforePictureLabel.setPixmap(QPixmap(sample))
        self.images = [sample]
        self.imagesInOrder[0] = sample

        numpyArrayResponse = conv.convertQImageToNumpy(sample)
        numpyArray = numpyArrayResponse[0]

        # get model
        self.modelFilename = join(parameters.modelFolder, parameters.modelFile)
        model = self.getModel()
        if model is None:
            # getModel() has already reported the error. Forget the cached parameters
            # so the next request retries, and undo the button/progress-bar swap
            # done above so the GUI is not left in the "busy" state.
            self.cachedLastPreviewUpdateParameters = None
            self.previewInProgress = False
            self.previewProgressBar.setVisible(False)
            self.updatePreviewButton.setVisible(True)
            return

        self.infer_request_for_preview = model.create_infer_request()

        self.cachedPreInfo = preinfo
        numpyArray = improc.applyPreProcessingNumpy(numpyArray, preinfo)
        afterPreProcessing = conv.convertNumpyToQImage(improc.uninvertIfNeeded(self.invert, numpyArray.copy()))
        self.images.append(afterPreProcessing)
        self.imagesInOrder[1] = afterPreProcessing
        self.preprocessedPictureLabel.setPixmap(QPixmap(afterPreProcessing))

        # Match the dtype the compiled model expects as input.
        numpyArray = numpyArray.astype(model.input().get_element_type().to_dtype())

        (samples, fullExpectedShape) = conv.cutToSamples(numpyArray, model.input().get_shape(), sampleCount)
        self.inferredNumpy = np.zeros(fullExpectedShape)

        # The callback generates the last two previews once all parts have arrived.
        self.samplesGenerated = 0
        self.samplesToGenerate = len(samples)

        self.infer_queue = ov.AsyncInferQueue(model)
        self.infer_queue.set_callback(self.previewMultiSampleRequestCallback)

        for part in samples:
            # part[0], part[1] are passed as userData: the part's position in the output buffer.
            input_tensor = ov.Tensor(array=part[2])
            self.infer_queue.start_async(input_tensor, (part[0], part[1]))

    def previewMultiSampleRequestCallback(self, infer_request, userData):
        """Copy one finished preview part into ``self.inferredNumpy``.

        Installed as the ``AsyncInferQueue`` callback in ``updatePreview``.
        ``userData`` is the ``(sample[0], sample[1])`` tuple that ``updatePreview``
        passes to ``start_async``; it is used here as the part's block index
        along axes 2 and 3. Once ``self.samplesToGenerate`` parts have been
        stored, the last two previews are generated from the assembled buffer.

        NOTE(review): OpenVINO may invoke AsyncInferQueue callbacks on its own
        worker threads. If so, the unsynchronised ``+=`` on ``samplesGenerated``
        can lose an increment, and the preview would then never complete. The Qt
        widget updates reached through ``generateLastTwoPreviewsFromNumpy`` would
        also run off the GUI thread. Confirm, and guard with a lock / queued
        signal if needed.
        """
        #print(f"~~~~~~~~~~~~~~~~~~~~~~ UPDATE PREVIEW BUTTON CALLBACK - NEW ~~~~~~~~~~~~~~~~~~~~~ {len(self.images)}, {self.samplesGenerated}/{self.samplesToGenerate}")
        output = infer_request.get_output_tensor().data
        # The output is indexed with four axes below, so it is assumed to be 4-D;
        # these are the sizes of its axes 2 and 3.
        sampleWidth = output.shape[2]
        sampleHeight = output.shape[3]
        
        # Write the part into block (userData[0], userData[1]) of the full output buffer.
        self.inferredNumpy[:, :, (userData[0]*sampleWidth):((userData[0]+ 1)*sampleWidth), (userData[1]*sampleHeight):((userData[1]+ 1)*sampleHeight)] = output
        self.samplesGenerated += 1
        # Last part stored: build the final previews.
        if self.samplesGenerated == self.samplesToGenerate:
            self.generateLastTwoPreviewsFromNumpy(self.inferredNumpy)



    def generateLastTwoPreviewsFromNumpy(self, numpyArray: np.ndarray):
        """Build the "after model" and "after postprocessing" previews from the model output.

        The raw output (passed through ``improc.uninvertIfNeeded``) and the
        postprocessed result are shown in their labels. They are appended to
        ``self.images`` and stored in ``self.imagesInOrder[2]`` and ``[3]``.
        Afterwards the in-progress state ends and the button replaces the
        progress bar. If further preview requests arrived meanwhile, a single
        new update is started.

        NOTE(review): called from the AsyncInferQueue callback; if that runs off
        the Qt GUI thread, the widget updates here are not thread-safe — confirm.
        """
        # Un-invert a copy, as updatePreview does for its preprocessing preview, so
        # the postprocessing below always receives the unmodified model output.
        afterJustPlugin = conv.convertNumpyToQImage(improc.uninvertIfNeeded(self.invert, numpyArray.copy()))
        self.afterPluginPictureLabel.setPixmap(QPixmap(afterJustPlugin))
        self.images.append(afterJustPlugin)
        self.imagesInOrder[2] = afterJustPlugin

        postinfo = self.getPostProcessingInfo()
        self.cachedPostInfo = postinfo
        outputBuffer = improc.applyPostProcessingNumpy(numpyArray, postinfo)
        outputBuffer = conv.convertOutputToRGBANumpy(outputBuffer)
        preview = conv.convertNumpyToQImage(outputBuffer)

        self.postprocessedPictureLabel.setPixmap(QPixmap(preview))
        self.images.append(preview)
        self.imagesInOrder[3] = preview

        self.previewInProgress = False
        self.previewProgressBar.setVisible(False)
        self.updatePreviewButton.setVisible(True)
        # Requests made while this preview was running are collapsed into one new update.
        if self.previewUpdatesPending > 0:
            self.previewUpdatesPending = 0
            self.updatePreview()


    def convertKritaImageToNumpy(self):
        """Return the active document's projection converted by ``conv.convertQImageToNumpy``.

        The projection is refreshed first. The requested area starts at (0, 0).
        Its width and height come from ``conv.ensureSizeDivisableBy`` applied to
        the document size and ``self.divisableBy`` (presumably rounded up to a
        multiple of it — confirm). ``run()`` unpacks the result as
        ``(data, width, height)``.
        """
        currentDoc = Krita.instance().activeDocument()
        currentDoc.refreshProjection()

        width = conv.ensureSizeDivisableBy(currentDoc.width(), self.divisableBy)
        height = conv.ensureSizeDivisableBy(currentDoc.height(), self.divisableBy)

        projection = currentDoc.projection(0, 0, width, height) # that's QImage

        return conv.convertQImageToNumpy(projection)

    def levelsParametersToString(self, lower: float, upper: float) -> str:
        """Format a levels range as ``(lower; upper)``, each value with two decimals."""
        return f"({lower:.2f}; {upper:.2f})"
        

    def parametersToStringShort(self) -> str:
        """Summarise the non-default processing parameters, e.g. for the generated layer's name.

        Only values that differ from their defaults are listed: invert, the
        preprocessing scale, preprocessing levels, and postprocessing levels
        and/or sharpen strength. Returns an empty string when everything is
        at its default.
        """
        preinfo = self.getPreProcessingInfo()
        postinfo = self.getPostProcessingInfo()

        parts = []
        if self.invert:
            parts.append("invert = True, ")

        if preinfo.resizeInPreprocessingValue != 1.0:
            parts.append(f"scale: {preinfo.resizeInPreprocessingValue:.2f}, ")

        if preinfo.preprocessingLevelsLowerValue != 0.0 or preinfo.preprocessingLevelsUpperValue != 1.0:
            preLevels = self.levelsParametersToString(preinfo.preprocessingLevelsLowerValue, preinfo.preprocessingLevelsUpperValue)
            parts.append(f"(pre: levels = {preLevels}), ")

        # Collect the post-processing entries so the group is always closed with ")";
        # previously the levels-only case left the parenthesis open.
        postEntries = []
        if postinfo.postprocessingLevelsLowerValue != 0.0 or postinfo.postprocessingLevelsUpperValue != 1.0:
            postLevels = self.levelsParametersToString(postinfo.postprocessingLevelsLowerValue, postinfo.postprocessingLevelsUpperValue)
            postEntries.append(f"levels = {postLevels}")
        if postinfo.isSharpenChecked:
            postEntries.append(f"sharpen = {postinfo.sharpenStrength}")
        if postEntries:
            parts.append(f"(post: {', '.join(postEntries)})")

        return "".join(parts)
        
        

    def convertOutputToKritaLayer(self, output, width, height):
        """Add ``output`` to the active document as a new paint layer.

        The pixel data comes from ``conv.convertOutputToLayerData(output)``. The
        layer uses the RGBA / U8 / sRGB-elle-V2-srgbtrc.icc colour space and is
        filled at (0, 0) with the given ``width`` x ``height``. It is named after
        the model file and the non-default parameters, and is added as a child
        of the document's root node.
        """
        pixelData = conv.convertOutputToLayerData(output)

        document = Krita.instance().activeDocument()
        rootNode = document.rootNode()

        # realpath resolves symlinks so the layer name shows the actual model file.
        modelName = os.path.basename(os.path.realpath(self.modelFilename))
        layer = document.createNode(f"Generated by {modelName}: {self.parametersToStringShort()}", "paintLayer")

        layer.setColorSpace("RGBA", "U8", "sRGB-elle-V2-srgbtrc.icc")
        layer.setPixelData(pixelData, 0, 0, width, height)

        rootNode.addChildNode(layer, None)
        
    
    def updateGuiAndRun(self):
        """Swap the execute button for the progress bar, let Qt repaint, then start ``run()``."""
        for widget, shown in ((self.executeButton, False), (self.progressBar, True)):
            widget.setVisible(shown)

        # Process pending events so the visibility change is painted before the run starts.
        QCoreApplication.processEvents()
        self.run()

        
    def readConfigForModel(self, modelFile):
        """Return the config read from the ``.yaml`` file that sits next to ``modelFile``.

        The config path is the model path with its extension replaced by
        ``.yaml`` (``model.xml`` -> ``model.yaml``). Only the final extension is
        replaced, so directories whose names contain ".xml" are left intact.
        Models with another extension no longer resolve to the model file itself.
        """
        configFile = os.path.splitext(modelFile)[0] + ".yaml"
        return self.readConfig(configFile)

    def readConfig(self, configFile):
        """Read the ``invert`` and ``note`` entries from a simple ``key: value`` config file.

        Returns a dict with keys ``"invert"`` (bool, or None if absent) and
        ``"note"`` (str, or None if absent). A missing file yields both as None.
        ``invert`` is True when its value contains "true", compared
        case-insensitively so YAML's ``True``/``TRUE`` also count.
        """
        config = {"invert": None, "note": None}

        if not isfile(configFile):
            return config

        invertPrefix = "invert: "
        notePrefix = "note: "
        # Explicit encoding: notes may contain non-ASCII text, and the platform
        # default encoding differs between systems. "utf-8-sig" also drops a BOM.
        with open(configFile, "r", encoding="utf-8-sig", errors="replace") as file:
            for line in file:
                if line.startswith(invertPrefix):
                    value = line[len(invertPrefix):]
                    config["invert"] = "true" in value.lower()
                elif line.startswith(notePrefix):
                    note = line[len(notePrefix):]
                    config["note"] = note.replace("\n", "").replace("\r", "")

        return config


    def getModel(self):
        """Read ``self.modelFilename`` with OpenVINO and compile it for ``self.device``.

        Returns the compiled model, or None on failure. On failure the error is
        printed to stderr and reported to the user in a ``QMessageBox`` on
        ``self.dialog``. A model that cannot be read and a model that cannot be
        compiled get different messages.
        """
        core = ov.Core()

        try:
            model = core.read_model(model=self.modelFilename)
        except Exception as e:
            # Report read errors as read errors instead of falling through to
            # the "cannot be compiled" message below.
            print(e, file=sys.stderr)
            model = None
        if model is None:
            QMessageBox.critical(self.dialog, "The model file cannot be read", f"The model file cannot be read: {self.modelFilename}.")
            return None

        # NOTE: no performance hint is passed here, so the device's default mode applies.
        try:
            compiled_model = core.compile_model(model=model, device_name=self.device)
        except Exception as e:
            print(e, file=sys.stderr)
            compiled_model = None

        if compiled_model is None:
            print("The model cannot be compiled.", file=sys.stderr)
            QMessageBox.critical(self.dialog, "The model cannot be compiled", f"The model {self.modelFilename} cannot be compiled to specified device: {self.device}.")
            return None

        return compiled_model

            



    def run(self):
        """Start inference on the whole active document with the selected model.

        Does nothing while an inference is already running. The current
        parameters are saved first. If ``getModel()`` fails (it reports the
        error itself), the progress bar is hidden and the execute button shown
        again, and nothing else happens. Otherwise the document projection is
        preprocessed, padded by ``conv.extendToBeDivisable``, and handed to a new
        ``InferenceRunner`` stored in ``self.currentInference``.
        """
        if self.currentInference is not None:
            return

        self.getParameters().saveToConfig()

        compiled_model = self.getModel()
        if compiled_model is None:
            # Without this check, .input() below would raise AttributeError and
            # leave the progress bar shown with no way to retry.
            self.progressBar.setVisible(False)
            self.executeButton.setVisible(True)
            return

        compiledModelInputShape = compiled_model.input().get_shape()

        # The image is processed in square parts whose side is the model input size.
        partSize = compiledModelInputShape[2]
        assert partSize == compiledModelInputShape[3], "Part size must be equal in both dimensions"

        (data, _width, _height) = self.convertKritaImageToNumpy()

        preinfo = self.getPreProcessingInfo()
        data = improc.applyPreProcessingNumpy(data, preinfo)
        data = conv.extendToBeDivisable(data, self.divisableBy)

        # Create the runner only once the input is ready, so a failure in the
        # preparation above does not leave a stale runner that blocks later runs.
        self.currentInference = InferenceRunner(self)
        self.currentInference.startInference(compiled_model, data, partSize, margin=self.margin, divisableBy=self.divisableBy)


    def stepInInference(self, currentStep, allSteps):
        """Show inference progress as ``currentStep`` of ``allSteps`` and keep the GUI responsive."""
        bar = self.progressBar
        bar.setMaximum(allSteps)
        bar.setValue(currentStep)
        # Let Qt repaint the bar while the inference is running.
        QCoreApplication.processEvents()


    def finishInference(self, outputData, width, height):
        """Postprocess the inference output, add it as a new layer, and close the dialog.

        ``width`` and ``height`` are not used; the layer size is taken from axes
        3 and 2 of ``outputData``. ``self.currentInference`` is cleared even if
        postprocessing or layer creation raises. Otherwise ``run()``, which
        refuses to start while it is set, would be blocked for good.
        """
        try:
            postinfo = self.getPostProcessingInfo()
            outputData = improc.applyPostProcessingNumpy(outputData, postinfo)
            self.convertOutputToKritaLayer(outputData, outputData.shape[3], outputData.shape[2])
        finally:
            # Cleared before close(): dialogFinished cancels any runner still set here.
            self.currentInference = None
        self.dialog.close()


    def dialogFinished(self):
        """Cancel and forget a still-running inference when the dialog finishes."""
        runner = self.currentInference
        if runner is None:
            return
        runner.cancelAllInference()
        self.currentInference = None

