#!/usr/bin/env python3
import argparse
import json
import os
import re
import sys
import urllib.error
import urllib.request


def read_json(path):
    """Return the parsed JSON document stored in the UTF-8 file at *path*.

    Errors from opening, decoding, or parsing the file propagate to the caller.
    """
    with open(path, encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)


def read_config(path):
    """Load the optional configuration file, falling back to an empty dict.

    An empty dict is returned when *path* is empty, does not exist, cannot be
    loaded, or does not contain a JSON object at the top level.
    """
    if path and os.path.exists(path):
        try:
            loaded = read_json(path)
        except Exception:
            # Configuration is best-effort: any load failure means "no config".
            loaded = None
        if isinstance(loaded, dict):
            return loaded
    return {}


def rect_union(left, right):
    """Return the bounding box that covers both ``[x, y, width, height]`` boxes.

    When *left* is ``None`` there is nothing to merge yet, so *right* is
    returned unchanged (the same object).
    """
    if left is None:
        return right
    # Index 0/1 hold the origin and index 2/3 the matching extent per axis.
    near = [min(left[axis], right[axis]) for axis in (0, 1)]
    far = [
        max(left[axis] + left[axis + 2], right[axis] + right[axis + 2])
        for axis in (0, 1)
    ]
    return [near[0], near[1], far[0] - near[0], far[1] - near[1]]


def no_leading_space(text):
    """Tell whether *text* starts with punctuation that attaches to the previous word.

    Returns ``False`` for empty or missing text.
    """
    if not text:
        return False
    return text[0] in ".,;:!?)]}"


def should_insert_space(previous, current):
    """Decide whether a space belongs between two adjacent tokens on a line.

    No space is inserted when either token is missing/empty or when *current*
    begins with attaching punctuation. Otherwise a space is inserted only if
    the horizontal gap between the boxes exceeds a threshold of 28% of the
    smaller box height, with a floor of 3.0.
    """
    if not previous or not current:
        return False
    if no_leading_space(current.get("text", "")):
        return False
    prev_box = previous["box"]
    cur_box = current["box"]
    # Distance from the right edge of the previous box to the left edge of the current one.
    horizontal_gap = cur_box[0] - (prev_box[0] + prev_box[2])
    smaller_height = min(prev_box[3], cur_box[3])
    return horizontal_gap > max(3.0, smaller_height * 0.28)


def source_segments(tokens):
    """Merge OCR tokens into one text segment per line.

    Entries that are not dicts, have no text after stripping, or lack a
    four-element list under ``box``/``bbox`` are ignored. Remaining tokens are
    grouped by their ``line`` value (default 0), ordered by ``index`` then by
    the box x coordinate, and joined with spaces where the gap warrants one.

    Returns:
        A list of ``{"id": line, "text": ..., "box": [x, y, w, h]}`` dicts in
        ascending line order.
    """
    grouped = {}
    for raw in tokens:
        if not isinstance(raw, dict):
            continue
        label = str(raw.get("text") or "").strip()
        rect = raw.get("box") or raw.get("bbox")
        usable = bool(label) and isinstance(rect, list) and len(rect) == 4
        if not usable:
            continue
        line_no = int(raw.get("line", 0))
        bucket = grouped.setdefault(line_no, [])
        bucket.append(
            {
                "text": label,
                "box": [float(value) for value in rect],
                "index": int(raw.get("index", 0)),
            }
        )

    merged = []
    for line_no in sorted(grouped):
        ordered = sorted(
            grouped[line_no],
            key=lambda entry: (entry["index"], entry["box"][0]),
        )
        pieces = []
        bounds = None
        prior = None
        for entry in ordered:
            if should_insert_space(prior, entry):
                pieces.append(" ")
            pieces.append(entry["text"])
            bounds = rect_union(bounds, entry["box"])
            prior = entry
        joined = "".join(pieces).strip()
        if joined and bounds:
            merged.append({"id": line_no, "text": joined, "box": bounds})
    return merged


