#!/usr/bin/env python3

# SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC
# SPDX-License-Identifier: Apache-2.0

"""This tool creates and advertises a SciToken key in the Frontend

The private and public key generation follows the example in the scitokens library:
https://github.com/scitokens/scitokens/blob/master/src/scitokens/tools/admin_create_key.py

The advertising follows the instructions and conventions in
https://htcondor.readthedocs.io/en/latest/admin-manual/file-and-cred-transfer.html#generating-a-scitokens-key-pair
"""

import argparse
import binascii
import grp
import json
import os
import os.path
import pwd
import sys

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization

# from cryptography.hazmat.primitives.asymmetric import ec, rsa, types
from cryptography.hazmat.primitives.asymmetric import ec, rsa
from scitokens.utils import bytes_from_long, string_from_long

from glideinwms.frontend import glideinFrontendConfig
from glideinwms.lib import defaults

# Module-wide verbosity switch read by log(); main() turns it on for --verbose.
verbose = False


def log(*args):
    """Write ``args`` to stderr (space separated, newline terminated, like print) when verbose is on."""
    if not verbose:
        return
    print(*args, file=sys.stderr)


def get_args() -> argparse.Namespace:
    """Parse sys.argv and complete the unset path options with installation defaults.

    The path defaults come from glideinwms.lib.defaults. When --client-user is given,
    ``defaults.set_base_dir()`` is called with it first, so the computed paths depend
    on that user.

    Returns:
        argparse.Namespace: The parsed arguments; private_key_file, work_dir and web_dir
            are always filled in.

    Raises:
        ValueError: If running as root (uid 0) without --client-user.
    """
    # (short flag, long flag, add_argument keyword arguments) in the order shown by --help
    option_table = (
        ("-v", "--verbose", dict(help="Print verbose messages", action="store_true")),
        ("-p", "--print-config", dict(help="Print the configuration line of the credential", action="store_true")),
        (
            "-k",
            "--private-key-file",
            dict(
                help="Location of the unencrypted private key file. Will be created if it does not exist. Defaults to the standard one.",
                default=None,
            ),
        ),
        (
            "-f",
            "--force-new-key",
            dict(
                help="Forces the creation of a new private key and save it to the private key file. Defaults to False.",
                action="store_true",
            ),
        ),
        (
            "-e",
            "--ec",
            dict(
                help="Use elliptical curve cryptography (ES256) instead of RSA (RS256). Defaults to False or whatever is the type of the existing key.",
                action="store_true",
            ),
        ),
        (
            "-d",
            "--work-dir",
            dict(
                help="Location of the Frontend configuration. This is used to set the base URI if not provided. Defaults to RPM installation.",
                default=None,
            ),
        ),
        (
            "-w",
            "--web-dir",
            dict(
                help="Location of base directory for web server. A subdirectory is added if a group is specified. Defaults to the scitokens subdirectory of the Web area.",
                default=None,
            ),
        ),
        (
            "-u",
            "--web-url",
            dict(help="Base URL for web server. Defaults to the /cred/scitokens URL in the Frontend."),
        ),
        (
            "-g",
            "--group",
            dict(
                help="Optional key group, must be one of the Frontend groups. This argument changes the key issuer and the publishing of the public key, not the private key (use '-k' for that). By default will use the global key, not specific to any group."
            ),
        ),
        (
            "-c",
            "--client-user",
            dict(
                help="The user running the client (Frontend or Decision Engine). This is used to determine files locations."
            ),
        ),
    )
    parser = argparse.ArgumentParser()
    for short_flag, long_flag, settings in option_table:
        parser.add_argument(short_flag, long_flag, **settings)
    parsed = parser.parse_args()

    if parsed.client_user:
        # Point the default directory variables at the client user's layout
        defaults.set_base_dir(parsed.client_user)
    elif os.getuid() == 0:
        # Could also guess from the directories in /etc
        raise ValueError("Set --client-user (e.g. to frontend/decisionengine) when running as root.")

    # Fill in the path options the user left unset
    if parsed.private_key_file is None:
        parsed.private_key_file = os.path.join(defaults.key_dir, "frontend_scitokens_key.pem")
    if parsed.work_dir is None:
        parsed.work_dir = os.path.join(defaults.base_dir, "vofrontend")
    if parsed.web_dir is None:
        parsed.web_dir = os.path.join(defaults.web_sec_dir, "scitokens")
    return parsed


