"""
okx_framework.py - General OKX data access framework
Place this file in the project; other data types only need to import and inherit from it
"""

import time, json, threading, os, math, queue, asyncio
from typing import List, Any, Optional
from abc import ABC, abstractmethod
import numpy as np
from ..configs import DDB,OKX_BASE_CONFIG
from okx.websocket.WsPublicAsync import WsPublicAsync
import dolphindb as ddb


# ========================= JSON safe serialization =========================
def normalize_scalar(x: Any) -> Any:
    """Map a single value onto something ``json.dumps`` accepts.

    Conversions performed:
      - ``None`` stays ``None``.
      - ``np.datetime64`` becomes an ``int`` holding the value at millisecond
        resolution (the int64 view of ``datetime64[ms]``).
      - NumPy integers / booleans become the builtin ``int`` / ``bool``.
      - NumPy floats become builtin ``float``; any float (NumPy or builtin)
        that is NaN or +/-Inf becomes ``None``.
      - Every other value is handed back untouched.
    """
    if x is None:
        return None

    if isinstance(x, np.datetime64):
        as_millis = np.datetime64(x, 'ms').astype('datetime64[ms]')
        return int(as_millis.astype(np.int64))

    if isinstance(x, np.integer):
        return int(x)

    # NumPy floats are checked before builtin floats so that np.float64
    # (a subclass of float) is always returned as a plain Python float.
    from_numpy = isinstance(x, np.floating)
    if from_numpy or isinstance(x, float):
        value = float(x) if from_numpy else x
        return value if math.isfinite(value) else None

    if isinstance(x, np.bool_):
        return bool(x)

    # int / bool / str and anything unrecognised pass through unchanged.
    return x

def normalize_row(row: Any) -> Any:
    """Walk a nested value and make every leaf JSON-serialisable.

    Lists and tuples come back as lists, dicts keep their keys and get their
    values normalised, ndarrays are converted with ``tolist()`` and processed
    again, and any other value is delegated to ``normalize_scalar``.
    """
    if row is None:
        return None

    if isinstance(row, dict):
        converted = {}
        for key, value in row.items():
            converted[key] = normalize_row(value)
        return converted

    if isinstance(row, (list, tuple)):
        return list(map(normalize_row, row))

    if isinstance(row, np.ndarray):
        # tolist() yields nested Python lists (or a bare scalar for 0-d arrays).
        return normalize_row(row.tolist())

    return normalize_scalar(row)


# ========================= Generic SafeWsPublicAsync =========================
class SafeWsPublicAsync(WsPublicAsync):
    """
    A "safety wrapper" around the official WsPublicAsync:
      - Do not use its loop.create_task; manage the consumer Task handle ourselves
      - stop(): unsubscribe → close websocket → cancel task → close factory
        (do not call loop.stop)
      - unsubscribe(): do not override callback; support unsubscribing all

    Attributes used from the parent class: ``connect()``, ``consume()``,
    ``websocket``, ``factory`` and ``callback``.
    """
    def __init__(self, url: str):
        super().__init__(url)   # Initialize parent WsPublicAsync
        self._task: Optional[asyncio.Task] = None  # Handle of the consumer task started by start()
        self._subs: List[dict] = []     # Subscription args sent so far (used by "unsubscribe all")

    async def start(self):
        """Connect the WebSocket and start the async consumer task."""
        await self.connect()
        self._task = asyncio.create_task(self.consume())

    async def subscribe(self, params: list, callback):
        """Register ``callback`` and send a subscribe request for ``params``.

        ``params`` is remembered in ``_subs`` so that a later ``unsubscribe()``
        without arguments can unsubscribe everything.
        """
        self.callback = callback
        if params:
            self._subs.extend(params)   # Add subscription parameters to _subs
        payload = json.dumps({"op": "subscribe", "args": params})
        await self.websocket.send(payload)

    async def unsubscribe(self, params: Optional[List[dict]] = None):
        """Unsubscribe from ``params``, or from every recorded subscription when None.

        Best-effort: send failures are ignored, and the entries are removed
        from ``_subs`` regardless of whether the request could be sent.
        """
        if params is None:
            params = list(self._subs)
        if not params:
            return
        payload = json.dumps({"op": "unsubscribe", "args": params})
        try:
            await self.websocket.send(payload)
        except Exception:
            # Deliberately ignored: the socket may already be closed or missing.
            pass
        try:
            # Compare by canonical JSON text so dict key order does not matter.
            rm = {json.dumps(p, sort_keys=True) for p in params}
            self._subs = [p for p in self._subs if json.dumps(p, sort_keys=True) not in rm]
        except Exception:
            pass

    async def stop(self):
        """Shut the connection down; every step is best-effort.

        Order: unsubscribe all → close the websocket → cancel the consumer
        task → close the factory. ``loop.stop()`` is never called.
        """
        try:
            await self.unsubscribe()
        except Exception:
            pass

        if getattr(self, "websocket", None):
            try:
                await self.websocket.close()                     # Close WebSocket connection
                if hasattr(self.websocket, "wait_closed"):
                    await asyncio.wait_for(self.websocket.wait_closed(), timeout=5.0)   # Wait for connection to close
            except Exception:
                pass
            finally:
                self.websocket = None

        task, self._task = self._task, None
        if task is not None and not task.done():
            task.cancel()
            try:
                # asyncio.wait() does not re-raise the awaited task's outcome.
                # Awaiting the cancelled task directly (or through wait_for)
                # raises asyncio.CancelledError, which is a BaseException since
                # Python 3.8 and therefore slips past `except Exception`; that
                # would abort stop() before the factory is closed.
                await asyncio.wait([task], timeout=5.0)
            except Exception:
                pass
            if task.done() and not task.cancelled():
                # Mark a possible consumer exception as retrieved so asyncio
                # does not log "Task exception was never retrieved".
                task.exception()

        try:
            await self.factory.close()
        except Exception:
            pass


