import logging
import os
import subprocess
import threading
import time
from abc import ABC, abstractmethod

import cv2
import numpy
import requests
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect

# Prefix every log record with a timestamp. No level= is passed, so the root
# logger keeps Python's default WARNING threshold (presumably why the latency
# metric in MyServer is emitted with logging.error rather than logging.info).
logging.basicConfig(format="%(asctime)s - %(message)s")


class YUVGetter(ABC):
    """Abstract source of raw YUV frames.

    Subclasses implement :meth:`prepare_yuv_source` (setup before frames are
    requested; may block for a long time) and :meth:`get_yuv` (return one
    frame as ``bytes``).

    Inheriting from :class:`abc.ABC` is what makes ``@abstractmethod`` take
    effect. Without it, the decorators are ignored and a subclass that
    forgets a hook silently returns ``None`` instead of failing when it is
    instantiated.
    """

    def __init__(self, source_url: str):
        # Where frames come from: a directory, an HTTP URL or a stream URL,
        # depending on the subclass.
        self.source_url = source_url

    @abstractmethod
    def prepare_yuv_source(self):
        """Set up the frame source before frames are requested."""

    @abstractmethod
    def get_yuv(self):
        """Return one frame of YUV data as bytes."""


class FolderYUVGetter(YUVGetter):
    """Serve the frame files of a directory in an endless loop.

    Every regular file in ``yuv_dir`` is read into memory as one frame.
    Frames are handed out in sorted file-name order and wrap back to the
    first frame after the last one.
    """

    def __init__(self, yuv_dir: str):
        super(FolderYUVGetter, self).__init__(yuv_dir)
        self.yuv_list = []  # raw frame bytes, in playback order
        self.yuv_index = 0  # index of the frame the next get_yuv() returns

    def prepare_yuv_source(self):
        """Load every regular file in the directory into memory.

        ``os.listdir`` returns names in arbitrary order, so they are sorted
        to make playback order deterministic. Sorting is lexicographic, so
        frame files need zero-padded numbers to play in numeric order.
        Sub-directories and other non-regular entries are skipped instead of
        crashing ``open``.
        """
        for file_name in sorted(os.listdir(self.source_url)):
            path = os.path.join(self.source_url, file_name)
            if not os.path.isfile(path):
                continue
            with open(path, 'rb') as f:
                self.yuv_list.append(f.read())

    def get_yuv(self):
        """Return the next frame, cycling back to the first after the last.

        Raises:
            RuntimeError: if no frames have been loaded yet
                (message means "unable to get YUV data").
        """
        if not self.yuv_list:
            raise RuntimeError("无法获取YUV数据")
        count = len(self.yuv_list)
        # Modulo keeps the index in range even if the list changed size since
        # the previous call (prepare_yuv_source may still be appending on
        # another thread).
        index = self.yuv_index % count
        self.yuv_index = (index + 1) % count
        return self.yuv_list[index]


class FaceYUVGetter(YUVGetter):
    """Fetch an image over HTTP and convert it into one YUV I420 frame.

    Each :meth:`get_yuv` call POSTs to ``face_url``, decodes the returned
    image (presumably JPEG), rotates it 90 degrees clockwise and converts it
    to I420.
    """

    def __init__(self, face_url: str, timeout=None):
        super(FaceYUVGetter, self).__init__(face_url)
        # One session reused for every request, so connections can be pooled.
        self.session = requests.session()
        # Forwarded to requests. None (the default) waits indefinitely, which
        # is what a POST without a timeout argument does.
        self.timeout = timeout

    def prepare_yuv_source(self):
        # Nothing to prepare: every frame is fetched on demand in get_yuv().
        pass

    def get_yuv(self):
        """Fetch, decode, rotate and convert one frame; return it as bytes.

        Raises:
            requests.HTTPError: if the endpoint answers with an error status.
            RuntimeError: if the response body cannot be decoded as an image
                (message means "unable to get YUV data").
        """
        response = self.session.post(self.source_url, timeout=self.timeout)
        # Fail clearly on 4xx/5xx instead of trying to decode an error page.
        response.raise_for_status()
        jpg_nparray = numpy.frombuffer(response.content, numpy.uint8)
        img = cv2.imdecode(jpg_nparray, cv2.IMREAD_COLOR)
        if img is None:
            # imdecode returns None for undecodable input; without this check
            # cv2.rotate would fail with an opaque assertion error.
            raise RuntimeError("无法获取YUV数据")
        # Rotate 90 degrees clockwise.
        img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
        yuv_img = cv2.cvtColor(img, cv2.COLOR_BGR2YUV_I420)
        return yuv_img.tobytes()


