import json
import os
import asyncio
import base64
from datetime import datetime
# pyrefly: ignore [missing-import]
from cryptography.fernet import Fernet
from config import DB_ENCRYPTION_KEY

class DatabaseManager:
    """Process-wide singleton that persists bot data in an encrypted JSON file.

    The whole database is held in memory (``_db_cache``). Every mutating
    method rewrites the complete file, encrypted with Fernet using
    ``DB_ENCRYPTION_KEY``. All coroutine methods serialize access to the cache
    through one ``asyncio.Lock``. ``initialize()`` must be awaited before the
    other coroutine methods are used, because they index ``_db_cache``
    directly.
    """

    _instance = None
    # NOTE(review): this lock is created when the class body executes (import
    # time). On Python < 3.10, asyncio.Lock binds to the event loop that is
    # current at construction, which can differ from the loop asyncio.run()
    # creates later.
    _lock = asyncio.Lock()
    DB_PATH = "data/db.json"

    def __new__(cls):
        # Singleton: every DatabaseManager() call returns the same object.
        if cls._instance is None:
            cls._instance = super(DatabaseManager, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every DatabaseManager() call; only set up once.
        if self._initialized:
            return
        # Create the directory that actually contains DB_PATH.
        db_dir = os.path.dirname(self.DB_PATH)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        self._db_cache = None
        # Fernet instance used to encrypt/decrypt the whole DB file.
        self.fernet = Fernet(DB_ENCRYPTION_KEY.encode())
        # Set last: if Fernet() raises (e.g. a malformed key), the next
        # DatabaseManager() call retries setup instead of returning a
        # half-built instance that has no ``fernet`` attribute.
        self._initialized = True

    @staticmethod
    def _default_schema():
        """Return a fresh copy of the initial database layout."""
        return {
            "users": {},      # user_id as key
            "products": {},   # name as key
            "vouchers": {},   # product_name as key, list of codes
            "settings": {
                "usdt_rate": "125",
                "payment_info": "Bkash: 017xxxxxxxx"
            },
            "payments": {},   # txnid as key
            "logs": [],
            "orders": []
        }

    async def initialize(self):
        """Load the database into memory, creating the file if it doesn't exist.

        A missing file is created with the default schema. An existing file is
        decrypted and loaded via ``_load_db``.
        """
        async with self._lock:
            if not os.path.exists(self.DB_PATH):
                initial_data = self._default_schema()
                await self._save_db(initial_data)
                self._db_cache = initial_data
                print("[OK] Encrypted JSON Database initialized.")
            else:
                self._db_cache = self._load_db()

    def _load_db(self):
        """Read, decrypt and parse the database file.

        Returns:
            The stored dict, with any missing top-level collections filled in
            from the default schema. A missing or empty file yields the
            default schema. If reading, decrypting or parsing fails, the error
            is printed and ``{}`` is returned.
        """
        try:
            if not os.path.exists(self.DB_PATH):
                data = {}
            else:
                with open(self.DB_PATH, "rb") as f:
                    encrypted_data = f.read()
                if not encrypted_data:
                    data = {}
                else:
                    decrypted_data = self.fernet.decrypt(encrypted_data)
                    data = json.loads(decrypted_data.decode('utf-8'))
        except Exception as e:
            print(f"[ERROR] Database loading failed: {e}")
            # Deliberately NOT backfilled: if decryption failed (e.g. the key
            # changed), writes keep failing on missing keys instead of
            # silently overwriting the undecryptable file with an empty DB.
            return {}

        if not isinstance(data, dict):
            print(f"[ERROR] Database loading failed: expected a JSON object, got {type(data).__name__}")
            return {}

        # Older or partially written databases may lack some top-level keys;
        # fill them so every method can index them safely.
        for key, value in self._default_schema().items():
            data.setdefault(key, value)
        return data

    async def _save_db(self, data):
        """Make ``data`` the in-memory cache and persist it encrypted to disk.

        Serialization, encryption and file I/O run in a worker thread. The
        file is replaced atomically: the payload is written to a sibling temp
        file, fsynced, then swapped in with ``os.replace``. A crash mid-write
        therefore leaves the previous file intact rather than truncated.
        """
        self._db_cache = data  # Update cache

        def write_to_disk():
            # Convert to JSON string then to bytes, then encrypt.
            json_data = json.dumps(data, indent=4).encode('utf-8')
            encrypted_data = self.fernet.encrypt(json_data)

            # Sibling path keeps the rename on the same filesystem, which
            # os.replace needs in order to be atomic.
            tmp_path = f"{self.DB_PATH}.tmp"
            try:
                with open(tmp_path, "wb") as f:
                    f.write(encrypted_data)
                    f.flush()
                    os.fsync(f.fileno())
                os.replace(tmp_path, self.DB_PATH)
            except BaseException:
                # Best-effort cleanup of the partial temp file; re-raise the
                # original error.
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
                raise

        await asyncio.to_thread(write_to_disk)

    # --- USER METHODS ---
    async def get_user(self, user_id):
        """Return the user record for ``user_id`` (coerced to str), or None."""
        user_id = str(user_id)
        async with self._lock:
            # NOTE(review): returns the live cached dict; mutating it bypasses
            # the lock and is only persisted by a later save.
            return self._db_cache["users"].get(user_id)

    async def add_user(self, user_id, fullname, role="user", group_id=None):
        """Create a user with zeroed balances; no-op if the user already exists."""
        user_id = str(user_id)
        async with self._lock:
            db = self._db_cache
            if user_id not in db["users"]:
                db["users"][user_id] = {
                    "user_id": user_id,
                    "fullname": fullname,
                    "role": role,
                    "status": "active",
                    "balance": 0.0,
                    "due": 0.0,
                    "due_limit": 0.0,
                    "currency": "BDT",
                    "group_id": group_id,
                    "created_at": str(datetime.now())
                }
                await self._save_db(db)

    async def update_balance(self, user_id, amount, is_add=True):
        """Add (``is_add=True``) or subtract ``amount`` from a user's balance.

        Unknown users are ignored.
        """
        user_id = str(user_id)
        async with self._lock:
            db = self._db_cache
            if user_id in db["users"]:
                if is_add:
                    db["users"][user_id]["balance"] += float(amount)
                else:
                    db["users"][user_id]["balance"] -= float(amount)
                await self._save_db(db)

    async def update_due(self, user_id, amount, is_add=True):
        """Add (``is_add=True``) or subtract ``amount`` from a user's due.

        Unknown users are ignored.
        """
        user_id = str(user_id)
        async with self._lock:
            db = self._db_cache
            if user_id in db["users"]:
                if is_add:
                    db["users"][user_id]["due"] += float(amount)
                else:
                    db["users"][user_id]["due"] -= float(amount)
                await self._save_db(db)

    async def update_due_limit(self, user_id, amount):
        """Set a user's ``due_limit`` to ``float(amount)``; unknown users are ignored."""
        user_id = str(user_id)
        async with self._lock:
            db = self._db_cache
            if user_id in db["users"]:
                db["users"][user_id]["due_limit"] = float(amount)
                await self._save_db(db)

    async def clear_all_due(self, user_id):
        """Reset a user's ``due`` to 0.0; unknown users are ignored."""
        user_id = str(user_id)
        async with self._lock:
            db = self._db_cache
            if user_id in db["users"]:
                db["users"][user_id]["due"] = 0.0
                await self._save_db(db)

    async def update_user_status(self, user_id, status):
        """Set a user's ``status`` field; unknown users are ignored."""
        user_id = str(user_id)
        async with self._lock:
            db = self._db_cache
            if user_id in db["users"]:
                db["users"][user_id]["status"] = status
                await self._save_db(db)

    # --- PRODUCT METHODS ---
    async def get_product(self, name):
        """Return the product record named ``name``, or None (live cached dict)."""
        async with self._lock:
            return self._db_cache["products"].get(name)

    async def get_all_products(self):
        """Return a new list of all product records (the records themselves are live)."""
        async with self._lock:
            return list(self._db_cache["products"].values())

    async def update_product(self, name, rate=None, increment_stock=None):
        """Create the product if missing, then optionally set its rate and/or bump its stock.

        Always saves, even when both optional arguments are None.
        """
        async with self._lock:
            db = self._db_cache
            if name not in db["products"]:
                db["products"][name] = {"name": name, "rate": 0.0, "stock": 0}

            if rate is not None:
                db["products"][name]["rate"] = float(rate)
            if increment_stock is not None:
                db["products"][name]["stock"] += int(increment_stock)
            await self._save_db(db)

    # --- VOUCHER METHODS ---
    async def add_vouchers(self, product_name, codes):
        """Append voucher ``codes`` for a product and raise its stock by ``len(codes)``.

        Creates the voucher list and the product record if they don't exist.
        """
        async with self._lock:
            db = self._db_cache
            if product_name not in db["vouchers"]:
                db["vouchers"][product_name] = []
            db["vouchers"][product_name].extend(codes)

            # Also update stock count in products
            if product_name not in db["products"]:
                db["products"][product_name] = {"name": product_name, "rate": 0.0, "stock": 0}
            db["products"][product_name]["stock"] += len(codes)

            await self._save_db(db)

    async def get_vouchers(self, product_name, qty):
        """Remove and return the first ``qty`` voucher codes for a product.

        Returns ``[]`` without changing anything when ``qty`` is not positive
        or fewer than ``qty`` codes are available.
        """
        # A negative qty would otherwise pass the length check, hand out every
        # code but the last, and *increase* the stock count.
        if qty <= 0:
            return []
        async with self._lock:
            db = self._db_cache
            if product_name in db["vouchers"] and len(db["vouchers"][product_name]) >= qty:
                extracted = db["vouchers"][product_name][:qty]
                db["vouchers"][product_name] = db["vouchers"][product_name][qty:]

                # Update stock count
                if product_name in db["products"]:
                    db["products"][product_name]["stock"] -= qty

                await self._save_db(db)
                return extracted
            return []

    # --- SETTINGS METHODS ---
    async def get_setting(self, key, default=None):
        """Return the setting ``key`` or ``default`` (stored values are strings)."""
        async with self._lock:
            return self._db_cache["settings"].get(key, default)

    async def set_setting(self, key, value):
        """Store ``str(value)`` under setting ``key``."""
        async with self._lock:
            db = self._db_cache
            db["settings"][key] = str(value)
            await self._save_db(db)

    # --- LOGS & STATS ---
    async def add_log(self, user_id, action, details):
        """Append a timestamped log entry (the log list grows without bound)."""
        async with self._lock:
            db = self._db_cache
            db["logs"].append({
                "user_id": str(user_id),
                "action": action,
                "details": details,
                "timestamp": str(datetime.now())
            })
            await self._save_db(db)

    async def get_stats(self):
        """Return the user count and the sums of all users' balance and due."""
        async with self._lock:
            db = self._db_cache
            total_bal = sum(u["balance"] for u in db["users"].values())
            total_due = sum(u["due"] for u in db["users"].values())
            return {
                "total_users": len(db["users"]),
                "total_balance": total_bal,
                "total_due": total_due
            }

    # --- PAYMENT METHODS ---
    async def get_payment(self, txnid):
        """Return the payment record for ``txnid``, or None."""
        async with self._lock:
            db = self._db_cache
            return db["payments"].get(txnid)

    async def add_payment(self, txnid, amount, sender, method):
        """Record a payment as ``"verified"``; an existing ``txnid`` is overwritten."""
        async with self._lock:
            db = self._db_cache
            db["payments"][txnid] = {
                "txnid": txnid,
                "amount": amount,
                "sender": sender,
                "method": method,
                "status": "verified",
                "timestamp": str(datetime.now())
            }
            await self._save_db(db)

    async def clear_user_data(self, user_id):
        """Delete the user record for ``user_id``; unknown users are ignored."""
        user_id = str(user_id)
        async with self._lock:
            db = self._db_cache
            if user_id in db["users"]:
                del db["users"][user_id]
                await self._save_db(db)

    async def get_all_user_ids(self):
        """Return a list of all stored user ids (strings)."""
        async with self._lock:
            return list(self._db_cache["users"].keys())