# ========================= Generic IOThread =========================
class IOThread(threading.Thread):
    """
    Centralized control of the **single write channel** for "write MTW / write to local / replay", executed serially to naturally avoid out-of-order writes.
    Three-state state machine:
      - live   : DDB healthy; read from real-time queue and write to MTW (switch to offline on failure)
      - offline: DDB unhealthy; flush real-time queue to local file; after cooldown, probe from first local row
      - replay : only write local cache back to MTW (strictly do not consume real-time queue); switch to live after successful replay

    The local cache is a text file with one JSON-encoded row per line. A line
    without a trailing newline is treated as incomplete and is not replayed.

    ``config`` must provide: BUFFER_FILE, realtime_q, file_lock, tableName,
    the dolphindb_* connection settings, get_create_table_script(),
    READ_BATCH_SIZE, PROBE_COOLDOWN_SECS and LIVE_GET_TIMEOUT.
    """
    def __init__(self, config):
        super().__init__(daemon=True)
        self.config = config
        self.path = config.BUFFER_FILE          # Local cache file (JSON lines)
        self.mode = 'live'                      # One of 'live' / 'offline' / 'replay'
        self.next_probe_ts = 0.0                # Earliest time.time() at which offline mode probes again
        self.writer = None                      # Current MultithreadedTableWriter, or None if it must be rebuilt
        self.realtime_q = config.realtime_q
        self.file_lock = config.file_lock

    def ensure_table_exists(self):
        """Run the config's table-creation script on a short-lived session.

        Errors from the script itself are ignored (best-effort). A failure to
        connect propagates to the caller. The session is always closed.
        """
        s = ddb.session()
        try:
            s.connect(
                self.config.dolphindb_address,
                self.config.dolphindb_port,
                self.config.dolphindb_user,
                self.config.dolphindb_password
            )
            try:
                s.run(self.config.get_create_table_script())
            except Exception:
                # Deliberately ignored; build_mtw() will fail afterwards if
                # the table is really unusable.
                pass
        finally:
            # Close on every path so a failed connect/run does not leak the session.
            try:
                s.close()
            except Exception:
                pass

    def build_mtw(self) -> bool:
        """Try to rebuild MTW; return True on success, False on failure."""
        try:
            self.ensure_table_exists()
            self.writer = ddb.MultithreadedTableWriter(
                self.config.dolphindb_address,
                self.config.dolphindb_port,
                self.config.dolphindb_user,
                self.config.dolphindb_password,
                dbPath="",
                tableName=self.config.tableName,
                batchSize=10000,
                throttle=1,
                threadCount=1,
                reconnect=False
            )
            return True
        except Exception as e:
            print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Failed to rebuild MTW: {e}")
            self.writer = None
            return False

    def save_unwritten_to_local(self):
        """
        **Critical ordering point**: once MTW write fails, first persist MTW's internal
        unwrittenData to local storage, then persist the current row / real-time queue data, ensuring "data that entered MTW earlier is written to local first".

        Does nothing when there is no writer. If getUnwrittenData() itself
        fails, the writer is dropped and nothing is persisted.
        """
        if self.writer is None:
            return
        try:
            unwritten = self.writer.getUnwrittenData()
        except Exception as e:
            print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] getUnwrittenData failed: {e}")
            self.writer = None
            return

        if not unwritten:
            return

        self._append_rows_to_local(unwritten)
        print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Persisted {len(unwritten)} unwritten MTW rows to local storage")

    def _append_rows_to_local(self, rows: List[List[Any]]):
        """Append ``rows`` to the cache file, one JSON line per row, under file_lock."""
        if not rows:
            return
        with self.file_lock, open(self.path, "a", encoding="utf-8") as f:
            for row in rows:
                f.write(json.dumps(normalize_row(row), ensure_ascii=False) + "\n")
            f.flush()

    def _save_queue_to_local(self, max_n: int = 50000):
        """Move up to ``max_n`` rows from the real-time queue into the cache file."""
        moved = 0
        buf: List[List[Any]] = []
        while moved < max_n:
            try:
                row = self.realtime_q.get_nowait()
            except queue.Empty:
                break
            buf.append(row)
            moved += 1
        if buf:
            self._append_rows_to_local(buf)

    def _insert_one(self, row: List[Any]):
        """Insert one row through MTW, building the writer first if needed.

        Raises:
            RuntimeError: if the writer cannot be built, or if the insert
                result reports an error.
        """
        if self.writer is None and not self.build_mtw():
            raise RuntimeError("MTW unavailable")
        res = self.writer.insert(*row)
        if hasattr(res, "hasError") and res.hasError():
            raise RuntimeError(res.errorInfo)

    def _probe_from_local_first_line(self) -> bool:
        """Liveness probe: try to insert the first complete cached row.

        Returns True when the insert succeeds. Returns False when the cache
        file is missing, the writer cannot be built, the first line is absent
        or incomplete, or the insert fails (the writer is dropped then).
        The probed row is not removed from the cache file.
        """
        if not os.path.exists(self.path):
            return False
        if self.writer is None and not self.build_mtw():
            return False

        with self.file_lock:
            with open(self.path, "r", encoding="utf-8", newline="\n") as f:
                line = f.readline()
                if not line or not line.endswith("\n"):
                    return False
                try:
                    row = json.loads(line)
                    self._insert_one(row)
                    print("First local row probe write succeeded")
                    return True
                except Exception:
                    self.writer = None
                    return False

    def _write_back_after_replay_failure(self, src, pending_lines, error) -> None:
        """Rewrite the cache after a failed replay so nothing is lost.

        The new cache holds, in order: the rows MTW reports as unwritten, the
        not-yet-inserted lines of the current batch (``pending_lines``), and
        whatever is still unread in ``src``. If the unwritten rows cannot be
        obtained, the cache file is left untouched: replaying some rows twice
        is preferred over dropping rows. The writer is discarded either way.
        """
        try:
            uw = self.writer.getUnwrittenData()
        except Exception:
            print("getUnwrittenData failed during replay")
            self.writer = None
            print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Replay interrupted: {error}, cache retained, will retry")
            return

        tmp_path = self.path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as tmp:
            for row in uw:
                tmp.write(json.dumps(normalize_row(row), ensure_ascii=False) + "\n")
            print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Replay interrupted: {error}, unwritten data written back to local")
            tmp.writelines(pending_lines)
            for rest in src:
                tmp.write(rest)
        # Close the reader before swapping files; replacing an open file fails
        # on some platforms.
        src.close()
        os.replace(tmp_path, self.path)
        self.writer = None
        print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Replay interrupted: {error}, unprocessed data written back to local")

    def _replay_all_local(self) -> bool:
        """Insert every complete cached line through MTW, then clear the cache.

        Lines are read in batches of ``config.READ_BATCH_SIZE``. Reading stops
        at EOF or at the first line lacking a trailing newline.

        Returns:
            True when the cache file is absent or was replayed completely
            (the file is then truncated while still holding file_lock).
            False when the writer is unavailable or any insert fails; the
            remaining data stays in the cache file for a later retry.
        """
        if not os.path.exists(self.path):
            return True
        if self.writer is None and not self.build_mtw():
            return False

        total = 0
        try:
            with self.file_lock:
                with open(self.path, "r", encoding="utf-8", newline="\n") as src:
                    while True:
                        batch_lines = []
                        for _ in range(self.config.READ_BATCH_SIZE):
                            line = src.readline()
                            # EOF, or a torn trailing line, ends the batch.
                            if not line or not line.endswith("\n"):
                                break
                            batch_lines.append(line)
                        if not batch_lines:
                            break

                        i = 0
                        try:
                            for i, line in enumerate(batch_lines):
                                self._insert_one(json.loads(line))
                        except Exception as e:
                            print("Issue encountered during replay")
                            # batch_lines[i] is the line that failed; it and
                            # everything after it still has to be replayed.
                            self._write_back_after_replay_failure(src, batch_lines[i:], e)
                            return False
                        total += len(batch_lines)

                # Truncate while still holding file_lock so that no row can be
                # appended between the end of the read and the clearing.
                with open(self.path, "w", encoding="utf-8"):
                    pass

            print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Replay completed, {total} rows written, cache cleared")
            return True

        except Exception as e:
            print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Replay failed: {e} (cache retained, will retry)")
            self.writer = None
            return False

    def run(self):
        """Thread body: drive the live / offline / replay state machine forever."""
        while True:
            # If local cache exists, prioritize probing / replay
            if os.path.exists(self.path) and os.path.getsize(self.path) > 0:
                if self.mode != 'replay':
                    now = time.time()
                    # During cooldown: flush queue to local and wait for next probe
                    if self.mode == 'offline' and now < self.next_probe_ts:
                        self._save_queue_to_local()
                        time.sleep(0.1)
                        continue

                    # At probe time: try writing the first local row
                    ok = self._probe_from_local_first_line()
                    if ok:
                        self.mode = 'replay'
                        print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Liveness check passed, entering replay mode")
                    else:
                        self.mode = 'offline'
                        self.next_probe_ts = time.time() + self.config.PROBE_COOLDOWN_SECS
                        self._save_queue_to_local()
                        time.sleep(0.2)
                        continue

                # Replay mode: strictly do not consume real-time queue
                if self.mode == 'replay':
                    success = self._replay_all_local()
                    if success:
                        self.mode = 'live'  # Replay finished, back to live
                    else:
                        self.mode = 'offline'
                        self.next_probe_ts = time.time() + self.config.PROBE_COOLDOWN_SECS
                        self._save_queue_to_local()
                        time.sleep(0.2)
                        continue

            # No local data or already cleared; live writing
            if self.mode == 'live':
                try:
                    row = self.realtime_q.get(timeout=self.config.LIVE_GET_TIMEOUT)
                except queue.Empty:
                    time.sleep(0.05)
                    continue

                try:
                    self._insert_one(row)
                except Exception as e:
                    print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] Real-time write failed: {e}, switching offline and persisting")
                    # Order matters: rows already inside MTW first, then the
                    # row that just failed, then everything still queued.
                    self.save_unwritten_to_local()
                    self._append_rows_to_local([row])
                    self._save_queue_to_local()
                    self.writer = None
                    self.mode = 'offline'
                    self.next_probe_ts = time.time() + self.config.PROBE_COOLDOWN_SECS
                    time.sleep(0.2)
                    continue

            elif self.mode == 'offline':
                self._save_queue_to_local()
                time.sleep(0.1)
                continue


