#!/usr/libexec/platform-python

# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction, 
# disclosure or distribution of this material and related documentation 
# without an express license agreement from NVIDIA CORPORATION or 
# its affiliates is strictly prohibited.

from __future__ import print_function

import errno
import os
import subprocess
import sys

from contextlib import contextmanager


@contextmanager
def silence_stdout():
    """Temporarily point ``sys.stdout`` at the null device.

    Yields the writable null-device file object.  On exit, that file is
    closed and ``sys.stdout`` is set back to whatever it was on entry, even
    if the managed body raised.
    """
    saved_stdout = sys.stdout
    sink = open(os.devnull, "w")
    try:
        sys.stdout = sink
        yield sink
    finally:
        # Close the sink first, but restore stdout no matter what.
        try:
            sink.close()
        finally:
            sys.stdout = saved_stdout


def error_print(*args, **kwargs):
    """Print *args* to standard error.

    Any keyword arguments (``sep``, ``end``, ...) are forwarded to ``print``
    unchanged; the output stream itself is always ``sys.stderr``.
    """
    stream = sys.stderr
    print(*args, file=stream, **kwargs)


def find_binary(bin_name, dest_dir):
    """Locate the executable *bin_name* and return its path.

    Directories are searched in this order:

    1. every entry of ``$PATH`` (if set),
    2. ``$HOME/bin`` (if ``$HOME`` is set),
    3. the standard system binary directories,
    4. the same standard directories rooted at *dest_dir* (if given).

    Only regular files that the current user may execute are accepted, so a
    directory or a non-executable file that happens to carry the same name
    does not shadow the real binary.

    If nothing is found, an explanation is written to stderr and the process
    exits with status 1.
    """
    system_dirs = ["/usr/local/sbin", "/usr/local/bin",
                   "/usr/sbin", "/usr/bin", "/sbin", "/bin"]

    search_dirs = []
    if "PATH" in os.environ:
        search_dirs += os.environ["PATH"].split(os.pathsep)
    if "HOME" in os.environ:
        search_dirs.append(os.environ["HOME"] + "/bin")
    search_dirs += system_dirs

    # Drop trailing separators so that dest_root + "/usr/bin" is well formed.
    dest_root = dest_dir.rstrip(os.sep) if dest_dir else dest_dir
    if dest_root:
        search_dirs += [dest_root + system_dir for system_dir in system_dirs]

    for directory in search_dirs:
        candidate = os.path.join(directory, bin_name)
        if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
            return candidate

    error_print("'{0}' not found, exiting.".format(bin_name))
    error_print("To continue, you need to install '{0}'.".format(bin_name))
    sys.exit(1)


def is_subdir(path, directory):
    """Return True if *path* is *directory* itself or lies inside it.

    Both arguments are canonicalised with ``os.path.realpath`` first, so
    symlinks and ``..`` components are resolved before the comparison.
    """
    path = os.path.realpath(path)
    directory = os.path.realpath(directory)
    relative = os.path.relpath(path, directory)
    # relpath() yields exactly ".." when path is the direct parent of
    # directory, and "../..." for anything further outside; both mean
    # "not inside".  Names that merely begin with ".." (e.g. "..cache")
    # are ordinary entries and must not be rejected.
    if relative == os.pardir:
        return False
    return not relative.startswith(os.pardir + os.sep)


def recreate_if_needed(cur_file_path, link_in_file, need_recreate):
    """Replace *cur_file_path* with a symlink pointing at *link_in_file*.

    Nothing happens unless *need_recreate* is true.  Otherwise the existing
    entry is removed and a new symlink with the given target is created in
    its place.
    """
    if not need_recreate:
        return
    os.remove(cur_file_path)
    os.symlink(link_in_file, cur_file_path)

# Restore symlinks in downloaded packages: many of the paths there are absolute.
# This function checks whether each symlink's destination exists; if it does not,
# but an equivalent destination exists in the installed directory, the symlink is
# re-created to point at that destination.


def _lib_variants(path):
    """Return *path* followed by its lib -> lib64 and lib64 -> lib spellings."""
    return [path,
            path.replace("/lib/", "/lib64/"),
            path.replace("/lib64/", "/lib/")]


def _replacement_candidates(link_target, subdir, dest_dir):
    """Return possible new targets for a broken symlink, most preferred first.

    The order is: the target re-rooted at "/" (computed from the link's
    directory, then from *dest_dir*), the original target text, and finally
    the target placed under *dest_dir* and under the link's own directory.
    """
    relative_target = link_target[1:] if link_target.startswith("/") else link_target
    under_dest = os.path.join(dest_dir, relative_target)
    under_subdir = os.path.join(subdir, relative_target)

    def system_path(candidate):
        # Map a path located inside dest_dir to the same path rooted at "/".
        try:
            if is_subdir(candidate, dest_dir):
                return "/" + os.path.relpath(candidate, dest_dir)
        except (OSError, ValueError):
            pass
        return None

    ordered = [system_path(under_subdir), system_path(under_dest),
               link_target, under_dest, under_subdir]
    return [candidate for candidate in ordered if candidate]


