/* Stevie, the client for steve
 * (c) 2025 Michał Górny
 * SPDX-License-Identifier: GPL-2.0-or-later
 *
 * Inspired by nixos-jobserver (draft) and guildmaster:
 * https://github.com/RaitoBezarius/nixpkgs/blob/e97220ecf1e8887b949e4e16547bf0334826d076/pkgs/by-name/ni/nixos-jobserver/nixos-jobserver.cpp#L213
 * https://codeberg.org/amonakov/guildmaster/
 */

#include <cassert>
#include <cerrno>
#include <climits>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <format>
#include <print>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include <fcntl.h>
#include <getopt.h>
#include <signal.h>
#include <sys/ioctl.h>
#include <sys/wait.h>
#include <unistd.h>

#include "steve.h"
#include "util.hxx"

/* Gives a job token back to the jobserver when it goes out of scope.
 *
 * Members are, in order, the descriptor the token was read from and the
 * token byte itself.  The destructor writes that byte back, retrying when
 * interrupted by a signal, and reports any other failure via perror().
 *
 * NOTE(review): the struct is copyable; a copy would write the token back a
 * second time, so only ever construct it in place.
 */
struct token_guard {
	int jobserver_fd;
	char job_token;

	~token_guard() {
		for (;;) {
			if (write(jobserver_fd, &job_token, 1) != -1)
				return;
			/* interrupted by a signal: just try again */
			if (errno != EINTR)
				break;
		}
		perror("Writing job token failed");
	}
};

/* PID of the command started by run_command().  It is -1 before fork() and
 * after a failed fork(), the child's PID in the parent, and 0 inside the
 * child itself (fork()'s return value there) until it execs.
 * NOTE(review): this is read from signal_handler(); strictly, only
 * volatile sig_atomic_t objects are safe to access from a signal handler.
 */
pid_t pid = -1;

/* SA_SIGINFO handler that run_command() installs for SIGHUP, SIGINT, SIGTERM,
 * SIGUSR1 and SIGUSR2.
 *
 * When a child exists (parent side, pid > 0), the handler forwards the signal
 * to it.  It then restores the default disposition, so that repeating the
 * signal terminates stevie itself.
 *
 * When there is nobody to forward to, the handler restores the default
 * disposition and re-raises the signal so it takes its default effect.  This
 * covers the case where no child exists yet, and the child between fork()
 * and exec().
 *
 * NOTE(review): std::print(), perror() and exit() are not async-signal-safe.
 */
void signal_handler(int signum, siginfo_t *, void *)
{
	/* Only a positive pid names a child.  In the child between fork() and
	 * exec() pid is 0, and kill(0, ...) would signal the whole process group
	 * (which by default includes stevie itself). */
	const bool forward = pid > 0;

	if (forward) {
		std::print("stevie: passing signal SIG{} to child {}, repeat to force termination\n",
				signal_name(signum), pid);
		if (kill(pid, signum) == -1 && errno != ESRCH) {
			perror("Unable to pass signal to child process");
			exit(1);
		}
	}

	struct sigaction sigact{};
	sigact.sa_handler = SIG_DFL;
	if (sigaction(signum, &sigact, NULL) == -1) {
		perror("Unable to restore default signal handler");
		exit(1);
	}

	/* nothing to forward to: let the signal take its default action */
	if (!forward) {
		if (raise(signum) == -1) {
			perror("Unable to reraise signal");
			exit(1);
		}
	}
}

/* Child half of run_command(): advertise the jobserver in MAKEFLAGS and exec
 * argv.  Never returns.
 *
 * Failures are reported on stderr.  stderr is unbuffered, so the message
 * survives _exit(), which discards stdio buffers.
 */