# ========================= Generic OKXLoopThread =========================
class OKXLoopThread(threading.Thread):
    """
    Dedicated event-loop thread:
      - connect → subscribe
      - callback handle_message(message) (parse only → enqueue)
      - no data for TIMEOUT seconds → shutdown → reconnect
      - generation ID to filter data from old connections

    ``config`` must provide: OKX_WS_URL, TIMEOUT, RECONNECT_TIME,
    get_subscription_args(inst_ids) and handle_message(raw).
    """
    def __init__(self, config, inst_ids: List[str]):
        super().__init__(daemon=True)
        self.config = config
        self.inst_ids = inst_ids
        self.loop = None                                # Event loop owned by this thread (created in run())
        self.ws: Optional[SafeWsPublicAsync] = None     # Current connection, None between connections
        self.last_recv_ms = 0                           # Wall-clock ms of the last handled message
        self._current_generation = 0                    # Incremented once per connection attempt

    def run(self):
        """Create a private event loop, run ``_main`` on it, then clean up."""
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        try:
            self.loop.run_until_complete(self._main())
        finally:
            try:
                # Cancel whatever is still pending and let the cancellations
                # settle before the loop is closed.
                pending = [t for t in asyncio.all_tasks(self.loop) if not t.done()]
                for t in pending:
                    t.cancel()
                if pending:
                    self.loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            except Exception:
                pass
            self.loop.close()

    def _make_message_handler(self, generation: int):
        """Return a message callback that is bound to one connection generation.

        The generation is passed in as an argument so that each callback keeps
        its own value. A closure over a local of ``_main`` would see that local
        being reassigned on the next reconnect, and callbacks of old
        connections would then never be recognised as stale.
        """
        def _on_msg(raw: str):
            if generation != self._current_generation:
                return  # Message from a superseded connection
            # Call the configured message handler
            self.config.handle_message(raw)
            self.last_recv_ms = self._now_ms()

        return _on_msg

    async def _main(self):
        """Connect, subscribe and watch for silence; reconnect forever."""
        while True:
            try:
                self._current_generation += 1
                my_gen = self._current_generation

                self.ws = SafeWsPublicAsync(self.config.OKX_WS_URL)
                await self.ws.start()
                self.last_recv_ms = self._now_ms()

                # Get subscription parameters
                args = self.config.get_subscription_args(self.inst_ids)

                await self.ws.subscribe(args, self._make_message_handler(my_gen))

                # Watchdog: leave the loop (and reconnect) once no message has
                # been handled for more than TIMEOUT seconds.
                while True:
                    await asyncio.sleep(1)
                    if self._now_ms() - self.last_recv_ms > self.config.TIMEOUT * 1000:
                        break

            except Exception as e:
                print(f"[{time.strftime('%Y-%m-%dT%H:%M:%S')}] OKX connection error: {e!r}, reconnecting in {self.config.RECONNECT_TIME}s")
                await asyncio.sleep(self.config.RECONNECT_TIME)

            finally:
                try:
                    if self.ws:
                        await self.ws.stop()
                except Exception:
                    pass
                finally:
                    self.ws = None

    def _now_ms(self) -> int:
        """Return the current wall-clock time in whole milliseconds."""
        return int(time.time() * 1000)


