#!/usr/bin/env python3
import argparse
import contextlib
import csv
import json
import os
import shutil
import subprocess
import sys


def default_model_dir():
    """Return the default directory used to cache OCR models.

    The directory is ``<data home>/mark-shot/models``. Here ``<data home>`` is
    ``$XDG_DATA_HOME`` when that variable is set to something other than
    whitespace, and ``~/.local/share`` otherwise.
    """
    data_home = os.environ.get("XDG_DATA_HOME", "").strip() or os.path.expanduser("~/.local/share")
    return os.path.join(data_home, "mark-shot", "models")


def configured_model_dir():
    """Return the OCR model directory to use.

    A non-blank ``$MARK_SHOT_OCR_MODEL_DIR`` (surrounding whitespace stripped)
    takes precedence. Otherwise the value of ``default_model_dir()`` is used.
    """
    override = os.environ.get("MARK_SHOT_OCR_MODEL_DIR", "").strip()
    if override:
        return override
    return default_model_dir()


def maybe_reexec_venv():
    """Re-exec this script under the dedicated OCR virtualenv interpreter, if present.

    The interpreter is ``$MARK_SHOT_OCR_PYTHON`` with ``~`` expanded. When that
    variable is unset or empty, it defaults to
    ``~/.local/share/mark-shot/ocr-venv/bin/python``.

    Nothing happens when any of these hold:

    * ``MARK_SHOT_OCR_REEXEC=1``: this process is already the re-exec'd
      child, so stop to prevent an exec loop.
    * ``MARK_SHOT_OCR_NO_VENV=1``: the user opted out.
    * The interpreter path is not an existing executable file.
    * The current interpreter already is that path.

    If the exec succeeds, this function never returns because the process
    image is replaced. If ``execve`` fails, a note is printed to stderr and
    the current interpreter carries on (best effort).
    """
    if os.environ.get("MARK_SHOT_OCR_REEXEC") == "1":
        return
    if os.environ.get("MARK_SHOT_OCR_NO_VENV") == "1":
        return

    preferred_python = os.environ.get("MARK_SHOT_OCR_PYTHON")
    if preferred_python:
        preferred_python = os.path.expanduser(preferred_python)
    else:
        # NOTE(review): unlike default_model_dir(), this default ignores
        # XDG_DATA_HOME; confirm it matches where the venv is installed.
        preferred_python = os.path.expanduser("~/.local/share/mark-shot/ocr-venv/bin/python")

    # A directory or non-executable file would make execve fail. Skip it so
    # the tool keeps working with the current interpreter.
    if not (os.path.isfile(preferred_python) and os.access(preferred_python, os.X_OK)):
        return
    if os.path.abspath(sys.executable) == os.path.abspath(preferred_python):
        return

    env = os.environ.copy()
    env["MARK_SHOT_OCR_REEXEC"] = "1"
    try:
        os.execve(preferred_python, [preferred_python, __file__, *sys.argv[1:]], env)
    except OSError as exc:
        # The re-exec is an optimisation. Failing it must not abort the OCR run.
        print(f"ocr venv re-exec failed ({preferred_python}): {exc}", file=sys.stderr)


def as_plain(value):
    """Recursively convert array-likes and tuples into plain Python lists.

    * Anything with a ``tolist`` method (e.g. numpy arrays) is converted with it.
    * Lists and tuples become lists whose elements are converted recursively.
    * Every other value is returned unchanged.
    """
    if hasattr(value, "tolist"):
        return value.tolist()
    if not isinstance(value, (list, tuple)):
        return value
    return [as_plain(element) for element in value]


def rect_from_box(box):
    """Convert an OCR box into an ``[x, y, width, height]`` list of floats.

    Two input shapes are accepted after ``as_plain`` normalisation:

    * Exactly four numbers, taken to already be ``x, y, width, height``.
    * A sequence of points, each a list/tuple with at least two
      coordinates. These are reduced to their axis-aligned bounding
      rectangle; entries that are not points are ignored.

    Returns ``None`` when the input is not a list of at least two entries or
    contains no usable point.
    """
    plain = as_plain(box)
    if not isinstance(plain, list) or len(plain) < 2:
        return None

    looks_like_xywh = len(plain) == 4 and all(isinstance(value, (int, float)) for value in plain)
    if looks_like_xywh:
        return [float(value) for value in plain]

    coords = [
        (float(point[0]), float(point[1]))
        for point in plain
        if isinstance(point, (list, tuple)) and len(point) >= 2
    ]
    if not coords:
        return None

    xs, ys = zip(*coords)
    left, top = min(xs), min(ys)
    return [left, top, max(xs) - left, max(ys) - top]


