#!/usr/bin/env python3

# SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC
# SPDX-License-Identifier: Apache-2.0

# Description:
#   Generate the glidein_startup.sh command to be run manually for given entry


import argparse
import itertools
import logging
import os
import pprint
import stat
import sys

try:
    import htcondor  # pylint: disable=import-error
except ImportError:
    import htcondor2 as htcondor

# Attribute-name prefix that marks glidein parameters; main() prepends "-" to
# it to recognize parameter overrides passed through --override-args
GLIDEIN_PARAM_PREFIX = "GlideinParam"
# Shell wrapper written to --cmd-out-file; the single %s is replaced with the
# complete glidein_startup.sh command line
SCRIPT_TEMPLATE = """#!/bin/sh

# glidein_startup.sh command
%s
"""

# Following is duplicated from glideFactoryLib to keep this script standalone
# Maps each special character to the ".<name>," token that replaces it in an
# escaped parameter value; characters not listed here are passed through as is
escape_table = {
    ".": ".dot,",
    ",": ".comma,",
    "&": ".amp,",
    "\\": ".backslash,",
    "|": ".pipe,",
    "`": ".fork,",
    '"': ".quot,",
    "'": ".singquot,",
    "=": ".eq,",
    "+": ".plus,",
    "-": ".minus,",
    "<": ".lt,",
    ">": ".gt,",
    "(": ".open,",
    ")": ".close,",
    "{": ".gopen,",
    "}": ".gclose,",
    "[": ".sopen,",
    "]": ".sclose,",
    "#": ".comment,",
    "$": ".dollar,",
    "*": ".star,",
    "?": ".question,",
    "!": ".not,",
    "~": ".tilde,",
    ":": ".colon,",
    ";": ".semicolon,",
    " ": ".nbsp,",
}


# Following is duplicated from glideFactoryLib to keep this script standalone
def escape_param(param_str):
    """Return param_str with every character listed in escape_table replaced by its token.

    Args:
        param_str: the string to escape.

    Returns:
        The escaped string; characters without an entry in escape_table are
        copied unchanged, and an empty input yields an empty string.
    """
    # Single join instead of repeated concatenation; the table is only read,
    # so no `global` declaration is needed
    return "".join(escape_table.get(c, c) for c in param_str)


def log_debug(msg, header=""):
    """Log a pretty-printed object at DEBUG level, optionally under a banner.

    Args:
        msg: any object; it is rendered with pprint.pformat before logging.
        header: optional title. When non-empty it is logged first, framed
            above and below by a line of "=" two characters longer than it.
    """
    if header:
        rule = "=" * (len(header) + 2)
        for banner_line in (rule, " %s " % header, rule):
            logging.debug(banner_line)
    rendered = pprint.pformat(msg)
    logging.debug(rendered)


def parse_opts(argv):
    """Parse the command line and configure the root logger.

    Args:
        argv: the full argument vector; element 0 (the program name) is skipped.

    Returns:
        The argparse.Namespace holding wms_collector, req_name, client_name,
        glidein_startup, override_args, cmd_out_file and debug.

    Raises:
        SystemExit: with status 1 when --req-name or --client-name is missing
            (argparse itself also exits on malformed options).
    """
    description = "Generate glidein_startup command\n\n"
    description += "Example:\n"
    description += 'manual_glidein_startup --wms-collector=fermicloud145.fnal.gov:8618 --client-name=Frontend-master-v1_0.main --req-name=TEST_SITE_2@v1_0@GlideinFactory-master --cmd-out-file=/tmp/glidein_startup_wrapper --override-args="-proxy http://httpproxy.mydomain -v fast"'

    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.RawTextHelpFormatter)

    parser.add_argument(
        "--wms-collector",
        type=str,
        action="store",
        dest="wms_collector",
        default="gfactory-2.opensciencegrid.org,gfactory-itb-1.opensciencegrid.org",
        help="COLLECTOR_HOST for WMS Collector(s) in CSV format (default: gfactory-2.opensciencegrid.org,gfactory-itb-1.opensciencegrid.org)",
    )

    parser.add_argument(
        "--req-name",
        type=str,
        action="store",
        dest="req_name",
        help="Factory entry info: ReqName in the glideclient classad",
    )

    parser.add_argument(
        "--client-name",
        type=str,
        action="store",
        dest="client_name",
        help="Frontend group info: ClientName in the glideclient classad",
    )

    parser.add_argument(
        "--glidein-startup",
        type=str,
        action="store",
        dest="glidein_startup",
        default="glidein_startup.sh",
        help="Full path to glidein_startup.sh to use",
    )

    parser.add_argument(
        "--override-args", type=str, action="store", dest="override_args", help="Override args to glidein_startup.sh"
    )

    parser.add_argument(
        "--cmd-out-file",
        type=str,
        action="store",
        dest="cmd_out_file",
        help="File where glidein_startup.sh command is created",
    )

    parser.add_argument("--debug", action="store_true", dest="debug", default=False, help="Enable debug logging")

    options = parser.parse_args(argv[1:])

    # Initialize logging before any message is emitted, so that the validation
    # error below is printed with the same format as every other message
    logging.basicConfig(
        format="%(levelname)s: %(message)s",
        level=logging.DEBUG if options.debug else logging.INFO,
    )

    if (options.req_name is None) or (options.client_name is None):
        logging.error("Missing one or more required options: --client-name, --req-name")
        sys.exit(1)

    return options


