//go:build integration
// +build integration

/*
** Copyright (C) 2001-2025 Zabbix SIA
**
** Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
** documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
** rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
** permit persons to whom the Software is furnished to do so, subject to the following conditions:
**
** The above copyright notice and this permission notice shall be included in all copies or substantial portions
** of the Software.
**
** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
** WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
** TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
** SOFTWARE.
**/

package tlsconfig

import (
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"math/big"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"
)

// testCerts bundles the TLS fixtures produced by generateTestCerts: the parsed
// CA and server certificates, their PEM encodings, the CA signing key and the
// paths of the PEM files written to disk.
type testCerts struct {
	// Parsed certificates.
	caCert     *x509.Certificate
	serverCert *x509.Certificate

	// PEM-encoded material: both certificates and the server's private key.
	caCertPEM     []byte
	serverCertPEM []byte
	serverKeyPEM  []byte

	// Key the CA certificate was self-signed with and that signed the server
	// certificate; generateTestCerts stores an *rsa.PrivateKey here.
	caKey crypto.PrivateKey

	// Paths of the PEM files holding caCertPEM, serverCertPEM and serverKeyPEM.
	caFile, certFile, keyFile string
}

// generateTestCerts creates a throw-away PKI for the tests in this file: a
// self-signed CA and a server certificate for dnsName issued by that CA.
//
// The CA certificate, the server certificate and the server's PKCS#1 RSA
// private key are PEM-encoded and written with mode 0600 to ca.pem, cert.pem
// and key.pem inside t.TempDir(). Both certificates are valid for one year
// from creation. Any failure aborts the calling test via t.Fatalf.
func generateTestCerts(t *testing.T, dnsName string) testCerts {
	t.Helper()

	// CA: a certificate that is allowed to sign other certificates.
	caTemplate := x509.Certificate{
		SerialNumber:          big.NewInt(2025),
		Subject:               pkix.Name{Organization: []string{"Test CA"}},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().AddDate(1, 0, 0),
		IsCA:                  true,
		BasicConstraintsValid: true,
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage: []x509.ExtKeyUsage{
			x509.ExtKeyUsageClientAuth,
			x509.ExtKeyUsageServerAuth,
		},
	}

	caKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("Failed to generate CA private key: %v", err)
	}

	// Passing the template as its own parent is what makes the CA self-signed.
	caDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, caKey.Public(), caKey)
	if err != nil {
		t.Fatalf("Failed to create CA certificate: %v", err)
	}

	caCert, err := x509.ParseCertificate(caDER)
	if err != nil {
		t.Fatalf("Failed to parse CA certificate: %v", err)
	}

	// Server: a leaf certificate for dnsName, signed with the CA key.
	serverTemplate := x509.Certificate{
		SerialNumber: big.NewInt(2026),
		Subject:      pkix.Name{Organization: []string{"Test Server"}},
		DNSNames:     []string{dnsName},
		NotBefore:    time.Now(),
		NotAfter:     time.Now().AddDate(1, 0, 0),
		KeyUsage:     x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
	}

	serverKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("Failed to generate server private key: %v", err)
	}

	serverDER, err := x509.CreateCertificate(rand.Reader, &serverTemplate, caCert, serverKey.Public(), caKey)
	if err != nil {
		t.Fatalf("Failed to create server certificate: %v", err)
	}

	serverCert, err := x509.ParseCertificate(serverDER)
	if err != nil {
		t.Fatalf("Failed to parse server certificate: %v", err)
	}

	toPEM := func(blockType string, der []byte) []byte {
		return pem.EncodeToMemory(&pem.Block{Type: blockType, Bytes: der})
	}

	out := testCerts{
		caCert:        caCert,
		caCertPEM:     toPEM("CERTIFICATE", caDER),
		caKey:         caKey,
		serverCert:    serverCert,
		serverCertPEM: toPEM("CERTIFICATE", serverDER),
		serverKeyPEM:  toPEM("RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(serverKey)),
	}

	// Persist the PEM material so code under test can load it by path.
	dir := t.TempDir()
	out.caFile = filepath.Join(dir, "ca.pem")
	out.certFile = filepath.Join(dir, "cert.pem")
	out.keyFile = filepath.Join(dir, "key.pem")

	if err = os.WriteFile(out.caFile, out.caCertPEM, 0600); err != nil {
		t.Fatalf("Failed to write CA file: %v", err)
	}

	if err = os.WriteFile(out.certFile, out.serverCertPEM, 0600); err != nil {
		t.Fatalf("Failed to write cert file: %v", err)
	}

	if err = os.WriteFile(out.keyFile, out.serverKeyPEM, 0600); err != nil {
		t.Fatalf("Failed to write key file: %v", err)
	}

	return out
}