def append_token(tokens, text, box, line, index, confidence=0.0):
    """Append a normalised token dict to *tokens* when it is usable.

    Input normalisation:

    * *text* is converted with ``str`` and stripped (``None`` counts as empty).
    * *box* is converted with ``rect_from_box``.

    Nothing is appended when the text is empty, the box cannot be converted,
    or the box has non-positive width or height.

    Returns:
        ``index + 1`` if a token was appended, otherwise *index* unchanged.
        This lets callers thread a running per-line word index through
        repeated calls.
    """
    cleaned = str(text or "").strip()
    rect = rect_from_box(box)
    degenerate = rect is None or rect[2] <= 0 or rect[3] <= 0
    if cleaned and not degenerate:
        tokens.append(
            dict(
                text=cleaned,
                box=rect,
                line=int(line),
                index=int(index),
                confidence=float(confidence or 0.0),
            )
        )
        return index + 1
    return index


def normalize_token_order(tokens):
    """Group tokens into visual lines and renumber them in reading order.

    Tokens are visited top to bottom by vertical centre, with ties broken by
    left edge. Each token joins the existing line whose mean vertical centre
    is *closest* to its own, provided that distance is within the threshold
    ``max(6, 0.65 * mean height)``. The mean height is taken over the line's
    tokens plus the candidate. If no line qualifies, the token starts a new
    line.

    Lines are numbered in creation order, and tokens within a line are sorted
    by left edge. The ``line`` and ``index`` fields of every token are
    overwritten in place.

    Args:
        tokens: token dicts, each with a ``"box"`` of ``[x, y, width, height]``.

    Returns:
        A new list holding the same token objects in reading order. An empty
        input is returned as-is.
    """
    if not tokens:
        return tokens

    def center_y(token):
        box = token["box"]
        return box[1] + box[3] / 2.0

    # Each line is tracked as [members, sum of member centres, sum of member
    # heights]. The running sums give each line's mean in O(1) instead of
    # re-summing every member on each comparison.
    lines = []
    for token in sorted(tokens, key=lambda item: (center_y(item), item["box"][0])):
        token_center = center_y(token)
        token_height = token["box"][3]

        best_line = None
        best_distance = None
        for line in lines:
            members, center_sum, height_sum = line
            distance = abs(token_center - center_sum / len(members))
            threshold = max(6.0, (height_sum + token_height) / (len(members) + 1) * 0.65)
            # Choose the closest qualifying line, not merely the first one.
            # With tight line spacing a token can be within the threshold of
            # the line above as well as the line it actually sits on.
            if distance <= threshold and (best_distance is None or distance < best_distance):
                best_line = line
                best_distance = distance

        if best_line is None:
            lines.append([[token], token_center, token_height])
        else:
            best_line[0].append(token)
            best_line[1] += token_center
            best_line[2] += token_height

    ordered = []
    for line_index, (members, _center_sum, _height_sum) in enumerate(lines):
        members.sort(key=lambda item: item["box"][0])
        for token_index, token in enumerate(members):
            token["line"] = line_index
            token["index"] = token_index
            ordered.append(token)
    return ordered


def load_rapidocr_engine():
    """Import a RapidOCR engine class, preferring the ``rapidocr`` package.

    Returns:
        ``(engine_class, backend_name)``, where ``backend_name`` is
        ``"rapidocr"`` or ``"rapidocr_onnxruntime"`` depending on which
        package imported.

    Raises:
        RuntimeError: if neither package can be imported. The message
            includes both import failures, so the reason the preferred
            package was skipped is not lost.
    """
    # Any exception (not just ImportError) marks a package as unusable, as in
    # the original behaviour.
    try:
        from rapidocr import RapidOCR
    except Exception as exc:
        primary_error = exc
    else:
        return RapidOCR, "rapidocr"

    try:
        from rapidocr_onnxruntime import RapidOCR
    except Exception as exc:
        raise RuntimeError(
            f"rapidocr unavailable ({primary_error}); rapidocr_onnxruntime unavailable ({exc})"
        ) from exc

    return RapidOCR, "rapidocr_onnxruntime"