def _message_content(result):
    """Return the assistant message text from a decoded chat-completions response.

    Raises:
        RuntimeError: if ``choices[0].message.content`` is absent or is not a
            non-empty string (for example an error body, or a null content).
    """
    try:
        content = result["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        # Include a bounded excerpt so the caller can see what came back instead.
        snippet = json.dumps(result, ensure_ascii=False)[:500]
        raise RuntimeError(f"llm response missing choices[0].message.content: {snippet}") from exc
    if not isinstance(content, str) or not content.strip():
        raise RuntimeError("llm response contained no message content")
    return content


def chat_completion(config, target_language, segments):
    """Translate *segments* through an OpenAI-compatible chat-completions endpoint.

    Settings are read from ``config["translation"]`` first and then from
    environment variables; the API base defaults to
    ``https://api.openai.com/v1`` and the model to ``gpt-4o-mini``.

    Args:
        config: Configuration mapping; a non-dict ``translation`` entry is ignored.
        target_language: Language name passed to the model.
        segments: Dicts with ``id`` and ``text`` keys to translate.

    Returns:
        The mapping produced by ``parse_translation_content`` for the
        assistant's reply.

    Raises:
        RuntimeError: if no API key is configured, the server answers with an
            HTTP error status, or the response carries no message content.
    """
    translation_config = config.get("translation") if isinstance(config.get("translation"), dict) else {}
    api_base = (
        translation_config.get("apiBase")
        or translation_config.get("baseUrl")
        or os.environ.get("MARK_SHOT_LLM_API_BASE")
        or os.environ.get("OPENAI_BASE_URL")
        or os.environ.get("OPENAI_API_BASE")
        or "https://api.openai.com/v1"
    ).rstrip("/")
    model = (
        translation_config.get("model")
        or os.environ.get("MARK_SHOT_LLM_MODEL")
        or os.environ.get("OPENAI_MODEL")
        or "gpt-4o-mini"
    )
    api_key = translation_config.get("apiKey")
    api_key_env = translation_config.get("apiKeyEnv") or "OPENAI_API_KEY"
    if not api_key:
        api_key = os.environ.get(api_key_env) or os.environ.get("MARK_SHOT_LLM_API_KEY")
    if not api_key:
        raise RuntimeError(f"missing api key: set {api_key_env} or translation.apiKey")

    timeout = float(translation_config.get("timeoutSeconds") or os.environ.get("MARK_SHOT_LLM_TIMEOUT_SECONDS") or 60)
    # Compare against None (not truthiness) so an explicit 0 is honoured while
    # a null/absent value falls back to the default.
    temperature = translation_config.get("temperature")
    temperature = 0.2 if temperature is None else float(temperature)
    system_prompt = translation_config.get("systemPrompt") or (
        "You translate OCR text segments. Preserve meaning, keep segment count and ids unchanged, "
        "and return only valid JSON."
    )
    user_prompt = {
        "target_language": target_language,
        "instructions": [
            "Translate each segment into target_language.",
            "Return JSON exactly as {\"translations\":[{\"id\":0,\"text\":\"...\"}]} with no markdown.",
            "Do not add explanations.",
        ],
        "segments": [{"id": item["id"], "text": item["text"]} for item in segments],
    }
    payload = {
        "model": model,
        "temperature": temperature,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": json.dumps(user_prompt, ensure_ascii=False)},
        ],
    }

    data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
    request = urllib.request.Request(
        f"{api_base}/chat/completions",
        data=data,
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        method="POST",
    )
    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            result = json.loads(response.read().decode("utf-8"))
    except urllib.error.HTTPError as exc:
        # Surface a bounded excerpt of the error body alongside the status code.
        detail = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"llm http {exc.code}: {detail[:500]}") from exc

    return parse_translation_content(_message_content(result))