class RTMPYuvGetter(YUVGetter):
    """Decode a stream with ffmpeg into raw yuv420p frames.

    :meth:`prepare_yuv_source` runs ffmpeg and keeps :attr:`frame` set to the
    most recently decoded complete frame. :meth:`get_yuv` returns that latest
    frame, so a client that polls slower than the stream rate skips frames.
    """

    def __init__(self, rtmp_url: str, width: int = 960, height: int = 540):
        super(RTMPYuvGetter, self).__init__(rtmp_url)
        self.width = width
        self.height = height
        # yuv420p = a full-size Y plane plus two quarter-size chroma planes,
        # i.e. 1.5 bytes per pixel (assumes even width and height).
        self.frame_size = width * height * 3 // 2
        self.frame = None  # latest complete frame; None until the first one

    def prepare_yuv_source(self):
        """Run ffmpeg and publish decoded frames until its output ends.

        This blocks for the lifetime of the stream, so run it on a worker
        thread. ffmpeg is asked to scale its output to ``width`` x ``height``,
        so each ``frame_size``-byte read is exactly one frame. When ffmpeg
        stops (stream ended, connection lost), the last complete frame stays
        available, the process is reaped, and this method returns instead of
        spinning on EOF.
        """
        command = [
            'ffmpeg',
            '-loglevel', 'quiet',
            '-i', self.source_url,
            '-f', 'rawvideo',
            '-pix_fmt', 'yuv420p',
            '-s', f'{self.width}x{self.height}',
            '-vcodec', 'rawvideo', '-'
        ]

        pipe = subprocess.Popen(command, stdout=subprocess.PIPE, bufsize=10 ** 8)
        try:
            while True:
                # Blocks until a full frame is available or ffmpeg closes stdout.
                frame = pipe.stdout.read(self.frame_size)
                if len(frame) < self.frame_size:
                    # EOF (empty or truncated trailing read): never publish a
                    # partial frame.
                    break
                self.frame = frame
        finally:
            pipe.stdout.close()
            if pipe.poll() is None:
                pipe.kill()
            pipe.wait()
        logging.warning("ffmpeg stopped producing frames for %s", self.source_url)

    def get_yuv(self):
        """Return the latest decoded frame.

        Raises:
            RuntimeError: if no frame has been decoded yet
                (message means "unable to get YUV data").
        """
        if self.frame is None:
            raise RuntimeError("无法获取YUV数据")
        return self.frame


class MyServer:
    """Serve frames from a YUVGetter over a WebSocket endpoint at ``/``.

    Protocol: for every text message the client sends, the server replies
    with one binary message containing the next frame from the getter.
    """

    # NOTE(review): this FastAPI app is a class attribute, so all MyServer
    # instances register their "/" route on the same shared app and the
    # first-registered route wins. It is kept at class level because external
    # code may refer to ``MyServer.app``.
    app = FastAPI()

    def __init__(self, host: str, port: int, getter: YUVGetter):
        self.host = host
        self.port = port
        self.getter = getter
        self.setup_routes()

    def setup_routes(self):
        """Register the WebSocket frame endpoint on ``self.app``."""
        @self.app.websocket("/")
        async def yuv(ws: WebSocket):
            t0 = time.time()
            ret = 0  # number of frames sent on this connection
            await ws.accept()
            try:
                while True:
                    # Each incoming text message is a request for one frame.
                    await ws.receive_text()
                    # NOTE: get_yuv() is synchronous; a slow getter (e.g. an
                    # HTTP fetch) blocks the event loop while it runs.
                    yuv_bytes = self.getter.get_yuv()
                    ret += 1
                    if ret % 100 == 0:
                        t1 = time.time()
                        # Average wall-clock time per frame over the last
                        # 100 frames, in milliseconds.
                        logging.error(f"延迟 => {round((t1 - t0) * 1000 / 100, 3)}ms")
                        t0 = t1
                    await ws.send_bytes(yuv_bytes)
            except WebSocketDisconnect:
                # The client went away: end this handler quietly instead of
                # letting the disconnect escape as an unhandled exception.
                return

    def run_forever(self):
        """Start uvicorn on host:port and block until the server stops."""
        # Message: "Starting YUV server: ..."
        print(f"启动YUV服务器: http://{self.host}:{self.port}")
        uvicorn.run(self.app, host=self.host, port=self.port)


if __name__ == "__main__":
    # Pick exactly one frame source.
    # yuv_getter = FolderYUVGetter("D:\\shuziren\\test\\yuv_960x540")
    # yuv_getter = FaceYUVGetter("http://127.0.0.1:17000/vcam/getCameraImage/XX")
    yuv_getter = RTMPYuvGetter("rtmp://127.0.0.1:1935/live/test")

    my_server = MyServer(host="0.0.0.0", port=9876, getter=yuv_getter)

    # The source is prepared on a daemon thread: RTMPYuvGetter's
    # prepare_yuv_source blocks for as long as the stream runs, and a daemon
    # thread does not keep the process alive after the server shuts down.
    threading.Thread(target=yuv_getter.prepare_yuv_source, daemon=True).start()

    # Run the server on the main thread. uvicorn only installs its signal
    # handlers there, so Ctrl+C stops the server and the process exits.
    my_server.run_forever()
