#!/usr/libexec/platform-python

# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction, 
# disclosure or distribution of this material and related documentation 
# without an express license agreement from NVIDIA CORPORATION or 
# its affiliates is strictly prohibited.

from __future__ import print_function

import os
import sys
import time

from contextlib import contextmanager


class PackageManager:
    """Identifiers returned by get_package_manager().

    YUM: the ``yum`` Python module is importable.
    DNF: the ``dnf`` Python module is importable.
    """
    YUM, DNF = 1, 2


def get_package_manager():
    """Detect which RPM package-manager Python API can be imported.

    yum is probed first, then dnf.

    Returns:
        PackageManager.YUM if the ``yum`` module imports, otherwise
        PackageManager.DNF if the ``dnf`` module imports.

    Raises:
        Exception: if neither module can be imported.
    """
    try:
        import yum  # noqa: F401 -- imported only to probe availability
        return PackageManager.YUM
    except ImportError:
        pass

    try:
        import dnf  # noqa: F401 -- imported only to probe availability
        return PackageManager.DNF
    except ImportError:
        pass

    # Raised outside the except clause so the ImportError is not reported as
    # an extra "during handling of the above exception" traceback.
    raise Exception("Neither the yum nor the dnf package manager is available. "
                    "To continue, you need to install one of them.")


@contextmanager
def silence_stdout():
    """Send everything written to ``sys.stdout`` to the null device.

    Yields the opened devnull file object. On exit, normal or via an
    exception, the previous ``sys.stdout`` is put back and devnull is closed.
    """
    sink = open(os.devnull, "w")
    previous_stdout, sys.stdout = sys.stdout, sink
    try:
        yield sink
    finally:
        sys.stdout = previous_stdout
        sink.close()


def error_print(*values, **print_options):
    """Like print(), but always writes to standard error.

    All print() keyword options are accepted except ``file``, which is fixed
    to sys.stderr (passing it raises TypeError).
    """
    stream = sys.stderr
    print(*values, file=stream, **print_options)


def download_with_yum(dest_dir, package_name, only_list_deps, exist_deps, name_resolved=False):
    """Resolve, and optionally download, a package and its dependencies with yum.

    Args:
        dest_dir: directory the RPM files are downloaded into.
        package_name: package name. If no package of that name is found, yum's
            "provides" search is tried once and the first provider is used.
        only_list_deps: when true, return package names instead of downloading.
        exist_deps: accepted for interface parity with
            download_with_dnf_resolve(). NOTE(review): it is not used here.
        name_resolved: internal flag. It is set on the recursive call made after
            the provides lookup, so that lookup happens at most once.

    Returns:
        The ``str()`` names of the requested package and its dependencies when
        only_list_deps is true. Otherwise, the local paths of the downloaded
        RPM files.

    Exits the process with status 0 if the package is already installed. Exits
    with status 50 if it cannot be found, resolved, or downloaded.
    """
    import yum
    from rpmUtils.arch import getBaseArch

    yb = yum.YumBase()
    yb.preconf.init_plugins = False
    # to support work as non root user.
    yb.setCacheDir()

    yb.conf.downloaddir = dest_dir

    # Try the noarch flavour first, then the base arch of this host.
    for arch in ["noarch", getBaseArch()]:
        packages_list = yb.doPackageLists(patterns=[package_name + "." + arch])
        if packages_list.available:
            error_print(package_name + "." + arch + " is available")
            break
        if packages_list.installed:
            break

    if packages_list.installed:
        error_print(package_name + " is already installed")
        sys.exit(0)

    if not packages_list.available:
        if not name_resolved:
            # package_name may be a capability rather than a package name:
            # retry once with the first package that provides it.
            providers = yb.searchPackageProvides(args=[package_name])
            if providers:
                return download_with_yum(dest_dir, next(iter(providers)).name,
                                         only_list_deps, exist_deps, True)
        error_print("Package '{0}' is not installed and not found in enabled package "
                    "repositories. To continue, you need to install '{0}' package."
                    .format(package_name))
        sys.exit(50)

    dep_name_set = set()
    # NOTE: aliases packages_list.available; dependencies are appended below.
    packages_to_download = packages_list.available
    yb.doTsSetup()
    for pkg in packages_to_download:
        yb.tsInfo.addInstall(pkg)
        yb.localPackages.append(pkg)
        dep_name_set.add(str(pkg))
    result, result_msg = yb.resolveDeps()
    if result == 1:
        for msg in result_msg:
            error_print(msg)
        error_print("Dependency resolution of {0} package failed. "
                    "To continue, you need to install {0} package.".format(package_name))
        sys.exit(50)

    # ts_state "i"/"u": presumably install/update members -- confirm against
    # yum's TransactionMember. Each one is also downloaded.
    for member in yb.tsInfo.getMembers():
        if member.ts_state in ("i", "u") and member.po not in packages_to_download:
            packages_to_download.append(member.po)
            dep_name_set.add(str(member.po))

    if only_list_deps:
        return list(dep_name_set)

    for pkg in packages_to_download:
        # Presumably makes yum copy packages into downloaddir rather than
        # reuse its cache or local-repo files in place -- TODO confirm.
        pkg.repo.copy_local = True
        pkg.repo.cache = 0

    probs = yb.downloadPkgs(packages_to_download)
    if probs:
        # Report every download problem before giving up.
        for key in probs:
            for error in probs[key]:
                error_print("{0}: {1}".format(key, error))
        error_print("Package {0} is not installed and not found in enabled package "
                    "repositories. To continue, you need to install {0} package."
                    .format(package_name))
        sys.exit(50)

    return [pkg.localpath for pkg in packages_to_download]


