#!/usr/bin/env python3
"""
extract_totals.py
Usage:
    python3 extract_totals.py /absolute/path/to/receipt.jpg

Requires:
    pip install pillow pytesseract
And Tesseract OCR installed on the machine.
"""

import sys
import json
import re
from typing import Optional, List, Tuple
from PIL import Image
import pytesseract

# ============ Money token ============
# A monetary amount: an optional currency symbol, an integer part written
# either as plain digits ("1969") or with comma-grouped thousands ("1,969"),
# and exactly two decimal places.
# The digit lookarounds at both ends keep the regex from matching only a
# fragment of a longer number (e.g. just the tail "9.69").
MONEY_TOKEN = (
    r"(?<!\d)"                      # must not start in the middle of a number
    r"(?:[\$\£\€]\s*)?"             # optional currency symbol, optional gap
    r"(?:\d{1,3}(?:,\d{3})+|\d+)"   # comma-grouped thousands, or plain digits
    r"(?:\.\d{2})"                  # mandatory two-digit cents
    r"(?!\d)"                       # must not end in the middle of a number
)
MONEY_RE = re.compile(MONEY_TOKEN)

# To also accept whole amounts without cents, make the cents group optional:
# MONEY_TOKEN = r"(?<!\d)(?:[\$\£\€]\s*)?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d{2})?(?!\d)"
# MONEY_RE = re.compile(MONEY_TOKEN)

# ============ Flexible keyword patterns per category ============
# Regex fragments that identify the label of each receipt category.
# They tolerate common OCR glitches: optional spaces/hyphens after "sub",
# and l/1/i or a/i/1 confusion in the tail of "total".
CATEGORY_PATTERNS = {
    "subtotal": [
        r"\bsub[\s\-]*total\b",
        # OCR variants of the tail: the "a" may be read as i/1, the "l" as
        # i/1, and the final "l" may be doubled
        # (subtotil, subtot1l, subtotai, subtota1, subtotall).
        r"\bsub[\s\-]*tot[ai1][li1]l?\b",
    ],
    "tax": [
        r"\btax\b", r"\bsales\s*tax\b", r"\bhst\b", r"\bgst\b", r"\bpst\b", r"\bvat\b",
    ],
    "total": [
        r"\bgrand[\s\-]*total\b",
        r"\b(total|total\s*due|amount\s*due|balance\s*due|amount\s*payable)\b",
    ],
    # Optional:
    # "tip": [r"\btip\b", r"\bgratuity\b", r"\bservice\s*charge\b"],
}

def build_keyword_patterns():
    """Compile the keyword detectors for every category in CATEGORY_PATTERNS.

    Returns a dict mapping each category name to a list of compiled,
    case-insensitive regexes. Each source pattern contributes two entries,
    in this order: one anchored to the start of the line (leading whitespace
    allowed) and one that may match anywhere. Both also consume an optional
    colon and surrounding whitespace after the keyword.
    """
    def _detectors(fragment):
        # Start-of-line form first, then the anywhere-in-line form.
        yield re.compile(rf"^\s*(?:{fragment})\s*:?\s*", re.IGNORECASE)
        yield re.compile(rf"(?:{fragment})\s*:?\s*", re.IGNORECASE)

    return {
        category: [rx for fragment in fragments for rx in _detectors(fragment)]
        for category, fragments in CATEGORY_PATTERNS.items()
    }

# ============ Helpers ============

def rightmost_amount(line: str) -> Optional[Tuple[str, float]]:
    """Locate the last monetary amount on *line*.

    Returns ``(raw, value)`` where *raw* is the matched text exactly as it
    appears on the line and *value* is its numeric value, or None when the
    line contains no amount (or the amount cannot be parsed as a float).
    """
    last_match = None
    for last_match in MONEY_RE.finditer(line):
        pass  # walk to the final match; earlier ones are not needed
    if last_match is None:
        return None

    raw = last_match.group(0)
    # Keep only digits and the decimal point: "$1,969.69" -> "1969.69"
    numeric = "".join(ch for ch in raw if ch.isdecimal() or ch == ".")
    try:
        value = float(numeric)
    except ValueError:
        return None
    return raw, value

