#!/usr/bin/env python3
"""
Lightweight QA regression harness for the RAG system.

Usage:
    python -m tests.run_regression
    python -m tests.run_regression --categories history,finance --limit 5
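
Each baseline case in baselines/qa_cases.json is assumed (inferred from the
fields this harness reads, not from a documented schema) to look roughly like
the hypothetical entry below; only "id" and "question" are accessed
unconditionally, the remaining fields default to empty lists:

    {
        "id": "history-001",
        "category": "history",
        "question": "When was the treaty signed?",
        "preferred_sources": ["treaty_overview.md"],
        "required_terms": ["1848"],
        "optional_terms": ["Guadalupe Hidalgo"]
    }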
"""
from __future__ import annotations

import argparse
import json
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Any

from main import RAGSystem


ROOT_DIR = Path(__file__).resolve().parent
BASELINE_PATH = ROOT_DIR / "baselines" / "qa_cases.json"
RESULTS_DIR = ROOT_DIR / "baselines" / "results"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

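# Phrases that suggest the model fell back to boilerplate instead of grounding
# its answer in the retrieved documents; any match marks the answer as
# "generic" and fails the case (see is_generic_answer and evaluate_case).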
GENERIC_PATTERNS = [
    "based on the available documents",
    "i analyzed the available documents",
    "no documents available",
    "no relevant documents found",
    "according to the distilled documents"
]


def load_cases(categories: List[str] | None, limit: int | None) -> List[Dict[str, Any]]:
    with open(BASELINE_PATH, "r", encoding="utf-8") as f:
        cases = json.load(f)

    if categories:
        categories_lower = {c.strip().lower() for c in categories}
        cases = [
            case for case in cases
            if case.get("category", "").lower() in categories_lower
        ]

    if limit is not None:
        cases = cases[:limit]

    return cases


def is_generic_answer(answer: str) -> bool:
    answer_lower = answer.lower()
    return any(pattern in answer_lower for pattern in GENERIC_PATTERNS)


def evaluate_case(rag: RAGSystem, case: Dict[str, Any]) -> Dict[str, Any]:
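    """Run a single baseline case through the RAG system and grade the answer.

    Assumes RAGSystem.search_and_answer returns either a plain error string or
    a dict with an "answer" string and a "sources" list of dicts keyed by
    "source" (inferred from how the result is unpacked below, not from a
    documented contract). A case passes only if every required term appears in
    the answer, no generic fallback phrasing is detected, and at least one
    preferred source (when any are specified) appears among the cited sources.
    """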
    question = case["question"]
    preferred_sources = case.get("preferred_sources", [])
    required_terms = case.get("required_terms", [])
    optional_terms = case.get("optional_terms", [])

    try:
        raw_result = rag.search_and_answer(question)
    except Exception as exc:
        return {
            "id": case["id"],
            "category": case.get("category"),
            "question": question,
            "status": "error",
            "error": str(exc),
            "answer": None,
            "sources": [],
            "metrics": {}
        }

    if isinstance(raw_result, str):
        return {
            "id": case["id"],
            "category": case.get("category"),
            "question": question,
            "status": "error",
            "error": raw_result,
            "answer": raw_result,
            "sources": [],
            "metrics": {}
        }

    answer = raw_result.get("answer", "")
    sources = raw_result.get("sources", []) or []
    source_names = [src.get("source") for src in sources if isinstance(src, dict)]

    answer_lower = answer.lower()
    missing_required = [
        term for term in required_terms
        if term.lower() not in answer_lower
    ]
    optional_hits = [
        term for term in optional_terms
        if term.lower() in answer_lower
    ]

    # An empty preferred_sources list means "no source requirement"; without
    # this guard such cases would always fail the preferred-source check.
    preferred_hit = (
        not preferred_sources
        or any(src in preferred_sources for src in source_names)
    )

    generic = is_generic_answer(answer)
    answer_length = len(answer.split())

    status = "pass"
    if missing_required or generic or not preferred_hit:
        status = "fail"

    metrics = {
        "answer_length": answer_length,
        "missing_required_terms": missing_required,
        "optional_hits": optional_hits,
        "preferred_sources_hit": preferred_hit,
        "generic_detected": generic,
        "top_sources": source_names[:3]
    }

    return {
        "id": case["id"],
        "category": case.get("category"),
        "question": question,
        "status": status,
        "answer": answer,
        "sources": sources,
        "metrics": metrics
    }


def summarise(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    total = len(results)
    passes = sum(1 for result in results if result["status"] == "pass")
    # Count failures and errors separately so errors are not double-reported
    # as failures in the summary.
    failures = sum(1 for result in results if result["status"] == "fail")
    errors = sum(1 for result in results if result["status"] == "error")

    categories: Dict[str, Dict[str, int]] = {}
    for result in results:
        cat = (result.get("category") or "uncategorized").lower()
        bucket = categories.setdefault(cat, {"pass": 0, "fail": 0, "error": 0})
        bucket[result["status"]] = bucket.get(result["status"], 0) + 1

    return {
        "total_cases": total,
        "passes": passes,
        "failures": failures,
        "errors": errors,
        "by_category": categories
    }


def save_results(results: List[Dict[str, Any]], summary: Dict[str, Any]) -> Path:
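    """Write a timestamped results file plus a latest.json copy for easy diffing."""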
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    payload = {
        "generated_at": timestamp,
        "summary": summary,
        "cases": results
    }
    output_path = RESULTS_DIR / f"{timestamp}.json"
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)

    latest_path = RESULTS_DIR / "latest.json"
    with open(latest_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)

    return output_path


def print_summary(summary: Dict[str, Any]) -> None:
    total = summary["total_cases"]
    passes = summary["passes"]
    failures = summary["failures"]
    errors = summary["errors"]
    print("\n=== QA Regression Summary ===")
    print(f"Total cases: {total}")
    print(f"Passes    : {passes}")
    print(f"Failures  : {failures}")
    print(f"Errors    : {errors}")

    print("\nBy category:")
    for category, counts in sorted(summary["by_category"].items()):
        print(f"  {category.title():<20} pass={counts.get('pass', 0)} "
              f"fail={counts.get('fail', 0)} error={counts.get('error', 0)}")


def main(argv: List[str] | None = None) -> int:
    parser = argparse.ArgumentParser(description="Run QA regression checks.")
    parser.add_argument(
        "--categories",
        type=lambda s: [part.strip() for part in s.split(",") if part.strip()],
        default=None,
        help="Comma-separated list of categories to evaluate."
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Optional limit on number of cases."
    )
    parser.add_argument(
        "--fail-fast",
        action="store_true",
        help="Stop after the first failure or error."
    )
    args = parser.parse_args(argv)

    cases = load_cases(args.categories, args.limit)
    if not cases:
        print("No cases matched the filters.", file=sys.stderr)
        return 1

    rag = RAGSystem()
    rag.vector_store.load_existing_index()
    if not rag.vector_store.documents:
        print("Vector store is empty. Run setup before regression.", file=sys.stderr)
        return 1
    rag.initialize_model()

    results: List[Dict[str, Any]] = []
    for case in cases:
        result = evaluate_case(rag, case)
        results.append(result)
        status = result["status"].upper()
        print(f"[{status}] {case['id']}  :: {case['question']}")
        if status != "PASS":
            metrics = result.get("metrics", {})
            if result.get("error"):
                print(f"    error: {result['error']}")
            else:
                missing = metrics.get("missing_required_terms") or []
                if missing:
                    print(f"    missing terms: {missing}")
                if not metrics.get("preferred_sources_hit", False):
                    print(f"    preferred sources not found in top results")
                if metrics.get("generic_detected"):
                    print("    generic response detected")
            if args.fail_fast:
                break

    summary = summarise(results)
    save_path = save_results(results, summary)
    print_summary(summary)
    print(f"\nDetailed results saved to: {save_path}")

    return 0 if summary["failures"] == 0 and summary["errors"] == 0 else 1


if __name__ == "__main__":
    sys.exit(main())