def restore_symlinks(dest_dir):
    """Re-point broken symlinks found under *dest_dir* at an existing target.

    Every non-directory entry below *dest_dir* is examined.  A symlink whose
    target does not resolve is replaced by a symlink to the first candidate
    (see ``_replacement_candidates``) that exists as a file or directory;
    for each candidate the "/lib/" <-> "/lib64/" spellings are tried as well.
    Links that already resolve, links that name themselves, and links for
    which no candidate exists are left untouched.

    This is best effort: a link that cannot be replaced is reported on
    stderr and skipped.
    """
    if not dest_dir:
        return
    if not os.path.isabs(dest_dir):
        # Replacement targets are built from dest_dir and stored verbatim in
        # the new link; a relative target would be resolved against the
        # link's directory and therefore be broken again.
        dest_dir = os.path.join(os.getcwd(), dest_dir)

    for subdir, _dirs, files in os.walk(dest_dir):
        for file_name in files:
            cur_file_path = os.path.join(subdir, file_name)

            # os.path.exists() follows the link the way the kernel does, i.e.
            # relative targets are resolved against the link's directory.
            if not os.path.islink(cur_file_path) or os.path.exists(cur_file_path):
                continue
            try:
                link_in_file = os.readlink(cur_file_path)
            except OSError:
                continue
            if not link_in_file or link_in_file == file_name:
                continue

            for candidate in _replacement_candidates(link_in_file, subdir, dest_dir):
                # A relative variant can only stem from the original link
                # text, so probe it relative to the link's directory.
                new_target = next(
                    (variant for variant in _lib_variants(candidate)
                     if os.path.isdir(os.path.join(subdir, variant))
                     or os.path.isfile(os.path.join(subdir, variant))),
                    None)
                if new_target is None:
                    continue
                try:
                    recreate_if_needed(cur_file_path, new_target, True)
                except OSError as exc:
                    error_print("Could not restore symlink '{0}': {1}".format(
                        cur_file_path, exc))
                break


def extract_rpm(dest_dir, rpm_files):
    """Unpack RPM archives into *dest_dir* using ``rpm2cpio | cpio -idvu``.

    Args:
        dest_dir: directory the archive contents are extracted into.
        rpm_files: iterable of RPM file paths; relative paths are interpreted
            against the current working directory.

    Returns:
        The list of entries from *rpm_files* that were extracted
        successfully (both ``rpm2cpio`` and ``cpio`` exited with status 0).
        Failures are reported on stderr and omitted from the list.

    The process exits (via ``find_binary``) if ``rpm2cpio`` or ``cpio``
    cannot be located.
    """
    extracted_pkgs = []
    # cpio is started with cwd=dest_dir, and a relative executable path would
    # then be looked up relative to that directory, so make both absolute.
    rpm2cpio = os.path.abspath(find_binary("rpm2cpio", dest_dir))
    cpio = os.path.abspath(find_binary("cpio", dest_dir))

    for rpm_file in rpm_files:
        producer = subprocess.Popen([rpm2cpio, rpm_file],
                                    stdout=subprocess.PIPE)
        try:
            # Only cpio needs to run inside dest_dir; using cwd= avoids
            # changing the working directory of this process.
            consumer = subprocess.Popen([cpio, "-idvu"], stdin=producer.stdout,
                                        stdout=subprocess.PIPE,
                                        stderr=subprocess.PIPE,
                                        cwd=dest_dir)
        except OSError:
            producer.stdout.close()
            producer.wait()
            raise
        # Drop our copy of the pipe's read end so that rpm2cpio receives
        # SIGPIPE instead of blocking forever if cpio exits early.
        producer.stdout.close()
        consumer.communicate()
        producer_status = producer.wait()

        if producer_status != 0 or consumer.returncode != 0:
            error_print(
                "Failed to extract '{0}' (rpm2cpio exit status {1}, "
                "cpio exit status {2})".format(
                    rpm_file, producer_status, consumer.returncode))
            continue
        extracted_pkgs.append(rpm_file)
    return extracted_pkgs


if __name__ == "__main__":
    # Usage: <script> DEST_DIR RPM_FILE [RPM_FILE ...]
    # While working, everything written to stdout is discarded so that the
    # only stdout output is the final list of extracted packages, one per line.
    extracted_pkgs = []
    with silence_stdout():
        arguments = sys.argv[1:]
        if len(arguments) < 2:
            error_print(
                "Destination directory and RPM filenames should be set to unpack RPM")
            sys.exit(50)
        destination = arguments[0]
        extracted_pkgs = extract_rpm(destination, arguments[1:])
        restore_symlinks(destination)

    print("\n".join(map(str, extracted_pkgs)))