def ocr_text(image_path: str) -> str:
    """Run Tesseract OCR on the image at *image_path* and return the raw text.

    The image is converted to 8-bit grayscale ("L") before recognition.
    Errors from opening the image (e.g. a missing file) or from Tesseract
    propagate to the caller.
    """
    # Open inside a context manager so the underlying file handle is closed
    # as soon as the grayscale copy exists, instead of waiting for garbage
    # collection. convert() returns an independent image, so it stays usable
    # after the source is closed.
    with Image.open(image_path) as source:
        grayscale = source.convert("L")
    # preserve_interword_spaces helps keep gaps between label and amount
    return pytesseract.image_to_string(
        grayscale, config="--oem 1 --psm 6 -c preserve_interword_spaces=1"
    )

def line_has_keyword(line: str, patterns: List[re.Pattern]) -> bool:
    """Return True if at least one of *patterns* matches somewhere in *line*."""
    for pattern in patterns:
        if pattern.search(line):
            return True
    return False

def next_non_empty_index(lines: List[str], start_idx: int) -> Optional[int]:
    """Return the index of the first non-blank line at or after *start_idx*.

    A line counts as blank when it is empty or whitespace-only. Returns None
    when no such line exists before the end of *lines*.
    """
    for pos in range(start_idx, len(lines)):
        if lines[pos].strip():
            return pos
    return None

def find_category_amounts(
    lines_lc: List[str],
    patterns: List[re.Pattern],
    allow_multiple: bool = False
) -> List[dict]:
    """
    Find amounts for a category.
    - Bottom-up same-line matches first (keyword and amount on one line).
    - Then two-line fallback, scanned top-down: keyword on one line, amount
      on the next non-empty line.
    - If allow_multiple=True (tax), return all matches; else return a list
      holding only the first match found.

    A line whose amount was already taken as a same-line match is never used
    again by the fallback, neither as the keyword line nor as the amount
    line, so no amount is attributed to the category twice.

    Each hit is a dict with keys:
      line_index  int for same-line hits, (keyword_idx, amount_idx) tuple
                  for fallback hits
      line        stripped text of the line(s) involved
      raw_amount  amount text as it appears on the line
      amount      parsed float value
      source      "same_line" or "next_line"
    """
    results: List[dict] = []
    # Indices of lines whose amount has already been recorded in step 1.
    consumed = set()

    # 1) Same-line matches (bottom-up)
    for idx in range(len(lines_lc) - 1, -1, -1):
        line = lines_lc[idx]
        if not line_has_keyword(line, patterns):
            continue
        ra = rightmost_amount(line)
        if ra is None:
            continue
        raw, amt = ra
        results.append({
            "line_index": idx,
            "line": line.strip(),
            "raw_amount": raw,
            "amount": amt,
            "source": "same_line",
        })
        if not allow_multiple:
            return results
        consumed.add(idx)

    # 2) Two-line fallback (keyword line + next non-empty line has amount)
    for i in range(len(lines_lc) - 1):
        # A keyword line that already carried its own amount must not also
        # claim the amount of the following line (that would double count).
        if i in consumed:
            continue
        line1 = lines_lc[i]
        if not line_has_keyword(line1, patterns):
            continue
        j = next_non_empty_index(lines_lc, i + 1)
        # Likewise, an amount already recorded in step 1 is not reused.
        if j is None or j in consumed:
            continue
        ra = rightmost_amount(lines_lc[j])
        if ra is None:
            continue
        raw, amt = ra
        results.append({
            "line_index": (i, j),
            "line": (line1.strip() + " | " + lines_lc[j].strip()),
            "raw_amount": raw,
            "amount": amt,
            "source": "next_line",
        })
        if not allow_multiple:
            return results

    return results

# ============ Main extraction ============

