import os
import sqlite3
import time
import re
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Point the OpenAI client at the local DeepSeek-Qwen OpenAI-compatible endpoints.
# The base URL is read from MEM0_OPENAI_BASE_URL and falls back to a localhost default.
BASE = os.getenv('MEM0_OPENAI_BASE_URL', 'http://127.0.0.1:5004/v1')
# setdefault leaves any OPENAI_* value that is already present in the environment untouched.
# NOTE(review): these assignments run before the `mem0` import that follows,
# presumably so the library sees them when it loads — confirm before reordering.
os.environ.setdefault('OPENAI_BASE_URL', BASE)
os.environ.setdefault('OPENAI_API_BASE', BASE)
# API key: MEM0_OPENAI_API_KEY when set, otherwise the placeholder value 'local'.
os.environ.setdefault('OPENAI_API_KEY', os.getenv('MEM0_OPENAI_API_KEY', 'local'))

from mem0 import Memory

# FastAPI application object; the route decorators in this file register on it.
app = FastAPI(title='eventheodds-mem0', version='0.2.0')

# Use mem0 ONLY as the extractor. We persist results ourselves so they survive restarts.
extractor = Memory()

# Path of the SQLite file that holds persisted memories; override with MEM0_SQLITE_PATH.
DB_PATH = os.getenv('MEM0_SQLITE_PATH', '/var/www/html/eventheodds/mem0-service/mem0.sqlite')

# Default length of the vectors produced by _hash_embedding.
DIMS = 1536


def _hash_embedding(text: str, dims: int = DIMS) -> List[float]:
    """Build a deterministic, unit-length bag-of-words vector for ``text``.

    The text is lower-cased and split into runs of ``[a-z0-9]``. Each token
    is reduced to a 32-bit rolling hash (``h = 31 * h + ord(ch)``, masked to
    32 bits). The hash picks one of ``dims`` slots, and its parity decides
    whether that slot is incremented or decremented. The accumulated vector
    is finally scaled to Euclidean length 1.

    Args:
        text: Input string; ``None`` or an empty string yields no tokens.
        dims: Length of the returned vector.

    Returns:
        A list of ``dims`` floats. When no token contributes, the all-zero
        vector is returned unscaled.
    """
    slots = [0.0] * dims
    for word in re.findall(r"[a-z0-9]+", (text or "").lower()):
        code = 0
        for char in word:
            # 31 * code equals (code << 5) - code; keep only the low 32 bits.
            code = (31 * code + ord(char)) & 0xFFFFFFFF
        # Even hashes add to their slot, odd hashes subtract from it.
        slots[code % dims] += -1.0 if code & 1 else 1.0

    length = sum(s * s for s in slots) ** 0.5
    if length > 0:
        return [s / length for s in slots]
    return slots


def _cos(a: List[float], b: List[float]) -> float:
    """Return the dot product of two equal-length vectors.

    For vectors that are already unit length (as produced by
    ``_hash_embedding``) this dot product is their cosine similarity.

    Returns:
        ``0.0`` when either vector is empty/``None`` or the lengths differ,
        otherwise the sum of the pairwise products as a float.
    """
    comparable = bool(a) and bool(b) and len(a) == len(b)
    if not comparable:
        return 0.0
    dot = sum(left * right for left, right in zip(a, b))
    return float(dot)