# ========================= Base configuration class =========================
class OKXBaseConfig(ABC):
    """
    Common base for every OKX data-access configuration.

    A subclass sets ``tableName`` / ``BUFFER_FILE`` and implements the three
    abstract hooks; ``start()`` then launches the writer thread and the
    WebSocket thread that share this object.
    """

    # DolphinDB connection configuration
    dolphindb_address = DDB["HOST"]
    dolphindb_port = DDB["PORT"]
    dolphindb_user = DDB["USER"]
    dolphindb_password = DDB["PWD"]

    # Table name and cache file (must be set by subclasses)
    tableName = ""
    BUFFER_FILE = ""

    # Behavioral parameters (tunable as needed)
    TIMEOUT = OKX_BASE_CONFIG["TIMEOUT"]
    PROBE_COOLDOWN_SECS = OKX_BASE_CONFIG["PROBE_COOLDOWN_SECS"]
    READ_BATCH_SIZE = OKX_BASE_CONFIG["READ_BATCH_SIZE"]
    LIVE_GET_TIMEOUT = OKX_BASE_CONFIG["LIVE_GET_TIMEOUT"]

    # OKX WebSocket configuration
    OKX_WS_URL = OKX_BASE_CONFIG["OKX_WS_URL"]
    RECONNECT_TIME = OKX_BASE_CONFIG["RECONNECT_TIME"]

    def __init__(self):
        # Objects shared between the WebSocket thread and the writer thread
        self.realtime_q = queue.Queue(maxsize=1_000_000)
        self.file_lock = threading.Lock()
        self.writer = None

        # Make sure the directory that will hold the cache file exists
        if self.BUFFER_FILE:
            cache_dir = os.path.dirname(self.BUFFER_FILE) or '.'
            os.makedirs(cache_dir, exist_ok=True)

    @abstractmethod
    def get_create_table_script(self) -> str:
        """Return the DolphinDB table creation script."""

    @abstractmethod
    def get_subscription_args(self, inst_ids: List[str]) -> List[dict]:
        """Return WebSocket subscription parameters."""

    @abstractmethod
    def handle_message(self, message: str):
        """Handle received messages."""

    def start(self, inst_ids: List[str]):
        """Launch the pipeline and return ``(io_thread, okx_thread)``.

        The writer thread is started first, then the WebSocket thread; a short
        summary is printed afterwards.
        """
        # Writer side
        writer_thread = IOThread(self)
        writer_thread.start()

        # WebSocket side
        ws_thread = OKXLoopThread(self, inst_ids)
        ws_thread.start()

        started_at = time.strftime('%Y-%m-%dT%H:%M:%S')
        print(f"[{started_at}] Data pipeline started")
        summary = (
            f"- Table: {self.tableName}",
            f"- Cache: {self.BUFFER_FILE}",
            f"- Instruments: {inst_ids}",
        )
        for entry in summary:
            print(entry)

        return writer_thread, ws_thread