def extract_receipt_totals(image_path: str):
    """
    OCR a receipt image and extract its subtotal, tax and total.

    Returns a JSON-serialisable dict with these keys:
      success           True if any of the values below was found or derived
      subtotal          amount on the subtotal line, or None
      tax               sum of all tax-line amounts (rounded to 2 dp), or None
      total             amount on the total line, or None
      derived_total     subtotal + tax, only when no total was found
      derived_subtotal  total - tax, only when no subtotal was found
      matches           the raw hits per category
      raw_text          the unmodified OCR output
    plus "reconciliation_delta_total" (total - (subtotal + tax)) when all
    three values were found and disagree by at least 0.01.

    Exceptions raised by ocr_text propagate to the caller.
    """
    text = ocr_text(image_path)
    # Lowercased lines for matching; we keep only lc for simplicity
    lines_lc = [ln.lower() for ln in text.splitlines()]

    kw_patterns = build_keyword_patterns()
    subtotal_patterns = kw_patterns["subtotal"]

    # Subtotal: first match
    subtotal_hits = find_category_amounts(lines_lc, subtotal_patterns, allow_multiple=False)
    subtotal = subtotal_hits[0]["amount"] if subtotal_hits else None

    # Tax: sum all matches (hst/gst/pst/etc.)
    tax_hits = find_category_amounts(lines_lc, kw_patterns["tax"], allow_multiple=True)
    tax_sum = round(sum(hit["amount"] for hit in tax_hits), 2) if tax_hits else None

    def _touches_subtotal_line(hit: dict) -> bool:
        """True if any line involved in *hit* carries a subtotal keyword."""
        idx = hit["line_index"]
        involved = idx if isinstance(idx, tuple) else (idx,)
        return any(line_has_keyword(lines_lc[k], subtotal_patterns) for k in involved)

    # Total: first match (prefer the bottom-most).
    # The bare "total" keyword also matches "sub total" / "sub-total", so
    # collect every candidate and drop those that sit on a subtotal line;
    # otherwise the subtotal could be reported as the total.
    total_candidates = find_category_amounts(lines_lc, kw_patterns["total"], allow_multiple=True)
    total_hits = [hit for hit in total_candidates if not _touches_subtotal_line(hit)][:1]
    total = total_hits[0]["amount"] if total_hits else None

    # ---- Derived fields ----
    derived_total = None
    derived_subtotal = None

    if total is None and subtotal is not None and tax_sum is not None:
        derived_total = round(subtotal + tax_sum, 2)

    if subtotal is None and total is not None and tax_sum is not None:
        derived_subtotal = round(total - tax_sum, 2)

    result = {
        "success": any(v is not None for v in [subtotal, tax_sum, total, derived_total, derived_subtotal]),
        "subtotal": subtotal,
        "tax": tax_sum,
        "total": total,
        "derived_total": derived_total,         # when we had subtotal + tax but no total
        "derived_subtotal": derived_subtotal,   # when we had total + tax but no subtotal
        "matches": {
            "subtotal": subtotal_hits,
            "tax": tax_hits,
            "total": total_hits,
        },
        "raw_text": text,
    }

    # Optional reconciliation diagnostics.
    # Only the total is reconciled: derived_subtotal exists only when no
    # subtotal was read, so there is never a read subtotal to compare it to.
    if total is not None and subtotal is not None and tax_sum is not None:
        calc_total = round(subtotal + tax_sum, 2)
        delta_total = round(total - calc_total, 2)
        if abs(delta_total) >= 0.01:
            result["reconciliation_delta_total"] = delta_total

    return result

# ============ CLI ============

if __name__ == "__main__":
    # The script always prints a single JSON document and exits with
    # status 0; failures are reported only through "success": false.
    cli_args = sys.argv[1:]
    if not cli_args:
        print(json.dumps({"success": False, "message": "No image path provided."}))
        sys.exit(0)

    path = cli_args[0]
    try:
        print(json.dumps(extract_receipt_totals(path)))
    except FileNotFoundError:
        failure = {"success": False, "message": f"File not found: {path}"}
        print(json.dumps(failure))
    except Exception as e:
        # Any other error (OCR, serialisation, ...) is reported by message.
        failure = {"success": False, "message": str(e)}
        print(json.dumps(failure))
    sys.exit(0)