def is_package_installed(installed_pkgs, pkg):
    """Return True if *installed_pkgs* records a package matching *pkg*.

    Matching is by (name, arch, version) only; release and epoch are ignored.
    *installed_pkgs* maps that tuple to a list of installed package objects.
    The package counts as installed when the first entry of that list is not
    None.
    """
    candidates = installed_pkgs.get((pkg.name, pkg.arch, pkg.version))
    if candidates is None:
        return False
    return candidates[0] is not None


def parse_packages(base, installed_pkgs, pkgs):
    """Turn package specs into the latest available packages not yet installed.

    Each spec in *pkgs* is parsed with dnf.subject.Subject and matched through
    its best selector (obsoletes included). The union of all matches is
    filtered to available packages of arch ``noarch`` or hawkey.detect_arch()
    and reduced with ``.latest()``. Anything is_package_installed() reports as
    installed is then dropped.

    Returns:
        A list of package objects.
    """
    import dnf
    import hawkey
    sack = base.sack
    candidates = set()
    for spec in pkgs:
        selector = dnf.subject.Subject(spec).get_best_selector(sack, obsoletes=True)
        candidates.update(selector.matches())
    latest = (sack.query()
              .available()
              .filter(pkg=list(candidates), arch=["noarch", hawkey.detect_arch()])
              .latest()
              .run())
    return [candidate for candidate in latest
            if not is_package_installed(installed_pkgs, candidate)]


def locate_deps(base, installed_pkgs, pkgs):
    """Map each package's requirements to the available packages providing them.

    Args:
        base: a dnf.Base whose sack has been filled.
        installed_pkgs: dict keyed by (name, arch, version), as consumed by
            is_package_installed().
        pkgs: iterable of package objects whose ``requires`` are examined.

    Returns:
        ``{pkg: {req: [provider, ...]}}``. Requirements starting with
        ``rpmlib(`` are skipped. A provider list is empty in two cases:
        - no latest available provider (arch ``noarch`` or the host arch)
          exists;
        - one of those providers matches an installed package.
    """
    import hawkey
    available = base.sack.query().available()
    # Same arch filter for every requirement, so compute it once.
    arches = ["noarch", hawkey.detect_arch()]
    results = {}
    for pkg in pkgs:
        pkg_results = results[pkg] = {}
        for req in pkg.requires:
            if str(req).startswith("rpmlib("):
                continue
            satisfiers = []
            for po in available.filter(provides=req, arch=arches).latest():
                # Check the *provider*: if it is already installed, the
                # requirement is satisfied and nothing needs fetching for it.
                if is_package_installed(installed_pkgs, po):
                    satisfiers = []
                    break
                satisfiers.append(po)
            pkg_results[req] = satisfiers
    return results