# No return annotation: cryptography.hazmat.primitives.asymmetric.types.PrivateKeyTypes is not
# available in all supported cryptography releases (AttributeError on older ones).
def make_key(private_key_file, use_ec=False, force_new=False, owner=None):
    """Load the RSA or EC private key, generating and saving a new one when needed.

    The key is read from ``private_key_file`` when that file exists and ``force_new`` is False
    (``use_ec`` is then ignored). Otherwise a new key is generated: EC on curve SECP256R1 when
    ``use_ec`` is True, else 2048-bit RSA with public exponent 65537. It is written unencrypted,
    in PEM TraditionalOpenSSL format, to ``private_key_file`` with permissions 0600.

    Args:
        private_key_file (str): Location of the private key.
        use_ec (bool): If True, generate an EC private key instead of an RSA one. Defaults to False.
        force_new (bool): If True, generate a new key and overwrite the existing file. Defaults to False.
        owner (str): The user running the client. If set, a newly written key file is chowned to
            this user and its primary group. Defaults to None (ownership left unchanged).

    Returns:
        types.PrivateKeyTypes: The RSA or EC private key object (not its serialized bytes).

    Raises:
        KeyError: If ``owner`` is not a known user (from ``pwd.getpwnam``).
        OSError: If the key file cannot be read or written.
        Errors from ``serialization.load_pem_private_key`` propagate if the existing file is not
        a valid unencrypted PEM private key.
    """
    if os.path.isfile(private_key_file) and not force_new:
        log(f"Reading private key from {private_key_file}.")
        with open(private_key_file, "rb") as key_file:
            return serialization.load_pem_private_key(key_file.read(), password=None, backend=default_backend())

    # Create the private key
    if use_ec:
        log(f"Generating EC private key and saving it to {private_key_file}.")
        private_key = ec.generate_private_key(ec.SECP256R1(), backend=default_backend())
    else:
        log(f"Generating RSA private key and saving it to {private_key_file}.")
        private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend())

    # Write the PEM private key
    private_pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )
    pk_descriptor = os.open(
        private_key_file,
        os.O_WRONLY | os.O_CREAT | os.O_TRUNC,  # write only, create if missing, truncate to zero
        0o600,
    )
    with open(pk_descriptor, "wb") as key_file:
        # The os.open mode only applies when the file is created. When an existing key file is
        # overwritten (e.g. force_new), reset its permissions explicitly so the new unencrypted key
        # is never left readable by other users. fchmod is not affected by the umask.
        os.fchmod(key_file.fileno(), 0o600)
        key_file.write(private_pem)
    if owner:
        user_info = pwd.getpwnam(owner)
        os.chown(private_key_file, user_info.pw_uid, user_info.pw_gid)
    return private_key