// Test_CreateConfig is a table-driven test of CreateConfig. Each case passes a
// Details value, which may point at freshly generated CA/cert/key files or at a
// file holding invalid PEM data, together with a skipVerify flag. Success cases
// inspect the returned *tls.Config through validate. Failure cases assert that
// the returned error message contains errText.
//
// Case names use the "+" (success) / "-" (failure) prefix convention shared
// with Test_VerifyPeerCertificateFunc.
//
//nolint:gocognit,gocyclo,cyclop // this is unit test
func Test_CreateConfig(t *testing.T) {
	t.Parallel()

	certs := generateTestCerts(t, "localhost")

	// Create a temporary file that exists but contains no PEM data.
	badPemFile := filepath.Join(t.TempDir(), "bad.pem")
	if err := os.WriteFile(badPemFile, []byte("not a pem"), 0600); err != nil {
		t.Fatal(err)
	}

	var tests = []struct {
		name       string
		details    Details
		skipVerify bool
		validate   func(*testing.T, *tls.Config)
		wantErr    bool
		errText    string
	}{
		{
			name: "+validConfigWithVerification",
			details: Details{
				RawURI:      "tcp://localhost:5432",
				TLSCaFile:   certs.caFile,
				TLSCertFile: certs.certFile,
				TLSKeyFile:  certs.keyFile,
			},
			skipVerify: false,
			validate: func(t *testing.T, cfg *tls.Config) {
				t.Helper()
				if cfg.InsecureSkipVerify {
					t.Error("InsecureSkipVerify should be false")
				}
				if cfg.ServerName != "localhost" {
					t.Errorf("got ServerName %q, want %q", cfg.ServerName, "localhost")
				}
				if len(cfg.Certificates) != 1 {
					t.Errorf("got %d certificates, want 1", len(cfg.Certificates))
				}
				if cfg.RootCAs == nil {
					t.Error("RootCAs should not be nil")
				}
			},
		},
		{
			name: "+validConfigWithSkipVerify",
			details: Details{
				RawURI:      "tcp://localhost:5432",
				TLSCaFile:   certs.caFile,
				TLSCertFile: certs.certFile,
				TLSKeyFile:  certs.keyFile,
			},
			skipVerify: true,
			validate: func(t *testing.T, cfg *tls.Config) {
				t.Helper()
				if !cfg.InsecureSkipVerify {
					t.Error("InsecureSkipVerify should be true")
				}
				if cfg.ServerName != "" {
					t.Errorf("got ServerName %q, want empty", cfg.ServerName)
				}
			},
		},
		// NOTE(review): "no such file or directory" is the POSIX wording of
		// ENOENT. On Windows the OS reports a different message, so the two
		// path cases below are expected to fail there.
		{
			name:    "-invalidCAFilePath",
			details: Details{TLSCaFile: "/non/existent/ca.pem"},
			wantErr: true,
			errText: "no such file or directory",
		},
		{
			name:    "-invalidCertificateFilePath",
			details: Details{TLSCaFile: certs.caFile, TLSCertFile: "/non/existent/cert.pem"},
			wantErr: true,
			errText: "no such file or directory",
		},
		{
			name:    "-malformedCAFile",
			details: Details{TLSCaFile: badPemFile, TLSCertFile: certs.certFile, TLSKeyFile: certs.keyFile},
			wantErr: true,
			errText: "Failed to append PEM",
		},
		{
			name: "-invalidURI",
			details: Details{
				RawURI:      "://invalid",
				TLSCaFile:   certs.caFile,
				TLSCertFile: certs.certFile,
				TLSKeyFile:  certs.keyFile,
			},
			skipVerify: false,
			wantErr:    true,
			errText:    "missing protocol scheme",
		},
	}

	for i := range tests {
		// Copy the case into a per-iteration variable before the closure captures it.
		// Parallel subtests run after the loop has advanced. When the module
		// targets Go < 1.22, a captured range variable is shared across
		// iterations, so every subtest would see only the last case.
		tc := tests[i]

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			config, err := CreateConfig(tc.details, tc.skipVerify)

			if tc.wantErr {
				if err == nil {
					t.Fatalf("CreateConfig() error = nil, want error containing %q", tc.errText)
				}

				if !strings.Contains(err.Error(), tc.errText) {
					t.Errorf("CreateConfig() error = %q, want it to contain %q", err.Error(), tc.errText)
				}

				return
			}

			if err != nil {
				t.Fatalf("CreateConfig() unexpected error: %v", err)
			}

			if tc.validate != nil {
				tc.validate(t, config)
			}
		})
	}
}