def get_dependency_list(base, installed_pkgs, pkgs):
    """Resolve package specs and map the requirements of the resulting packages.

    Specs are resolved with parse_packages() (latest available, not
    installed). The matching packages are then passed to locate_deps(), whose
    ``{pkg: {req: [provider, ...]}}`` mapping is returned.
    """
    matched = parse_packages(base, installed_pkgs, pkgs)
    return locate_deps(base, installed_pkgs, list(matched))


def process_results(results):
    """Flatten a ``{pkg: {req: [provider, ...]}}`` mapping into dependency names.

    Returns:
        A pair ``(found, notfound)``:
        - found: maps ``str(req)`` to a de-duplicated list of provider names,
          formatted as ``name-version-release.arch``.
        - notfound: maps ``str(req)`` to ``[]`` for each requirement whose
          provider list was empty.
    """
    unresolved = {}
    providers_by_req = {}
    for _pkg, req_map in results.items():
        if not req_map:
            continue
        for req, candidates in req_map.items():
            if candidates:
                providers_by_req[req] = candidates
            else:
                unresolved[str(req)] = []

    found = {}
    for req, candidates in providers_by_req.items():
        names = found[str(req)] = []
        for candidate in candidates:
            # NOTE(review): epoch is collected but the name format omits it.
            fields = dict(name=candidate.name,
                          arch=candidate.arch,
                          epoch=str(candidate.epoch),
                          release=candidate.release,
                          version=candidate.version)
            nvra = "{name}-{version}-{release}.{arch}".format(**fields)
            if nvra not in names:
                names.append(nvra)
    return found, unresolved


def get_recursive_dep_list(base, installed_pkgs, pkgs):
    """Collect requirement data for *pkgs* and, transitively, their dependencies.

    Each round resolves the pending specs with get_dependency_list(). It then
    queues every dependency name reported by process_results() that has not
    been seen yet. Rounds repeat until nothing new is queued.

    Returns:
        The merged ``{pkg: {req: [provider, ...]}}`` mapping from all rounds.
    """
    # A set keeps the membership tests O(1).
    seen = set(pkgs)
    to_solve = list(pkgs)
    all_results = {}

    while to_solve:
        results = get_dependency_list(base, installed_pkgs, to_solve)
        all_results.update(results)
        found = process_results(results)[0]

        next_round = []
        for dep_names in found.values():
            for dep in dep_names:
                if dep not in seen:
                    # Mark immediately so a dependency shared by several
                    # requirements is queued only once.
                    seen.add(dep)
                    next_round.append(dep)
        to_solve = next_round
    return all_results


def get_centos_stream_version(os_release_path="/etc/os-release"):
    """Return VERSION_ID from an os-release file when the OS is CentOS Stream.

    Args:
        os_release_path: os-release file to parse (default /etc/os-release).

    Returns:
        The unquoted VERSION_ID (e.g. "8") when NAME contains "centos stream",
        compared case-insensitively. Otherwise None. None is also returned when
        the file cannot be read or has no VERSION_ID.
    """
    try:
        with open(os_release_path) as os_release:
            lines = os_release.readlines()
    except (IOError, OSError):
        return None

    info = {}
    for raw_line in lines:
        line = raw_line.strip()
        # Skip blank lines, comments and anything that is not KEY=VALUE
        # instead of abandoning the whole file.
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        info[key.strip()] = value.strip().strip("'\"")

    # Membership test, lower-cased so that NAME="CentOS Stream" matches.
    # (str.find() returns -1, which is truthy, when the text is absent.)
    if "centos stream" in info.get("NAME", "").lower():
        return info.get("VERSION_ID")
    return None

    
