from __future__ import annotations

import asyncio
from datetime import timedelta
import base64
import hashlib
import logging
import socket
from typing import Any, Dict, Optional, Tuple

from homeassistant.core import HomeAssistant
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.entity_registry import async_get as async_get_entity_registry
from homeassistant.helpers.storage import Store
from homeassistant.helpers import instance_id as hass_instance_id
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed

from .const import (
    DOMAIN,
    REGISTER_URL,
    FETCH_CODE_URL,
    STORAGE_VERSION,
    DATA_DEVICE_TOKEN,
    KEY_VALID_UNTIL,
    KEY_HA_INTEGRATION,
    KEY_STATUS,
    KEY_UPDATE_INTERVAL_HOURS,
    DATA_UPDATE_INTERVAL_HOURS,
    DATA_LAST_STATUS,
    CONNECTION_RETRY_INTERVAL_MINUTES
)

_LOGGER = logging.getLogger(__name__)




async def _store_for(hass: HomeAssistant, key: str) -> Store:
    """Build a ``Store`` for *key* using this integration's ``STORAGE_VERSION``."""
    store = Store(hass, STORAGE_VERSION, key)
    return store


async def get_ha_uuid(hass: HomeAssistant) -> str:
    """Return this Home Assistant instance's unique id.

    A truthy value cached in ``hass.data["core.uuid"]`` is returned as-is;
    otherwise the id is obtained from the ``instance_id`` helper.
    """
    cached = hass.data.get("core.uuid")
    if cached:
        return cached
    return await hass_instance_id.async_get(hass)


async def persist_device_token(hass: HomeAssistant, license_id: str, token: str) -> None:
    """Store *token* under ``DATA_DEVICE_TOKEN`` in the license's state store.

    Any other keys already persisted for this license are preserved.
    """
    storage_key = f"onoff_licenser.state.{license_id}"
    store = await _store_for(hass, storage_key)
    state = await store.async_load()
    if not state:
        state = {}
    state[DATA_DEVICE_TOKEN] = token
    await store.async_save(state)


async def load_state(hass: HomeAssistant, license_id: str) -> Dict[str, Any]:
    """Return the persisted state dict for *license_id*, or ``{}`` if none is stored."""
    store = await _store_for(hass, f"onoff_licenser.state.{license_id}")
    stored = await store.async_load()
    return stored if stored else {}


async def save_state(hass: HomeAssistant, license_id: str, data: Dict[str, Any]) -> None:
    """Merge *data* into the persisted state for *license_id* and write it back.

    Keys present in *data* overwrite stored values; other stored keys are kept.
    """
    storage_key = f"onoff_licenser.state.{license_id}"
    store = await _store_for(hass, storage_key)
    merged = await store.async_load()
    if not merged:
        merged = {}
    merged.update(data)
    await store.async_save(merged)


def _get_private_ip_blocking() -> str:
    """Blocking helper that returns the host's private IP address."""
    try:
        # Use a UDP socket to let the OS pick the correct outbound interface
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(("8.8.8.8", 80))  # no packets are actually sent
            return s.getsockname()[0]
    except Exception:
        return "0.0.0.0"


async def async_get_private_ip(hass: HomeAssistant) -> str:
    """Return the host's outbound private IPv4 address without blocking the event loop.

    The socket probe runs in Home Assistant's executor.
    """
    blocking_lookup = _get_private_ip_blocking
    private_ip = await hass.async_add_executor_job(blocking_lookup)
    return private_ip


