from flask import Flask, request, jsonify
from contextlib import closing
from conf import connect_to_database, Config
from app import app
import razorpay
import os
import traceback
import csv

# Razorpay credentials are read from Config (see the client construction below).
# The earlier direct-environment lookup is kept here, disabled, for reference:
# RAZORPAY_KEY_ID = os.getenv('RAZORPAY_KEY_ID')
# RAZORPAY_KEY_SECRET = os.getenv('RAZORPAY_KEY_SECRET')

# Load the settings defined on Config into the Flask app configuration
# (the original note says this is needed for mail).
app.config.from_object(Config)

# Module-level Razorpay API client used by the plan helpers in this file.
razorpay_client = razorpay.Client(auth=(Config.RAZORPAY_KEY_ID, Config.RAZORPAY_KEY_SECRET))


def _razorpay_schedule(billing_cycle):
    """Translate a package billing cycle into a Razorpay ``(period, interval)`` pair.

    Razorpay charges once every ``interval`` units of ``period``, so a plan
    billed once a year is ``('yearly', 1)``; ``('yearly', 12)`` would bill
    every twelve years.  'lifetime' has no native Razorpay equivalent and is
    modelled as a yearly plan.  Missing or unrecognised cycles fall back to
    monthly.
    """
    cycle = str(billing_cycle or 'monthly').lower()
    if cycle in ('annual', 'yearly', 'lifetime'):
        return 'yearly', 1
    return 'monthly', 1


def _to_minor_units(price):
    """Convert a price in major units to integer minor units (e.g. paise).

    Rounds instead of truncating so binary float noise (19.99 * 100 ==
    1998.999...) cannot drop a unit.
    """
    return int(round(float(price) * 100))


def create_razorpay_plan(package_data):
    """Create a Razorpay plan for a package and return the API response.

    Args:
        package_data: Mapping describing the package.  'price' and
            'currency_code' are required; 'billing_cycle', 'package_name',
            'package_description', 'country_id' and 'package_id' are optional.

    Returns:
        The plan object returned by ``razorpay_client.plan.create``.

    Raises:
        ValueError: If 'price' or 'currency_code' is missing.
        Exception: Any error raised by the Razorpay client is logged and
            re-raised unchanged.
    """
    try:
        # Validate required fields
        if 'price' not in package_data or 'currency_code' not in package_data:
            raise ValueError("package_data must include 'price' and 'currency_code'")

        # Normalise once so a None billing cycle behaves like a missing one.
        billing_cycle = str(package_data.get('billing_cycle') or 'monthly').lower()
        period, interval = _razorpay_schedule(billing_cycle)

        plan_params = {
            "period": period,
            "interval": interval,
            "item": {
                "name": package_data.get('package_name', 'Unnamed Plan'),
                "description": package_data.get('package_description', ''),
                "amount": _to_minor_units(package_data['price']),
                "currency": package_data['currency_code']
            },
            "notes": {
                "country_id": package_data.get('country_id', ''),
                "package_id": package_data.get('package_id', 'new')
            }
        }

        # Lifetime packages are ordinary yearly plans tagged with a note
        # holding 100 years expressed in minutes.
        if billing_cycle == 'lifetime':
            plan_params['notes']['expire_by'] = 100 * 365 * 24 * 60

        # Log the payload just before the call to ease debugging of rejections.
        app.logger.debug("Creating Razorpay plan with payload: %s", plan_params)

        razorpay_plan = razorpay_client.plan.create(plan_params)

        app.logger.info("Razorpay plan created: %s", razorpay_plan)
        return razorpay_plan

    except razorpay.errors.BadRequestError as bre:
        app.logger.error("Razorpay BadRequestError: %s", str(bre))
        app.logger.debug(traceback.format_exc())
        raise
    except Exception as e:
        app.logger.error("Error creating Razorpay plan: %s", str(e))
        app.logger.debug(traceback.format_exc())
        raise