def download_with_dnf(dest_dir, package_name, only_list_deps):
    """Resolve, and optionally download, a package by walking requirements manually.

    Dependencies are gathered with get_recursive_dep_list() rather than dnf's
    depsolver. NOTE(review): the __main__ entry point in this file calls
    download_with_dnf_resolve() instead.

    Args:
        dest_dir: existing directory the RPMs are downloaded into. Files already
            in it are excluded from the returned list.
        package_name: package spec to resolve.
        only_list_deps: when true, return package names instead of downloading.

    Returns:
        The names of packages to install when only_list_deps is true.
        Otherwise, the paths of the ``.rpm`` files that appeared in dest_dir.

    Raises:
        The last exception from fill_sack() if every attempt fails.
    """
    files_in_dest_dir = set(os.listdir(dest_dir))
    import dnf
    import tempfile

    attempts = 5
    downloaded_rpms = []
    dep_name_set = set()
    with tempfile.TemporaryDirectory() as tmp_cache_dirname:
        base = dnf.Base()
        base.conf.cachedir = tmp_cache_dirname
        base.conf.destdir = dest_dir
        centos_stream_version = get_centos_stream_version()
        if centos_stream_version:
            base.conf.substitutions["stream"] = centos_stream_version + "-stream"

        # Read the repo configuration once, outside the retry loop, as
        # download_with_dnf_resolve() does. Re-reading it on a retry would
        # register the same repositories again (presumably a dnf
        # duplicate-repo error -- confirm).
        base.read_all_repos()

        last_exception = None
        for attempt in range(attempts):
            try:
                base.fill_sack(load_system_repo=True, load_available_repos=True)
                last_exception = None
                break
            except Exception as e:
                last_exception = e
                # No point waiting after the final attempt.
                if attempt + 1 < attempts:
                    time.sleep(3)
        if last_exception:
            raise last_exception

        installed_pkgs = {}
        for pkg in base.sack.query().installed().run():
            key = (pkg.name, pkg.arch, pkg.version)
            installed_pkgs.setdefault(key, []).append(pkg)

        packages_requested = parse_packages(base, installed_pkgs, [package_name])
        dep_list_dict = get_recursive_dep_list(base, installed_pkgs, [package_name])
        found = process_results(dep_list_dict)[0]
        dep_set = set(str(p) for p in packages_requested)
        for dep_vals in found.values():
            dep_set.update(dep_vals)

        deps_to_install = []
        for dep_name in dep_set:
            for pkg in base.sack.query().filter(nevra=dep_name):
                # Keep only the first not-installed match per name.
                if dep_name not in dep_name_set and not is_package_installed(installed_pkgs, pkg):
                    deps_to_install.append(pkg)
                    dep_name_set.add(dep_name)

        if not only_list_deps:
            base.repos.all().pkgdir = base.conf.destdir
            base.download_packages(deps_to_install)

            # Report only files that were not in dest_dir before this call.
            new_files = set(os.listdir(dest_dir)) - files_in_dest_dir
            downloaded_rpms = [os.path.join(dest_dir, f)
                               for f in new_files if f.endswith(".rpm")]

    return list(dep_name_set) if only_list_deps else downloaded_rpms