def get_jwk_pk_string(private_key, key_type) -> (str, str):
    """Build the JWK Set that advertises the public half of a private key.

    The key ID (kid) is the first 4 hex characters of the SHA-256 digest of the public
    ``x`` coordinate (EC) or modulus ``n`` (RSA).

    Args:
        private_key (types.PrivateKeyTypes): RSA or EC private key object.
        key_type (str): Key family, "RSA" or "EC". The JWT algorithm names "RS256" and
            "ES256" are accepted as synonyms. Any other value is treated as RSA.

    Returns:
        (str, str): Key ID, and a multiline JSON string with the JWK Set (a single entry in
            "keys"), serialized with sorted keys and 4-space indentation.

    Raises:
        ValueError: If an EC key is not on curve P-256 (secp256r1), the only curve ES256 allows.
    """
    # Callers may identify the key by its JWT algorithm ("ES256") rather than its family ("EC").
    # Both must select the EC layout; otherwise an EC key would take the RSA branch and fail on
    # the missing modulus ``n``.
    is_ec = key_type in ("EC", "ES256")
    public_numbers = private_key.public_key().public_numbers()
    if is_ec and public_numbers.curve.name != "secp256r1":
        # The JWK below hard-codes "crv": "P-256"; refuse to advertise a mismatching key
        raise ValueError(f"Unsupported EC curve {public_numbers.curve.name}: ES256 requires P-256 (secp256r1).")

    # Hash the public x coordinate (EC) or modulus n (RSA), and use it for the Key ID (kid)
    digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
    digest.update(bytes_from_long(public_numbers.x if is_ec else public_numbers.n))
    # Shorten the kid to 4 hex characters
    kid = binascii.hexlify(digest.finalize())[:4].decode("utf-8")

    if is_ec:
        jwk = {
            "alg": "ES256",
            "crv": "P-256",
            "x": string_from_long(public_numbers.x),
            "y": string_from_long(public_numbers.y),
            "kty": "EC",
            "use": "sig",
            "kid": kid,
        }
    else:  # RSA
        jwk = {
            "alg": "RS256",
            "n": string_from_long(public_numbers.n),
            "e": string_from_long(public_numbers.e),
            "kty": "RSA",
            "use": "sig",
            "kid": kid,
        }
    jwk_public_key = {"keys": [jwk]}
    log(f"Returning {key_type} JWK public key, kid: {kid}.")
    return kid, json.dumps(jwk_public_key, sort_keys=True, indent=4, separators=(",", ": "))


def prepare_web_dir(web_dir, group_name, uri, public_key, owner) -> (str, str):
    """Set up the web directory that advertises the SciToken public key.

    Writes ``public-key.jwks`` (the JWK public key) into ``web_dir`` (or ``web_dir/<group_name>``),
    and ``.well-known/openid-configuration`` below it, pointing at that file.

    Args:
        web_dir (str): Base web directory used to advertise the SciToken key.
        group_name (str): Optional group name. When set, files go to a subdirectory with that
            name and ``/<group_name>`` is appended to the issuer.
        uri (str): Base URI of the issuer.
        public_key (str): Serialized JWK public key to publish.
        owner (str): User running the client and owner of the files. When set, directories and
            files are chowned to this user and to the "glidein" group if it exists (the user's
            primary group otherwise), and the files are made mode 0664.

    Returns:
        str, str: Public key issuer and advertising URI.

    Raises:
        KeyError: If ``owner`` is not a known user (from ``pwd.getpwnam``).
    """
    uid = gid = None
    if owner:
        user_info = pwd.getpwnam(owner)
        uid = user_info.pw_uid
        gid = user_info.pw_gid
        try:
            # If there is the glidein group, use that instead of the default one for the Web files
            gid = grp.getgrnam("glidein").gr_gid
        except KeyError:
            pass  # No glidein group: keep the user's primary group
    public_key_fname = "public-key.jwks"
    group_uri = ""
    if group_name:
        web_dir = os.path.join(web_dir, group_name)
        group_uri = f"/{group_name}"
    os.makedirs(web_dir, mode=0o775, exist_ok=True)
    log(f"Setting up web directory {web_dir} with config and public key '{public_key_fname}' files.")
    public_key_path = os.path.join(web_dir, public_key_fname)
    with open(public_key_path, "wt") as pk_file:
        pk_file.write(public_key)
    if uid is not None:
        os.chown(web_dir, uid, gid)
        os.chown(public_key_path, uid, gid)
        os.chmod(public_key_path, 0o664)

    well_known_dir = os.path.join(web_dir, ".well-known")
    os.makedirs(well_known_dir, mode=0o775, exist_ok=True)
    issuer = f"{uri}{group_uri}"
    jwks_uri = f"{issuer}/{public_key_fname}"
    config_file_path = os.path.join(well_known_dir, "openid-configuration")
    with open(config_file_path, "wt") as config_file:
        # json.dumps quotes and escapes the values, so the document stays valid JSON even if the
        # issuer contains '"' or '\'. For ordinary URLs the output matches the hand-written layout.
        config_file.write("{\n" f'    "issuer":{json.dumps(issuer)},\n    "jwks_uri":{json.dumps(jwks_uri)}\n' "}\n")
    if uid is not None:
        os.chown(well_known_dir, uid, gid)
        os.chown(config_file_path, uid, gid)
        os.chmod(config_file_path, 0o664)
    return issuer, jwks_uri


