import logging
import time
from threading import Thread, Lock
from enum import Enum
import baco


# Module-level logger; HlsPlayer reports state transitions, retries and
# compensation adjustments through it.
logger = logging.getLogger(__name__)


class State(Enum):
    """
    States of the HlsPlayer state machine.

    Modeled after baco.PlayerState; REOPENING is an extra, player-side state
    that baco.PlayerState does not define (yet).
    """

    STOPPED = 0
    OPENING = 1
    BUFFERING = 2
    READY = 3
    PLAYING = 4
    DRAINING = 5
    FINISHED = 6
    # Player-side only: input hit EOF while the FIFO still holds audio.
    REOPENING = 7


class HlsPlayer:
    """
    Full-buffer HLS player assembled from individual Baco library modules.

    Constructing an instance immediately starts two background threads:

    - the player thread, a state machine (see ``State``) that opens the input,
      builds the pipeline
      ``StreamInput -> DecoderChain -> [Equalizer] -> MetadataExporter -> AlsaSink``,
      buffers until the level reaches ``readyThreshold`` ms and then plays;
    - the timer thread, which refreshes ``self.elapsed`` (ms since start)
      every 200 ms to pace the buffer compensation mechanism.

    Playback only starts after ``connectSource()``. The decoder chain is
    created muted; ``connectAudio()`` unmutes it. ``disconnect()`` stops both
    threads. Objects shared with the player thread are guarded by
    ``self.lock``.

    - supports a buffer-level compensation mechanism (average + instant)
    """

    # Compensation settings. Ratios compare the averaged buffer level with the
    # level sampled at each recalculation (see __applyCompensation).
    THRESHOLD_VERY_LOW          = 0.90
    THRESHOLD_LOW               = 0.95
    TARGET_LEVEL                = 0.99
    COMP_TIMER_INTERVAL_MS      = 1000
    AVG_COMP_MAX                = 0.0125  # 1.25%
    AVG_COMP_MILD               = 0.005   # 0.5%
    AVG_COMP_STEP               = 0.00025 # 0.025%
    MAX_IMM_COMP_RATE           = 0.001   # 0.1%

    def __init__(self, url, deviceName, sampleRate):
        """
        Create the player and start its player and timer threads.

        :param url: input stream URL passed to ``baco.StreamInput.open``
        :param deviceName: device name passed to ``baco.AlsaSink``
        :param sampleRate: sample rate of the stereo, I16 output format
        """
        # player related settings
        self.url                = url
        self.deviceName         = deviceName
        self.inputTimeout       = 500   # presumably milliseconds; passed to StreamInput.setTimeout
        self.inputFifoDuration  = 0
        self.outputFifoDuration = 0
        self.readyThreshold     = 0     # buffer level (ms) required to start output
        self.outputFormat       = baco.FrameFormat(baco.ChannelLayout.STEREO, baco.SampleFormat.I16, sampleRate)
        self.inputFormat        = ""
        self.inputFormatOptions = {}

        # Guards the pipeline objects and state shared with the player thread.
        self.lock               = Lock()

        self.playerThread       = None
        self.timerThread        = None
        self.callback           = None
        self.streamInput        = None
        self.decoderChain       = None
        self.metadataExporter   = None
        self.equalizerConfig    = None
        self.equalizer          = None
        self.equalizerEnabled   = False
        self.alsaSink           = None
        self.stopPlayer         = False
        self.codecId            = baco.UNKNOWN
        self.state              = State.STOPPED
        self.volume             = 1.0
        self.metadataCallback   = None

        # compensation counters
        self.avgDelayAcc     = 0
        self.avgDelayCount   = 0
        self.elapsed         = 0   # ms since startTime, refreshed by the timer thread
        self.nextRecalc      = self.COMP_TIMER_INTERVAL_MS
        self.curAvgComp      = 0
        self.fillingSilence  = True

        self.playerThread = Thread(target=self.__player)
        logger.info("start player thread")
        self.stopPlayer = False
        self.startTime = time.time()
        self.playerThread.start()
        self.timerThread = Thread(target=self.__timeTask)
        self.timerThread.start()

    def setInputTimeout(self, timeout):
        """Set the StreamInput timeout used for subsequent (re)opens."""
        self.inputTimeout = timeout

    def setInputFifoDuration(self, inputFifoDuration):
        """Set ``Config.inputFifoDuration`` for the next decoder chain."""
        self.inputFifoDuration = inputFifoDuration

    def setOutputFifoDuration(self, outputFifoDuration):
        """Set ``Config.outputFifoDuration`` for the next decoder chain."""
        self.outputFifoDuration = outputFifoDuration

    def setReadyThreshold(self, threshold):
        """Set the buffer level (ms) at which BUFFERING switches to PLAYING."""
        self.readyThreshold = threshold

    def forceInputFormat(self, format, options):
        """ Define stream input options.
        options: ex dictionary -> {"sample_rate": sampleRate,"channels": channels}
        """
        self.inputFormat = format
        self.inputFormatOptions = options

    def forceCodec(self, codec):
        """Force the codec id passed to ``StreamInput.open``."""
        self.codecId = codec

    def setCallback(self, func):
        """Register ``func(alive: bool)``: True on PLAYING, False on STOPPED."""
        self.callback = func

    def connectSource(self):
        """Start opening the input if the player is idle and still running."""
        with self.lock:
            if self.state == State.STOPPED and self.playerThread is not None:
                self.__setState(State.OPENING)

    def connectAudio(self):
        """Unmute the decoder chain (no-op before the pipeline exists)."""
        with self.lock:
            if self.decoderChain is not None:
                logger.info("decoderChain unmute")
                self.decoderChain.mute(False)

    def disconnectAudio(self):
        """Mute the decoder chain (no-op before the pipeline exists)."""
        with self.lock:
            if self.decoderChain is not None:
                logger.info("decoderChain mute")
                self.decoderChain.mute(True)

    def setVolume(self, vol):
        """
        Set the output volume.

        :param vol: volume in percent (100 -> gain 1.0)

        The value is always remembered, so a volume set before the pipeline
        exists is applied when the decoder chain is created.
        """
        with self.lock:
            self.volume = vol / 100.0
            if self.decoderChain is not None:
                logger.info(f"Set Volume:{vol}%")
                self.decoderChain.setGain(self.volume, 0, 0)

    def getState(self):
        """Return the current ``State``."""
        with self.lock:
            return self.state

    def setMetadataCallback(self, callback):
        """Register the callback handed to the next MetadataExporter."""
        self.metadataCallback = callback

    def setEqualizer(self, config: dict):
        """
        Set Equalizer configuration, used when the next pipeline is built.
        :param config: equalizer json config band list
        """
        with self.lock:
            self.equalizerConfig = config
        logger.info("Add Equalizer configuration")

    def setEqualizerFrequency(self, name, f):
        """
        Set Equalizer band filter central frequency
        :param name: band filter name
        :param f: central frequency
        """
        with self.lock:
            try:
                if self.equalizer is not None:
                    self.equalizer.setFrequency(name, f)
                    logger.info(f"Set Filter {name} central frequency: {f}")
            except Exception as e:
                logger.error(f"{e}")

    def setEqualizerQFactor(self, name, q):
        """
        Set Equalizer band filter Q-factor
        :param name: band filter name
        :param q: q-factor
        """
        with self.lock:
            try:
                if self.equalizer is not None:
                    self.equalizer.setQFactor(name, q)
                    logger.info(f"Set Filter {name} Q-factor: {q}")
            except Exception as e:
                logger.error(f"{e}")

    def setEqualizerGain(self, name, g):
        """
        Set Equalizer band filter gain
        :param name: band filter name
        :param g: gain
        """
        with self.lock:
            try:
                if self.equalizer is not None:
                    self.equalizer.setGain(name, g)
                    logger.info(f"Set Filter {name} gain: {g}")
            except Exception as e:
                logger.error(f"{e}")

    def enableEqualizer(self, enable):
        """
        Enable or disable the equalizer.

        The flag is always remembered so it is applied when the equalizer is
        created with the next pipeline.
        """
        with self.lock:
            self.equalizerEnabled = enable
            if self.equalizer is None:
                return
            try:
                self.equalizer.enable(enable)
                logger.info(f"Enable Equalizer: {enable}")
            except Exception as e:
                logger.error(e)

    def disconnect(self):
        """Mute the output, stop both background threads and wait for them."""
        with self.lock:
            if self.decoderChain is not None:
                self.decoderChain.mute(True)
        self.stopPlayer = True
        # The player thread resets self.playerThread to None when it exits,
        # so snapshot the references and tolerate None.
        for thread in (self.playerThread, self.timerThread):
            if thread is not None and thread.is_alive():
                thread.join()
        logger.info("HlsPlayer disconnected...")

    # ---------------------------------------------------------------------
    # Player internal methods
    # ---------------------------------------------------------------------

    def __timeTask(self):
        """Timer thread: refresh ``self.elapsed`` (ms) every 200 ms."""
        while not self.stopPlayer:
            self.elapsed = int((time.time() - self.startTime) * 1000)  # milliseconds
            time.sleep(0.2)

    def __setState(self, new_state):
        """Change state and fire the user callback on a separate thread."""
        if new_state == self.state:
            return
        logger.info(f"State {self.state.name} --> {new_state.name}")
        self.state = new_state
        if self.callback is not None:
            thread = Thread(target=self.__execCallback, args=(new_state,))
            thread.start()

    def __execCallback(self, stat):
        """Translate a state into the user callback's alive flag."""
        if stat == State.PLAYING:
            logger.info("Stream is Alive")
            self.callback(True)
        if stat == State.STOPPED:
            logger.info("Stream is Dead")
            self.callback(False)

    def __player(self):
        """
        Player Thread: run each state's handler in turn until disconnect().
        """
        while not self.stopPlayer:
            self.__waitWhileStopped()
            self.__runOpening()
            self.__runBuffering()
            self.__runPlaying()
            self.__runReopening()

        logger.info("stop player thread")
        self.playerThread = None

    def __waitWhileStopped(self):
        """STOPPED: idle until connectSource() or disconnect()."""
        while self.state == State.STOPPED and not self.stopPlayer:
            time.sleep(0.2)

    def __runOpening(self):
        """OPENING: (re)try opening and probing the input, then build the pipeline."""
        while self.state == State.OPENING and not self.stopPlayer:
            self.streamInput = baco.StreamInput()
            self.streamInput.setTimeout(self.inputTimeout)
            res = self.streamInput.open(self.url, self.inputFormat, self.inputFormatOptions, self.codecId)

            if res == baco.ReturnCode.OK:
                logger.info(f"opened input: {self.url}")
                srcInfo = self.streamInput.getSourceInfo()
                # Only query stream 0 once we know at least one stream exists.
                streamDesc = None if srcInfo.getNumStreams() == 0 else srcInfo.getStream(0)
                if not streamDesc:
                    logger.warning("Input format is unknown! Retrying in 5s")
                    self.streamInput.close()
                    time.sleep(5)
                elif streamDesc.codec() == baco.CodecId.UNKNOWN:
                    logger.warning("Unknown CodecId! Retrying in 1s")
                    self.streamInput.close()
                    time.sleep(1)  # avoid a busy retry loop against the server
                else:
                    self.__buildPipeline(streamDesc)
            elif res == baco.ReturnCode.E_TIMEOUT:
                logger.error(f"Timeout probing stream (t={self.inputTimeout}). Retrying in 1s")
                time.sleep(1)
            elif res == baco.ReturnCode.E_IO:
                logger.error("Connection refused... Retrying in 1s")
                time.sleep(1)
            else:
                logger.error(f"Failed to open input ({res}). Retrying in 1s")
                time.sleep(1)

    def __buildPipeline(self, streamDesc):
        """Create decoder chain, optional equalizer, metadata exporter and ALSA sink."""
        with self.lock:
            config = baco.Config(streamDesc, self.outputFormat)
            config.inputFifoDuration = self.inputFifoDuration
            config.outputFifoDuration = self.outputFifoDuration
            self.decoderChain = baco.DecoderChain(config)
            self.decoderChain.mute(True)
            self.decoderChain.setGain(self.volume, 0, 0)
            self.decoderChain.setGainMask(1)
            self.alsaSink = baco.AlsaSink(self.deviceName, self.outputFormat, 0, 0)
            self.metadataExporter = baco.MetadataExporter()
            if self.metadataCallback is not None:
                self.metadataExporter.setMetadataCallback(self.metadataCallback)
            if self.equalizerConfig is not None:
                self.equalizer = baco.Equalizer(self.outputFormat, self.equalizerConfig)
                self.decoderChain.attachSink(self.equalizer)
                self.equalizer.setFrameSink(self.metadataExporter)
                if self.equalizerEnabled:
                    self.equalizer.enable(True)
            else:
                self.decoderChain.attachSink(self.metadataExporter)

            self.metadataExporter.setFrameSink(self.alsaSink)
            self.__setState(State.BUFFERING)

    def __runBuffering(self):
        """BUFFERING: read packets until the level reaches readyThreshold, then start output."""
        while self.state == State.BUFFERING and not self.stopPlayer:
            with self.lock:
                bufferLevel = self.decoderChain.getBufferLevel(1000)
            logger.debug(f"Buffer Level: {bufferLevel}")

            if bufferLevel >= self.readyThreshold:
                with self.lock:
                    self.decoderChain.startOutput()
                    self.__setState(State.PLAYING)
                continue

            res = self.streamInput.readPacket(self.decoderChain)
            if res != baco.ReturnCode.OK:
                if res == baco.ReturnCode.E_TIMEOUT:
                    logger.error(f"Read timeout while buffering (t={self.inputTimeout})")
                elif res == baco.ReturnCode.E_EOF:
                    logger.error("Input is EOF while buffering")
                self.streamInput.close()
                with self.lock:
                    self.__setState(State.STOPPED)

            self.__applyCompensation()

    def __runPlaying(self):
        """PLAYING: feed packets; EOF -> REOPENING (FIFO not empty) or STOPPED, timeout -> STOPPED."""
        while self.state == State.PLAYING and not self.stopPlayer:
            res = self.streamInput.readPacket(self.decoderChain)
            if res == baco.ReturnCode.E_EOF:
                logger.info("Input is EOF while playing")
                # if fifo has data reopen the stream, else stop
                with self.lock:
                    bufferLevel = self.decoderChain.getBufferLevel(1000)
                    self.streamInput.close()
                    self.__setState(State.REOPENING if bufferLevel > 0 else State.STOPPED)
            elif res == baco.ReturnCode.E_TIMEOUT:
                logger.error(f"Read timeout while playing (t={self.inputTimeout})")
                with self.lock:
                    self.streamInput.close()
                    self.__setState(State.STOPPED)
            # Other non-OK codes are ignored here and reading simply continues.

            self.__applyCompensation()

    def __runReopening(self):
        """
        REOPENING: reopen the input while the FIFO keeps playing.

        Success -> PLAYING; gives up (-> STOPPED) once the buffer level drops
        below 100 ms.
        """
        while self.state == State.REOPENING and not self.stopPlayer:
            # open() can presumably block for up to the input timeout; keep it
            # outside the lock so setVolume()/connectAudio()/... do not stall.
            res = self.streamInput.open(self.url, self.inputFormat, self.inputFormatOptions, self.codecId)
            with self.lock:
                if res == baco.ReturnCode.OK:
                    self.__setState(State.PLAYING)
                elif self.decoderChain.getBufferLevel(1000) < 100:
                    self.__setState(State.STOPPED)
            if self.state == State.REOPENING:
                time.sleep(0.1)  # pace retries instead of hammering the server

    def __applyCompensation(self):
        """
        Run one step of the buffer compensation mechanism.

        1. Sets the decoder dispatch mode to FORCE_ON when the buffer level
           drops below 10% of ``readyThreshold`` and back to FORCE_OFF once it
           exceeds ``readyThreshold`` (tracked by ``fillingSilence``).
        2. Accumulates buffer-level samples and, every COMP_TIMER_INTERVAL_MS
           of ``self.elapsed``, adjusts the average compensation rate or
           applies an instant compensation bounded by MAX_IMM_COMP_RATE.
        """
        with self.lock:
            buffLevel = int(self.decoderChain.getBufferLevel(1000))

            # "Critical" = below 10% of readyThreshold. Written without a
            # division so readyThreshold == 0 (the default) cannot raise
            # ZeroDivisionError and kill the player thread.
            if not self.fillingSilence and buffLevel * 10 < self.readyThreshold:
                self.fillingSilence = True
                logger.info(f"buffer level critical. start filling - lv={buffLevel}ms th={self.readyThreshold}ms")
                self.decoderChain.setDispatchMode(baco.FORCE_ON)
            elif self.fillingSilence and buffLevel > self.readyThreshold:
                logger.info(f"buffer refill complete (lv={buffLevel}ms). resuming audio")
                self.fillingSilence = False
                self.decoderChain.setDispatchMode(baco.FORCE_OFF)

            self.avgDelayAcc += self.decoderChain.getBufferLevel(1000000)
            self.avgDelayCount += 1

            cur_elapsed = self.elapsed
            if cur_elapsed < self.nextRecalc:
                return
            # Skip over any intervals missed since the last recalculation.
            while self.nextRecalc <= cur_elapsed:
                self.nextRecalc += self.COMP_TIMER_INTERVAL_MS

            cur_buf_delay = round(self.avgDelayAcc / self.avgDelayCount)
            full_buf_delay = self.decoderChain.getBufferLevel(1000000)

            self.avgDelayAcc = 0
            self.avgDelayCount = 0

            # Prevent zero division
            percent = 0 if full_buf_delay == 0 else cur_buf_delay / full_buf_delay
            comp_changed = False

            # Between THRESHOLD_VERY_LOW and THRESHOLD_LOW the average rate is
            # left untouched (dead band).
            if percent < self.THRESHOLD_VERY_LOW:
                if self.curAvgComp < self.AVG_COMP_MAX:
                    self.curAvgComp += self.AVG_COMP_STEP
                    comp_changed = True
            elif self.THRESHOLD_LOW < percent < self.TARGET_LEVEL:
                if self.curAvgComp > self.AVG_COMP_MILD:
                    self.curAvgComp -= self.AVG_COMP_STEP
                    comp_changed = True
            elif percent > self.TARGET_LEVEL:
                if self.curAvgComp > 0:
                    self.curAvgComp -= self.AVG_COMP_STEP
                    comp_changed = True
                else:
                    self.curAvgComp = 0

            # Instant correction: a quarter of the distance to the target,
            # clamped to +/- MAX_IMM_COMP_RATE of 1000000.
            target_delay = round(self.TARGET_LEVEL * full_buf_delay)
            instant_comp = (target_delay - cur_buf_delay) // 4
            max_delta_abs = round(1000000 * self.MAX_IMM_COMP_RATE)
            instant_comp = max(-max_delta_abs, min(max_delta_abs, instant_comp))

            if comp_changed:
                if self.curAvgComp > 0:
                    logger.info(f"average compensation --> x{1 + self.curAvgComp:.4f} speed")
                    self.decoderChain.setAvgCompRate(True, 1 + self.curAvgComp)
                else:
                    logger.info("average compensation disabled")
                    self.decoderChain.setAvgCompRate(False, 0)
            elif instant_comp != 0:
                logger.debug(f"instant compensation: {instant_comp / 1000.0:+.1f}ms")
                self.decoderChain.addCompensation(instant_comp, 1000000)