async def async_get_host_address(hass: HomeAssistant) -> str:
    """Return a human-readable address for HA's configured location.

    Reverse-geocodes ``hass.config.latitude``/``longitude`` through the public
    Nominatim API. This is best effort and never raises for remote problems:
    missing coordinates, request errors, a non-200 status, or a response that
    is not a JSON object all yield ``"Unknown"``.
    """
    lat = hass.config.latitude
    lon = hass.config.longitude

    if lat is None or lon is None:
        return "Unknown"

    session = async_get_clientsession(hass)

    url = "https://nominatim.openstreetmap.org/reverse"
    params = {
        "format": "jsonv2",
        "lat": str(lat),
        "lon": str(lon),
    }
    headers = {
        # Nominatim requires a proper User-Agent identifying your app
        "User-Agent": "home-assistant-onoff_licenser/1.0",
        "Accept-Language": "en",
    }

    try:
        async with session.get(url, params=params, headers=headers, timeout=5) as resp:
            if resp.status != 200:
                return "Unknown"
            data = await resp.json()
    except Exception:
        return "Unknown"

    # The payload is remote and untrusted. An empty body can decode to None,
    # and a JSON array/scalar has no .get(). Without this guard the resulting
    # AttributeError would escape into api_register/api_fetch and fail them.
    if not isinstance(data, dict):
        return "Unknown"

    address = data.get("address")
    if not isinstance(address, dict):
        address = {}

    # Build a nice address string; adjust as you like
    city = address.get("city") or address.get("town") or address.get("village")
    parts = [
        address.get("road"),
        address.get("house_number"),
        city,
        address.get("state"),
        address.get("postcode"),
        address.get("country"),
    ]

    # str() so a non-string field (e.g. a numeric postcode) cannot make join() raise.
    parts = [str(p) for p in parts if p]
    if parts:
        return ", ".join(parts)

    display_name = data.get("display_name")
    return str(display_name) if display_name else "Unknown"


async def api_register(
    hass: HomeAssistant,
    *,
    license_id: str,
    activation_code: str,
    owner_email: str,
) -> Dict[str, Any]:
    """Register this Home Assistant instance for *license_id* with the backend.

    Posts the activation code, owner email and host metadata (instance uuid,
    location name, private IP, coordinates, reverse-geocoded address) to
    ``REGISTER_URL``.

    Returns:
        The decoded JSON object when the server answers 200 with a truthy ``ok``.

    Raises:
        RuntimeError: ``registration_failed: <status> <body>`` when the status is
            not 200, ``ok`` is falsy, or the body is not a JSON object.
    """
    session = async_get_clientsession(hass)
    ha_uuid = await get_ha_uuid(hass)
    ha_name = hass.config.location_name
    private_ip = await async_get_private_ip(hass)
    address = await async_get_host_address(hass)
    payload = {
        "license_id": license_id,
        "activation_code": activation_code,
        "owner_email": owner_email,
        "ha_uuid": ha_uuid,
        "ha_name": ha_name,
        "private_ip": private_ip,
        "latitude": hass.config.latitude,
        "longitude": hass.config.longitude,
        "address": address,
    }
    async with session.post(REGISTER_URL, json=payload, timeout=30) as resp:
        try:
            data = await resp.json(content_type=None)
        except ValueError:
            # Non-JSON body (e.g. an HTML error page from a proxy). Report it with
            # the HTTP status below instead of a bare JSONDecodeError.
            data = None
        if resp.status != 200 or not isinstance(data, dict) or not data.get("ok"):
            raise RuntimeError(f"registration_failed: {resp.status} {data}")
        return data


async def api_fetch(
    hass: HomeAssistant,
    *,
    license_id: str,
    device_token: str,
) -> Tuple[str | None, str | None, str | None, Dict[str, Any]]:
    """Fetch the licensed code and license metadata from the backend.

    Posts the device token plus host metadata to ``FETCH_CODE_URL``.

    Returns:
        ``(code, sha256_hex, filename, raw)`` when the response carries
        ``code_b64``. ``filename`` defaults to ``"remote.py"`` when absent.
        Returns ``(None, None, None, raw)`` when the response is ok but carries no
        code (inactive license). ``raw`` is always the decoded JSON object.

    Raises:
        RuntimeError: ``fetch_failed`` when the status is not 200, ``ok`` is falsy,
            or the body is not a JSON object. ``sha256_mismatch`` when the payload
            hash differs from the advertised one. ``code_decode_failed`` when
            ``code_b64`` is not valid base64 or not UTF-8.
    """
    session = async_get_clientsession(hass)
    ha_uuid = await get_ha_uuid(hass)
    ha_name = hass.config.location_name
    private_ip = await async_get_private_ip(hass)
    address = await async_get_host_address(hass)
    payload = {
        "license_id": license_id,
        "ha_uuid": ha_uuid,
        "device_token": device_token,
        "ha_name": ha_name,
        "private_ip": private_ip,
        "latitude": hass.config.latitude,
        "longitude": hass.config.longitude,
        "address": address,
    }
    async with session.post(FETCH_CODE_URL, json=payload, timeout=30) as resp:
        try:
            data = await resp.json(content_type=None)
        except ValueError:
            # Non-JSON body: fall through so the error below includes the status.
            data = None

        # Fail on non-200, on a non-object body, or if 'ok' is false
        # (e.g. invalid token or server error).
        if resp.status != 200 or not isinstance(data, dict) or not data.get("ok"):
            raise RuntimeError(f"fetch_failed: {resp.status} {data}")

    code_b64 = data.get("code_b64")
    if not code_b64:
        # Inactive license (200 OK, no code): only the metadata is returned.
        return None, None, None, data

    try:
        code_bytes = base64.b64decode(code_b64)
    except ValueError as err:  # binascii.Error is a ValueError subclass
        raise RuntimeError(f"code_decode_failed: {err}") from err

    sha_calc = hashlib.sha256(code_bytes).hexdigest()
    expected_sha = data.get("sha256")
    # hexdigest() is lowercase; compare case-insensitively against the server value.
    if expected_sha and sha_calc != str(expected_sha).strip().lower():
        raise RuntimeError("sha256_mismatch")

    try:
        code = code_bytes.decode("utf-8")
    except UnicodeDecodeError as err:
        raise RuntimeError(f"code_decode_failed: {err}") from err

    # Active license: return code and metadata
    return code, sha_calc, data.get("filename", "remote.py"), data