def main():
    """Create or load the SciToken signing key and advertise its public key.

    Steps:
        1. Parse the command line.
        2. Load the Frontend group configuration, used for the default issuer URL.
        3. Load or create the private key.
        4. Publish the JWK public key and OpenID configuration in the web directory.
        5. Log, and with --print-config print, the matching XML credential configuration.

    Raises:
        ValueError: If the private key is neither RSA nor EC, or is RSA while --ec was requested.
    """
    global verbose
    # Get command line arguments and set up
    args = get_args()
    if args.verbose:
        # Enable log() output via the module-level flag
        verbose = True
    # Load the Frontend configuration to retrieve its URI
    group = args.group or "main"
    gfe_desc = glideinFrontendConfig.ElementMergedDescript(args.work_dir, group)
    # Calculate and verify consistency of private key
    private_key = make_key(args.private_key_file, args.ec, args.force_new_key, args.client_user)
    if isinstance(private_key, rsa.RSAPrivateKey):
        key_family, key_type = "RSA", "RS256"
    elif isinstance(private_key, ec.EllipticCurvePrivateKey):
        key_family, key_type = "EC", "ES256"
    else:
        raise ValueError("The private key is of an unknown type (not RSA/RS256 or EC/ES256).")
    if args.ec and key_type != "ES256":
        raise ValueError("Private key type mismatch. Requested EC/ES256, but loaded RSA/RS256.")
    # Calculate and publish public key.
    # get_jwk_pk_string() documents key_type as the key family ("RSA" or "EC"), so pass the
    # family; the algorithm name (key_type) is only used in the credential configuration.
    key_id, pk_str = get_jwk_pk_string(private_key, key_family)
    uri = args.web_url or os.path.join(os.path.dirname(gfe_desc.frontend_data["MonitoringWebURL"]), "cred", "scitokens")
    if uri.startswith("http://"):
        # Forcing https, http is not accepted for token verification
        uri = "https:" + uri[5:]
    iss, adv_uri = prepare_web_dir(args.web_dir, args.group, uri, pk_str, args.client_user)
    log(f"Public key from iss: {iss} advertised at: {adv_uri}")
    # Build the example credential entry once: it is both logged and, on request, printed
    credential_xml = f"""<credential absfname="SciTokenGenerator"
    context="{{'algorithm': '{key_type}', 'key_file': '{args.private_key_file}', 'key_id': '{key_id}', 'issuer': '{iss}', 'scope': 'compute.read compute.modify compute.create compute.cancel', 'type': 'scitoken'}}"
    purpose="request" security_class="frontend" trust_domain="grid" type="generator"/>"""
    log(f"Example of XML credential configuration:\n{credential_xml}")
    if args.print_config:
        print(credential_xml)


# Script invocation: any error from main() is reported as a one-line message on stderr with exit status 1
if __name__ == "__main__":
    try:
        exit_status = main()
    except Exception as err:
        sys.stderr.write(f"ERROR: Exception msg {err}\n")
        sys.exit(1)
    sys.exit(exit_status)