def _db() -> sqlite3.Connection:
    """Open the SQLite database at ``DB_PATH`` and ensure the schema exists.

    The parent directory of ``DB_PATH`` is created when the path has one.
    A path without a directory component (a bare filename, or ``:memory:``)
    is passed straight to SQLite, because ``os.makedirs('')`` would raise.

    Returns:
        An open connection; the caller is responsible for closing it.

    Raises:
        OSError: If the parent directory cannot be created.
        sqlite3.Error: If the database cannot be opened or the table cannot
            be created. The connection is closed before re-raising.
    """
    parent = os.path.dirname(DB_PATH)
    if parent:
        os.makedirs(parent, exist_ok=True)

    conn = sqlite3.connect(DB_PATH)
    try:
        # (user_id, memory) is the primary key, so the INSERT OR IGNORE
        # statements used by the endpoints de-duplicate per user.
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS memories (
              user_id TEXT NOT NULL,
              memory  TEXT NOT NULL,
              embedding BLOB NOT NULL,
              created_at INTEGER NOT NULL,
              PRIMARY KEY (user_id, memory)
            )
            """
        )
    except Exception:
        # Do not leak the handle when schema setup fails.
        conn.close()
        raise
    return conn


# One chat message as submitted to /memories/add: a role label plus its text.
class ChatMsg(BaseModel):
    role: str
    content: str


# Request body for /memories/add.
class AddRequest(BaseModel):
    # Owner of the memories; used as the key when persisting and searching.
    user_id: str
    # Conversation passed to the extractor as role/content dicts.
    messages: List[ChatMsg]
    # Optional free-form data; not read by the handlers visible in this file.
    metadata: Optional[Dict[str, Any]] = None


# Request body for /memories/search.
class SearchRequest(BaseModel):
    # Only memories stored under this user_id are considered.
    user_id: str
    # Free text that is hashed into a vector and compared to stored embeddings.
    query: str
    # Maximum number of hits; search_memories treats 0 as 5 and raises
    # anything below 1 to 1.
    limit: int = 5



# Request body for /memories/add_raw: store a memory string as-is, without
# running the extractor.
class AddRawRequest(BaseModel):
    user_id: str
    memory: str
    # Optional timestamp; add_memory_raw substitutes int(time.time()) when
    # this is None (so presumably Unix seconds — confirm with callers).
    created_at: Optional[int] = None


@app.post('/memories/add_raw')
def add_memory_raw(req: AddRawRequest):
    # Persist one memory string verbatim (no extractor involved).
    #
    # Blank or whitespace-only text is acknowledged with stored=False and
    # nothing is written. Otherwise the trimmed text is inserted with its
    # hashed embedding; an existing (user_id, memory) row is left untouched
    # by INSERT OR IGNORE. Any failure is reported as HTTP 500 carrying the
    # exception text.
    try:
        text = (req.memory or '').strip()
        if text == '':
            return {'ok': True, 'stored': False}

        # Prefer the caller-supplied timestamp; fall back to the current time.
        if req.created_at is None:
            stamp = int(time.time())
        else:
            stamp = int(req.created_at)

        insert_sql = 'INSERT OR IGNORE INTO memories (user_id, memory, embedding, created_at) VALUES (?,?,?,?)'
        db = _db()
        try:
            blob = sqlite3.Binary(float_to_bytes(_hash_embedding(text)))
            db.execute(insert_sql, (req.user_id, text, blob, stamp))
            db.commit()
        finally:
            db.close()

        return {'ok': True, 'stored': True}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

@app.get('/health')
def health():
    # Liveness probe: echoes the configured OpenAI base URL and SQLite path.
    status = {'ok': True, 'base_url': BASE, 'db': DB_PATH}
    return status


def _extracted_texts(res: Any) -> List[str]:
    """Collect the non-empty, trimmed ``memory`` strings from an extractor result.

    Accepts either a dict with a ``results`` list or a bare list of entries.
    Anything else, including a missing or ``None`` ``results`` value and
    entries that are not dicts, contributes nothing instead of raising.
    """
    if isinstance(res, dict):
        items = res.get('results') or []
    elif isinstance(res, (list, tuple)):
        items = res
    else:
        items = []
    if not isinstance(items, (list, tuple)):
        items = []

    texts = []
    for item in items:
        mem = item.get('memory') if isinstance(item, dict) else None
        if isinstance(mem, str) and mem.strip():
            texts.append(mem.strip())
    return texts


@app.post('/memories/add')
def add_memories(req: AddRequest):
    # Run the mem0 extractor over the conversation and persist what it returns.
    #
    # Each extracted memory string is stored under req.user_id together with
    # its hashed embedding; INSERT OR IGNORE skips (user_id, memory) pairs
    # that already exist. The raw extractor result is returned to the caller
    # unchanged. Any failure is reported as HTTP 500 carrying the exception
    # text.
    #
    # NOTE(review): every returned entry is persisted regardless of any
    # event/type field the extractor may attach to it — confirm that is the
    # intended behavior.
    try:
        msgs = [{'role': m.role, 'content': m.content} for m in req.messages]
        res = extractor.add(msgs, user_id=req.user_id)

        # Persist extracted memories (if any)
        extracted = _extracted_texts(res)
        if extracted:
            now = int(time.time())
            conn = _db()
            try:
                for mem in extracted:
                    emb = _hash_embedding(mem)
                    conn.execute(
                        'INSERT OR IGNORE INTO memories (user_id, memory, embedding, created_at) VALUES (?,?,?,?)',
                        (req.user_id, mem, sqlite3.Binary(float_to_bytes(emb)), now),
                    )
                conn.commit()
            finally:
                conn.close()

        return {'ok': True, 'result': res}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


def float_to_bytes(vec: List[float]) -> bytes:
    """Serialize ``vec`` as consecutive little-endian float32 values.

    Every element is coerced with ``float()`` first, so the result is
    4 bytes per element; an empty input yields ``b''``.
    """
    import struct

    values = [float(item) for item in vec]
    # A repeat count in the format packs the whole sequence in one call.
    return struct.pack('<%df' % len(values), *values)


def bytes_to_float(b: bytes) -> List[float]:
    """Decode a buffer of little-endian float32 values into a list of floats.

    Only whole 4-byte groups are decoded; up to three trailing bytes are
    ignored. An empty buffer yields an empty list.
    """
    import struct

    count = len(b) // 4
    usable = b[: count * 4]
    return list(struct.unpack('<%df' % count, usable))


@app.post('/memories/search')
def search_memories(req: SearchRequest):
    # Rank a user's stored memories against the query text.
    #
    # The query is hashed into a vector, every row stored for req.user_id is
    # loaded and scored with _cos, and the highest-scoring rows are returned
    # as {'memory', 'score', 'created_at'} dicts. Any failure is reported as
    # HTTP 500 carrying the exception text.
    try:
        query_vec = _hash_embedding(req.query)

        db = _db()
        try:
            stored = db.execute(
                'SELECT memory, embedding, created_at FROM memories WHERE user_id = ?',
                (req.user_id,),
            ).fetchall()
        finally:
            db.close()

        hits = []
        for text, blob, stamp in stored:
            try:
                candidate = bytes_to_float(blob)
            except Exception:
                # A row whose embedding cannot be decoded is skipped.
                continue
            hits.append({'memory': text, 'score': _cos(query_vec, candidate), 'created_at': stamp})

        ranked = sorted(hits, key=lambda hit: hit['score'], reverse=True)
        # A limit of 0 falls back to 5; anything below 1 is raised to 1.
        top_n = max(1, int(req.limit or 5))
        return {'ok': True, 'result': {'results': ranked[:top_n]}}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
