#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

# pyre-ignore-all-errors[56]

import math
import random
import unittest
from unittest.mock import MagicMock, patch

import hypothesis.strategies as st
import numpy as np
import torch
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
    CacheAlgorithm,
    EmbeddingLocation,
    PoolingMode,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    RESParams,
    SplitTableBatchedEmbeddingBagsCodegen,
)
from fbgemm_gpu.tbe.utils import (
    b_indices,
    generate_requests,
    get_table_batched_offsets_from_dense,
    round_up,
    to_device,
)
from hypothesis import assume, given, HealthCheck, settings, Verbosity

from .. import common  # noqa E402
from ..common import (
    format_ref_tensors_in_mixed_B_layout,
    FORWARD_MAX_THREADS,
    gen_mixed_B_batch_sizes,
    get_max_thread_blocks,
    MAX_EXAMPLES_LONG_RUNNING,
    open_source,
)

if open_source:
    # pyre-ignore[21]
    from test_utils import (
        additional_decorators,
        gpu_unavailable,
        is_nvidia_device,
        optests,
        TEST_WITH_ROCM,
    )
else:
    from fbgemm_gpu.test.test_utils import (
        additional_decorators,
        gpu_unavailable,
        is_nvidia_device,
        optests,
        TEST_WITH_ROCM,
    )

# Hypothesis verbosity passed to the @settings(...) decorators below.
VERBOSITY: Verbosity = Verbosity.verbose

# FP8 dtype used when preparing NFP8 weights: the "fnuz" variant is selected
# when torch reports a HIP version, the "fn" variant otherwise.
fp8_dtype: torch.dtype = (
    torch.float8_e4m3fnuz if torch.version.hip is not None else torch.float8_e4m3fn
)

# Extra decorators, keyed by test name, that are handed to
# optests.generate_opcheck_tests on the test class below. Every entry added
# here is a unittest.skip with the reason the test cannot run.
# pyre-ignore
additional_decorators.update(
    {
        # TODO: Implement the operator registrations later
        "test_faketensor__test_forward_cpu_int8": [
            unittest.skip("Operator not implemented for Meta tensors"),
        ],
        "test_faketensor__test_forward_fused_pooled_emb_quant": [
            unittest.skip("Operator not implemented for Meta tensors"),
        ],
        "test_faketensor__test_forward_gpu_no_cache_int8": [
            unittest.skip("Operator not implemented for Meta tensors"),
        ],
        "test_faketensor__test_forward_gpu_uvm_cache_int8": [
            unittest.skip("Operator not implemented for Meta tensors"),
        ],
        "test_faketensor__test_forward_cpu_fp32": [
            unittest.skip("Operator not implemented for Meta tensors"),
        ],
        # TODO: Make it compatible with opcheck tests
        "test_schema__test_forward_gpu_uvm_cache_fp16": [
            unittest.skip(
                "Failed with Argument lxu_cache_locations_output is not defined to alias output but was aliasing"
            ),
        ],
        "test_schema__test_forward_gpu_uvm_cache_fp32": [
            unittest.skip(
                "Failed with Argument lxu_cache_locations_output is not defined to alias output but was aliasing"
            ),
        ],
        # The learning rate tensor needs to be on CPU to avoid a D->H sync point, since it is used as a float in the kernel.
        # This fails the fake_tensor test, which expects all tensors to be on the same device.
        "test_pt2_compliant_tag_fbgemm_split_embedding_codegen_lookup_rowwise_adagrad_function": [
            unittest.skip(
                "Operator failed on FakeTensor test since learning rate tensor is always on CPU regardless of other tensors"
            ),
        ],
    }
)


@optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators)
class ForwardTest(unittest.TestCase):
    def execute_forward_(  # noqa C901
        self,
        T: int,
        D: int,
        B: int,
        log_E: int,
        L: int,
        weights_precision: SparseType,
        weighted: bool,
        mixed: bool,
        mixed_B: bool,
        use_cache: bool,
        cache_algorithm: CacheAlgorithm,
        pooling_mode: PoolingMode,
        use_cpu: bool,
        output_dtype: SparseType,
        use_experimental_tbe: bool,
        enable_raw_embedding_streaming: bool = False,
        prefetch_pipeline: bool = False,
    ) -> None:
        """Compare the TBE forward output against per-table torch.nn references.

        Builds one ``torch.nn.EmbeddingBag`` (or ``torch.nn.Embedding`` when
        ``pooling_mode`` is ``NONE``) per table, evaluates them through
        ``b_indices`` to obtain the baseline, copies the same weights into a
        ``SplitTableBatchedEmbeddingBagsCodegen`` op, runs it on the same
        indices and checks both results with ``torch.testing.assert_close``.

        Unsupported parameter combinations are rejected up front through
        Hypothesis ``assume`` calls.

        Args:
            T: Number of embedding tables.
            D: Embedding dimension seed; the dimension actually used is
                ``D * 4`` on GPU and ``(D + 15) // 16 * 4`` on CPU.
            B: Batch size (base batch size when ``mixed_B`` is set).
            log_E: Base-10 logarithm of the number of rows per table.
            L: Number of indices generated per sample.
            weights_precision: Precision of the TBE weights.
            weighted: Whether per-sample weights are passed to both paths.
            mixed: Whether each table gets a random dimension and row count.
            mixed_B: Whether batch sizes come from ``gen_mixed_B_batch_sizes``
                and are passed as ``batch_size_per_feature_per_rank``.
            use_cache: Whether tables are placed in ``MANAGED_CACHING``.
            cache_algorithm: Cache algorithm forwarded to the TBE op.
            pooling_mode: Pooling mode forwarded to the TBE op.
            use_cpu: Whether to run on CPU instead of CUDA.
            output_dtype: Output type forwarded to the TBE op.
            use_experimental_tbe: Forwarded to the TBE op.
            enable_raw_embedding_streaming: When set, ``RESParams`` are built
                and forwarded to the TBE op.
            prefetch_pipeline: Forwarded to the TBE op.

        Raises:
            RuntimeError: If ``pooling_mode`` is not SUM, MEAN or NONE.
            AssertionError: If the TBE output is not close to the baseline.
        """
        # NOTE: cache is not applicable to CPU version.
        assume(not use_cpu or not use_cache)
        # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
        assume(not use_cpu or T * B * L * D <= 2048)
        # NOTE: CPU does not support FP16.
        assume(not (use_cpu and weights_precision == SparseType.FP16))

        # NOTE: weighted operation can be done only for SUM.
        assume(pooling_mode == PoolingMode.SUM or not weighted)
        # NOTE: No bag ops, no mixed
        assume(not mixed or pooling_mode != PoolingMode.NONE)
        # NOTE: No bag CPU doesn't support INT8
        assume(
            not (
                use_cpu
                and weights_precision == SparseType.INT8
                and pooling_mode == PoolingMode.NONE
            )
        )
        # TODO: Support these cases
        assume(
            not mixed_B
            or (
                weights_precision != SparseType.INT8
                and output_dtype != SparseType.INT8
                and pooling_mode != PoolingMode.NONE
            )
        )
        # NOTE: Raw embedding streaming requires UVM cache
        assume(not enable_raw_embedding_streaming or use_cache)
        # NOTE: Raw embedding streaming not supported on CPU
        assume(not enable_raw_embedding_streaming or not use_cpu)

        emb_op = SplitTableBatchedEmbeddingBagsCodegen
        # Map the TBE pooling mode onto the reference module's mode string and
        # decide whether the reference is an EmbeddingBag (pooled) or Embedding.
        if pooling_mode == PoolingMode.SUM:
            mode = "sum"
            do_pooling = True
        elif pooling_mode == PoolingMode.MEAN:
            mode = "mean"
            do_pooling = True
        elif pooling_mode == PoolingMode.NONE:
            mode = "sum"
            do_pooling = False
        else:
            # This proves that we have exhaustively checked all PoolingModes
            raise RuntimeError("Unknown PoolingMode!")

        E = int(10**log_E)
        # Turn the dimension seed into a multiple of 4.
        if use_cpu:
            D = (D + 15) // 16 * 4
        else:
            D = D * 4
        if not mixed:
            Ds = [D] * T
            Es = [E] * T
        else:
            # Per-table random dimensions (rounded up to a multiple of 4) and
            # row counts.
            Ds = [
                round_up(np.random.randint(low=int(0.25 * D), high=int(1.0 * D)), 4)
                for _ in range(T)
            ]
            Es = [
                np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T)
            ]

        if not mixed_B:
            Bs = [B] * T
            # Placeholder only: it is neither passed to the TBE op nor used to
            # format the reference when mixed_B is off.
            Bs_rank_feature = [[0]]
        else:
            Bs_rank_feature, Bs = gen_mixed_B_batch_sizes(B, T)

        # Pick the memory location of every table.
        compute_device = ComputeDevice.CUDA
        if use_cpu:
            managed = [EmbeddingLocation.HOST] * T
            compute_device = ComputeDevice.CPU
        elif TEST_WITH_ROCM:
            # ROCm managed memory allocation is under development
            managed = [EmbeddingLocation.DEVICE] * T
        elif use_cache:
            managed = [EmbeddingLocation.MANAGED_CACHING] * T
            if mixed:
                # Tables narrower than the average dimension stay on device.
                average_D = sum(Ds) // T
                for t, d in enumerate(Ds):
                    managed[t] = (
                        EmbeddingLocation.DEVICE if d < average_D else managed[t]
                    )
        else:
            managed = [
                np.random.choice(
                    [
                        EmbeddingLocation.DEVICE,
                        EmbeddingLocation.MANAGED,
                    ]
                )
                for _ in range(T)
            ]
        # Reference modules, one per table.
        if do_pooling:
            bs = [
                to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu)
                for (E, D) in zip(Es, Ds)
            ]
        else:
            bs = [
                to_device(torch.nn.Embedding(E, D, sparse=True), use_cpu)
                for (E, D) in zip(Es, Ds)
            ]
        # For the reduced-precision weight types, round-trip the reference
        # weights through that precision so the reference holds values that
        # have already gone through the same conversion.
        if weights_precision == SparseType.INT8:
            for t in range(T):
                bs[t].weight.data.copy_(
                    torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat(
                        torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(
                            bs[t].weight.data
                        )
                    )
                )

        if weights_precision == SparseType.NFP8:
            for t in range(T):
                bs[t].weight.data.copy_(bs[t].weight.data.to(fp8_dtype).to(torch.float))

        if weights_precision == SparseType.FP16:
            bs = [b.half() for b in bs]

        # Generate indices
        xs = [
            to_device(torch.randint(low=0, high=e, size=(b, L)), use_cpu)
            for e, b in zip(Es, Bs)
        ]
        # Generate positional weights
        xws = [to_device(torch.randn(size=(b, L)), use_cpu) for b in Bs]
        if weights_precision == SparseType.FP16:
            xws = [xw.half() for xw in xws]

        # Run baseline
        fs = (
            [
                b_indices(b, x, use_cpu=use_cpu, do_pooling=do_pooling)
                for (b, x) in zip(bs, xs)
            ]
            if not weighted
            else [
                b_indices(
                    b,
                    x,
                    per_sample_weights=xw.view(-1),
                    use_cpu=use_cpu,
                    do_pooling=do_pooling,
                )
                for (b, x, xw) in zip(bs, xs, xws)
            ]
        )

        # Combine the per-table reference outputs into a single tensor.
        if do_pooling:
            if mixed_B:
                f = format_ref_tensors_in_mixed_B_layout(fs, Bs_rank_feature)
            else:
                f = torch.cat([f.view(B, -1) for f in fs], dim=1)
        else:
            f = torch.cat(fs, dim=0).view(-1, D)

        # Create RES parameters if raw embedding streaming is enabled
        res_params = None
        if enable_raw_embedding_streaming:
            res_params = RESParams(
                res_store_shards=1,
                table_names=[f"table_{i}" for i in range(T)],
                # Cumulative row counts: T + 1 entries starting at 0.
                table_offsets=[sum(Es[:i]) for i in range(T + 1)],
                table_sizes=Es,
            )

        # Create a TBE op
        cc = emb_op(
            embedding_specs=[
                (
                    E,
                    D,
                    # NOTE(review): presumably np.random.choice does not hand
                    # back the enum member itself, hence the explicit
                    # conversion - confirm.
                    EmbeddingLocation(M),
                    compute_device,
                )
                for (E, D, M) in zip(Es, Ds, managed)
            ],
            weights_precision=weights_precision,
            optimizer=(
                OptimType.EXACT_ROWWISE_ADAGRAD if mixed_B else OptimType.EXACT_SGD
            ),
            learning_rate=0.05,
            cache_algorithm=cache_algorithm,
            pooling_mode=pooling_mode,
            output_dtype=output_dtype,
            use_experimental_tbe=use_experimental_tbe,
            prefetch_pipeline=prefetch_pipeline,
            enable_raw_embedding_streaming=enable_raw_embedding_streaming,
            res_params=res_params,
        )
        # Test torch JIT script compatibility
        if not use_cpu:
            cc = torch.jit.script(cc)

        # Load the reference weights into the TBE op, converted to its
        # storage precision.
        for t in range(T):
            if weights_precision == SparseType.INT8:
                b_weight = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(
                    bs[t].weight
                )
            elif weights_precision == SparseType.NFP8:
                b_weight = bs[t].weight.to(fp8_dtype)
            else:
                b_weight = bs[t].weight
            cc.split_embedding_weights()[t].data.copy_(b_weight)

        # Flatten the per-table indices / weights into single 1-D tensors.
        x = torch.cat([x.contiguous().flatten() for x in xs], dim=0)
        xw = torch.cat([xw.contiguous().flatten() for xw in xws], dim=0)

        (indices, offsets) = get_table_batched_offsets_from_dense(
            x, L, sum(Bs), use_cpu
        )

        batch_size_per_feature_per_rank = Bs_rank_feature if mixed_B else None

        # Run TBE
        fc2 = (
            cc(
                indices,
                offsets,
                batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
            )
            if not weighted
            else cc(
                indices,
                offsets,
                to_device(xw.contiguous().view(-1), use_cpu),
                batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
            )
        )

        # Compare results: f = baseline, fc2 = TBE
        # A tight tolerance is used only when both weights and output are FP32.
        tolerance = (
            1.0e-5
            if weights_precision == SparseType.FP32 and output_dtype == SparseType.FP32
            else 8.0e-3
        )
        torch.testing.assert_close(
            fc2.float(), f.float(), atol=tolerance, rtol=tolerance
        )

    def test_forward_cpu_int8(
        self,
    ) -> None:
        """Randomized CPU forward check with INT8 weights and FP32 output.

        Table count, dimension, batch size and bag length are drawn so that
        their product stays within 2048. Pooling is SUM or MEAN, and weighted
        lookups are only drawn for SUM pooling.
        """
        # Each upper bound shrinks with the values drawn before it.
        num_tables = random.randint(1, 10)
        emb_dim = random.randint(2, min(256, int(2048 / num_tables)))
        batch_size = random.randint(1, min(128, int(2048 / num_tables / emb_dim)))
        bag_size = random.randint(
            0, min(20, int(2048 / num_tables / emb_dim / batch_size))
        )
        log_rows = random.randint(3, 5)

        pooling = random.choice([PoolingMode.SUM, PoolingMode.MEAN])
        # Only draw the weighted flag when pooling is SUM.
        use_weights = pooling == PoolingMode.SUM and random.choice([True, False])

        self.execute_forward_(
            T=num_tables,
            D=emb_dim,
            B=batch_size,
            log_E=log_rows,
            L=bag_size,
            weights_precision=SparseType.INT8,
            weighted=use_weights,
            mixed=False,
            mixed_B=False,
            use_cache=False,
            # The cache algorithm is irrelevant because no cache is used.
            cache_algorithm=CacheAlgorithm.LRU,
            pooling_mode=pooling,
            use_cpu=True,
            output_dtype=SparseType.FP32,
            use_experimental_tbe=False,
        )

    def test_forward_cpu_fp32(
        self,
    ) -> None:
        """Randomized CPU forward check with FP32 weights and FP32 output.

        Shapes are drawn so that tables * dim * batch * bag length stays
        within 2048. Pooling is SUM or MEAN, variable batch sizes are toggled
        at random, and weighted lookups are only drawn for SUM pooling.
        """
        # Each upper bound shrinks with the values drawn before it.
        num_tables = random.randint(1, 10)
        emb_dim = random.randint(2, min(256, int(2048 / num_tables)))
        batch_size = random.randint(1, min(128, int(2048 / num_tables / emb_dim)))
        bag_size = random.randint(
            0, min(20, int(2048 / num_tables / emb_dim / batch_size))
        )
        log_rows = random.randint(3, 5)

        pooling = random.choice([PoolingMode.SUM, PoolingMode.MEAN])
        variable_batch = random.choice([False, True])
        # Only draw the weighted flag when pooling is SUM.
        use_weights = pooling == PoolingMode.SUM and random.choice([True, False])

        self.execute_forward_(
            T=num_tables,
            D=emb_dim,
            B=batch_size,
            log_E=log_rows,
            L=bag_size,
            weights_precision=SparseType.FP32,
            weighted=use_weights,
            mixed=False,
            mixed_B=variable_batch,
            use_cache=False,
            # The cache algorithm is irrelevant because no cache is used.
            cache_algorithm=CacheAlgorithm.LRU,
            pooling_mode=pooling,
            use_cpu=True,
            output_dtype=SparseType.FP32,
            use_experimental_tbe=False,
        )

    def test_forward_cpu_fp32_nobag(
        self,
    ) -> None:
        """Randomized CPU forward check with FP32 weights and no pooling.

        Shapes are drawn so that tables * dim * batch * bag length stays
        within 2048. Mixed dimensions, variable batch sizes and weighted
        lookups are all switched off.
        """
        # Each upper bound shrinks with the values drawn before it.
        num_tables = random.randint(1, 10)
        emb_dim = random.randint(2, min(256, int(2048 / num_tables)))
        batch_size = random.randint(1, min(128, int(2048 / num_tables / emb_dim)))
        bag_size = random.randint(
            0, min(20, int(2048 / num_tables / emb_dim / batch_size))
        )
        log_rows = random.randint(3, 5)

        self.execute_forward_(
            T=num_tables,
            D=emb_dim,
            B=batch_size,
            log_E=log_rows,
            L=bag_size,
            weights_precision=SparseType.FP32,
            weighted=False,
            mixed=False,
            mixed_B=False,
            use_cache=False,
            # The cache algorithm is irrelevant because no cache is used.
            cache_algorithm=CacheAlgorithm.LRU,
            pooling_mode=PoolingMode.NONE,
            use_cpu=True,
            output_dtype=SparseType.FP32,
            use_experimental_tbe=False,
        )

    @unittest.skipIf(True, "INT8 support is disabled")
    def test_forward_gpu_no_cache_int8(
        self,
    ) -> None:
        """Randomized GPU forward check with INT8 weights and no cache.

        Currently always skipped by the decorator above. Pooling is SUM, MEAN
        or NONE; mixed dimensions are only drawn when pooling is enabled and
        weighted lookups only for SUM pooling.
        """
        num_tables = random.randint(1, 10)
        emb_dim = random.randint(2, 256)
        batch_size = random.randint(1, 128)
        bag_size = random.randint(0, 20)
        log_rows = random.randint(3, 5)

        pooling = random.choice(
            [PoolingMode.SUM, PoolingMode.MEAN, PoolingMode.NONE]
        )
        # Flags are drawn only when the pooling mode allows them.
        mixed_dims = pooling != PoolingMode.NONE and random.choice([True, False])
        use_weights = pooling == PoolingMode.SUM and random.choice([True, False])

        self.execute_forward_(
            T=num_tables,
            D=emb_dim,
            B=batch_size,
            log_E=log_rows,
            L=bag_size,
            weights_precision=SparseType.INT8,
            weighted=use_weights,
            mixed=mixed_dims,
            mixed_B=False,
            use_cache=False,
            # The cache algorithm is irrelevant because no cache is used.
            cache_algorithm=CacheAlgorithm.LRU,
            pooling_mode=pooling,
            use_cpu=False,
            output_dtype=SparseType.FP32,
            use_experimental_tbe=False,
        )

    def _test_forward_gpu_no_cache_fp16_impl(
        self,
        T: int,
        B: int,
        L: int,
        use_experimental_tbe: bool,
    ) -> None:
        """Shared body of the FP16 no-cache GPU forward tests.

        The caller fixes the table count, batch size and bag length; the
        embedding dimension, row count, pooling mode and feature flags are
        drawn at random here.

        Args:
            T: Number of embedding tables.
            B: Batch size.
            L: Number of indices per sample.
            use_experimental_tbe: Forwarded to ``execute_forward_``; when set,
                NONE pooling and variable batch sizes are never drawn.
        """
        emb_dim = random.randint(2, 256)
        log_rows = random.randint(3, 5)

        pooling_candidates = [PoolingMode.SUM, PoolingMode.MEAN]
        if not use_experimental_tbe:
            pooling_candidates.append(PoolingMode.NONE)
        pooling = random.choice(pooling_candidates)

        # Flags are drawn only when the pooling mode / TBE variant allows them.
        pooled = pooling != PoolingMode.NONE
        mixed_dims = pooled and random.choice([True, False])
        variable_batch = (
            pooled and not use_experimental_tbe and random.choice([True, False])
        )
        use_weights = pooling == PoolingMode.SUM and random.choice([True, False])

        self.execute_forward_(
            T=T,
            D=emb_dim,
            B=B,
            log_E=log_rows,
            L=L,
            weights_precision=SparseType.FP16,
            weighted=use_weights,
            mixed=mixed_dims,
            mixed_B=variable_batch,
            use_cache=False,
            # The cache algorithm is irrelevant because no cache is used.
            cache_algorithm=CacheAlgorithm.LRU,
            pooling_mode=pooling,
            use_cpu=False,
            output_dtype=SparseType.FP32,
            use_experimental_tbe=use_experimental_tbe,
        )

    @unittest.skipIf(*gpu_unavailable)
    @given(
        use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False),
    )
    @settings(
        verbosity=VERBOSITY,
        max_examples=MAX_EXAMPLES_LONG_RUNNING,
        deadline=None,
        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
    )
    def test_forward_gpu_no_cache_fp16(
        self,
        use_experimental_tbe: bool,
    ) -> None:
        """FP16 no-cache GPU forward check with randomly drawn T, B and L."""
        num_tables = random.randint(1, 10)
        batch_size = random.randint(1, 128)
        bag_size = random.randint(0, 20)
        return self._test_forward_gpu_no_cache_fp16_impl(
            T=num_tables,
            B=batch_size,
            L=bag_size,
            use_experimental_tbe=use_experimental_tbe,
        )

    @unittest.skipIf(*gpu_unavailable)
    @given(
        use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False),
    )
    @settings(
        verbosity=VERBOSITY,
        max_examples=MAX_EXAMPLES_LONG_RUNNING,
        deadline=None,
        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
    )
    def test_forward_gpu_no_cache_fp16_large(
        self,
        use_experimental_tbe: bool,
    ) -> None:
        """FP16 no-cache GPU forward check with more indices than threads.

        The batch size is chosen so that B * L exceeds
        FORWARD_MAX_THREADS * get_max_thread_blocks(current stream).
        """
        torch.cuda.empty_cache()

        stream = torch.cuda.current_stream()
        thread_budget = FORWARD_MAX_THREADS * get_max_thread_blocks(stream)

        # The bag length is an arbitrary choice.
        bag_size = 10
        # Smallest power of two that brings B * L up to the thread budget,
        # plus 256 in case B * L would otherwise equal it exactly.
        batch_size = 2 ** (math.ceil(math.log2(thread_budget / bag_size))) + 256
        # Few tables, to avoid OOM errors given how large B * L has to be.
        num_tables = 3

        assert (
            batch_size * bag_size > thread_budget
        ), "Should be testing the case where B * L is larger than max_num_threads"

        return self._test_forward_gpu_no_cache_fp16_impl(
            T=num_tables,
            B=batch_size,
            L=bag_size,
            use_experimental_tbe=use_experimental_tbe,
        )

    @optests.dontGenerateOpCheckTests("FP8 compute requires custom op support.")
    @unittest.skipIf(*gpu_unavailable)
    def test_forward_gpu_no_cache_fp8(
        self,
        use_experimental_tbe: bool = False,  # TODO This does not yet work when True.
    ) -> None:
        """Randomized GPU forward check with NFP8 weights and no cache.

        The test is reported as skipped (rather than silently passing) when
        the device is not an NVIDIA device. Pooling is SUM or MEAN, plus NONE
        when the experimental TBE is off; mixed dimensions and variable batch
        sizes are only drawn when pooling is enabled, and weighted lookups
        only for SUM pooling.

        Args:
            use_experimental_tbe: Forwarded to ``execute_forward_``.
        """
        # Skip on rocm as fp8 is not supported for all versions.
        if not is_nvidia_device:
            self.skipTest("FP8 is only tested on NVIDIA devices")

        weights_precision = SparseType.NFP8
        use_cpu = False
        T = random.randint(1, 10)
        D = random.randint(2, 256)
        B = random.randint(1, 128)
        L = random.randint(0, 20)
        log_E = random.randint(3, 5)

        use_cache = False
        # cache_algorithm is don't care as we don't use cache.
        cache_algorithm = CacheAlgorithm.LRU

        pooling_mode = random.choice(
            [
                PoolingMode.SUM,
                PoolingMode.MEAN,
            ]
            + ([PoolingMode.NONE] if not use_experimental_tbe else [])
        )
        if pooling_mode == PoolingMode.NONE:
            mixed = False
            mixed_B = False
        else:
            mixed = random.choice([True, False])
            mixed_B = (
                random.choice([True, False]) if not use_experimental_tbe else False
            )
        if pooling_mode == PoolingMode.SUM:
            weighted = random.choice([True, False])
        else:
            weighted = False
        self.execute_forward_(
            T,
            D,
            B,
            log_E,
            L,
            weights_precision,
            weighted,
            mixed,
            mixed_B,
            use_cache,
            cache_algorithm,
            pooling_mode,
            use_cpu,
            SparseType.FP32,
            use_experimental_tbe,
        )

    @unittest.skipIf(*gpu_unavailable)
    @given(
        use_experimental_tbe=st.booleans(),
    )
    @settings(
        verbosity=VERBOSITY,
        max_examples=MAX_EXAMPLES_LONG_RUNNING,
        deadline=None,
        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
    )
    def test_forward_gpu_no_cache_fp32(
        self,
        use_experimental_tbe: bool,
    ) -> None:
        """Randomized GPU forward check with FP32 weights and no cache.

        Pooling is SUM or MEAN, plus NONE when the experimental TBE is off.
        Mixed dimensions and variable batch sizes are only drawn when pooling
        is enabled, and weighted lookups only for SUM pooling.
        """
        num_tables = random.randint(1, 10)
        emb_dim = random.randint(2, 256)
        batch_size = random.randint(1, 128)
        bag_size = random.randint(0, 20)
        log_rows = random.randint(3, 5)

        pooling_candidates = [PoolingMode.SUM, PoolingMode.MEAN]
        if not use_experimental_tbe:
            pooling_candidates.append(PoolingMode.NONE)
        pooling = random.choice(pooling_candidates)

        # Flags are drawn only when the pooling mode / TBE variant allows them.
        pooled = pooling != PoolingMode.NONE
        mixed_dims = pooled and random.choice([True, False])
        variable_batch = (
            pooled and not use_experimental_tbe and random.choice([True, False])
        )
        use_weights = pooling == PoolingMode.SUM and random.choice([True, False])

        self.execute_forward_(
            T=num_tables,
            D=emb_dim,
            B=batch_size,
            log_E=log_rows,
            L=bag_size,
            weights_precision=SparseType.FP32,
            weighted=use_weights,
            mixed=mixed_dims,
            mixed_B=variable_batch,
            use_cache=False,
            # The cache algorithm is irrelevant because no cache is used.
            cache_algorithm=CacheAlgorithm.LRU,
            pooling_mode=pooling,
            use_cpu=False,
            output_dtype=SparseType.FP32,
            use_experimental_tbe=use_experimental_tbe,
        )

    @unittest.skipIf(True, "INT8 support is disabled")
    @given(
        cache_algorithm=st.sampled_from(CacheAlgorithm),
    )
    @settings(
        verbosity=VERBOSITY,
        max_examples=MAX_EXAMPLES_LONG_RUNNING,
        deadline=None,
        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
    )
    def test_forward_gpu_uvm_cache_int8(
        self,
        cache_algorithm: CacheAlgorithm,
    ) -> None:
        """Randomized GPU forward check with INT8 weights and the cache on.

        Currently always skipped by the decorator above. Pooling is SUM, MEAN
        or NONE and the output type is FP32, FP16 or BF16; mixed dimensions
        are only drawn when pooling is enabled and weighted lookups only for
        SUM pooling.
        """
        num_tables = random.randint(1, 10)
        emb_dim = random.randint(2, 256)
        batch_size = random.randint(1, 128)
        bag_size = random.randint(0, 20)
        log_rows = random.randint(3, 5)

        pooling = random.choice(
            [PoolingMode.SUM, PoolingMode.MEAN, PoolingMode.NONE]
        )
        out_type = random.choice([SparseType.FP32, SparseType.FP16, SparseType.BF16])
        # Flags are drawn only when the pooling mode allows them.
        mixed_dims = pooling != PoolingMode.NONE and random.choice([True, False])
        use_weights = pooling == PoolingMode.SUM and random.choice([True, False])

        self.execute_forward_(
            T=num_tables,
            D=emb_dim,
            B=batch_size,
            log_E=log_rows,
            L=bag_size,
            weights_precision=SparseType.INT8,
            weighted=use_weights,
            mixed=mixed_dims,
            mixed_B=False,
            use_cache=True,
            cache_algorithm=cache_algorithm,
            pooling_mode=pooling,
            use_cpu=False,
            output_dtype=out_type,
            use_experimental_tbe=False,
        )

    @unittest.skipIf(*gpu_unavailable)
    @optests.dontGenerateOpCheckTests("FP8 compute requires custom op support.")
    @given(
        cache_algorithm=st.sampled_from(CacheAlgorithm),
    )
    @settings(
        verbosity=VERBOSITY,
        max_examples=MAX_EXAMPLES_LONG_RUNNING,
        deadline=None,
        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
    )
    def test_forward_gpu_uvm_cache_fp8(
        self,
        cache_algorithm: CacheAlgorithm,
    ) -> None:
        """Randomized GPU forward check with NFP8 weights and the cache on.

        Returns immediately when the device is not an NVIDIA device. Pooling
        is SUM, MEAN or NONE and the output type is FP32, FP16 or BF16; mixed
        dimensions are only drawn when pooling is enabled and weighted lookups
        only for SUM pooling.
        """
        # Not run on rocm, where fp8 does not work for all versions.
        if not is_nvidia_device:
            return

        num_tables = random.randint(1, 10)
        emb_dim = random.randint(2, 256)
        batch_size = random.randint(1, 128)
        bag_size = random.randint(0, 20)
        log_rows = random.randint(3, 5)

        pooling = random.choice(
            [PoolingMode.SUM, PoolingMode.MEAN, PoolingMode.NONE]
        )
        out_type = random.choice([SparseType.FP32, SparseType.FP16, SparseType.BF16])
        # Flags are drawn only when the pooling mode allows them.
        mixed_dims = pooling != PoolingMode.NONE and random.choice([True, False])
        use_weights = pooling == PoolingMode.SUM and random.choice([True, False])

        self.execute_forward_(
            T=num_tables,
            D=emb_dim,
            B=batch_size,
            log_E=log_rows,
            L=bag_size,
            weights_precision=SparseType.NFP8,
            weighted=use_weights,
            mixed=mixed_dims,
            mixed_B=False,
            use_cache=True,
            cache_algorithm=cache_algorithm,
            pooling_mode=pooling,
            use_cpu=False,
            output_dtype=out_type,
            use_experimental_tbe=False,
        )

    @unittest.skipIf(*gpu_unavailable)
    @given(
        cache_algorithm=st.sampled_from(CacheAlgorithm),
        use_experimental_tbe=st.booleans(),
    )
    @settings(
        verbosity=VERBOSITY,
        max_examples=MAX_EXAMPLES_LONG_RUNNING,
        deadline=None,
        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
    )
    def test_forward_gpu_uvm_cache_fp16(
        self,
        cache_algorithm: CacheAlgorithm,
        use_experimental_tbe: bool,
    ) -> None:
        """Randomized GPU forward check with FP16 weights and the cache on.

        Pooling is SUM or MEAN, plus NONE when the experimental TBE is off;
        the output type is FP32, FP16 or BF16. Mixed dimensions and variable
        batch sizes are only drawn when pooling is enabled, and weighted
        lookups only for SUM pooling.
        """
        num_tables = random.randint(1, 10)
        emb_dim = random.randint(2, 256)
        batch_size = random.randint(1, 128)
        bag_size = random.randint(0, 20)
        log_rows = random.randint(3, 5)

        pooling_candidates = [PoolingMode.SUM, PoolingMode.MEAN]
        if not use_experimental_tbe:
            pooling_candidates.append(PoolingMode.NONE)
        pooling = random.choice(pooling_candidates)
        out_type = random.choice([SparseType.FP32, SparseType.FP16, SparseType.BF16])

        # Flags are drawn only when the pooling mode / TBE variant allows them.
        pooled = pooling != PoolingMode.NONE
        mixed_dims = pooled and random.choice([True, False])
        variable_batch = (
            pooled and not use_experimental_tbe and random.choice([True, False])
        )
        use_weights = pooling == PoolingMode.SUM and random.choice([True, False])

        self.execute_forward_(
            T=num_tables,
            D=emb_dim,
            B=batch_size,
            log_E=log_rows,
            L=bag_size,
            weights_precision=SparseType.FP16,
            weighted=use_weights,
            mixed=mixed_dims,
            mixed_B=variable_batch,
            use_cache=True,
            cache_algorithm=cache_algorithm,
            pooling_mode=pooling,
            use_cpu=False,
            output_dtype=out_type,
            use_experimental_tbe=use_experimental_tbe,
        )

    @unittest.skipIf(*gpu_unavailable)
    @given(
        cache_algorithm=st.sampled_from(CacheAlgorithm),
        use_experimental_tbe=st.booleans(),
    )
    @settings(
        verbosity=VERBOSITY,
        max_examples=MAX_EXAMPLES_LONG_RUNNING,
        deadline=None,
        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
    )
    def test_forward_gpu_uvm_cache_fp32(
        self,
        cache_algorithm: CacheAlgorithm,
        use_experimental_tbe: bool,
    ) -> None:
        """Randomized GPU forward check with FP32 weights and the cache on.

        Pooling is SUM or MEAN, plus NONE when the experimental TBE is off;
        the output type is FP32, FP16 or BF16. Mixed dimensions and variable
        batch sizes are only drawn when pooling is enabled, and weighted
        lookups only for SUM pooling.
        """
        num_tables = random.randint(1, 10)
        emb_dim = random.randint(2, 256)
        batch_size = random.randint(1, 128)
        bag_size = random.randint(0, 20)
        log_rows = random.randint(3, 5)

        pooling_candidates = [PoolingMode.SUM, PoolingMode.MEAN]
        if not use_experimental_tbe:
            pooling_candidates.append(PoolingMode.NONE)
        pooling = random.choice(pooling_candidates)
        out_type = random.choice([SparseType.FP32, SparseType.FP16, SparseType.BF16])

        # Flags are drawn only when the pooling mode / TBE variant allows them.
        pooled = pooling != PoolingMode.NONE
        mixed_dims = pooled and random.choice([True, False])
        variable_batch = (
            pooled and not use_experimental_tbe and random.choice([True, False])
        )
        use_weights = pooling == PoolingMode.SUM and random.choice([True, False])

        self.execute_forward_(
            T=num_tables,
            D=emb_dim,
            B=batch_size,
            log_E=log_rows,
            L=bag_size,
            weights_precision=SparseType.FP32,
            weighted=use_weights,
            mixed=mixed_dims,
            mixed_B=variable_batch,
            use_cache=True,
            cache_algorithm=cache_algorithm,
            pooling_mode=pooling,
            use_cpu=False,
            output_dtype=out_type,
            use_experimental_tbe=use_experimental_tbe,
        )

    @unittest.skipIf(*gpu_unavailable)
    @given(
        T=st.integers(min_value=1, max_value=10),
        D=st.integers(min_value=2, max_value=128),
        B=st.integers(min_value=1, max_value=128),
        log_E=st.integers(min_value=3, max_value=5),
        L=st.integers(min_value=0, max_value=20),
        output_dtype=st.sampled_from([SparseType.FP16]),
    )
    @settings(
        verbosity=VERBOSITY,
        max_examples=MAX_EXAMPLES_LONG_RUNNING,
        deadline=None,
        suppress_health_check=[HealthCheck.filter_too_much],
    )
    def test_forward_fused_pooled_emb_quant(
        self,
        T: int,
        D: int,
        B: int,
        log_E: int,
        L: int,
        output_dtype: SparseType,
    ) -> None:
        """
        Compare a TBE op emitting ``output_dtype`` against an FP32 reference.

        Two ``SplitTableBatchedEmbeddingBagsCodegen`` ops are built over the
        same table specs, one with ``output_dtype`` and one with FP32 output,
        and the reference weights are copied into the first.  For each
        generated request the low-precision output is converted to float and
        compared, via ``torch.testing.assert_close``, with the FP32 output
        after it has been passed through the matching precision round trip.
        """
        # Per-table embedding dims: a random draw in [max(D / 4, 1), D), then
        # passed through round_up(..., 4).
        dim_low = int(max(0.25 * D, 1))
        dim_high = int(1.0 * D)
        table_dims = [
            round_up(np.random.randint(low=dim_low, high=dim_high), 4)
            for _ in range(T)
        ]

        # Per-table row counts: a random draw in [E / 2, 2 * E), E = 10**log_E.
        E = int(10**log_E)
        rows_low = int(0.5 * E)
        rows_high = int(2.0 * E)
        table_rows = [
            np.random.randint(low=rows_low, high=rows_high) for _ in range(T)
        ]

        def build_op(dtype: SparseType) -> SplitTableBatchedEmbeddingBagsCodegen:
            # Both ops share the table specs and differ only in output dtype.
            return SplitTableBatchedEmbeddingBagsCodegen(
                embedding_specs=[
                    (rows, dim, EmbeddingLocation.DEVICE, ComputeDevice.CUDA)
                    for rows, dim in zip(table_rows, table_dims)
                ],
                output_dtype=dtype,
                device=torch.cuda.current_device(),
            )

        op = build_op(output_dtype)
        op_ref = build_op(SparseType.FP32)

        # sync weights between two ops (reference -> op under test)
        for dst, src in zip(
            op.split_embedding_weights(), op_ref.split_embedding_weights()
        ):
            dst.data.copy_(src)

        requests = generate_requests(2, B, T, L, min(table_rows), reuse=0.1)
        is_int8 = output_dtype == SparseType.INT8

        for req in requests:
            indices, offsets = req.unpack_2()
            lowp_out = op(indices=indices, offsets=offsets)
            fp32_out = op_ref(indices=indices, offsets=offsets)

            # For INT8 output each table's slice is 8 columns wider than its
            # embedding dim; otherwise the slice width equals the dim.
            lowp_widths = [d + 8 if is_int8 else d for d in op.dims]
            lowp_chunks = torch.split(lowp_out, lowp_widths, dim=1)
            fp32_chunks = torch.split(fp32_out, op.dims, dim=1)

            if is_int8:
                # Dequantize the op's output; quantize and dequantize the
                # reference so both sides went through the same conversion.
                actual = [
                    torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat(
                        chunk.contiguous()
                    )
                    for chunk in lowp_chunks
                ]
                expected = [
                    torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat(
                        torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(
                            chunk.contiguous()
                        ).contiguous()
                    )
                    for chunk in fp32_chunks
                ]
            else:
                # Non-INT8 path: the reference is rounded through half
                # precision before being compared as float.
                actual = [chunk.float() for chunk in lowp_chunks]
                expected = [chunk.half().float() for chunk in fp32_chunks]

            torch.testing.assert_close(
                torch.cat(actual, dim=1), torch.cat(expected, dim=1)
            )

    def _check_raw_embedding_stream_call_counts(
        self,
        mock_raw_embedding_stream: unittest.mock.Mock,
        num_iterations: int,
        prefetch_pipeline: bool,
        L: int,
    ) -> None:
        """
        Assert that the mock's ``call_count`` lies within a loose range.

        The nominal expectation is ``num_iterations`` calls when ``L > 0`` and
        none otherwise, reduced by one (never below zero) when
        ``prefetch_pipeline`` is set.  The observed count must be at least 0
        and at most that expectation plus a slack of 2.

        Args:
            mock_raw_embedding_stream: object whose ``call_count`` is checked.
            num_iterations: number of forward iterations that were run.
            prefetch_pipeline: whether the prefetch pipeline was enabled.
            L: value of the test's ``L`` parameter; a non-positive value
                means no calls are expected.
        """
        # For TBE (not SSD), raw_embedding_stream is called once per prefetch
        # when there's data to stream
        if L > 0:
            expected_calls = num_iterations
        else:
            expected_calls = 0

        # With prefetch pipeline, there might be fewer calls initially
        if prefetch_pipeline:
            expected_calls = max(0, expected_calls - 1)

        # NOTE(review): `call_count` only counts direct calls of the object
        # passed in, not calls to its attributes/methods - confirm that this
        # is the counter callers intend to check.
        # The lower bound always holds for a non-negative count.
        self.assertGreaterEqual(mock_raw_embedding_stream.call_count, 0)

        # Allow some flexibility in call count due to caching behavior
        upper_bound = expected_calls + 2
        self.assertLessEqual(mock_raw_embedding_stream.call_count, upper_bound)

    @unittest.skipIf(*gpu_unavailable)
    @given(
        T=st.integers(min_value=1, max_value=5),
        D=st.integers(min_value=2, max_value=64),
        B=st.integers(min_value=1, max_value=32),
        log_E=st.integers(min_value=3, max_value=4),
        L=st.integers(min_value=1, max_value=10),
        weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]),
        cache_algorithm=st.sampled_from(CacheAlgorithm),
        pooling_mode=st.sampled_from([PoolingMode.SUM, PoolingMode.MEAN]),
        weighted=st.booleans(),
        mixed=st.booleans(),
        prefetch_pipeline=st.booleans(),
    )
    @settings(
        verbosity=VERBOSITY,
        max_examples=MAX_EXAMPLES_LONG_RUNNING,
        deadline=None,
        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
    )
    def test_forward_raw_embedding_streaming(
        self,
        T: int,
        D: int,
        B: int,
        log_E: int,
        L: int,
        weights_precision: SparseType,
        cache_algorithm: CacheAlgorithm,
        pooling_mode: PoolingMode,
        weighted: bool,
        mixed: bool,
        prefetch_pipeline: bool,
    ) -> None:
        """
        Exercise the forward pass with raw embedding streaming switched on.

        ``RawEmbeddingStreamer`` is patched with a mock class for the duration
        of the test.  ``execute_forward_`` is then run ``num_iterations``
        times with ``enable_raw_embedding_streaming=True``, and the call count
        of the mock streamer instance is checked afterwards.
        """
        # only LRU supports prefetch_pipeline
        if prefetch_pipeline:
            assume(cache_algorithm == CacheAlgorithm.LRU)

        num_iterations = 5
        streamer_target = "fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer"

        with patch(streamer_target) as streamer_cls:
            # Instantiating the patched class always hands back this one
            # mock, so it can be inspected after the forward runs.
            streamer = MagicMock()
            streamer_cls.return_value = streamer

            # Run multiple iterations to test streaming behavior
            for _ in range(num_iterations):
                self.execute_forward_(
                    # Values supplied by the test's strategies.
                    T=T,
                    D=D,
                    B=B,
                    log_E=log_E,
                    L=L,
                    weights_precision=weights_precision,
                    weighted=weighted,
                    mixed=mixed,
                    cache_algorithm=cache_algorithm,
                    pooling_mode=pooling_mode,
                    prefetch_pipeline=prefetch_pipeline,
                    # Fixed configuration for the streaming scenario.
                    mixed_B=False,  # Keep simple for streaming tests
                    use_cache=True,  # Required for streaming
                    use_cpu=False,  # Streaming not supported on CPU
                    output_dtype=SparseType.FP32,
                    use_experimental_tbe=False,
                    enable_raw_embedding_streaming=True,
                )

            self._check_raw_embedding_stream_call_counts(
                mock_raw_embedding_stream=streamer,
                num_iterations=num_iterations,
                prefetch_pipeline=prefetch_pipeline,
                L=L,
            )


# Run this module's tests via unittest when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
