#!/usr/bin/env python3
"""
Test script to debug Zitadel JWT creation and token exchange
"""

import os
import json
import jwt
import time
import httpx
import asyncio
from dotenv import load_dotenv


def _parse_service_account_key(raw_key):
    """Extract the private key, key ID and user ID from a service account key.

    Args:
        raw_key: Contents of ``ZITADEL_SERVICE_ACCOUNT_KEY``; must be a JSON
            object with non-empty string fields ``key``, ``keyId`` and
            ``userId``.

    Returns:
        A ``(private_key, key_id, user_id)`` tuple of the field values.

    Raises:
        ValueError: If ``raw_key`` is not valid JSON, is not a JSON object, or
            a required field is missing, empty, or not a string.
    """
    try:
        key_data = json.loads(raw_key)
    except json.JSONDecodeError as err:
        raise ValueError("Failed to parse JSON from service account key") from err
    print("✅ Service account key is valid JSON")

    if not isinstance(key_data, dict):
        raise ValueError("Service account key JSON must be an object")

    values = []
    for field in ("key", "keyId", "userId"):
        value = key_data.get(field)
        if not isinstance(value, str) or not value:
            raise ValueError(
                f"Service account key JSON has a missing or empty field '{field}'"
            )
        print(f"✅ Found {field} in JSON structure")
        values.append(value)
    return tuple(values)


def _describe_private_key(private_key):
    """Return a non-secret summary of a PEM private key, safe to print.

    Only the ``-----BEGIN ...-----`` / ``-----END ...-----`` markers and the
    total length are reported; the base64 body between them is key material
    and is never included.
    """
    import re

    markers = re.findall(r"-----(?:BEGIN|END) [A-Z0-9 ]+-----", private_key)
    summary = " ... ".join(markers) if markers else "<no PEM markers found>"
    return f"{summary} ({len(private_key)} chars)"


async def test_zitadel_connection():
    """Exercise Zitadel's JWT-bearer token exchange and print diagnostics.

    Reads ``ZITADEL_HOST`` and ``ZITADEL_SERVICE_ACCOUNT_KEY`` from the
    environment (after ``load_dotenv()``), signs an RS256 JWT assertion with the
    service account's private key, prints an equivalent curl command, then
    POSTs the assertion to ``{ZITADEL_HOST}/oauth/v2/token``.

    Returns:
        True if the token endpoint answered HTTP 200; False on missing
        configuration, JWT signing failure, a non-200 response, or any other
        exception (which is printed together with its traceback).
    """
    load_dotenv()

    print("🔍 Testing Zitadel Connection...")
    print("=" * 50)

    # Get configuration
    zitadel_host = os.getenv("ZITADEL_HOST", "").rstrip("/")
    service_account_key = os.getenv("ZITADEL_SERVICE_ACCOUNT_KEY")

    if not all([zitadel_host, service_account_key]):
        print("❌ Missing required configuration")
        return False

    try:
        private_key, key_id, client_id = _parse_service_account_key(
            service_account_key
        )

        print(f"Host: {zitadel_host}")
        print(f"Client ID: {client_id}")
        print(f"Key ID: {key_id}")
        # Always set here: the configuration check above returned otherwise.
        print("Service Account Key: [SET]")

        # A key that was double-escaped before being put in the environment
        # still contains literal "\n" sequences after JSON decoding; the PEM
        # parser needs real newlines.
        if "\\n" in private_key:
            private_key = private_key.replace("\\n", "\n")
            print("✅ Cleaned up escaped newlines in private key")

        # Create JWT payload; the service user is both issuer and subject.
        now = int(time.time())
        payload = {
            "iss": client_id,
            "sub": client_id,
            "aud": zitadel_host,
            "iat": now,
            "exp": now + 3600,  # 1 hour
        }

        # Create JWT header with keyID
        header = {"kid": key_id, "alg": "RS256", "typ": "JWT"}

        print(f"JWT Payload: {json.dumps(payload, indent=2)}")
        print(f"JWT Header: {json.dumps(header, indent=2)}")

        # Never print the key body itself — only its PEM markers and length.
        print(f"Private key: {_describe_private_key(private_key)}\n")
        try:
            jwt_token = jwt.encode(
                payload, private_key, algorithm="RS256", headers=header
            )
            print("✅ JWT token created successfully")
            print(f"JWT Token (first 50 chars): {jwt_token[:50]}...")
        except Exception as e:
            print(f"❌ Failed to create JWT: {e}")
            return False

        scope = "openid urn:zitadel:iam:org:project:id:zitadel:aud"
        grant_type = "urn:ietf:params:oauth:grant-type:jwt-bearer"
        token_url = f"{zitadel_host}/oauth/v2/token"

        print("Curl command:")
        print("curl --request POST \\")
        print(f"  --url {token_url} \\")
        print("  --header 'Content-Type: application/x-www-form-urlencoded' \\")
        print(f"  --data grant_type={grant_type} \\")
        # The scope contains a space, so let curl URL-encode it.
        print(f"  --data-urlencode 'scope={scope}' \\")
        print(f"  --data assertion={jwt_token}")

        # Exchange JWT for access token. The scope is sent as the bare
        # space-separated string: the shell quotes in the curl command above
        # are not part of the value and must not be sent to the server.
        data = {
            "grant_type": grant_type,
            "scope": scope,
            "assertion": jwt_token,
        }

        print(f"\n🔗 Requesting token from: {token_url}")
        print(f"Request data: {json.dumps(data, indent=2)}")

        async with httpx.AsyncClient() as client:
            response = await client.post(
                token_url,
                data=data,
                headers={
                    "Content-Type": "application/x-www-form-urlencoded",
                },
            )

            print(f"Response status: {response.status_code}")
            print(f"Response headers: {dict(response.headers)}")

            if response.status_code == 200:
                token_data = response.json()
                print("✅ Successfully obtained access token")
                print(f"Token type: {token_data.get('token_type')}")
                print(f"Expires in: {token_data.get('expires_in')} seconds")
                print(
                    f"Access token (first 50 chars): {token_data.get('access_token', '')[:50]}..."
                )
                return True

            print("❌ Failed to get access token")
            print(f"Response: {response.text}")
            return False

    except Exception as e:
        # Top-level boundary of a diagnostic script: report everything.
        print(f"❌ Error during testing: {e}")
        import traceback

        traceback.print_exc()
        return False


if __name__ == "__main__":
    import sys

    # Turn the boolean result into the process exit status (0 = success,
    # 1 = failure) so shells and CI jobs can detect a failed exchange.
    sys.exit(0 if asyncio.run(test_zitadel_connection()) else 1)