def update_razorpay_plan(plan_id, package_data):
    """Return the Razorpay plan id that matches ``package_data``.

    This helper never edits a plan in place (the original author notes that
    Razorpay does not support direct plan updates).  It fetches the existing
    plan and compares amount, name, description, period and interval against
    what ``create_razorpay_plan`` would produce for ``package_data``.  If any
    differ, a new plan is created and its id returned; otherwise ``plan_id``
    is returned unchanged.

    Args:
        plan_id: Id of the plan currently attached to the package.
        package_data: Mapping with at least 'price' and 'package_name';
            'currency_code' is also needed when a new plan must be created.

    Returns:
        The id of the plan the package should reference.

    Raises:
        Exception: Any lookup, comparison or creation error is logged and
            re-raised.
    """
    try:
        existing_plan = razorpay_client.plan.fetch(plan_id)
        item = existing_plan['item']

        # Derive the expected schedule and amount with the same helpers the
        # creation path uses, so an unchanged package never looks "changed".
        period, interval = _razorpay_schedule(package_data.get('billing_cycle'))

        critical_changes = (
            item['amount'] != _to_minor_units(package_data['price']) or
            item['name'] != package_data['package_name'] or
            # Treat a missing/None description the same as an empty one.
            (item.get('description') or '') != (package_data.get('package_description') or '') or
            existing_plan['period'] != period or
            existing_plan['interval'] != interval
        )

        if critical_changes:
            # For critical changes, create a new plan and return its ID
            new_plan = create_razorpay_plan(package_data)
            return new_plan['id']

        # No critical changes, keep the existing plan
        return plan_id

    except Exception as e:
        app.logger.error(f"Error updating Razorpay plan: {str(e)}")
        raise

def delete_razorpay_plan(plan_id):
    """Placeholder for removing a Razorpay plan; currently performs no API call.

    The original author notes that Razorpay does not allow deleting plans
    that have active subscriptions, so nothing is sent to Razorpay here and
    ``plan_id`` is not used.

    Returns:
        Always ``True``.
    """
    # Intentionally a no-op: nothing is deleted or deactivated remotely.
    return True


def get_or_create_country(connection, country_data):
    """
    Return the llx_c_country id for a country, inserting the row when absent.

    The lookup is by 2-letter code only. When no row matches, a new one is
    inserted (eec=0, active=1, favorite=0) and committed on ``connection``.

    country_data keys:
    - country: Full name (e.g., "United States") - required
    - country_code: 2-letter code (e.g., "US") - required
    - country_iso: 3-letter ISO code (e.g., "USA"); defaults to the 2-letter
      code followed by its first letter
    - numeric_code: 3-digit code (e.g., "840"); defaults to "000"

    Any exception is logged and re-raised.
    """
    try:
        with closing(connection.cursor()) as cur:
            code = country_data['country_code']
            insert_sql = """
                INSERT INTO llx_c_country (code, code_iso, numeric_code, label, eec, active, favorite)
                VALUES (%s, %s, %s, %s, 0, 1, 0)
            """

            # Look the country up by its 2-letter code first.
            cur.execute("SELECT id FROM llx_c_country WHERE code = %s", (code,))
            row = cur.fetchone()

            if not row:
                # Unknown country: insert it and hand back the generated id.
                params = (
                    code,
                    country_data.get('country_iso', code + code[0]),
                    country_data.get('numeric_code', '000'),
                    country_data['country'],
                )
                cur.execute(insert_sql, params)
                connection.commit()
                new_id = cur.lastrowid

                app.logger.info(f"Created new country: {country_data['country']} with ID: {new_id}")
                return new_id

            app.logger.info(f"Country {country_data['country']} already exists with ID: {row[0]}")
            return row[0]

    except Exception as exc:
        app.logger.error(f"Error in get_or_create_country: {exc!s}")
        raise


def get_or_create_currency(connection, currency_data):
    """
    Return the currencies id for a currency code, inserting the row when absent.

    When no row matches the code, a new one is inserted and committed on
    ``connection``.

    currency_data keys:
    - currency_code: 3-letter code (e.g., "USD") - required
    - currency_symbol: Symbol (e.g., "$"); defaults to an empty string

    Any exception is logged and re-raised.
    """
    try:
        with closing(connection.cursor()) as cur:
            code = currency_data['currency_code']
            insert_sql = """
                INSERT INTO currencies (currency_code, currency_symbol)
                VALUES (%s, %s)
            """

            # Look the currency up by its code first.
            cur.execute("SELECT id FROM currencies WHERE currency_code = %s", (code,))
            row = cur.fetchone()

            if not row:
                # Unknown currency: insert it and hand back the generated id.
                cur.execute(insert_sql, (code, currency_data.get('currency_symbol', '')))
                connection.commit()
                new_id = cur.lastrowid

                app.logger.info(f"Created new currency: {code} with ID: {new_id}")
                return new_id

            app.logger.info(f"Currency {code} already exists with ID: {row[0]}")
            return row[0]

    except Exception as exc:
        app.logger.error(f"Error in get_or_create_currency: {exc!s}")
        raise