class OnOffCoordinator(DataUpdateCoordinator):
    """Coordinator that periodically fetches license code & status.

    Each refresh reads the device token from persisted state, then calls
    ``api_fetch``, retrying indefinitely ``CONNECTION_RETRY_INTERVAL_MINUTES``
    apart until it succeeds. It then keeps the returned code, lets the backend
    override the polling interval, and passes the reported status to
    ``log_status_transition``.
    """

    def __init__(self, hass: HomeAssistant, license_id: str, hours: float) -> None:
        """Create the coordinator.

        Args:
            hass: Home Assistant instance.
            license_id: License whose code/status is polled.
            hours: Initial polling interval in hours; decimal hours allowed,
                e.g. 0.25 = 15 minutes.
        """
        self.license_id = license_id
        self._code: str | None = None
        super().__init__(
            hass,
            _LOGGER,
            name=f"{DOMAIN}_{license_id}",
            update_interval=timedelta(hours=float(hours)),
        )

    def get_received_code(self) -> str | None:
        """Return the last fetched code string, or None if the last fetch had none."""
        return self._code

    async def _async_load_token(self) -> str | None:
        """Return the persisted device token, accepting legacy key spellings."""
        state = await load_state(self.hass, self.license_id)
        # Prefer normalized key, but keep backward-compatible fallbacks
        return (
            state.get(DATA_DEVICE_TOKEN)
            or state.get("device_token")
            or state.get("deviceToken")
            or state.get("device-token")
        )

    async def _async_fetch_until_success(
        self, token: str
    ) -> Tuple[str | None, str | None, str | None, Dict[str, Any]]:
        """Call ``api_fetch`` until it succeeds and return its result.

        Any exception is logged, then retried after
        ``CONNECTION_RETRY_INTERVAL_MINUTES``. Before each retry the token is
        re-read from storage, so a token persisted in the meantime (e.g. via
        ``persist_device_token``) is used. If none is stored, the previous
        token is reused.
        """
        while True:
            try:
                return await api_fetch(
                    self.hass, license_id=self.license_id, device_token=token
                )
            except Exception as e:
                _LOGGER.warning(
                    "Connection or fetch failed for license %s: %s. Retrying in %s minutes...",
                    self.license_id,
                    e,
                    CONNECTION_RETRY_INTERVAL_MINUTES,
                )
            await asyncio.sleep(CONNECTION_RETRY_INTERVAL_MINUTES * 60)
            token = await self._async_load_token() or token

    async def _async_apply_backend_interval(self, raw: Dict[str, Any]) -> None:
        """Adopt the backend-provided polling interval when it is usable.

        ``raw[KEY_UPDATE_INTERVAL_HOURS]`` may be a number or a numeric string
        (decimal hours). The value is ignored if it is missing, non-positive,
        NaN, non-numeric, or too large for ``timedelta``. A changed interval is
        applied and persisted with ``save_state``. This runs whether or not the
        license is active.
        """
        interval_raw = raw.get(KEY_UPDATE_INTERVAL_HOURS)
        try:
            new_hours = float(interval_raw) if interval_raw is not None else 0.0
            # `not >` (rather than `<=`) also rejects NaN.
            if not new_hours > 0:
                return
            # inf or absurdly large values raise OverflowError here; previously that
            # escaped and turned a successful fetch into {"ok": False}.
            new_interval = timedelta(hours=new_hours)
        except (ValueError, TypeError, OverflowError):
            _LOGGER.debug(
                "Received invalid update interval '%s' for license %s",
                interval_raw, self.license_id
            )
            return

        if self.update_interval == new_interval:
            _LOGGER.debug(
                "License %s interval remains %s hours",
                self.license_id,
                new_hours,
            )
            return

        self.update_interval = new_interval
        await save_state(
            self.hass,
            self.license_id,
            {DATA_UPDATE_INTERVAL_HOURS: new_hours},
        )
        _LOGGER.info(
            "Updated license %s interval to %s hours",
            self.license_id,
            new_hours,
        )

    async def _async_update_data(self) -> Dict[str, Any]:
        """Fetch latest license info and adjust interval if server says so.

        Returns:
            After a successful fetch, ``{"ok": True, ...}`` with filename, sha256,
            validity, integration, status and ``error: None``. If post-fetch
            processing raised, ``{"ok": False, "error": str}``.

        Raises:
            UpdateFailed: no device token is stored for this license.
        """
        token = await self._async_load_token()
        if not token:
            raise UpdateFailed("device_token_missing")

        # Does not return until the backend answers successfully.
        code, sha, filename, raw = await self._async_fetch_until_success(token)

        try:
            # None when the license is revoked/inactive, disabling the component.
            self._code = code

            await self._async_apply_backend_interval(raw)

            status = str(raw.get(KEY_STATUS))
            # Records the status and reloads the target integration on a real
            # change. handle_license_status_change is not called afterwards: it
            # reads the same hass.data[DATA_LAST_STATUS] entry this call has just
            # overwritten, so it could never observe a change.
            await log_status_transition(self.hass, self.license_id, status)

            return {
                "ok": True,
                "filename": filename,
                "sha256": sha,
                KEY_VALID_UNTIL: raw.get(KEY_VALID_UNTIL),
                KEY_HA_INTEGRATION: raw.get(KEY_HA_INTEGRATION),
                KEY_STATUS: raw.get(KEY_STATUS),
                "error": None,
            }
        except Exception as e:
            # This catch block handles errors during data processing (post-fetch)
            _LOGGER.warning("License %s processing failed: %s", self.license_id, e)
            return {"ok": False, "error": str(e)}