def build_rapidocr_engine(engine_class, backend):
    """Instantiate the OCR engine, applying mark-shot configuration when possible.

    For any backend other than ``"rapidocr"`` the engine is built with its
    defaults. For ``"rapidocr"``, a ``params`` dict is assembled:

    * ``Global.log_level`` is set to ``"critical"``.
    * ``MARK_SHOT_OCR_VERSION`` (default ``"PP-OCRv5"``): only the exact
      value ``"PP-OCRv5"`` selects PP-OCRv5 for Det/Cls/Rec. Any other value
      leaves the engine's default version.
    * ``MARK_SHOT_OCR_MODEL_TYPE`` (default ``"mobile"``, case-insensitive):
      ``"server"`` selects server models for Det/Cls/Rec.
    * The directory from ``configured_model_dir()`` is created, then passed
      under several candidate parameter names. The first construction that
      does not raise is returned.

    If the model directory cannot be created, the engine is still built with
    the other params (logged to stderr). Any other configuration failure
    falls back to ``engine_class()`` with a note on stderr.
    """
    if backend != "rapidocr":
        return engine_class()

    try:
        from rapidocr import ModelType, OCRVersion

        params = {"Global.log_level": "critical"}
        version = os.environ.get("MARK_SHOT_OCR_VERSION", "PP-OCRv5")
        if version == "PP-OCRv5":
            params.update(
                {
                    "Det.ocr_version": OCRVersion.PPOCRV5,
                    "Cls.ocr_version": OCRVersion.PPOCRV5,
                    "Rec.ocr_version": OCRVersion.PPOCRV5,
                }
            )

        model_type = os.environ.get("MARK_SHOT_OCR_MODEL_TYPE", "mobile").lower()
        if model_type == "server":
            params.update(
                {
                    "Det.model_type": ModelType.SERVER,
                    "Cls.model_type": ModelType.SERVER,
                    "Rec.model_type": ModelType.SERVER,
                }
            )

        model_dir = configured_model_dir()
        try:
            os.makedirs(model_dir, exist_ok=True)
        except OSError as exc:
            # An unusable model dir should not throw away the log level /
            # version / model type settings. Let the engine use its own
            # default location instead.
            print(f"rapidocr model dir unavailable: {exc}", file=sys.stderr)
        else:
            for key in (
                "Global.model_dir",
                "Global.models_dir",
                "Global.root_dir",
                "Global.save_dir",
                "ModelRoot",
            ):
                try:
                    return engine_class(params={**params, key: model_dir})
                except Exception:
                    # Presumably this key is not accepted by the installed
                    # rapidocr version; try the next candidate.
                    continue

        return engine_class(params=params)
    except Exception as exc:
        print(f"rapidocr config fallback: {exc}", file=sys.stderr)
        return engine_class()


def call_rapidocr(engine, image_path):
    """Run *engine* on *image_path*, asking for word boxes and per-character boxes.

    If the call raises ``TypeError``, it is retried once with only
    ``return_word_box=True``. This handles engines that do not accept
    ``return_single_char_box``.
    """
    full_options = {"return_word_box": True, "return_single_char_box": True}
    try:
        return engine(image_path, **full_options)
    except TypeError:
        return engine(image_path, return_word_box=True)


def normalize_object_output(output):
    """Flatten an attribute-style OCR result object into token dicts.

    Word-level ``word_results`` are preferred. This is one sequence per line
    of ``(text, score, box, ...)`` entries; entries that are not lists/tuples
    of length >= 3 are skipped.

    If ``word_results`` is missing, empty, or yields no usable token, the
    parallel line-level ``boxes`` / ``txts`` / ``scores`` attributes are used
    instead. That path produces at most one token per line, each with
    index 0. A missing or short ``scores`` yields confidence ``0.0``.

    Returns an empty list when neither source is available.
    """
    tokens = []

    word_lines = getattr(output, "word_results", None)
    if word_lines:
        for line_no, entries in enumerate(word_lines):
            position = 0
            for entry in entries:
                if isinstance(entry, (list, tuple)) and len(entry) >= 3:
                    word, score, box = entry[0], entry[1], entry[2]
                    position = append_token(tokens, word, box, line_no, position, score)
        if tokens:
            return tokens

    raw_boxes = getattr(output, "boxes", None)
    raw_texts = getattr(output, "txts", None)
    raw_scores = getattr(output, "scores", None)
    if raw_boxes is None or raw_texts is None:
        return tokens

    plain_boxes = as_plain(raw_boxes)
    plain_texts = as_plain(raw_texts)
    plain_scores = [] if raw_scores is None else as_plain(raw_scores)
    for line_no, (box, text) in enumerate(zip(plain_boxes, plain_texts)):
        score = plain_scores[line_no] if line_no < len(plain_scores) else 0.0
        append_token(tokens, text, box, line_no, 0, score)
    return tokens