# Create Package
def create_package():
    """Create a package from the JSON request body and register its Razorpay plan.

    Required JSON fields: package_name, price, billing_cycle, country,
    country_code, currency_code.  Optional fields stored when present:
    package_description, maximum_bookings, price_after_max, is_active.
    country_iso, numeric_code and currency_symbol are used only when the
    country or currency row has to be created.

    Returns:
        A ``(response, status)`` tuple: 201 with the stored row joined with
        its country and currency columns, 400 when validation fails, or 500
        when the database or Razorpay step raises.
    """
    data = request.get_json()
    # A JSON body that is not an object (null, list, string...) cannot hold fields.
    if not isinstance(data, dict):
        return jsonify({'message': 'Request body must be a JSON object', 'error': 'true'}), 400

    required_fields = ['package_name', 'price', 'billing_cycle', 'country', 'country_code', 'currency_code']
    for field in required_fields:
        value = data.get(field)
        if value is None or value == '':
            return jsonify({'message': f'Missing required field: {field}', 'error': 'true'}), 400

    # Validate billing cycle
    valid_cycles = ['monthly', 'annual', 'lifetime']
    if data['billing_cycle'] not in valid_cycles:
        return jsonify({'message': f'Invalid billing cycle. Must be one of: {", ".join(valid_cycles)}', 'error': 'true'}), 400

    # Initialised before the try block so the error handler below can always
    # read them, even when the failure happens in the very first statement.
    # NOTE(review): nothing sets is_free to True yet, so every package gets a plan.
    is_free = False
    razorpay_plan_id = None

    connection = connect_to_database()

    try:
        # Get or create country
        country_id = get_or_create_country(connection, {
            'country': data['country'],
            'country_code': data['country_code'],
            'country_iso': data.get('country_iso', data['country_code']),
            'numeric_code': data.get('numeric_code', '000')
        })

        # Get or create currency
        currency_id = get_or_create_currency(connection, {
            'currency_code': data['currency_code'],
            'currency_symbol': data.get('currency_symbol', '')
        })

        # Create the Razorpay plan first so its id can be stored with the row.
        if not is_free:
            plan_payload = {
                'package_name': data['package_name'],
                'price': data['price'],
                'country_id': country_id,
                'currency_id': currency_id,
                'currency_code': data['currency_code'],
                'billing_cycle': data['billing_cycle'],
                'package_description': data.get('package_description', '')
            }
            try:
                razorpay_plan = create_razorpay_plan(plan_payload)
                razorpay_plan_id = razorpay_plan['id']
            except Exception as rp_err:
                app.logger.error(
                    "Error while creating Razorpay plan for %s: %s",
                    data.get('package_name'), rp_err
                )
                raise
            app.logger.info(
                "Razorpay plan created successfully: plan_id=%s package=%s price=%s "
                "currency=%s country_id=%s billing_cycle=%s response=%s",
                razorpay_plan_id, data['package_name'], data['price'],
                data['currency_code'], country_id, data['billing_cycle'], razorpay_plan
            )

        # Column names come from fixed lists in this function, never from the
        # request, so interpolating them into the statement is safe.
        fields = [
            'package_name', 'price', 'country_id', 'currency_id',
            'billing_cycle', 'razorpay_plan_id'
        ]
        values = [
            data['package_name'],
            data['price'],
            country_id,
            currency_id,
            data['billing_cycle'],
            razorpay_plan_id
        ]
        for optional in ('package_description', 'maximum_bookings', 'price_after_max', 'is_active'):
            if optional in data:
                fields.append(optional)
                values.append(data[optional])

        fields_str = ', '.join(fields)
        placeholders = ', '.join(['%s'] * len(values))

        with closing(connection.cursor()) as cursor:
            cursor.execute(f'INSERT INTO packages ({fields_str}) VALUES ({placeholders})', values)
            connection.commit()
            package_id = cursor.lastrowid

            # Fetch newly created package with joined data
            cursor.execute("""
                SELECT p.*, c.label AS country_name, c.code AS country_code, 
                       cr.currency_code, cr.currency_symbol 
                FROM packages p
                JOIN llx_c_country c ON p.country_id = c.id
                JOIN currencies cr ON p.currency_id = cr.id
                WHERE p.id = %s
            """, (package_id,))
            package = cursor.fetchone()
            column_names = [desc[0] for desc in cursor.description]
            created_package = dict(zip(column_names, package))

    except Exception as e:
        connection.rollback()
        app.logger.error("Error creating package: %s", e)
        if razorpay_plan_id is not None:
            # Compensate for the plan that was created before the failure.
            try:
                delete_razorpay_plan(razorpay_plan_id)
                app.logger.info(
                    "Rolled back Razorpay plan due to DB error. Razorpay Plan ID: %s",
                    razorpay_plan_id
                )
            except Exception as del_err:
                app.logger.error("Failed to delete Razorpay plan during rollback: %s", del_err)
        return jsonify({'message': f"Error creating package: {str(e)}", 'error': 'true'}), 500
    finally:
        try:
            # Best-effort schema dump: a failure here must neither replace the
            # response computed above nor prevent the connection from closing.
            # NOTE(review): this looks like a debugging aid; confirm it is
            # still wanted on every request.
            export_db_schema_to_csv(connection)
        except Exception as export_err:
            app.logger.warning("Schema export after create_package failed: %s", export_err)
        finally:
            connection.close()

    return jsonify({
        'message': 'Package created successfully',
        'data': created_package,
        'error': 'false'
    }), 201


