import json
import re
import argparse
import os
from typing import List, Dict, Tuple, Optional, Union


def convert_time_format(time_str: str) -> str:
    """转换时间格式（兼容SRT和VTT）为目标格式（00:00:00.000）"""
    # 统一处理逗号和点分隔符
    if ',' in time_str:
        time_str = time_str.replace(',', '.')

    # 确保毫秒部分为3位
    if '.' in time_str:
        main, ms = time_str.split('.', 1)
        ms = ms.ljust(3, '0')[:3]  # 补零到3位并截断
        # 检查main部分是否缺少小时
        parts = main.split(':')
        if len(parts) == 2:
            # MM:SS格式，补全小时
            main = f"00:{parts[0].zfill(2)}:{parts[1].zfill(2)}"
        elif len(parts) == 3:
            # HH:MM:SS格式，确保每段两位
            main = f"{parts[0].zfill(2)}:{parts[1].zfill(2)}:{parts[2].zfill(2)}"
        return f"{main}.{ms}"
    elif ':' in time_str:
        # 处理只有分:秒格式的情况
        parts = time_str.split(':')
        if len(parts) == 2:  # 分:秒格式
            return f"00:{parts[0].zfill(2)}:{parts[1].zfill(2)}.000"
        return time_str + ".000"
    else:
        return time_str + ".000"


def parse_srt(content: str) -> List[Dict[str, str]]:
    """解析SRT内容为字典列表"""
    blocks = re.split(r'\n\s*\n', content.strip())  # 按空行分割块
    entries = []

    for block in blocks:
        lines = block.strip().splitlines()
        if len(lines) < 3:  # 确保有编号、时间轴和内容
            continue

        # 解析时间轴（格式：00:00:00,000 --> 00:00:00,000）
        time_match = re.match(r'(\d{2}:\d{2}:\d{2}[.,]\d{3}) --\> (\d{2}:\d{2}:\d{2}[.,]\d{3})', lines[1])
        if not time_match:
            continue

        start_time = convert_time_format(time_match.group(1))
        end_time = convert_time_format(time_match.group(2))
        src_text = lines[2]
        trans_texts = lines[3:]

        entries.append({
            "start": start_time,
            "end": end_time,
            "srcText": src_text,
            "transTexts": trans_texts
        })

    return entries


def parse_vtt(content: str) -> List[Dict[str, str]]:
    """解析VTT内容为字典列表"""
    # 移除WEBVTT头部和可能的样式信息
    content = re.sub(r'^WEBVTT.*?\n\n', '', content, flags=re.DOTALL | re.IGNORECASE)

    # 按空行分割块
    blocks = re.split(r'\n\s*\n', content.strip())
    entries = []

    for block in blocks:
        lines = block.strip().splitlines()
        if not lines:
            continue

        # 检查是否有时间轴行（格式：00:00:00.000 --> 00:00:00.000 或 00:00.000 --> 00:00.000）
        time_match = re.match(r'(\d{2}:\d{2}:\d{2}[.,]\d{3}|\d{1,2}:\d{2}[.,]\d{3}) --\> (\d{2}:\d{2}:\d{2}[.,]\d{3}|\d{1,2}:\d{2}[.,]\d{3})', lines[0])
        if time_match:
            # 时间轴在第一行
            start_time = convert_time_format(time_match.group(1))
            end_time = convert_time_format(time_match.group(2))
            text_lines = lines[1:]
        elif len(lines) > 1:
            # 尝试第二行是否为时间轴
            time_match = re.match(r'(\d{2}:\d{2}:\d{2}[.,]\d{3}|\d{1,2}:\d{2}[.,]\d{3}) --\> (\d{2}:\d{2}:\d{2}[.,]\d{3}|\d{1,2}:\d{2}[.,]\d{3})', lines[1])
            if time_match:
                start_time = convert_time_format(time_match.group(1))
                end_time = convert_time_format(time_match.group(2))
                text_lines = lines[2:]
            else:
                continue
        else:
            continue

        def remove_cue_settings(text):  # 移除cue设置
            text = re.sub(r'<[^>]+>', '', text)  # 移除HTML标签
            text = re.sub(r'^\s*[-–]?\s*', '', text)  # 移除行首的破折号
            text = re.sub(r'^\s*<v\s+[^>]+>\s*', '', text, flags=re.IGNORECASE)  # 移除说话人标签
            return text

        src_text = text_lines[0]
        trans_texts = text_lines[1:]
        src_text = remove_cue_settings(src_text)
        for i in range(len(trans_texts)):
            trans_texts[i] = remove_cue_settings(trans_texts[i])

        entries.append({
            "start": start_time,
            "end": end_time,
            "srcText": src_text,
            "transTexts": trans_texts
        })

    return entries