def download_with_dnf_resolve(dest_dir, package_name, only_list_deps, exist_deps):
    """Resolve, and optionally download, a package using dnf's own depsolver.

    Weak dependencies are not pulled in (install_weak_deps = False). Packages
    whose ``str()`` form appears in *exist_deps* are left out of the result.

    Args:
        dest_dir: existing directory the RPMs are downloaded into. Files already
            in it are excluded from the returned list.
        package_name: spec passed to dnf.Base.install().
        only_list_deps: when true, return package names instead of downloading.
        exist_deps: iterable of package names (as ``str(pkg)``) to exclude.

    Returns:
        ``str(pkg)`` names when only_list_deps is true. Otherwise, the paths of
        the ``.rpm`` files that appeared in dest_dir during the download. An
        empty list if package_name cannot be marked for installation.

    Raises:
        The last exception from fill_sack() if every attempt fails.
        SystemExit(50) if dependency resolution fails.
    """
    files_in_dest_dir = set(os.listdir(dest_dir))
    import dnf
    import tempfile

    attempts = 5
    already_known = set(exist_deps)
    with tempfile.TemporaryDirectory() as tmp_cache_dirname:
        base = dnf.Base()
        base.conf.cachedir = tmp_cache_dirname
        base.conf.destdir = dest_dir
        base.conf.install_weak_deps = False

        centos_stream_version = get_centos_stream_version()
        if centos_stream_version:
            base.conf.substitutions["stream"] = centos_stream_version + "-stream"
        base.read_all_repos()

        last_exception = None
        for attempt in range(attempts):
            try:
                base.fill_sack(load_system_repo=True,
                               load_available_repos=True)
                last_exception = None
                break
            except Exception as e:
                last_exception = e
                # No point waiting after the final attempt.
                if attempt + 1 < attempts:
                    time.sleep(3)
        if last_exception:
            raise last_exception

        try:
            base.install(package_name)
        except dnf.exceptions.MarkingError:
            # NOTE(review): download_with_yum exits with status 50 here;
            # returning [] makes the script exit 0 with empty output.
            error_print("Package is not found in enabled repositories. "
                        "To continue, you need to install {0} package.".format(package_name))
            return []

        try:
            base.resolve()
        except dnf.exceptions.DepsolveError as e:
            # Report the failure the same way download_with_yum does.
            error_print(str(e))
            error_print("Dependency resolution of {0} package failed. "
                        "To continue, you need to install {0} package.".format(package_name))
            sys.exit(50)

        not_installed_set = set(pkg for pkg in base.transaction.install_set
                                if str(pkg) not in already_known)

        if only_list_deps:
            return [str(pkg) for pkg in not_installed_set]

        base.repos.all().pkgdir = base.conf.destdir
        base.download_packages(not_installed_set)

        # Report only files that were not in dest_dir before this call.
        new_files = set(os.listdir(dest_dir)) - files_in_dest_dir
        return [os.path.join(dest_dir, f) for f in new_files if f.endswith(".rpm")]


def str_to_bool(v):
    """Interpret a command-line flag string as a boolean.

    Case-insensitive. "yes", "true", "t" and "1" give True; any other string
    gives False.
    """
    truthy_spellings = frozenset(("yes", "true", "t", "1"))
    return v.lower() in truthy_spellings


if __name__ == "__main__":
    # Usage: <dest_dir> <package_name> <only_list_deps> [known_dep ...]
    # Prints one package name or downloaded RPM path per line on stdout.
    # Diagnostics go to stderr.
    if len(sys.argv) < 4:
        error_print(
            "Destination directory and package name should be set to download packages")
        error_print(
            "Usage: {0} <dest_dir> <package_name> <only_list_deps> [known_dep ...]".format(sys.argv[0]))
        sys.exit(50)
    dest_dir = sys.argv[1]
    package_name = sys.argv[2]
    # argv[3] is guaranteed to exist by the check above.
    only_list_deps = str_to_bool(sys.argv[3])

    # Remaining arguments name packages the caller already has. They are
    # passed to the download function; download_with_dnf_resolve() excludes
    # them from its result.
    exist_deps = sys.argv[4:]

    try:
        if not os.path.isdir(dest_dir):
            os.makedirs(dest_dir)
    except Exception as e:
        error_print("Destination {} directory cannot be opened: ".format(
            dest_dir) + str(e))
        # Stop here: both download paths need an existing dest_dir.
        sys.exit(50)

    package_list = []
    # Package-manager APIs may write to stdout. Silence them so stdout carries
    # only the final list printed below.
    with silence_stdout():
        package_manager = get_package_manager()
        if package_manager == PackageManager.YUM:
            package_list = download_with_yum(
                dest_dir, package_name, only_list_deps, exist_deps)
        elif package_manager == PackageManager.DNF:
            package_list = download_with_dnf_resolve(
                dest_dir, package_name, only_list_deps, exist_deps)

    print("\n".join(str(pkg) for pkg in package_list))