// Test_VerifyPeerCertificateFunc is a table-driven test of the callback
// returned by VerifyPeerCertificateFunc. Each case builds the callback for a
// DNS name and root pool. It then invokes the callback with a set of raw DER
// certificates and nil as the second argument. The success case expects no
// error. Failure cases expect the error message to contain errText.
//
// NOTE(review): the (rawCerts, nil) call presumably mirrors how crypto/tls
// invokes tls.Config.VerifyPeerCertificate — confirm against the implementation.
func Test_VerifyPeerCertificateFunc(t *testing.T) {
	t.Parallel()

	host := "zabbix.com"
	certs := generateTestCerts(t, host)

	rootPool := x509.NewCertPool()
	if !rootPool.AppendCertsFromPEM(certs.caCertPEM) {
		t.Fatal("Failed to create root CA pool")
	}

	rawServerCert := [][]byte{certs.serverCert.Raw}

	var tests = []struct {
		name         string
		dnsName      string
		rootPool     *x509.CertPool
		certificates [][]byte
		wantErr      bool
		errText      string
	}{
		{
			name:         "+validCertificate",
			dnsName:      host,
			rootPool:     rootPool,
			certificates: rawServerCert,
		},
		{
			name:         "-noCertificatesProvided",
			dnsName:      host,
			rootPool:     rootPool,
			certificates: [][]byte{},
			wantErr:      true,
			errText:      "No TLS certificates found",
		},
		{
			name:         "-malformedCertificateData",
			dnsName:      host,
			rootPool:     rootPool,
			certificates: [][]byte{[]byte("this is not a valid certificate")},
			wantErr:      true,
			errText:      "No TLS certificates found",
		},
		{
			name:         "-verificationFailsWrongDNSName",
			dnsName:      "wrong.host.com",
			rootPool:     rootPool,
			certificates: rawServerCert,
			wantErr:      true,
			errText:      "Failed to verify certificate",
		},
		{
			// An empty pool cannot anchor the chain to the test CA.
			name:         "-verificationFailsUnknownRootCA",
			dnsName:      host,
			rootPool:     x509.NewCertPool(),
			certificates: rawServerCert,
			wantErr:      true,
			errText:      "certificate signed by unknown authority",
		},
	}

	for i := range tests {
		// Copy the case into a per-iteration variable before the closure captures it.
		// Parallel subtests run after the loop has advanced. When the module
		// targets Go < 1.22, a captured range variable is shared across
		// iterations, so every subtest would see only the last case.
		tc := tests[i]

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			verifyFunc := VerifyPeerCertificateFunc(tc.dnsName, tc.rootPool)

			err := verifyFunc(tc.certificates, nil)

			if !tc.wantErr {
				if err != nil {
					t.Fatalf("verifyFunc() unexpected error: %v", err)
				}

				return
			}

			if err == nil {
				t.Fatalf("verifyFunc() error = nil, want error containing %q", tc.errText)
			}

			if !strings.Contains(err.Error(), tc.errText) {
				t.Errorf("verifyFunc() error = %q, want it to contain %q", err.Error(), tc.errText)
			}
		})
	}
}
