"""
Case Document Processor for AI Lawyer
Handles case-specific document upload, processing, and chunking
"""

import os
import hashlib
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from datetime import datetime
import json

# PDF processing
try:
    import PyPDF2
    HAS_PYPDF2 = True
except ImportError:
    HAS_PYPDF2 = False

try:
    import fitz  # PyMuPDF
    HAS_PYMUPDF = True
except ImportError:
    HAS_PYMUPDF = False


class CaseDocumentProcessor:
    """Process documents for a specific legal case.

    Files are stored under ``<base_path>/<user_id>/<case_uuid>/``. For each
    upload the text is extracted (PDF, DOCX, or read as plain text), the
    document is classified into a legal type and category with keyword
    heuristics, a few metadata fields are pulled out with regexes, and the
    text is split into chunks whose ``start_char`` / ``page`` fields point
    back into the extracted text.
    """

    # Legal document type classification keywords. Dict order matters:
    # _classify_document returns the first type with a keyword in the
    # filename, and max() breaks content-score ties in this order.
    DOCUMENT_PATTERNS = {
        'complaint': ['complaint', 'petition', 'plaintiff', 'cause of action', 'wherefore'],
        'answer': ['answer', 'defendant', 'denies', 'admits', 'affirmative defense'],
        'motion': ['motion', 'memorandum', 'points and authorities', 'opposition', 'reply'],
        'discovery': ['interrogatory', 'request for production', 'request for admission', 'subpoena'],
        'deposition': ['deposition', 'q.', 'a.', 'witness', 'sworn testimony', 'direct examination'],
        'contract': ['agreement', 'contract', 'parties agree', 'whereas', 'witnesseth', 'consideration'],
        'order': ['order', 'it is hereby ordered', 'judgment', 'decree', 'ruling'],
        'brief': ['brief', 'argument', 'issue presented', 'statement of facts', 'conclusion'],
        'declaration': ['declaration', 'declare under penalty', 'sworn', 'affidavit'],
        'exhibit': ['exhibit', 'attachment', 'appendix']
    }

    # Document categories; a document type belongs to the first category
    # whose keyword list contains a substring of it (see _get_category).
    CATEGORIES = {
        'pleadings': ['complaint', 'answer', 'counterclaim', 'cross-claim', 'reply'],
        'motions': ['motion', 'brief', 'memorandum', 'opposition', 'reply_brief'],
        'discovery': ['interrogatory', 'request_production', 'request_admission', 'deposition', 'subpoena'],
        'evidence': ['exhibit', 'declaration', 'affidavit', 'evidence'],
        'orders': ['order', 'judgment', 'ruling', 'decree'],
        'contracts': ['contract', 'agreement', 'lease', 'license'],
        'correspondence': ['letter', 'email', 'memo']
    }

    def __init__(self, case_uuid: str, user_id: int, base_path: str = None):
        """Create (if needed) the storage directory for one user's case.

        Args:
            case_uuid: Case identifier; used as a directory name.
            user_id: Owning user's id; used as a directory name.
            base_path: Root storage directory. Defaults to
                ``data/case_docs`` next to this module.

        Raises:
            ValueError: if ``case_uuid`` or ``user_id`` is not a single plain
                path component (empty, ``.``/``..``, or containing a path
                separator or NUL), which would let the case directory escape
                ``base_path``.
        """
        self.case_uuid = case_uuid
        self.user_id = user_id

        if base_path is None:
            base_path = Path(__file__).parent / 'data' / 'case_docs'

        user_part = self._require_path_component(str(user_id), 'user_id')
        case_part = self._require_path_component(case_uuid, 'case_uuid')
        self.case_dir = Path(base_path) / user_part / case_part
        self.case_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _require_path_component(value: str, label: str) -> str:
        """Return *value* unchanged if it is a single safe path component."""
        if (not value or value in ('.', '..') or '/' in value
                or '\\' in value or '\x00' in value):
            raise ValueError(f"Invalid {label} for storage path: {value!r}")
        return value

    def process_upload(self, file_content: bytes, filename: str,
                       document_type: str = None) -> Dict:
        """
        Process an uploaded legal document

        The raw bytes are written to the case directory before extraction,
        so the stored file remains on disk even when extraction fails.

        Args:
            file_content: Raw file bytes
            filename: Original filename
            document_type: Optional document type override; ``None`` or
                ``'auto'`` means classify automatically.

        Returns:
            Dict with processing results. On failure: ``success`` False,
            ``error`` and ``filename``. On success: storage filename, hash,
            size, type/category, text length, page count, chunks and
            extracted metadata.
        """
        # MD5 serves only as a content fingerprint for deduplication,
        # not for any security purpose.
        file_hash = hashlib.md5(file_content).hexdigest()

        # Prefixing with the hash makes the stored name unique per content.
        ext = Path(filename).suffix.lower()
        safe_filename = f"{file_hash}_{self._sanitize_filename(filename)}"
        filepath = self.case_dir / safe_filename
        filepath.write_bytes(file_content)

        text, page_info = self._extract_text(filepath, ext)

        # `[PAGE n]` markers alone (e.g. an image-only PDF) are not text.
        if not self._has_extractable_text(text):
            return {
                'success': False,
                'error': 'Could not extract text from document',
                'filename': filename
            }

        if not document_type or document_type == 'auto':
            document_type = self._classify_document(filename, text)

        category = self._get_category(document_type)
        metadata = self._extract_legal_metadata(text, document_type)
        chunks = self._create_legal_chunks(text, document_type, page_info)

        return {
            'success': True,
            'filename': safe_filename,
            'original_filename': filename,
            'file_hash': file_hash,
            'file_size': len(file_content),
            'file_type': ext.lstrip('.'),
            'document_type': document_type,
            'category': category,
            'text_length': len(text),
            'page_count': len(page_info) if page_info else 1,
            'chunks': chunks,
            'chunk_count': len(chunks),
            'metadata': metadata
        }

    @staticmethod
    def _has_extractable_text(text: str) -> bool:
        """True if *text* has content besides whitespace and ``[PAGE n]`` markers."""
        if not text:
            return False
        return bool(re.sub(r'\[PAGE \d+\]', '', text).strip())

    def _sanitize_filename(self, filename: str) -> str:
        """Sanitize filename for safe storage.

        Drops any directory part, replaces every character other than word
        characters, ``-`` and ``.`` with ``_``, and caps the length at 200
        characters while keeping the extension.
        """
        filename = os.path.basename(filename)
        filename = re.sub(r'[^\w\-_\.]', '_', filename)
        if len(filename) > 200:
            name, ext = os.path.splitext(filename)
            filename = name[:200 - len(ext)] + ext
        return filename

    def _extract_text(self, filepath: Path, ext: str) -> Tuple[str, List[Dict]]:
        """Extract text from a stored document, with page information.

        PDFs and Word files go to their dedicated extractors. Every other
        extension (including ``.txt``) is read as UTF-8 with undecodable
        bytes dropped and reported as a single page.

        Returns:
            ``(text, page_info)``; ``("", [])`` when the file cannot be read.
        """
        if ext == '.pdf':
            return self._extract_pdf(filepath)
        if ext in ('.docx', '.doc'):
            # NOTE(review): python-docx is generally understood to read only
            # .docx; legacy binary .doc files will likely fail and yield empty
            # text - confirm.
            return self._extract_docx(filepath)

        try:
            with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                text = f.read()
        except OSError as e:
            print(f"Text extraction failed: {e}")
            return "", []
        return text, [{'page': 1, 'start': 0, 'end': len(text)}]

    @staticmethod
    def _pages_to_text(page_texts: List[str]) -> Tuple[str, List[Dict]]:
        """Join per-page texts as ``\\n[PAGE n]\\n<text>\\n`` blocks.

        Returns:
            ``(text, page_info)`` where each page_info entry holds the 1-based
            ``page`` and the ``[start, end)`` range of its block in ``text``.
        """
        parts = []
        page_info = []
        pos = 0
        for number, page_text in enumerate(page_texts, start=1):
            block = f"\n[PAGE {number}]\n{page_text}\n"
            parts.append(block)
            page_info.append({'page': number, 'start': pos, 'end': pos + len(block)})
            pos += len(block)
        return ''.join(parts), page_info

    def _extract_pdf(self, filepath: Path) -> Tuple[str, List[Dict]]:
        """Extract text from a PDF with page tracking.

        Tries PyMuPDF first, then PyPDF2. Each backend collects all page texts
        before anything is assembled, so a backend that fails part-way leaves
        no partial output for the next one.

        Returns:
            ``(text, page_info)`` as built by ``_pages_to_text``;
            ``("", [])`` if no backend is available or all of them failed.
        """
        if HAS_PYMUPDF:
            try:
                doc = fitz.open(str(filepath))
                try:
                    page_texts = [page.get_text() for page in doc]
                finally:
                    doc.close()
                return self._pages_to_text(page_texts)
            except Exception as e:
                print(f"PyMuPDF extraction failed: {e}")

        if HAS_PYPDF2:
            try:
                with open(filepath, 'rb') as f:
                    reader = PyPDF2.PdfReader(f)
                    page_texts = [page.extract_text() or "" for page in reader.pages]
                return self._pages_to_text(page_texts)
            except Exception as e:
                print(f"PyPDF2 extraction failed: {e}")

        return "", []

    def _extract_docx(self, filepath: Path) -> Tuple[str, List[Dict]]:
        """Extract paragraph text from a Word file as a single page.

        Paragraphs are joined with blank lines; tables, headers and footers
        are not read. Returns ``("", [])`` if python-docx is missing or the
        file cannot be parsed.
        """
        try:
            from docx import Document
            doc = Document(filepath)
            paragraphs = [p.text for p in doc.paragraphs]
            text = '\n\n'.join(paragraphs)
            page_info = [{'page': 1, 'start': 0, 'end': len(text)}]
            return text, page_info
        except Exception as e:
            print(f"DOCX extraction failed: {e}")
            return "", []

    @staticmethod
    def _contains_term(haystack: str, term: str) -> bool:
        """True if *term* occurs in *haystack* not preceded by a letter/digit.

        The leading boundary stops short keywords such as ``a.`` or ``order``
        from matching inside other words (``data.``, ``recorder``).
        """
        return re.search(r'(?<![a-z0-9])' + re.escape(term), haystack) is not None

    def _classify_document(self, filename: str, text: str) -> str:
        """Classify document type based on filename and content.

        First, the first type (in DOCUMENT_PATTERNS order) with a keyword in
        the filename stem wins; the extension is ignored. Otherwise the type
        with the most distinct keywords in the first 5000 characters wins.
        Returns ``'general_document'`` when nothing matches.
        """
        stem = Path(filename).stem.lower()
        head = text[:5000].lower()

        for doc_type, patterns in self.DOCUMENT_PATTERNS.items():
            if any(self._contains_term(stem, p) for p in patterns):
                return doc_type

        scores = {}
        for doc_type, patterns in self.DOCUMENT_PATTERNS.items():
            score = sum(1 for p in patterns if self._contains_term(head, p))
            if score > 0:
                scores[doc_type] = score

        if scores:
            return max(scores, key=scores.get)
        return 'general_document'

    def _get_category(self, document_type: str) -> str:
        """Map a document type to its category (case-insensitive substring match).

        Returns ``'other'`` when no category keyword occurs in the type.
        """
        doc_type = (document_type or '').lower()
        for category, types in self.CATEGORIES.items():
            if any(t in doc_type for t in types):
                return category
        return 'other'

    def _extract_legal_metadata(self, text: str, document_type: str) -> Dict:
        """Extract legal metadata from the first 10000 characters of text.

        All fields are regex heuristics:
            - case_number: first ``Case No.`` / ``Civil Action No.`` /
              ``Docket No.`` value that contains a digit;
            - court: first district/superior/circuit/appeals court line;
            - parties: first plaintiff/defendant name found;
            - dates: up to 10 distinct date strings;
            - claims: causes of action (complaints/petitions only);
            - key_terms: known legal terms present in the text.
        """
        metadata = {
            'parties': {},
            'dates': [],
            'case_number': None,
            'court': None,
            'claims': [],
            'key_terms': []
        }

        head = text[:10000]  # captions and headings sit near the top

        # `\s*:?\s*` tolerates "Case No.: 12-34"; the digit check below
        # rejects prose such as "in this case no party ...".
        case_patterns = [
            r'Case\s*(?:No\.?|Number)\s*:?\s*([A-Z0-9][A-Z0-9\-:]*)',
            r'(?:Civil|Criminal)\s*(?:No\.?|Action\s*No\.?)\s*:?\s*([A-Z0-9][A-Z0-9\-:]*)',
            r'Docket\s*(?:No\.?|Number)\s*:?\s*([A-Z0-9][A-Z0-9\-:]*)'
        ]
        for pattern in case_patterns:
            candidate = next(
                (m.group(1).strip() for m in re.finditer(pattern, head, re.IGNORECASE)
                 if any(ch.isdigit() for ch in m.group(1))),
                None,
            )
            if candidate:
                metadata['case_number'] = candidate
                break

        # `[^\n]*` so a caption line that ends right after "COURT" matches.
        court_patterns = [
            r'(?:IN\s+THE\s+)?(?:UNITED\s+STATES\s+)?(DISTRICT\s+COURT[^\n]*)',
            r'(?:IN\s+THE\s+)?(SUPERIOR\s+COURT[^\n]*)',
            r'(?:IN\s+THE\s+)?(CIRCUIT\s+COURT[^\n]*)',
            r'(?:IN\s+THE\s+)?(COURT\s+OF\s+APPEALS?[^\n]*)'
        ]
        for pattern in court_patterns:
            match = re.search(pattern, head, re.IGNORECASE)
            if match:
                metadata['court'] = match.group(1).strip()
                break

        party_patterns = [
            (r'([A-Z][A-Za-z\s,\.]+),?\s*Plaintiff', 'plaintiff'),
            (r'([A-Z][A-Za-z\s,\.]+),?\s*Defendant', 'defendant'),
            (r'Plaintiff[:\s]+([A-Z][A-Za-z\s,\.]+)', 'plaintiff'),
            (r'Defendant[:\s]+([A-Z][A-Za-z\s,\.]+)', 'defendant')
        ]
        for pattern, party_type in party_patterns:
            match = re.search(pattern, head)
            if match and party_type not in metadata['parties']:
                party_name = re.sub(r'\s+', ' ', match.group(1).strip())
                party_name = party_name.rstrip(',.')
                if len(party_name) < 100:  # sanity check against runaway matches
                    metadata['parties'][party_type] = party_name

        date_patterns = [
            r'\b(\d{1,2}/\d{1,2}/\d{2,4})\b',
            r'\b(\d{1,2}-\d{1,2}-\d{2,4})\b',
            r'\b([A-Z][a-z]+\s+\d{1,2},?\s+\d{4})\b'
        ]
        dates = []
        for pattern in date_patterns:
            dates.extend(re.findall(pattern, head))
        # Distinct dates in first-seen order, 10 in total.
        metadata['dates'] = list(dict.fromkeys(dates))[:10]

        if document_type in ('complaint', 'petition'):
            claim_patterns = [
                r'(?:FIRST|SECOND|THIRD|FOURTH|FIFTH)\s+(?:CAUSE\s+OF\s+ACTION|CLAIM)[:\s]+([^\n]+)',
                r'COUNT\s+(?:I|II|III|IV|V|ONE|TWO|THREE)[:\s]+([^\n]+)',
                r'(?:FOR|CLAIM\s+FOR)\s+(BREACH\s+OF\s+CONTRACT|NEGLIGENCE|FRAUD|[A-Z][A-Za-z\s]+)'
            ]
            claims = []
            for pattern in claim_patterns:
                matches = re.findall(pattern, head, re.IGNORECASE)
                # The last pattern can span lines; collapse whitespace.
                claims.extend(re.sub(r'\s+', ' ', m).strip() for m in matches[:10])
            metadata['claims'] = [c for c in dict.fromkeys(claims) if c]

        legal_terms = [
            'breach of contract', 'negligence', 'fraud', 'misrepresentation',
            'breach of fiduciary duty', 'unjust enrichment', 'conversion',
            'defamation', 'intentional infliction', 'strict liability',
            'summary judgment', 'default judgment', 'preliminary injunction',
            'temporary restraining order', 'class action', 'derivative action'
        ]
        head_lower = head.lower()
        metadata['key_terms'] = [term for term in legal_terms if term in head_lower]

        return metadata

    def _create_legal_chunks(self, text: str, document_type: str,
                             page_info: List[Dict]) -> List[Dict]:
        """Chunk text with a strategy chosen by document type.

        Each returned chunk is tagged with ``chunk_index`` (0-based order),
        ``document_type`` and this processor's ``case_uuid``.
        """
        if document_type in ('complaint', 'answer', 'motion', 'brief'):
            chunks = self._chunk_by_sections(text, document_type, page_info)
        elif document_type == 'contract':
            chunks = self._chunk_contract(text, page_info)
        elif document_type == 'deposition':
            chunks = self._chunk_deposition(text, page_info)
        else:
            chunks = self._semantic_chunk(text, page_info)

        for index, chunk in enumerate(chunks):
            chunk['chunk_index'] = index
            chunk['document_type'] = document_type
            chunk['case_uuid'] = self.case_uuid

        return chunks

    def _chunks_from_boundaries(self, text: str, boundaries: List[Tuple[int, str]],
                                page_info: List[Dict], min_length: int,
                                max_length: int) -> List[Dict]:
        """Cut *text* into sections at ``(position, section_name)`` boundaries.

        Each section runs from its boundary to the next one (or to the end of
        the text). Sections shorter than ``min_length`` after stripping are
        skipped. Sections longer than ``max_length`` are split with
        ``_split_long_section`` and numbered via a 1-based ``subsection``.
        """
        ordered = sorted(boundaries, key=lambda b: b[0])
        ordered.append((len(text), 'END'))

        chunks = []
        for (start_pos, section_name), (end_pos, _) in zip(ordered, ordered[1:]):
            raw = text[start_pos:end_pos]
            section_text = raw.strip()
            if len(section_text) < min_length:
                continue
            # Point at the stripped content, not at the leading whitespace
            # the header patterns match on.
            content_start = start_pos + len(raw) - len(raw.lstrip())

            if len(section_text) > max_length:
                sub_chunks = self._split_long_section(section_text, page_info, content_start)
                for number, sub in enumerate(sub_chunks, start=1):
                    chunks.append({
                        'content': sub['content'],
                        'section': section_name,
                        'subsection': number,
                        'page': sub['page'],
                        'start_char': sub['start_char'],
                        'content_length': len(sub['content'])
                    })
            else:
                chunks.append({
                    'content': section_text,
                    'section': section_name,
                    'page': self._get_page_for_position(content_start, page_info),
                    'start_char': content_start,
                    'content_length': len(section_text)
                })
        return chunks

    def _chunk_by_sections(self, text: str, document_type: str,
                           page_info: List[Dict]) -> List[Dict]:
        """Chunk pleadings/motions/briefs at recognised section headers.

        Sections under 50 characters are skipped; sections over 2000 are
        split further. Falls back to semantic chunking when at most one chunk
        results. ``document_type`` is accepted but not used.
        """
        section_patterns = [
            r'\n\s*((?:FIRST|SECOND|THIRD|FOURTH|FIFTH|SIXTH|SEVENTH|EIGHTH|NINTH|TENTH)\s+(?:CAUSE\s+OF\s+ACTION|CLAIM|COUNT)[^\n]*)',
            r'\n\s*(COUNT\s+(?:I|II|III|IV|V|VI|VII|VIII|IX|X|ONE|TWO|THREE|FOUR|FIVE)[^\n]*)',
            r'\n\s*((?:I|II|III|IV|V|VI|VII|VIII|IX|X)\.\s+[A-Z][^\n]+)',
            r'\n\s*(JURISDICTION\s+AND\s+VENUE)',
            r'\n\s*(PARTIES)',
            r'\n\s*(FACTUAL\s+(?:ALLEGATIONS?|BACKGROUND))',
            r'\n\s*(STATEMENT\s+OF\s+(?:FACTS|THE\s+CASE))',
            r'\n\s*(ARGUMENT|DISCUSSION)',
            r'\n\s*(CONCLUSION)',
            r'\n\s*(PRAYER\s+FOR\s+RELIEF)',
            r'\n\s*(WHEREFORE)',
        ]

        boundaries = [(0, 'BEGINNING')]
        for pattern in section_patterns:
            for match in re.finditer(pattern, text, re.IGNORECASE):
                boundaries.append((match.start(), match.group(1).strip()))

        chunks = self._chunks_from_boundaries(text, boundaries, page_info,
                                              min_length=50, max_length=2000)
        if len(chunks) <= 1:
            return self._semantic_chunk(text, page_info)
        return chunks

    def _chunk_contract(self, text: str, page_info: List[Dict]) -> List[Dict]:
        """Chunk a contract at section/article/clause headings.

        Sections under 30 characters are skipped. Sections over 3000 are
        split further rather than truncated, so no text is lost. Falls back
        to semantic chunking when at most one chunk results.
        """
        patterns = [
            # "1. DEFINITIONS": single-line, letters-only title. The lookahead
            # leaves the newline for the next heading's match.
            r'\n\s*(\d+\.)\s+([A-Z][A-Z \t]+)(?=\n)',
            r'\n\s*(ARTICLE\s+[IVXLC\d]+)[:\.\s]+([^\n]+)',
            r'\n\s*(SECTION\s+\d+)[:\.\s]+([^\n]+)',
            r'\n\s*(RECITALS|WHEREAS|WITNESSETH)',
            r'\n\s*(DEFINITIONS|TERM|PAYMENT|TERMINATION|CONFIDENTIAL|INDEMNIF|LIMITATION|GOVERNING\s+LAW)'
        ]

        boundaries = [(0, 'PREAMBLE')]
        for pattern in patterns:
            for match in re.finditer(pattern, text, re.IGNORECASE):
                boundaries.append((match.start(), match.group(0).strip()[:100]))

        chunks = self._chunks_from_boundaries(text, boundaries, page_info,
                                              min_length=30, max_length=3000)
        if len(chunks) <= 1:
            return self._semantic_chunk(text, page_info)
        return chunks

    def _chunk_deposition(self, text: str, page_info: List[Dict]) -> List[Dict]:
        """Chunk a deposition transcript by Q&A exchanges.

        Each exchange runs from one Q/A match to the start of the next, so
        text between exchanges (e.g. objections) stays in the chunks.
        Exchanges are packed into chunks of up to ~1500 characters; a longer
        exchange becomes its own chunk. Text before the first and after the
        last exchange is chunked semantically. Falls back to semantic
        chunking when no exchange is found.
        """
        qa_pattern = r'(Q\.?\s+[^\n]+(?:\n(?![QA]\.)[^\n]+)*)\s*(A\.?\s+[^\n]+(?:\n(?![QA]\.)[^\n]+)*)'
        matches = list(re.finditer(qa_pattern, text, re.MULTILINE))
        if not matches:
            return self._semantic_chunk(text, page_info)

        chunks = []

        first_start = matches[0].start()
        if text[:first_start].strip():
            chunks.extend(self._semantic_chunk(text[:first_start], page_info))

        segment_ends = [m.start() for m in matches[1:]] + [matches[-1].end()]
        current = ""
        chunk_start = 0
        for match, segment_end in zip(matches, segment_ends):
            qa_text = text[match.start():segment_end].strip()
            if current and len(current) + len(qa_text) > 1500:
                chunks.append(self._make_chunk(current, 'Q&A Exchange', chunk_start, page_info))
                current = ""
            if current:
                current += "\n\n" + qa_text
            else:
                current = qa_text
                chunk_start = match.start()
        if current:
            chunks.append(self._make_chunk(current, 'Q&A Exchange', chunk_start, page_info))

        last_end = matches[-1].end()
        if text[last_end:].strip():
            chunks.extend(self._semantic_chunk(text[last_end:], page_info,
                                               base_offset=last_end))

        return chunks

    @staticmethod
    def _paragraph_spans(text: str) -> List[Tuple[int, str]]:
        """Split *text* on blank lines into ``(offset, paragraph)`` pairs.

        Paragraphs are stripped and blank ones dropped. ``offset`` is the
        exact index of the stripped paragraph's first character in *text*.
        """
        spans = []
        segment_start = 0
        separators = [(m.start(), m.end()) for m in re.finditer(r'\n\s*\n', text)]
        separators.append((len(text), len(text)))
        for segment_end, next_start in separators:
            segment = text[segment_start:segment_end]
            para = segment.strip()
            if para:
                spans.append((segment_start + len(segment) - len(segment.lstrip()), para))
            segment_start = next_start
        return spans

    def _make_chunk(self, content: str, section: str, start: int,
                    page_info: List[Dict]) -> Dict:
        """Build a chunk dict for *content* starting at offset *start* of the full text."""
        content = content.strip()
        return {
            'content': content,
            'section': section,
            'page': self._get_page_for_position(start, page_info),
            'start_char': start,
            'content_length': len(content)
        }

    def _semantic_chunk(self, text: str, page_info: List[Dict],
                        chunk_size: int = 1000, overlap: int = 100,
                        base_offset: int = 0) -> List[Dict]:
        """Default semantic chunking by paragraphs.

        Paragraphs (separated by blank lines) are packed greedily into chunks
        of about ``chunk_size`` characters. A single paragraph longer than
        ``chunk_size`` becomes its own, larger chunk. When the previous chunk
        is longer than ``overlap``, its last ``overlap`` characters are
        prepended to the next chunk, separated by a blank line.

        Args:
            text: Text to chunk.
            page_info: Page ranges of the full document text.
            chunk_size: Soft size limit per chunk, in characters.
            overlap: Characters carried over between chunks; 0 disables it.
            base_offset: Offset of ``text[0]`` within the full document text,
                so ``start_char``/``page`` refer to the full text when
                chunking a slice of it.
        """
        chunks = []
        current = ""
        chunk_start = 0  # offset of the current chunk within `text`
        prev_end = 0     # offset just past the last paragraph added

        for para_start, para in self._paragraph_spans(text):
            if current and len(current) + len(para) > chunk_size:
                chunks.append(self._make_chunk(current, 'content',
                                               base_offset + chunk_start, page_info))
                tail = current[-overlap:].lstrip() if 0 < overlap < len(current) else ""
                if tail:
                    current = f"{tail}\n\n{para}"
                    # Exact whenever the tail lies inside the previous paragraph.
                    chunk_start = max(prev_end - len(tail), 0)
                else:
                    current = para
                    chunk_start = para_start
            elif current:
                current += "\n\n" + para
            else:
                current = para
                chunk_start = para_start
            prev_end = para_start + len(para)

        if current:
            chunks.append(self._make_chunk(current, 'content',
                                           base_offset + chunk_start, page_info))
        return chunks

    def _split_long_section(self, text: str, page_info: List[Dict],
                            base_pos: int) -> List[Dict]:
        """Split a long section into ~1500-char chunks with 150-char overlap.

        ``base_pos`` is the offset of ``text[0]`` in the full document text.
        The returned ``start_char``/``page`` values are therefore absolute.
        """
        return self._semantic_chunk(text, page_info, chunk_size=1500,
                                    overlap=150, base_offset=base_pos)

    def _get_page_for_position(self, pos: int, page_info: List[Dict]) -> int:
        """Return the page whose ``[start, end)`` range contains *pos*.

        Defaults to 1 without page info and to the last page for positions
        past every range.
        """
        if not page_info:
            return 1
        for info in page_info:
            if info['start'] <= pos < info['end']:
                return info['page']
        return page_info[-1]['page']