def export_db_schema_to_csv(connection, output_path='db_schema.csv'):
    """Write one CSV row per column of every table in the current database.

    The database name is obtained with ``SELECT DATABASE()`` and the column
    metadata is read from ``INFORMATION_SCHEMA.COLUMNS``, ordered by table
    name and ordinal position.

    CSV columns: table_schema, table_name, column_name, ordinal_position,
                 column_type, data_type, is_nullable, column_default,
                 column_key, extra

    Returns ``output_path`` on success.  Raises ``RuntimeError`` when the
    current database name cannot be determined; every failure is printed
    and re-raised.
    """
    header = [
        'table_schema', 'table_name', 'column_name', 'ordinal_position',
        'column_type', 'data_type', 'is_nullable', 'column_default',
        'column_key', 'extra'
    ]

    try:
        with closing(connection.cursor()) as cur:
            # Ask the server which database this connection is using.
            cur.execute("SELECT DATABASE()")
            current = cur.fetchone()
            if not current or not current[0]:
                raise RuntimeError("Could not determine current database name")
            schema_name = current[0]

            query = """
                SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, ORDINAL_POSITION,
                       COLUMN_TYPE, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT,
                       COLUMN_KEY, EXTRA
                FROM INFORMATION_SCHEMA.COLUMNS
                WHERE TABLE_SCHEMA = %s
                ORDER BY TABLE_NAME, ORDINAL_POSITION
            """

            cur.execute(query, (schema_name,))
            column_rows = cur.fetchall()

            # Header first, then the metadata rows exactly as fetched.
            with open(output_path, mode='w', newline='', encoding='utf-8') as out:
                sheet = csv.writer(out)
                sheet.writerow(header)
                sheet.writerows(column_rows)

        print(f"Database schema exported to {output_path}")
        return output_path

    except Exception as failure:
        print(f"Failed to export DB schema: {failure}")
        raise


def _int_query_arg(args, name, default):
    """Read an integer query argument, using ``default`` when it is absent or not an integer."""
    raw = args.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except (TypeError, ValueError):
        return default


def _active_flag(value):
    """Normalise an ``is_active`` filter: booleans pass through, text is true only for 'true'."""
    if isinstance(value, bool):
        return value
    return str(value).lower() == 'true'