def normalize_list_output(output):
    """Flatten list-shaped OCR output into token dicts.

    Two entry shapes are understood:

    * Dicts with text under ``text``/``txt`` and a box under
      ``box``/``bbox``/``points`` (first non-empty wins). They may also carry
      ``line``/``index`` (defaulting to the entry position and 0) and
      ``confidence``/``score``.
    * Sequences ``(box, text, score[, words])``, where the optional
      ``words`` is a list/tuple of ``(text, score, box)`` entries. Word
      entries are used when at least one yields a token; otherwise the
      line-level box/text becomes a single token.

    Any other entry is skipped.
    """
    tokens = []
    for line_index, item in enumerate(output):
        if isinstance(item, dict):
            # Take the first non-empty box candidate. Each candidate is
            # converted with as_plain() before the emptiness test, because an
            # `a or b` chain asks for the truth value of an array-like box,
            # which numpy refuses (ValueError) for multi-element arrays.
            box = None
            for key in ("box", "bbox", "points"):
                box = as_plain(item.get(key))
                if box:
                    break
            append_token(
                tokens,
                item.get("text") or item.get("txt"),
                box,
                item.get("line", line_index),
                item.get("index", 0),
                item.get("confidence") or item.get("score") or 0.0,
            )
            continue

        if not isinstance(item, (list, tuple)) or len(item) < 3:
            continue

        box, text, score = item[0], item[1], item[2]
        token_index = 0
        if len(item) >= 4 and isinstance(item[3], (list, tuple)):
            for word in item[3]:
                if isinstance(word, (list, tuple)) and len(word) >= 3:
                    token_index = append_token(tokens, word[0], word[2], line_index, token_index, word[1])
        # No usable word-level token: fall back to the whole line.
        if token_index == 0:
            append_token(tokens, text, box, line_index, 0, score)
    return tokens


def run_rapidocr(image_path):
    """Run RapidOCR on *image_path* and return ``(backend_name, tokens)``.

    Importing, building and calling the engine happen while Python-level
    ``sys.stdout`` is redirected to ``sys.stderr``. This keeps library output
    written through ``sys.stdout`` out of the JSON report. Exceptions from
    those steps propagate to the caller.

    The raw result is dispatched by shape:

    * A non-empty tuple is reduced to its first element (presumably
      ``(result, elapsed)`` from some backends).
    * Objects exposing ``word_results`` or ``boxes`` go to
      ``normalize_object_output``.
    * Lists go to ``normalize_list_output``.
    * Anything else yields no tokens.
    """
    with contextlib.redirect_stdout(sys.stderr):
        engine_class, backend = load_rapidocr_engine()
        raw = call_rapidocr(build_rapidocr_engine(engine_class, backend), image_path)

    if isinstance(raw, tuple) and raw:
        raw = raw[0]

    if hasattr(raw, "word_results") or hasattr(raw, "boxes"):
        return backend, normalize_object_output(raw)
    if isinstance(raw, list):
        return backend, normalize_list_output(raw)
    return backend, []


def parse_tesseract_tsv(content):
    """Convert ``tesseract ... tsv`` output into token dicts.

    Rows kept:

    * Only word rows (``level == "5"``) with non-blank text.
    * Rows whose ``left``/``top``/``width``/``height``/``conf`` all parse
      as floats.

    Numbering:

    * Lines are numbered 0, 1, 2, ... in order of first appearance of each
      ``(block_num, par_num, line_num)`` triple.
    * Words are indexed within their line.
    * A line number is reserved when its first non-blank word row is seen,
      even if that row is then rejected.

    ``confidence`` is the ``conf`` column as-is.
    NOTE(review): this is presumably tesseract's 0-100 scale, while RapidOCR
    scores look like 0-1; confirm consumers handle both.
    """
    tokens = []
    line_ids = {}
    indexes = {}
    # Tesseract does not quote TSV fields. With csv's default quoting, a
    # recognised word starting with '"' would be read as the start of a
    # quoted field and swallow the following rows into one bogus token.
    reader = csv.DictReader(content.splitlines(), delimiter="\t", quoting=csv.QUOTE_NONE)
    for row in reader:
        if row.get("level") != "5":
            continue
        text = (row.get("text") or "").strip()
        if not text:
            continue

        line_key = (
            row.get("block_num", "0"),
            row.get("par_num", "0"),
            row.get("line_num", "0"),
        )
        if line_key not in line_ids:
            line_ids[line_key] = len(line_ids)
            indexes[line_key] = 0

        try:
            box = [
                float(row.get("left", 0)),
                float(row.get("top", 0)),
                float(row.get("width", 0)),
                float(row.get("height", 0)),
            ]
            confidence = float(row.get("conf", 0))
        except ValueError:
            continue

        indexes[line_key] = append_token(
            tokens,
            text,
            box,
            line_ids[line_key],
            indexes[line_key],
            confidence,
        )
    return tokens


