import cv2
import numpy as np
import os
import glob
from multiprocessing import Pool

# 定义处理单个视频的函数
def process_video(video_path, output_base_dir, target_width=700):
    # 获取视频文件名（不带扩展名）
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    # 创建输出目录
    output_dir = os.path.join(output_base_dir, video_name)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # 读取视频
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video {video_path}.")
        return

    frame_count = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # 将BGR图像转换为HSV颜色空间
        hsv_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)

        # 定义绿色的HSV范围（包括相近颜色）
        lower_green = np.array([35, 50, 50])  # 下限（H, S, V）
        upper_green = np.array([85, 255, 255])  # 上限（H, S, V）

        # 创建掩码，标记绿色区域
        mask = cv2.inRange(hsv_frame, lower_green, upper_green)

        # 形态学操作：去除绿色边缘残留
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)  # 闭运算，填充小孔洞
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)  # 开运算，去除小噪点

        # 高斯模糊：平滑掩码边缘
        mask = cv2.GaussianBlur(mask, (5, 5), 0)

        # 将原始图像转换为RGBA格式
        rgba_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2BGRA)

        # 只将绿色区域的透明度设置为0（完全透明），其他区域保持不变
        rgba_frame[mask > 0, 3] = 0

        # 调整图片大小：宽度固定为700，高度按比例缩放
        height, width = rgba_frame.shape[:2]
        scale_ratio = target_width / width
        new_height = int(height * scale_ratio)
        resized_frame = cv2.resize(rgba_frame, (target_width, new_height), interpolation=cv2.INTER_AREA)

        # 将图像保存为PNG
        output_path = os.path.join(output_dir, f"{frame_count:04d}.png")
        cv2.imwrite(output_path, resized_frame)

        frame_count += 1

    # 释放视频对象
    cap.release()
    print(f"Finished processing {video_path} ({frame_count} frames).")

# 主函数
def main():
    # 获取当前目录下所有 .mp4 文件
    video_files = glob.glob("*.mp4")
    if not video_files:
        print("No .mp4 files found in the current directory.")
        return

    # 设置输出基目录
    output_base_dir = "output"

    # 创建进程池，最大进程数为 20
    with Pool(processes=min(20, len(video_files))) as pool:
        # 并行处理每个视频文件
        pool.starmap(process_video, [(video, output_base_dir) for video in video_files])

    print("All videos processed.")

if __name__ == "__main__":
    main()