[[noreturn]] static void exec_child(int jobserver_fd, char **argv,
		const char *jobserver_path)
{
	const char *old_makeflags = getenv("MAKEFLAGS");
	std::string new_makeflags;
	if (old_makeflags) {
		new_makeflags = old_makeflags;
		new_makeflags += ' ';
	}
	new_makeflags += "--jobserver-auth=fifo:";
	new_makeflags += jobserver_path;

	/* Some clients (e.g. LLVM) try to read -j from MAKEFLAGS to determine
	 * the upper jobs bound.  To accommodate them, try to grab the job count
	 * from steve.  Note that this is not 100% reliable, as it doesn't
	 * account for job token count changing at runtime.
	 *
	 * We do not need to offset the value since stevie takes one job token,
	 * so we match GNU make exactly.
	 */
	int64_t num_jobs;
	if (ioctl(jobserver_fd, STEVE_IOC_GET_JOBS, &num_jobs) == 0)
		new_makeflags += std::format(" -j{}", num_jobs);

	if (setenv("MAKEFLAGS", new_makeflags.c_str(), 1) == -1) {
		std::print(stderr, "Unable to set MAKEFLAGS={}\n", new_makeflags);
		_exit(1);
	}

	execvp(argv[0], argv);
	std::print(stderr, "exec for {} failed: {}\n", argv[0], strerror(errno));
	_exit(1);
}

/* Take one job token from the jobserver, run argv as a child process while
 * holding it, and give the token back once the child has exited.
 *
 * SIGHUP, SIGINT, SIGTERM, SIGUSR1 and SIGUSR2 are routed through
 * signal_handler(), which forwards them to the child once it exists.
 *
 * Returns the child's exit status, or 128 + signal number if the child was
 * killed by a signal.  Returns 1 if stevie itself fails (installing handlers,
 * reading the token, fork() or waitpid()).
 */
static int run_command(int jobserver_fd, char **argv, const char *jobserver_path)
{
	struct sigaction sigact{};
	sigact.sa_sigaction = signal_handler;
	sigact.sa_flags = SA_SIGINFO;
	for (int signum : {SIGHUP, SIGINT, SIGTERM, SIGUSR1, SIGUSR2}) {
		if (sigaction(signum, &sigact, NULL) == -1) {
			perror("Setting signal handler failed");
			return 1;
		}
	}

	/* block until a token is available; retry reads interrupted by signals */
	char job_token;
	ssize_t res;
	while ((res = read(jobserver_fd, &job_token, 1)) == -1 && errno == EINTR);
	if (res == 0) {
		std::print(stderr, "EOF while waiting for job token\n");
		return 1;
	}
	if (res == -1) {
		perror("Reading job token failed");
		return 1;
	}

	/* Flush output buffered so far (e.g. results of --get-* actions).  When
	 * stdout is a pipe or file, it would otherwise appear only after
	 * everything the child prints. */
	fflush(stdout);

	pid = fork();
	if (pid == 0)
		exec_child(jobserver_fd, argv, jobserver_path);

	/* from here on the token is returned on every path, fork failure included */
	token_guard job_token_guard{jobserver_fd, job_token};
	if (pid == -1) {
		perror("Forking failed");
		return 1;
	}

	int wret;
	pid_t waited;
	while ((waited = waitpid(pid, &wret, 0)) == -1 && errno == EINTR);
	if (waited == -1) {
		std::print(stderr, "Waiting for PID {} failed: {}\n", pid, strerror(errno));
		return 1;
	}

	return WIFSIGNALED(wret) ? WTERMSIG(wret) + 128 : WEXITSTATUS(wret);
}

/* --help text.  This is a std::format string: {0} is replaced with the
 * program name (argv[0]), so any literal brace added here must be doubled.
 */
static constexpr char stevie_usage[] =
"usage: {0} [options] <argv>...\n"
"\n"
"options:\n"
"    --help, -h             print this help message\n"
"    --version, -V          print version\n"
"    --jobserver PATH, -s PATH\n"
"                           jobserver FIFO path (default: /dev/steve)\n"
"\n"
"other actions (executed before the command):\n"
"    --get-tokens, -t       print available token count\n"
"    --get-jobs, -j         print total job number\n"
"    --set-jobs JOBS, -J JOBS\n"
"                           set total job number\n"
"    --get-load-average, -l\n"
"                           print max load-average\n"
"    --set-load-average LOAD_AVG, -L LOAD_AVG\n"
"                           set max load-average\n"
"    --get-load-recheck-timeout, -r\n"
"                           print load-recheck-timeout\n"
"    --set-load-recheck-timeout TIMEOUT, -R TIMEOUT\n"
"                           set load-recheck-timeout\n"
"    --get-min-jobs, -m     print min-job number\n"
"    --set-min-jobs JOBS, -M JOBS\n"
"                           set min-job number\n"
"    --get-min-memory-avail, -a\n"
"                           print min. required available memory (in MiB)\n"
"    --set-min-memory-avail MIN_MEM_AVAIL, -A MIN_MEM_AVAIL\n"
"                           set min. required available memory (in MiB)\n"
"    --get-per-process-limit, -p\n"
"                           print per-process limit\n"
"    --set-per-process-limit LIMIT, -P LIMIT\n"
"                           set per-process limit\n";