def params2args(ad, param_prefix="GlideinParam"):
    """Convert prefixed attributes of a mapping into "-param_<name>" arguments.

    Args:
        ad: mapping-like object that yields attribute names when iterated and
            supports .get(name) (a classad or a plain dict).
        param_prefix: only attributes whose name starts with this prefix are
            converted; the prefix is stripped from the resulting name.

    Returns:
        dict mapping "-param_<name>" to the attribute value, converted to str
        and escaped with escape_param. Empty when no attribute matches.
    """
    # Slice off only the leading prefix: str.replace() would also remove any
    # later occurrence of the prefix inside the attribute name
    prefix_len = len(param_prefix)
    return {
        "-param_%s" % attr[prefix_len:]: escape_param(str(ad.get(attr)))
        for attr in ad
        if attr.startswith(param_prefix)
    }


def main(argv):
    """Build the glidein_startup.sh command line for a factory entry and frontend group.

    Each WMS collector is queried for the glideclient classad matching
    --client-name/--req-name and for the glidefactory classad matching
    --req-name. The command-line arguments are derived from those two ads,
    then adjusted with --override-args. The resulting command is written as
    an executable shell script to --cmd-out-file when given, otherwise it is
    printed to stdout.

    Args:
        argv: the full argument vector, program name included.

    Returns:
        0 on success, 1 when the command file could not be written.

    Raises:
        SystemExit: with status 1 when no matching classads are found or when
            --override-args is malformed or names an unknown option.
    """
    options = parse_opts(argv)
    req_name = options.req_name
    client_name = options.client_name
    wms_collector = options.wms_collector.split(",")

    constraint_gc = f'(MyType=="glideclient") && (ClientName=="{client_name}") && (ReqName=="{req_name}")'
    constraint_gf = '(MyType=="glidefactory") && (Name=="%s")' % (req_name)

    glidein_startup_args = {}
    for collector_host in wms_collector:
        collector = htcondor.Collector(collector_host)
        ads_gc = collector.query(htcondor.AdTypes.Any, constraint_gc)
        if not ads_gc:
            continue

        # If a given client name has multiple credentials, there will be
        # multiple classads matching the criteria.
        # For manual startup of glidein, credentials will be provided by
        # the user as an option so just take first classad and ignore
        # all the other classads
        ad_gc = ads_gc[0]
        log_debug(ad_gc, header="glideclient classad")

        # Fetch the entry classad
        ads_gf = collector.query(htcondor.AdTypes.Any, constraint_gf)
        if not ads_gf:
            # Did not find glidefactory classad. Cannot use this one.
            logging.warning(
                "Found glideclient for ClientName=%s classad but corresponding entry classad with Name=%s not found in WMS Collector %s, ignoring",
                client_name,
                req_name,
                collector_host,
            )
            continue

        ad_gf = ads_gf[0]
        log_debug(ad_gf, header="glidefactory classad")

        # The insertion order below is the order of the options in the command
        glidein_startup_args.update(
            {
                "-factory": ad_gf.get("FactoryName"),
                "-name": ad_gf.get("GlideinName"),
                "-entry": ad_gf.get("EntryName"),
                "-clientname": ad_gc.get("FrontendName"),
                "-clientgroup": ad_gc.get("GroupName"),
                "-proxy": ad_gf.get("GLIDEIN_ProxyURL"),
                "-web": ad_gf.get("WebURL"),
                "-dir": ad_gf.get("GLIDEIN_WorkDir"),
                "-sign": ad_gf.get("WebDescriptSign"),
                "-signtype": ad_gf.get("WebSignType"),
                "-signentry": ad_gf.get("WebEntryDescriptSign"),
                "-cluster": 0,
                "-subcluster": 0,
                "-submitcredid": "UNAVAILABLE",
                "-schedd": "UNAVAILABLE",
                "-descript": ad_gf.get("WebDescriptFile"),
                "-descriptentry": ad_gf.get("WebEntryDescriptFile"),
                "-clientweb": ad_gc.get("WebURL"),
                "-clientwebgroup": ad_gc.get("WebGroupURL"),
                "-clientsign": ad_gc.get("WebDescriptSign"),
                "-clientsigntype": ad_gc.get("WebSignType"),
                "-clientsigngroup": ad_gc.get("WebGroupDescriptSign"),
                "-clientdescript": ad_gc.get("WebDescriptFile"),
                "-clientdescriptgroup": ad_gc.get("WebGroupDescriptFile"),
                "-slotslayout": ad_gf.get("GLIDEIN_SlotsLayout"),
                "-v": ad_gf.get("GLIDEIN_Verbosity"),
                # Default param to add
                "-param_GLIDEIN_Client": ad_gc.get("ClientName"),
            }
        )

        # Now extract all the params that are advertised in the ads
        # params from glideclient classad should override params from
        # the glidefactory classad
        for ad in (ad_gf, ad_gc):
            glidein_startup_args.update(params2args(ad))

    if not glidein_startup_args:
        logging.error("Could not find entry and/or client matching your criteria")
        sys.exit(1)

    # Override args
    if options.override_args:
        # split() with no separator tolerates repeated, leading and trailing
        # blanks, which would otherwise misalign the flag/value pairs
        o_args_list = options.override_args.split()
        # Pair the tokens up as flag/value; a trailing flag gets an empty value
        o_args = dict(itertools.zip_longest(*[iter(o_args_list)] * 2, fillvalue=""))
        param_prefix = "-%s" % GLIDEIN_PARAM_PREFIX
        for o_a, o_value in o_args.items():
            if not o_a.startswith("-"):
                logging.error('Positioning error for "%s" in --override-args', o_a)
                sys.exit(1)
            if o_a.startswith(param_prefix):
                # Parameter overrides are converted by params2args below;
                # rejecting them here would make that step unreachable
                continue
            if o_a not in glidein_startup_args:
                logging.error('Unrecognized option "%s" in --override-args', o_a)
                sys.exit(1)
            glidein_startup_args[o_a] = o_value

        # Override params
        glidein_startup_args.update(params2args(o_args, param_prefix=param_prefix))

    # Options with no value (None or empty string) are left out of the command
    command = options.glidein_startup
    for arg, value in glidein_startup_args.items():
        if value not in (None, ""):
            command = f"{command} {arg} {value}"

    if options.cmd_out_file:
        try:
            with open(options.cmd_out_file, "w") as out:
                out.write(SCRIPT_TEMPLATE % command)
                # rwxrwxr-x: the generated wrapper must be executable
                os.fchmod(out.fileno(), stat.S_IRWXU | stat.S_IRWXG | stat.S_IROTH | stat.S_IXOTH)
        except Exception as e:
            logging.error("Unable to create cmd file %s", options.cmd_out_file)
            logging.error(e)
            # Report the failure through the exit status as well
            return 1
    else:
        log_debug(glidein_startup_args, header="Arguments")
        print(command)

    return 0


################################################################################
# Main
################################################################################

if __name__ == "__main__":
    # Run only when executed as a script; main()'s return value is handed to
    # sys.exit() and becomes the process exit status
    sys.exit(main(sys.argv))