def parse_translation_content(content):
    """Parse the model's reply into a ``{segment_id: translated_text}`` mapping.

    The reply is parsed as JSON directly; if that fails, the span from the
    first ``{`` to the last ``}`` is parsed instead, which tolerates markdown
    fences or prose wrapped around the object.

    Entries that are not dicts, have no ``id``, or whose ``id`` cannot be
    converted to an integer are skipped, so one malformed entry does not
    discard the remaining translations.

    Args:
        content: Raw text of the assistant message.

    Returns:
        A dict mapping integer segment ids to stripped translation strings
        (an empty string when the entry's text is missing or null).

    Raises:
        json.JSONDecodeError: if no JSON object can be decoded from *content*.
        RuntimeError: if the decoded value has no ``translations`` list.
    """
    content = content.strip()
    try:
        data = json.loads(content)
    except json.JSONDecodeError:
        match = re.search(r"\{.*\}", content, re.S)
        if not match:
            raise
        data = json.loads(match.group(0))

    translations = data.get("translations") if isinstance(data, dict) else None
    if not isinstance(translations, list):
        raise RuntimeError("llm response missing translations array")

    parsed = {}
    for item in translations:
        if not isinstance(item, dict) or "id" not in item:
            continue
        try:
            segment_id = int(item["id"])
        except (TypeError, ValueError, OverflowError):
            # Unusable id (null, non-numeric, infinite): ignore this entry only.
            continue
        parsed[segment_id] = str(item.get("text") or "").strip()
    return parsed


def build_output_tokens(segments, translations):
    """Turn translated segments into output tokens, one per segment.

    Each token reuses the segment's box, takes its position in *segments* as
    the ``line`` value, and falls back to the segment's source text when the
    translation for its id is missing or blank.
    """
    return [
        {
            "text": translations.get(segment["id"], "").strip() or segment["text"],
            "box": segment["box"],
            "line": position,
            "index": 0,
            "confidence": 1.0,
        }
        for position, segment in enumerate(segments)
    ]


def main():
    """Run the translation CLI and return a process exit code.

    Reads the optional config and the required input JSON, merges the input
    tokens into line segments, requests translations, and prints exactly one
    JSON object (``backend``, ``tokens``, ``errors``) to stdout.

    Returns:
        0 when translated tokens were produced; 1 when the input could not be
        read, contained no usable text, or the translation failed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True)
    parser.add_argument("--target-language", default="")
    parser.add_argument("--config", default="")
    parser.add_argument("--format", choices=("json",), default="json")
    args = parser.parse_args()

    def emit(tokens, errors):
        # Every outcome is reported through the same JSON envelope on stdout.
        print(json.dumps({"backend": "openai-compatible", "tokens": tokens, "errors": errors}, ensure_ascii=False))

    config = read_config(args.config)
    try:
        input_data = read_json(args.input)
    except (OSError, ValueError) as exc:
        # Missing/unreadable file or malformed JSON: report it in the JSON
        # envelope rather than crashing with a traceback.
        emit([], [f"cannot read input: {exc}"])
        return 1
    if not isinstance(input_data, dict):
        # A non-object document carries neither a language nor tokens.
        input_data = {}

    target_language = args.target_language or input_data.get("targetLanguage") or "Simplified Chinese"
    tokens = input_data.get("tokens")
    try:
        segments = source_segments(tokens if isinstance(tokens, list) else [])
    except (TypeError, ValueError) as exc:
        # Non-numeric line/index/box values in the token list.
        emit([], [f"invalid input tokens: {exc}"])
        return 1
    if not segments:
        emit([], ["no source text"])
        return 1

    try:
        translations = chat_completion(config, target_language, segments)
        output_tokens = build_output_tokens(segments, translations)
    except Exception as exc:
        # Top-level boundary: any translation failure becomes a JSON error.
        emit([], [str(exc)])
        return 1
    emit(output_tokens, [])
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