/* Long-option table for getopt_long().  Every entry maps onto the short
 * option letter of the same meaning in stevie_short_opts (with the same
 * argument requirement); the all-null entry terminates the table.
 */
static const struct option stevie_long_opts[] = {
	/* general */
	{"help",                     no_argument,       nullptr, 'h'},
	{"version",                  no_argument,       nullptr, 'V'},
	{"jobserver",                required_argument, nullptr, 's'},
	/* jobserver get/set actions */
	{"get-tokens",               no_argument,       nullptr, 't'},
	{"get-jobs",                 no_argument,       nullptr, 'j'},
	{"set-jobs",                 required_argument, nullptr, 'J'},
	{"get-load-average",         no_argument,       nullptr, 'l'},
	{"set-load-average",         required_argument, nullptr, 'L'},
	{"get-load-recheck-timeout", no_argument,       nullptr, 'r'},
	{"set-load-recheck-timeout", required_argument, nullptr, 'R'},
	{"get-min-jobs",             no_argument,       nullptr, 'm'},
	{"set-min-jobs",             required_argument, nullptr, 'M'},
	{"get-min-memory-avail",     no_argument,       nullptr, 'a'},
	{"set-min-memory-avail",     required_argument, nullptr, 'A'},
	{"get-per-process-limit",    no_argument,       nullptr, 'p'},
	{"set-per-process-limit",    required_argument, nullptr, 'P'},
	/* terminator */
	{nullptr, 0, nullptr, 0},
};

/* Short options for getopt_long().  The leading '+' stops option parsing at
 * the first non-option argument, so options belonging to the wrapped command
 * are left untouched. */
static constexpr char stevie_short_opts[] = "+hVs:tjJ:lL:rR:mM:aA:pP:";

/* Queue of --get-* / --set-* actions: each ioctl request number paired with
 * its argument slot, held as either an int64_t or a double depending on the
 * ioctl. */
using stevie_action_vector = std::vector<
	std::pair<unsigned long, std::variant<int64_t, double>>>;

/* Run the queued ioctl actions against the jobserver, in command-line order.
 *
 * For a set-ioctl the variant holds the value to send.  For a get-ioctl it
 * holds a zero of the type to be read back (main() queues 0 or 0.0), and the
 * ioctl fills it in.  Get results are printed one per line; each successful
 * set prints "ok".
 *
 * Returns 0 on success, or 1 after reporting the first failing ioctl.
 */
static int perform_actions(int jobserver_fd, stevie_action_vector &actions)
{
	for (auto &action : actions) {
		const unsigned long ioctl_num = action.first;

		/* pass a pointer to whichever alternative (int64_t or double) is stored */
		const int ret = std::visit(
			[&](auto &value) { return ioctl(jobserver_fd, ioctl_num, &value); },
			action.second);
		if (ret != 0) {
			perror("ioctl failed");
			return 1;
		}

		if (STEVE_IOC_IS_GET(ioctl_num))
			std::visit([](const auto &value) { std::print("{}\n", value); },
					action.second);
		else if (STEVE_IOC_IS_SET(ioctl_num))
			std::print("ok\n");
		else
			assert(0 && "not reached");
	}
	return 0;
}

/* stevie entry point.
 *
 * 1. Parse options; the leading '+' in stevie_short_opts stops parsing at
 *    the first non-option argument.
 * 2. Open the jobserver at --jobserver (default /dev/steve).
 * 3. Perform the queued get/set actions.
 * 4. If a command remains, run it under a job token and return its exit
 *    status.
 */