def run_tesseract(image_path):
    """Run the ``tesseract`` CLI on *image_path* and return ``("tesseract", tokens)``.

    Languages are tried in order:

    1. ``$MARK_SHOT_OCR_LANG`` (default ``"chi_sim+eng"``).
    2. ``"eng"``, if the first attempt exits non-zero.

    Duplicate or empty language strings are skipped. The page segmentation
    mode comes from ``$MARK_SHOT_OCR_PSM`` (default ``"6"``).

    Raises:
        RuntimeError: if ``tesseract`` is not on ``PATH``, or if every attempt
            exits non-zero. In the latter case the message is the last
            attempt's stderr, or ``"tesseract failed"`` when that was empty.
    """
    if shutil.which("tesseract") is None:
        raise RuntimeError("tesseract not found")

    primary_lang = os.environ.get("MARK_SHOT_OCR_LANG", "chi_sim+eng")
    psm = os.environ.get("MARK_SHOT_OCR_PSM", "6")
    languages = []
    for lang in (primary_lang, "eng"):
        if lang and lang not in languages:
            languages.append(lang)

    last_error = ""
    for lang in languages:
        command = [
            "tesseract",
            image_path,
            "stdout",
            "-l",
            lang,
            "--psm",
            psm,
            "tsv",
        ]
        # Tesseract emits UTF-8. Decoding with the locale encoding (what
        # text=True does) can fail or garble e.g. Chinese output under a
        # non-UTF-8 locale.
        result = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            encoding="utf-8",
            errors="replace",
            check=False,
        )
        if result.returncode == 0:
            return "tesseract", parse_tesseract_tsv(result.stdout)
        last_error = result.stderr.strip()

    raise RuntimeError(last_error or "tesseract failed")


def main():
    """Command-line entry point: OCR one image and print a JSON report on stdout.

    The report is ``{"backend": ..., "tokens": [...], "errors": [...]}``.

    * ``--backend auto`` (the default) tries RapidOCR first, then tesseract
      if RapidOCR produced no tokens or raised.
    * An explicitly requested backend that raises causes the report to be
      printed with empty tokens, and the function returns 1.

    Otherwise tokens are put in reading order with ``normalize_token_order``
    and the function returns 0.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("image")
    parser.add_argument("--backend", choices=("auto", "rapidocr", "tesseract"), default="auto")
    parser.add_argument("--format", choices=("json",), default="json")
    args = parser.parse_args()

    def emit(backend_name, token_list, error_list):
        report = {"backend": backend_name, "tokens": token_list, "errors": error_list}
        print(json.dumps(report, ensure_ascii=False))

    requested = args.backend
    errors = []
    backend, tokens = "", []

    # Both "auto" and "rapidocr" start with RapidOCR.
    if requested != "tesseract":
        try:
            backend, tokens = run_rapidocr(args.image)
        except Exception as exc:
            errors.append(f"rapidocr: {exc}")
            if requested == "rapidocr":
                emit("rapidocr", [], errors)
                return 1

    # Tesseract runs when explicitly requested, or as the "auto" fallback
    # whenever RapidOCR left us without tokens.
    if requested != "rapidocr" and not tokens:
        try:
            backend, tokens = run_tesseract(args.image)
        except Exception as exc:
            errors.append(f"tesseract: {exc}")
            if requested == "tesseract":
                emit("tesseract", [], errors)
                return 1

    emit(backend, normalize_token_order(tokens), errors)
    return 0


if __name__ == "__main__":
    # Switch to the dedicated OCR venv first (this may replace the process),
    # then use main()'s return value as the exit status.
    maybe_reexec_venv()
    raise SystemExit(main())