async def reload_integration(hass: HomeAssistant, domain: str) -> bool:
    """Reload every config entry that belongs to *domain*.

    Returns:
        True  – entries existed and each ``async_reload`` was awaited without raising
        False – no entries for *domain*, or an exception occurred (it is logged)
    """
    _LOGGER.info("Trying reload integration '%s'", domain)
    try:
        matching_entries = hass.config_entries.async_entries(domain)
        if matching_entries:
            for config_entry in matching_entries:
                _LOGGER.debug(
                    "Reloading integration '%s' (entry_id=%s)", domain, config_entry.entry_id
                )
                await hass.config_entries.async_reload(config_entry.entry_id)
            return True
        _LOGGER.warning("No config entries found for domain '%s'", domain)
        return False
    except Exception as exc:
        _LOGGER.error("Failed reloading integration '%s': %s", domain, exc)
        return False


def _integration_sensor_entity_id(license_id: str) -> str:
    # Matches sensor.py naming, normalize '-' to '_'
    normalized = str(license_id).replace("-", "_")
    return f"sensor.onoff_licenser_{normalized}_integration"


async def _get_target_domain_from_integration_sensor(
    hass: HomeAssistant, license_id: str
) -> Optional[str]:
    """Return the domain named by this license's ``*_integration`` sensor.

    The sensor's current state, stripped of whitespace, is the domain to reload.
    Returns None (and logs a warning) when there is no state object, the state is
    empty, or it is ``unknown``/``unavailable``. A missing entity-registry entry is
    only logged at debug level; the state is still read.
    """
    entity_id = _integration_sensor_entity_id(license_id)

    # Diagnostic only: absence from the registry does not short-circuit.
    if async_get_entity_registry(hass).async_get(entity_id) is None:
        _LOGGER.debug("Integration sensor '%s' not found in registry", entity_id)

    current = hass.states.get(entity_id)
    raw_state = current.state if current else None
    if not current or not raw_state or raw_state in ("unknown", "unavailable"):
        _LOGGER.warning(
            "Cannot resolve target domain: sensor '%s' state is '%s'",
            entity_id,
            raw_state if current else "None",
        )
        return None

    return raw_state.strip()