int main(int argc, char **argv)
{
	const char *jobserver_path = "/dev/steve";

	int opt;
	stevie_action_vector actions;
	while ((opt = getopt_long(argc, argv, stevie_short_opts,
				stevie_long_opts, nullptr)) != -1) {
		long long_arg;
		double double_arg;

		switch (opt) {
		case 'h':
			std::print(stevie_usage, argv[0]);
			return 0;
		case 'V':
			std::print("stevie {}\n", STEVE_VERSION);
			return 0;
		case 's':
			jobserver_path = optarg;
			break;
		case 't':
			actions.emplace_back(STEVE_IOC_GET_TOKENS, 0);
			break;
		case 'j':
			actions.emplace_back(STEVE_IOC_GET_JOBS, 0);
			break;
		case 'J':
			if (!arg_to_long(optarg, &long_arg)) {
				std::print(stderr, "invalid --set-jobs value: {}\n", optarg);
				return 1;
			}
			actions.emplace_back(STEVE_IOC_SET_JOBS, long_arg);
			break;
		case 'l':
			actions.emplace_back(STEVE_IOC_GET_LOAD_AVG, 0.0);
			break;
		case 'L':
			/* !(x >= 1) rather than x < 1, so that NaN (which fails every
			 * comparison) is rejected as well */
			if (!arg_to_double(optarg, &double_arg) || !(double_arg >= 1)) {
				std::print(stderr, "invalid --set-load-average value (must be >=1): {}\n", optarg);
				return 1;
			}
			actions.emplace_back(STEVE_IOC_SET_LOAD_AVG, double_arg);
			break;
		case 'm':
			actions.emplace_back(STEVE_IOC_GET_MIN_JOBS, 0);
			break;
		case 'M':
			if (!arg_to_long(optarg, &long_arg)) {
				std::print(stderr, "invalid --set-min-jobs value: {}\n", optarg);
				return 1;
			}
			actions.emplace_back(STEVE_IOC_SET_MIN_JOBS, long_arg);
			break;
		case 'a':
			actions.emplace_back(STEVE_IOC_GET_MIN_MEMORY_AVAIL, 0);
			break;
		case 'A':
			if (!arg_to_long(optarg, &long_arg)) {
				std::print(stderr, "invalid --set-min-memory-avail value: {}\n", optarg);
				return 1;
			}
			actions.emplace_back(STEVE_IOC_SET_MIN_MEMORY_AVAIL, long_arg);
			break;
		case 'p':
			actions.emplace_back(STEVE_IOC_GET_PER_PROCESS_LIMIT, 0);
			break;
		case 'P':
			if (!arg_to_long(optarg, &long_arg)) {
				std::print(stderr, "invalid --set-per-process-limit value: {}\n", optarg);
				return 1;
			}
			actions.emplace_back(STEVE_IOC_SET_PER_PROCESS_LIMIT, long_arg);
			break;
		case 'r':
			actions.emplace_back(STEVE_IOC_GET_LOAD_RECHECK_TIMEOUT, 0.0);
			break;
		case 'R':
			/* accept [1 us, INT_MAX]; written as a negated in-range test so
			 * that NaN is rejected too */
			if (!arg_to_double(optarg, &double_arg)
					|| !(double_arg >= 0.000001 && double_arg <= INT_MAX)) {
				std::print(stderr, "invalid --set-load-recheck-timeout value (must be >=1 us): {}\n", optarg);
				return 1;
			}
			actions.emplace_back(STEVE_IOC_SET_LOAD_RECHECK_TIMEOUT, double_arg);
			break;
		default:
			std::print(stderr, stevie_usage, argv[0]);
			return 1;
		}
	}

	/* argv[argc] is NULL, so this also covers "no arguments left" */
	if (actions.empty() && !argv[optind]) {
		std::print(stderr, "{}: no command provided\n", argv[0]);
		return 1;
	}

	int jobserver_fd = open(jobserver_path, O_RDWR | O_CLOEXEC);
	if (jobserver_fd == -1) {
		std::print(stderr, "unable to open {}: {}\n", jobserver_path, strerror(errno));
		return 1;
	}
	fd_guard jobserver_fd_guard{jobserver_fd};

	if (int ret = perform_actions(jobserver_fd, actions); ret != 0)
		return ret;

	if (argv[optind])
		return run_command(jobserver_fd, &argv[optind], jobserver_path);
	return 0;
}