# Get Packages with advanced filtering
def get_packages(request_args=None, **kwargs):
    """Return a filtered, sorted and paginated list of packages as JSON.

    Args:
        request_args: Optional mapping of query parameters.  When omitted,
            ``request.args`` of the current Flask request is used.  Every
            filter, as well as paging and sorting, is read from this mapping.
        **kwargs: Optional overrides for ``country_code``, ``currency_code``
            and ``is_active``; a value given here wins over the mapping.

    Supported parameters: id, country_id, currency_id, country_code,
    currency_code, is_active, billing_cycle, limit (default 10), page
    (default 1), order_by (default 'created_at'), sort (default 'DESC').

    Returns:
        A ``(response, status)`` tuple: 200 with ``data``, ``total_records``,
        ``page`` and ``per_page``; 400 for invalid sorting or paging values;
        500 when the query fails.
    """
    args = request_args if request_args is not None else request.args

    # Read everything from the same mapping so a caller-supplied request_args
    # is honoured for all parameters, not just a few of them.
    package_id = args.get('id')
    country_id = args.get('country_id')
    currency_id = args.get('currency_id')
    billing_cycle = args.get('billing_cycle')
    country_code = kwargs.get('country_code') or args.get('country_code')
    currency_code = kwargs.get('currency_code') or args.get('currency_code')
    # "is not None" (rather than "or") keeps an explicit is_active=False override.
    is_active = kwargs.get('is_active')
    if is_active is None:
        is_active = args.get('is_active')

    limit = _int_query_arg(args, 'limit', 10)
    page = _int_query_arg(args, 'page', 1)
    order_by = args.get('order_by', 'created_at')
    sort = str(args.get('sort', 'DESC')).upper()

    # Validate cheap inputs before opening a database connection.
    valid_columns = {
        'id', 'package_name', 'price', 'country_id', 'currency_id',
        'is_active', 'billing_cycle', 'created_at', 'updated_at'
    }
    if order_by not in valid_columns:
        return jsonify({'message': f'Invalid order_by column: {order_by}', 'error': 'true'}), 400
    if sort not in ('ASC', 'DESC'):
        return jsonify({'message': f'Invalid sort direction: {sort}', 'error': 'true'}), 400
    if page < 1 or limit < 0:
        return jsonify({'message': 'page must be >= 1 and limit must be >= 0', 'error': 'true'}), 400

    # Build the WHERE clause; user values only ever travel as bound parameters.
    conditions = []
    values = []
    if package_id:
        conditions.append("p.id = %s")
        values.append(package_id)
    if country_id:
        conditions.append("p.country_id = %s")
        values.append(country_id)
    if currency_id:
        conditions.append("p.currency_id = %s")
        values.append(currency_id)
    if country_code:
        conditions.append("c.code = %s")
        values.append(str(country_code).upper())
    if currency_code:
        conditions.append("cr.currency_code = %s")
        values.append(str(currency_code).upper())
    if is_active is not None:
        conditions.append("p.is_active = %s")
        values.append(_active_flag(is_active))
    if billing_cycle:
        conditions.append("p.billing_cycle = %s")
        values.append(billing_cycle)

    from_clause = """
            FROM packages p
            JOIN llx_c_country c ON p.country_id = c.id
            JOIN currencies cr ON p.currency_id = cr.id
            WHERE 1=1
            """
    if conditions:
        from_clause += " AND " + " AND ".join(conditions)

    count_query = "SELECT COUNT(*)" + from_clause
    # order_by and sort were checked against fixed whitelists above, so they
    # are safe to interpolate; LIMIT/OFFSET are bound parameters.
    select_query = (
        "SELECT p.*, c.label, c.code, cr.currency_code, cr.currency_symbol"
        + from_clause
        + f" ORDER BY p.{order_by} {sort} LIMIT %s OFFSET %s"
    )
    offset = (page - 1) * limit

    connection = connect_to_database()

    try:
        with closing(connection.cursor()) as cursor:
            # Total number of matching rows, for pagination.
            cursor.execute(count_query, values)
            total_records = cursor.fetchone()[0]

            cursor.execute(select_query, values + [limit, offset])
            packages = cursor.fetchall()
            column_names = [desc[0] for desc in cursor.description]

            # Convert to list of dictionaries
            packages_data = [dict(zip(column_names, pkg)) for pkg in packages]

    except Exception as e:
        app.logger.error("Error fetching packages: %s", e)
        return jsonify({'message': f"Error fetching packages: {str(e)}", 'error': 'true'}), 500
    finally:
        connection.close()

    return jsonify({
        'data': packages_data,
        'total_records': total_records,
        'page': page,
        'per_page': limit,
        'error': 'false'
    }), 200