async def _maybe_trigger_reload_on_status_change(
    hass: HomeAssistant, license_id: str, new_status: Optional[str]
) -> None:
    """Reload the license's target integration when the status crosses 'active'.

    On every call the lower-cased status (None becomes "") is recorded in
    ``hass.data[DATA_LAST_STATUS][license_id]``. A reload happens only if a
    previous status was recorded and exactly one of previous/new equals
    "active". The domain comes from the license's ``*_integration`` sensor.
    """
    current = (new_status or "").lower()

    # NOTE(review): log_status_transition/handle_license_status_change write the
    # same DATA_LAST_STATUS entry without lower-casing (and may store the string
    # "None"). If both paths run for one license, stored values may not compare
    # equal ('Active' vs 'active', 'None' vs ''). Confirm which callers use which.
    history = hass.data.setdefault(DATA_LAST_STATUS, {})
    previous = history.get(license_id)
    history[license_id] = current

    # First observation: nothing to compare against.
    if previous is None:
        return

    # Only a flip into or out of "active" matters; an unchanged status never flips.
    if (previous == "active") == (current == "active"):
        return

    target_domain = await _get_target_domain_from_integration_sensor(hass, license_id)
    if not target_domain:
        return

    _LOGGER.info(
        "License '%s' status changed %s -> %s; reloading integration '%s'",
        license_id, previous, current, target_domain
    )
    await reload_integration(hass, target_domain)


async def is_license_valid(hass: HomeAssistant, license_id: str, data: dict) -> bool:
    """Return whether *data* describes a usable license.

    The ``status`` field is also passed to the 'active'-boundary reload logic.
    Errors there are logged and do not affect the result. A missing or empty
    status counts as valid. Any other status that is not "active"
    (case-insensitive) is invalid.
    """
    status = (data or {}).get("status")  # e.g., "active", "revoked", etc.

    try:
        await _maybe_trigger_reload_on_status_change(hass, license_id, status)
    except Exception as exc:
        _LOGGER.error(
            "Failed handling reload for license '%s' on status change: %s",
            license_id, exc
        )

    if not status or status.lower() == "active":
        return True

    _LOGGER.warning(
        "License %s considered invalid due to status '%s'",
        license_id, status
    )
    return False


async def log_status_transition(hass: HomeAssistant, license_id: str, new_status: str) -> None:
    """Record *new_status* and react when it differs from the previously recorded one.

    The status is stored in ``hass.data[DATA_LAST_STATUS][license_id]`` on every
    call. When a previous status exists and its string form differs, the change
    is logged (debug) with the ``*_integration`` sensor state. The integration
    named by that state is then reloaded, or a debug message explains why it
    cannot be. Nothing else happens on the first observation.
    """
    last_seen = hass.data.setdefault(DATA_LAST_STATUS, {})
    previous = last_seen.get(license_id)
    last_seen[license_id] = new_status

    # Only react to real changes.
    if previous is None or str(previous) == str(new_status):
        return

    sensor_id = _integration_sensor_entity_id(license_id)
    sensor = hass.states.get(sensor_id)
    shown_state = sensor.state if sensor else "None"

    _LOGGER.debug(
        "onoff_licenser: License '%s' status changed: %s -> %s; %s state: %s",
        license_id, previous, new_status, sensor_id, shown_state
    )

    # The sensor's state string is the domain to reload.
    if sensor and sensor.state not in (None, "unknown", "unavailable"):
        target_domain = sensor.state.strip()
    else:
        target_domain = None

    if target_domain:
        await reload_integration(hass, target_domain)
        return

    _LOGGER.debug(
        "onoff_licenser: Cannot reload — %s has no valid state (got: %s)",
        sensor_id, shown_state
    )


async def handle_license_status_change(hass: HomeAssistant, license_id: str, new_status: str) -> None:
    """Handle license status change: log and trigger integration reload via *_integration sensor state.

    Behaves exactly like ``log_status_transition``, to which it delegates so the
    two implementations cannot drift apart. The status is recorded in
    ``hass.data[DATA_LAST_STATUS][license_id]``. On a real change from a
    previously recorded value, the integration named by the license's
    ``*_integration`` sensor is reloaded.

    Because both functions share that store, calling this right after
    ``log_status_transition`` with the same status finds no change.
    """
    await log_status_transition(hass, license_id, new_status)