"""
Advanced Keyword Management System for Knowledge Card Search
Provides comprehensive keyword extraction, tagging, filtering, and scoring capabilities.
"""

import re
import json
import math
from typing import Dict, List, Set, Tuple, Optional, Any
from collections import defaultdict, Counter
from pathlib import Path

# Optional NLTK imports
try:
    import nltk
    from nltk.corpus import stopwords
    from nltk.stem import WordNetLemmatizer
    NLTK_AVAILABLE = True
except ImportError:
    NLTK_AVAILABLE = False
    # Create dummy classes/functions for fallback
    class DummyLemmatizer:
        def lemmatize(self, word):
            return word
    WordNetLemmatizer = DummyLemmatizer

    class DummyStopwords:
        @staticmethod
        def words(lang):
            return [
                'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by',
                'i', 'you', 'he', 'she', 'it', 'we', 'they', 'me', 'him', 'her', 'us', 'them',
                'this', 'that', 'these', 'those', 'is', 'am', 'are', 'was', 'were', 'be', 'been', 'being',
                'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should', 'may', 'might', 'must'
            ]
    stopwords = DummyStopwords()

# Optional rapidfuzz imports
try:
    from rapidfuzz import fuzz
    RAPIDFUZZ_AVAILABLE = True
except ImportError:
    RAPIDFUZZ_AVAILABLE = False
    # Simple fallback implementation
    class FuzzFallback:
        @staticmethod
        def ratio(a, b):
            # Simple character overlap ratio
            a_set = set(a.lower())
            b_set = set(b.lower())
            intersection = len(a_set & b_set)
            union = len(a_set | b_set)
            return (intersection / union * 100) if union > 0 else 0

        @staticmethod
        def partial_ratio(a, b):
            # Simple substring matching
            a_lower = a.lower()
            b_lower = b.lower()
            if b_lower in a_lower:
                return 100
            # Check for partial matches
            shorter, longer = (a_lower, b_lower) if len(a_lower) < len(b_lower) else (b_lower, a_lower)
            matches = sum(1 for char in shorter if char in longer)
            return (matches / len(shorter) * 100) if len(shorter) > 0 else 0

    fuzz = FuzzFallback()

from airagagent.config import (
    TRADING_KEYWORD_TAXONOMY,
    TRADING_SYNONYMS,
    KEYWORD_WEIGHTS,
    DOMAIN_KEYWORD_MAPPING,
    TECHNICAL_TERMS
)

class KeywordManager:
    """
    Advanced keyword management system for knowledge card search.
    Handles keyword extraction, tagging, synonym expansion, and relevance scoring.
    """

    def __init__(self):
        # Initialize NLTK components (with fallbacks)
        if NLTK_AVAILABLE:
            # Make sure the required NLTK corpora are present
            try:
                nltk.data.find('corpora/stopwords')
                nltk.data.find('corpora/wordnet')
            except LookupError:
                # Try to download, but don't fail if unavailable (e.g. offline)
                try:
                    nltk.download('stopwords', quiet=True)
                    nltk.download('wordnet', quiet=True)
                except Exception:
                    pass

        # WordNetLemmatizer and stopwords resolve to the real NLTK objects when
        # available, otherwise to the dummy fallbacks defined at import time,
        # so a single assignment covers both cases.
        self.lemmatizer = WordNetLemmatizer()
        self.stop_words = set(stopwords.words('english'))

        # Custom stop words for trading/finance domain
        self.domain_stop_words = {
            'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by',
            'trading', 'strategy', 'analysis', 'market', 'price', 'volume', 'time', 'chart', 'data'
        }
        self.stop_words.update(self.domain_stop_words)

        # Build comprehensive keyword index
        self.keyword_index = self._build_keyword_index()
        self.synonym_map = self._build_synonym_map()
        self.domain_filters = DOMAIN_KEYWORD_MAPPING

        # Pre-compute keyword vectors for fast similarity search
        self.keyword_vectors = self._build_keyword_vectors()

    def _build_keyword_index(self) -> Dict[str, Dict[str, Any]]:
        """Build comprehensive keyword index with metadata."""
        keyword_index = {}

        # Process taxonomy keywords
        for domain, keywords in TRADING_KEYWORD_TAXONOMY.items():
            for keyword in keywords:
                keyword_index[keyword.lower()] = {
                    'keyword': keyword,
                    'domain': domain,
                    'weight': KEYWORD_WEIGHTS.get(keyword.lower(), 0.6),
                    'length': len(keyword.split()),
                    'type': 'taxonomy'
                }

        # Add technical terms
        for term in TECHNICAL_TERMS:
            if term.lower() not in keyword_index:
                keyword_index[term.lower()] = {
                    'keyword': term,
                    'domain': 'general',
                    'weight': 0.5,
                    'length': len(term.split()),
                    'type': 'technical'
                }

        return keyword_index

    def _build_synonym_map(self) -> Dict[str, List[str]]:
        """Build synonym expansion map."""
        synonym_map = defaultdict(list)

        # Add explicit synonyms from config
        for canonical, synonyms in TRADING_SYNONYMS.items():
            canonical_lower = canonical.lower()
            synonym_map[canonical_lower].extend([s.lower() for s in synonyms])

            # Add reverse mappings
            for synonym in synonyms:
                synonym_lower = synonym.lower()
                if synonym_lower not in synonym_map:
                    synonym_map[synonym_lower] = []
                if canonical_lower not in synonym_map[synonym_lower]:
                    synonym_map[synonym_lower].append(canonical_lower)

        # Add lemmatized variations (a no-op with the dummy lemmatizer fallback)
        for keyword in list(self.keyword_index.keys()):
            lemma = self.lemmatizer.lemmatize(keyword)
            if lemma != keyword and lemma not in synonym_map[keyword]:
                synonym_map[keyword].append(lemma)

        return dict(synonym_map)

    def _build_keyword_vectors(self) -> Dict[str, List[float]]:
        """Build simple keyword vectors for similarity matching."""
        vectors = {}

        # Simple character-based vectorization
        for keyword in self.keyword_index.keys():
            # Create vector based on character frequencies and word count
            vector = [0] * 256  # Latin-1 byte range (covers ASCII)
            for char in keyword:
                if ord(char) < 256:
                    vector[ord(char)] += 1

            # Normalize
            total = sum(vector) or 1
            vector = [x/total for x in vector]

            # Append word count as an extra, scaled dimension
            word_count = len(keyword.split())
            vector.append(word_count / 10)  # Scale roughly into [0, 1] for typical phrase lengths

            vectors[keyword] = vector

        return vectors

    def extract_keywords(self, text: str, max_keywords: int = 20) -> List[Dict[str, Any]]:
        """
        Extract relevant keywords from text with scoring.

        Args:
            text: Input text to analyze
            max_keywords: Maximum number of keywords to return

        Returns:
            List of keyword dictionaries with scores and metadata
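
        Example (illustrative; the keywords actually found depend on the
        taxonomy configured in airagagent.config):
            km = KeywordManager()
            top = km.extract_keywords("moving average crossover entry signal", max_keywords=5)
            # each entry: {'keyword', 'score', 'frequency', 'domain', 'type', 'length'}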
        """
        text_lower = text.lower()

        # Tokenize and clean
        words = re.findall(r'\b\w+\b', text_lower)
        words = [w for w in words if w not in self.stop_words and len(w) > 2]

        # Count word frequencies
        word_freq = Counter(words)

        # Extract n-grams (2-3 words)
        ngrams = self._extract_ngrams(text_lower, [2, 3])
        ngram_freq = Counter(ngrams)

        # Score candidates
        candidates = []

        # Score single words
        for word, freq in word_freq.items():
            if word in self.keyword_index:
                metadata = self.keyword_index[word]
                score = self._calculate_keyword_score(word, freq, metadata, text_lower)
                candidates.append({
                    'keyword': metadata['keyword'],
                    'score': score,
                    'frequency': freq,
                    'domain': metadata['domain'],
                    'type': metadata['type'],
                    'length': 1
                })

        # Score n-grams
        for ngram, freq in ngram_freq.items():
            if ngram in self.keyword_index:
                metadata = self.keyword_index[ngram]
                score = self._calculate_keyword_score(ngram, freq, metadata, text_lower)
                candidates.append({
                    'keyword': metadata['keyword'],
                    'score': score,
                    'frequency': freq,
                    'domain': metadata['domain'],
                    'type': metadata['type'],
                    'length': len(ngram.split())
                })

        # Sort by score and return top candidates
        candidates.sort(key=lambda x: x['score'], reverse=True)
        return candidates[:max_keywords]

    def _extract_ngrams(self, text: str, n_values: List[int]) -> List[str]:
        """Extract n-grams from text."""
        ngrams = []
        words = re.findall(r'\b\w+\b', text)

        for n in n_values:
            for i in range(len(words) - n + 1):
                ngram = ' '.join(words[i:i+n])
                if len(ngram) > 2:  # Filter very short ngrams
                    ngrams.append(ngram.lower())

        return ngrams

    def _calculate_keyword_score(self, keyword: str, frequency: int, metadata: Dict, text: str) -> float:
        """Calculate relevance score for a keyword."""
        base_score = metadata['weight']

        # Frequency boost
        freq_boost = min(frequency * 0.1, 1.0)

        # Position boost (keywords appearing early are more important)
        position_boost = 0
        keyword_pos = text.find(keyword.lower())
        if keyword_pos >= 0:
            position_boost = max(0, 1.0 - (keyword_pos / len(text)))

        # Length boost (longer phrases are more specific)
        length_boost = min(metadata['length'] * 0.1, 0.5)

        # Context boost (keywords near other trading terms)
        context_boost = self._calculate_context_boost(keyword, text)

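        # Weighted blend of the five signals; the weights below sum to 1.0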
        total_score = (base_score * 0.4 +
                      freq_boost * 0.2 +
                      position_boost * 0.1 +
                      length_boost * 0.1 +
                      context_boost * 0.2)

        return min(total_score, 1.0)

    def _calculate_context_boost(self, keyword: str, text: str) -> float:
        """Calculate boost based on surrounding trading context."""
        context_window = 100  # Characters around keyword
        keyword_pos = text.find(keyword.lower())

        if keyword_pos < 0:
            return 0

        # Extract context window
        start = max(0, keyword_pos - context_window)
        end = min(len(text), keyword_pos + len(keyword) + context_window)
        context = text[start:end]

        # Count trading-related terms in context
        trading_terms = ['trading', 'strategy', 'risk', 'profit', 'loss', 'market', 'price',
                        'analysis', 'indicator', 'signal', 'position', 'entry', 'exit']

        context_score = 0
        for term in trading_terms:
            if term in context and term != keyword:
                context_score += 0.1

        return min(context_score, 0.5)

    def expand_query(self, query: str, domain: Optional[str] = None) -> List[str]:
        """
        Expand query with synonyms and related terms.

        Args:
            query: Original query
            domain: Optional domain filter

        Returns:
            List of expanded query terms
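
        Example (illustrative; expansions depend on TRADING_SYNONYMS and
        DOMAIN_KEYWORD_MAPPING):
            km = KeywordManager()
            terms = km.expand_query("stop loss placement", domain=None)
            # returns the original tokens plus any configured synonyms,
            # deduplicated and with stop words removed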
        """
        expanded_terms = []

        # Tokenize query
        query_terms = re.findall(r'\b\w+\b', query.lower())

        for term in query_terms:
            # Add original term
            expanded_terms.append(term)

            # Add synonyms
            if term in self.synonym_map:
                expanded_terms.extend(self.synonym_map[term])

            # Add domain-specific related terms
            if domain and domain in self.domain_filters:
                for domain_keyword in self.domain_filters[domain]:
                    if fuzz.ratio(term, domain_keyword) > 80:
                        expanded_terms.append(domain_keyword)

        # Remove duplicates (preserving order) and filter stop words
        expanded_terms = list(dict.fromkeys(expanded_terms))
        expanded_terms = [t for t in expanded_terms if t not in self.stop_words and len(t) > 2]

        return expanded_terms[:50]  # Limit expansion

    def find_similar_keywords(self, keyword: str, top_k: int = 10) -> List[Tuple[str, float]]:
        """
        Find semantically similar keywords using vector similarity.

        Args:
            keyword: Keyword to find similar terms for
            top_k: Number of similar keywords to return

        Returns:
            List of (keyword, similarity_score) tuples
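
        Example (illustrative; candidates come from the indexed taxonomy):
            km = KeywordManager()
            similar = km.find_similar_keywords("momentum", top_k=3)
            # similar -> list of (keyword, cosine_similarity) tuples, best first
            # (empty if "momentum" is not in the configured index)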
        """
        if keyword.lower() not in self.keyword_vectors:
            return []

        target_vector = self.keyword_vectors[keyword.lower()]
        similarities = []

        for candidate, vector in self.keyword_vectors.items():
            if candidate != keyword.lower():
                similarity = self._cosine_similarity(target_vector, vector)
                similarities.append((candidate, similarity))

        # Sort by similarity
        similarities.sort(key=lambda x: x[1], reverse=True)

        return similarities[:top_k]

    def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
        """Calculate cosine similarity between two vectors."""
        dot_product = sum(a * b for a, b in zip(vec1, vec2))
        norm1 = math.sqrt(sum(a * a for a in vec1))
        norm2 = math.sqrt(sum(b * b for b in vec2))

        if norm1 == 0 or norm2 == 0:
            return 0

        return dot_product / (norm1 * norm2)

    def filter_by_domain(self, keywords: List[str], domain: str) -> List[str]:
        """
        Filter keywords by domain relevance.

        Args:
            keywords: List of keywords to filter
            domain: Target domain

        Returns:
            Filtered list of domain-relevant keywords
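
        Example (illustrative; "technical_analysis" is a hypothetical domain
        key, actual keys come from DOMAIN_KEYWORD_MAPPING):
            km = KeywordManager()
            kept = km.filter_by_domain(["rsi", "balance sheet"], "technical_analysis")
            # only keywords matching that domain's list (exactly or fuzzily) are kept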
        """
        if domain not in self.domain_filters:
            return keywords

        domain_keywords = set(self.domain_filters[domain])
        filtered = []

        for keyword in keywords:
            # Direct match
            if keyword.lower() in domain_keywords:
                filtered.append(keyword)
                continue

            # Fuzzy match with domain keywords
            for domain_kw in domain_keywords:
                if fuzz.ratio(keyword.lower(), domain_kw) > 75:
                    filtered.append(keyword)
                    break

        return filtered

    def get_domain_keywords(self, domain: str) -> List[str]:
        """Get all keywords for a specific domain."""
        if domain not in self.domain_filters:
            return []

        return self.domain_filters[domain]

    def rank_keywords_by_relevance(self, keywords: List[str], query: str) -> List[Tuple[str, float]]:
        """
        Rank keywords by relevance to a query.

        Args:
            keywords: Keywords to rank
            query: Query string

        Returns:
            List of (keyword, relevance_score) tuples
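
        Example (illustrative):
            km = KeywordManager()
            ranked = km.rank_keywords_by_relevance(["stop loss", "breakout"],
                                                   "where to place a stop loss")
            # "stop loss" scores 1.0 (exact substring of the query);
            # "breakout" falls back to fuzzy matching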
        """
        query_lower = query.lower()
        ranked = []

        for keyword in keywords:
            keyword_lower = keyword.lower()

            # Exact match in query
            if keyword_lower in query_lower:
                score = 1.0
            # Partial match
            elif any(word in query_lower for word in keyword_lower.split()):
                score = 0.7
            # Fuzzy match
            else:
                score = fuzz.partial_ratio(keyword_lower, query_lower) / 100 * 0.5

            # Boost if it's a known keyword
            if keyword_lower in self.keyword_index:
                score *= 1.2

            ranked.append((keyword, min(score, 1.0)))

        ranked.sort(key=lambda x: x[1], reverse=True)
        return ranked

    def save_keyword_index(self, filepath: str):
        """Save keyword index to file."""
        data = {
            'keyword_index': self.keyword_index,
            'synonym_map': self.synonym_map,
            'domain_filters': self.domain_filters
        }

        with open(filepath, 'w') as f:
            json.dump(data, f, indent=2)

    def load_keyword_index(self, filepath: str):
        """Load keyword index from file."""
        with open(filepath, 'r') as f:
            data = json.load(f)

        self.keyword_index = data.get('keyword_index', {})
        self.synonym_map = data.get('synonym_map', {})
        self.domain_filters = data.get('domain_filters', {})

# Singleton instance
_keyword_manager: Optional[KeywordManager] = None

def get_keyword_manager() -> KeywordManager:
    """Get singleton instance of KeywordManager."""
    global _keyword_manager
    if _keyword_manager is None:
        _keyword_manager = KeywordManager()
    return _keyword_manager
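
# Minimal usage sketch (illustrative only; assumes the airagagent.config
# taxonomies import cleanly in this environment).
if __name__ == "__main__":
    manager = get_keyword_manager()
    sample_text = (
        "A moving average crossover strategy with strict risk management "
        "rules for entries, exits, and stop loss placement."
    )
    # Print the top extracted keywords with their domains and scores
    for candidate in manager.extract_keywords(sample_text, max_keywords=5):
        print(f"{candidate['keyword']} [{candidate['domain']}] -> {candidate['score']:.2f}")
    # Show a simple synonym expansion for a short query
    print(manager.expand_query("stop loss", domain=None))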