# Update Package
def update_package():
    """Update an existing package from the JSON request body.

    The body must contain ``id``.  Updatable fields are package_name,
    package_description, price, country_id, currency_id, is_active,
    billing_cycle, maximum_bookings and price_after_max.  When both
    ``country`` and ``country_code`` are given, or when ``currency_code`` is
    given, the matching country/currency row is looked up or created and its
    id is used for the update.  If the package has a Razorpay plan and its
    name, price, description or billing cycle is part of the request, the
    plan is re-evaluated and a replacement plan id is stored when one is
    created.

    Returns:
        A ``(response, status)`` tuple: 200 with the updated row joined with
        its country and currency columns; 400 for a malformed body, missing
        id, invalid billing cycle or no updatable fields; 404 when the
        package, country or currency does not exist; 500 on any other
        failure.
    """
    data = request.get_json()
    # A JSON body that is not an object (null, list, string...) has no fields.
    if not isinstance(data, dict):
        return jsonify({'message': 'Request body must be a JSON object', 'error': 'true'}), 400

    package_id = data.get('id')
    if not package_id:
        return jsonify({'message': 'Package ID is required', 'error': 'true'}), 400

    # Same rule as create_package, so an update cannot store an unsupported cycle.
    valid_cycles = ['monthly', 'annual', 'lifetime']
    if 'billing_cycle' in data and data['billing_cycle'] not in valid_cycles:
        return jsonify({'message': f'Invalid billing cycle. Must be one of: {", ".join(valid_cycles)}', 'error': 'true'}), 400

    connection = connect_to_database()

    try:
        # Get or create country if country data is provided
        if 'country' in data and 'country_code' in data:
            data['country_id'] = get_or_create_country(connection, {
                'country': data['country'],
                'country_code': data['country_code'],
                'country_iso': data.get('country_iso', data['country_code']),
                'numeric_code': data.get('numeric_code', '000')
            })

        # Get or create currency if currency data is provided
        if 'currency_code' in data:
            data['currency_id'] = get_or_create_currency(connection, {
                'currency_code': data['currency_code'],
                'currency_symbol': data.get('currency_symbol', '')
            })

        with closing(connection.cursor()) as cursor:
            # Get existing package details including Razorpay plan ID
            cursor.execute("""
                SELECT p.*, c.code, cr.currency_code
                FROM packages p
                JOIN llx_c_country c ON p.country_id = c.id
                JOIN currencies cr ON p.currency_id = cr.id
                WHERE p.id = %s
            """, (package_id,))
            package = cursor.fetchone()

            if not package:
                return jsonify({'message': 'Package not found', 'error': 'true'}), 404

            column_names = [desc[0] for desc in cursor.description]
            existing_package = dict(zip(column_names, package))
            razorpay_plan_id = existing_package.get('razorpay_plan_id')

            # Start from the stored values; request values overwrite them
            # below so Razorpay sees the package as it will be after the update.
            plan_payload = {
                'package_name': existing_package['package_name'],
                'price': existing_package['price'],
                'country_id': existing_package['country_id'],
                'currency_id': existing_package['currency_id'],
                'currency_code': existing_package['currency_code'],
                'billing_cycle': existing_package['billing_cycle'],
                'package_description': existing_package.get('package_description', '')
            }

            # Whitelist of updatable columns; names are never taken from the request.
            allowed_fields = (
                'package_name', 'package_description', 'price', 'country_id',
                'currency_id', 'is_active', 'billing_cycle',
                'maximum_bookings', 'price_after_max'
            )

            update_fields = []
            values = []
            for field in allowed_fields:
                if field not in data:
                    continue

                # Foreign keys must point at existing rows.
                if field == 'country_id':
                    cursor.execute("SELECT id, code FROM llx_c_country WHERE id = %s", (data[field],))
                    if not cursor.fetchone():
                        return jsonify({'message': 'Country not found', 'error': 'true'}), 404
                elif field == 'currency_id':
                    cursor.execute("SELECT id, currency_code FROM currencies WHERE id = %s", (data[field],))
                    currency = cursor.fetchone()
                    if not currency:
                        return jsonify({'message': 'Currency not found', 'error': 'true'}), 404
                    # Keep the plan payload's currency code in step with the new id.
                    plan_payload['currency_code'] = currency[1]

                plan_payload[field] = data[field]
                update_fields.append(f"{field} = %s")
                values.append(data[field])

            if not update_fields:
                return jsonify({'message': 'No fields to update', 'error': 'true'}), 400

            # Update Razorpay plan if needed
            plan_fields = ('package_name', 'price', 'package_description', 'billing_cycle')
            if razorpay_plan_id and any(field in data for field in plan_fields):
                new_plan_id = update_razorpay_plan(razorpay_plan_id, plan_payload)
                if new_plan_id != razorpay_plan_id:
                    update_fields.append("razorpay_plan_id = %s")
                    values.append(new_plan_id)

            # Add package_id for WHERE clause
            values.append(package_id)

            query = f"UPDATE packages SET {', '.join(update_fields)} WHERE id = %s"
            cursor.execute(query, values)
            connection.commit()

            # Get the updated package with joined info
            cursor.execute("""
                SELECT p.*, c.label, c.code, cr.currency_code, cr.currency_symbol 
                FROM packages p
                JOIN llx_c_country c ON p.country_id = c.id
                JOIN currencies cr ON p.currency_id = cr.id
                WHERE p.id = %s
            """, (package_id,))
            package = cursor.fetchone()
            column_names = [desc[0] for desc in cursor.description]
            updated_package = dict(zip(column_names, package))

    except Exception as e:
        connection.rollback()
        app.logger.error("Error updating package %s: %s", package_id, e)
        return jsonify({'message': f"Error updating package: {str(e)}", 'error': 'true'}), 500
    finally:
        connection.close()

    return jsonify({
        'message': 'Package updated successfully',
        'data': updated_package,
        'error': 'false'
    }), 200