def align_entries(src_entries: List[Dict], dst_entries: List[Dict]) -> List[Tuple[Optional[Dict], Optional[Dict]]]:
    """对齐源语言和目标语言条目"""
    aligned = []

    # 简单对齐：按顺序匹配
    min_len = min(len(src_entries), len(dst_entries))
    for i in range(min_len):
        aligned.append((src_entries[i], dst_entries[i]))

    # 处理剩余条目
    if len(src_entries) > min_len:
        for i in range(min_len, len(src_entries)):
            aligned.append((src_entries[i], None))

    if len(dst_entries) > min_len:
        for i in range(min_len, len(dst_entries)):
            aligned.append((None, dst_entries[i]))

    return aligned


def detect_file_type(filename: str) -> str:
    """检测文件类型（SRT或VTT）"""
    ext = os.path.splitext(filename)[1].lower()
    if ext == '.srt':
        return 'srt'
    elif ext == '.vtt':
        return 'vtt'
    else:
        raise ValueError(f"不支持的文件类型: {filename}。只支持 .srt 和 .vtt 文件")


def main():
    parser = argparse.ArgumentParser(description='转换字幕文件到Speaker JSON格式')
    parser.add_argument('inputs', nargs='+', help='输入文件路径（1-2个文件）')
    parser.add_argument('output', help='输出JSON文件路径')
    parser.add_argument('--src_lang', default='zh', help='源语言代码 (默认: zh)')
    parser.add_argument('--dst_langs', default='en', help='逗号分隔的目标语言代码 (默认: en)')
    parser.add_argument('--alternate', action='store_true', help='在条目间交替说话人')
    parser.add_argument('--default_speaker', default='speaker_0', help='默认说话人ID (默认: speaker_0)')
    parser.add_argument('--speakers', default='male,female', help='逗号分隔的说话人性别 (默认: male,female)')
    parser.add_argument('--time_tolerance', type=float, default=0.5,
                        help='匹配条目的时间差容差（秒）(默认: 0.5)')
    args = parser.parse_args()

    # 验证输入文件数量
    if len(args.inputs) not in [1, 2]:
        parser.error("需要1或2个输入文件")

    # 检测文件类型
    file_types = [detect_file_type(f) for f in args.inputs]

    # 如果是双文件模式，确保文件类型一致
    if len(file_types) == 2 and file_types[0] != file_types[1]:
        parser.error("双文件模式下，输入文件必须是相同类型（都是SRT或都是VTT）")

    # 读取和解析文件
    entries_list = []
    for i, filename in enumerate(args.inputs):
        with open(filename, 'r', encoding='utf-8') as f:
            content = f.read()

        if file_types[i] == 'srt':
            entries = parse_srt(content)
        else:  # vtt
            entries = parse_vtt(content)

        entries_list.append(entries)

    # 处理单文件或双文件模式
    if len(entries_list) == 1:
        # 单文件模式
        src_entries = entries_list[0]
        dst_entries = []  # 空列表，表示无翻译
    else:
        # 双文件模式
        src_entries, dst_entries = entries_list

    # 对齐条目（双文件模式）
    if dst_entries:
        aligned_entries = align_entries(src_entries, dst_entries)
    else:
        # 单文件模式，每个源条目对应None作为目标
        aligned_entries = [(entry, None) for entry in src_entries]

    # 解析说话人配置
    speaker_genders = args.speakers.split(',')
    speakers = []
    for i, gender in enumerate(speaker_genders):
        speakers.append({"Id": f"speaker_{i}", "Gender": gender.strip()})

    # 解析目标语言
    dst_langs = args.dst_langs.split(',')

    # 构建目标结构
    result = {
        "SrcLang": args.src_lang,
        "DstLangs": dst_langs,
        "Speakers": speakers,
        "Clips": []
    }

    # 生成Clips
    for idx, (src_entry, dst_entry) in enumerate(aligned_entries):
        speaker_id = args.default_speaker
        if args.alternate:
            speaker_id = f"speaker_{idx % len(speakers)}"  # 循环使用所有说话人

        # 确保至少有一个条目存在
        if src_entry is None and dst_entry is None:
            continue

        # 创建Clip对象
        clip = {
            "SpeakerId": speaker_id,
            "DstTexts": {}
        }

        # 添加源语言信息
        if src_entry:
            clip["TextStartTime"] = src_entry["start"]
            clip["TextEndTime"] = src_entry["end"]
            clip["SrcText"] = src_entry["srcText"]
            # 单文件模式下从源语言条目中获取目标语言文本
            for i in range(len(src_entry["transTexts"])):
                if i < len(dst_langs):
                    clip["DstTexts"][dst_langs[i]] = src_entry["transTexts"][i]
        else:
            # 如果源语言条目缺失，使用目标语言的时间
            clip["TextStartTime"] = dst_entry["start"]
            clip["TextEndTime"] = dst_entry["end"]
            clip["SrcText"] = ""  # 源文本留空

        if dst_entry:  # 双文件模式下从目标语言条目中获取目标语言文本
            clip["DstTexts"][dst_langs[0]] = dst_entry["srcText"]

        result["Clips"].append(clip)

    # 写入输出文件
    with open(args.output, 'w', encoding='utf-8') as f:
        json.dump(result, f, ensure_ascii=False, indent=2)

    print(f"转换完成！共处理 {len(result['Clips'])} 个条目")


if __name__ == "__main__":
    main()