# Delete Package (Soft delete by setting is_active to False)
def delete_package():
    """Soft-delete a package by setting ``is_active`` to FALSE.

    The JSON body must contain ``id``.  After the row is deactivated and
    committed, a best-effort attempt is made to delete the linked Razorpay
    plan; a failure there is logged and does not affect the response.

    Returns:
        A ``(response, status)`` tuple: 200 on success, 400 for a malformed
        body or missing id, 404 when the package does not exist, 500 when
        the database step fails.
    """
    data = request.get_json()
    # A JSON body that is not an object (null, list, string...) has no fields.
    if not isinstance(data, dict):
        return jsonify({'message': 'Request body must be a JSON object', 'error': 'true'}), 400

    package_id = data.get('id')
    if not package_id:
        return jsonify({'message': 'Package ID is required', 'error': 'true'}), 400

    connection = connect_to_database()

    try:
        with closing(connection.cursor()) as cursor:
            # Get the package to check if it exists and get Razorpay plan ID
            cursor.execute("SELECT id, razorpay_plan_id FROM packages WHERE id = %s", (package_id,))
            package = cursor.fetchone()
            if not package:
                return jsonify({'message': 'Package not found', 'error': 'true'}), 404

            razorpay_plan_id = package[1]

            # Soft delete (set is_active to FALSE)
            cursor.execute("UPDATE packages SET is_active = FALSE WHERE id = %s", (package_id,))
            connection.commit()

            # Best-effort removal of the Razorpay plan, if the package has one.
            if razorpay_plan_id:
                try:
                    # NOTE(review): the delete_razorpay_plan helper in this file
                    # says Razorpay does not allow deleting plans; confirm that
                    # plan.delete exists in the installed SDK, otherwise this
                    # call always ends up in the except branch below.
                    razorpay_client.plan.delete(razorpay_plan_id, {})
                except Exception as e:
                    app.logger.error(f"Error deleting Razorpay plan: {str(e)}")
                    # Continue even if plan deletion fails as we've already soft-deleted in DB

    except Exception as e:
        connection.rollback()
        app.logger.error("Error deleting package %s: %s", package_id, e)
        return jsonify({'message': f"Error deleting package: {str(e)}", 'error': 'true'}), 500
    finally:
        connection.close()

    return jsonify({
        'message': 'Package deactivated successfully',
        'error': 'false'
    }), 200