#!/bin/sh

# shellcheck disable=SC1091
. ../../lib/sh-test-lib
# Output locations. The result file, the virtualenv, the TorchBench checkout
# and the per-model run logs all live under ./output, relative to the
# directory the script is started from.
OUTPUT="$(pwd)/output"
export RESULT_FILE="${OUTPUT}/result.txt"
TORCHBENCH_DIR="${OUTPUT}/benchmark"

# Defaults for the command-line flags; each can be overridden (see usage()).
SKIP_INSTALL="False"                                        # -s
MODELS="resnet18"                                           # -m
DEVICE="cpu"                                                # -d
TEST_MODE="eval"                                            # -t
ITERATIONS="10"                                             # -i
TORCHBENCH_REPO="https://github.com/pytorch/benchmark.git"  # -r
TORCHBENCH_VERSION="main"                                   # -v
PACKAGE_MANAGER="uv"                                        # -p
PYTHON_VERSION="system"                                     # -y
UV_VERSION="0.5.11"                                         # -u

# Print the command-line help to stderr and exit with status 1.
# Reached via -h and, through the getopts catch-all, for any unknown option
# or an option missing its argument.
usage() {
    # Group every line so the whole help text, not only the synopsis line,
    # goes to stderr and stays together when stdout is piped or redirected.
    {
        echo "Usage: $0 [-s <true|false>] [-m MODELS] [-d DEVICE] [-t TEST_MODE] [-i ITERATIONS] [-r REPO] [-v VERSION] [-p PACKAGE_MANAGER] [-y PYTHON_VERSION] [-u UV_VERSION]"
        echo "  -s: Skip install (default: False)"
        echo "  -m: Models to benchmark, comma-separated (default: resnet18)"
        echo "      Popular: resnet18, resnet50, mobilenet_v2, squeezenet1_1, mnasnet1_0"
        echo "      Use 'all' for all installed models"
        echo "  -d: Device: cpu, cuda (default: cpu)"
        echo "  -t: Test mode: eval, train (default: eval)"
        echo "  -i: Number of iterations (default: 10)"
        echo "  -r: TorchBench git repository URL"
        echo "  -v: TorchBench version/branch (default: main)"
        echo "  -p: Package manager: pip, uv (default: uv)"
        echo "  -y: Python version for uv: 3.11, 3.12, 3.13, system (default: system)"
        echo "  -u: uv version to install (default: 0.5.11)"
    } 1>&2
    exit 1
}

# Parse command-line flags; each one overrides the matching default above.
# -h, an unknown flag, or a flag missing its argument all end in usage().
while getopts "s:m:d:t:i:r:v:p:y:u:h" opt
do
    case "${opt}" in
        s)  SKIP_INSTALL="${OPTARG}" ;;
        m)  MODELS="${OPTARG}" ;;
        d)  DEVICE="${OPTARG}" ;;
        t)  TEST_MODE="${OPTARG}" ;;
        i)  ITERATIONS="${OPTARG}" ;;
        r)  TORCHBENCH_REPO="${OPTARG}" ;;
        v)  TORCHBENCH_VERSION="${OPTARG}" ;;
        p)  PACKAGE_MANAGER="${OPTARG}" ;;
        y)  PYTHON_VERSION="${OPTARG}" ;;
        u)  UV_VERSION="${OPTARG}" ;;
        *)  usage ;;
    esac
done

# Make the `uv` command available, installing UV_VERSION via the official
# standalone installer when it is not already on PATH.
#
# Globals read: UV_VERSION, HOME. Globals written: PATH.
# Calls error_msg (from sh-test-lib) if uv is still missing afterwards.
install_uv() {
    # shellcheck disable=SC2039
    local uv_bin_dir="${HOME}/.local/bin"

    if command -v uv > /dev/null 2>&1; then
        return 0
    fi

    info_msg "Installing uv version ${UV_VERSION}..."
    # The pipeline's status is sh's, not curl's, so a failed download is
    # only detected by the `command -v` check below.
    curl -LsSf "https://astral.sh/uv/${UV_VERSION}/install.sh" | sh

    # Sourcing a missing file with `.` aborts a non-interactive POSIX shell,
    # so only source the installer's env script when it actually exists.
    # NOTE(review): assumes the default install dir ~/.local/bin; installers
    # honouring XDG_BIN_HOME/UV_INSTALL_DIR may place uv elsewhere.
    if [ -f "${uv_bin_dir}/env" ]; then
        # shellcheck disable=SC1091
        . "${uv_bin_dir}/env"
    else
        PATH="${uv_bin_dir}:${PATH}"
        export PATH
    fi

    if ! command -v uv > /dev/null 2>&1; then
        error_msg "Failed to install uv ${UV_VERSION}"
    fi
}

# Install the OS packages this test needs, choosing the package list by
# distribution. The list and SKIP_INSTALL are handed to install_deps
# (sh-test-lib), which presumably honours the skip flag itself.
# Unknown distributions only get a warning.
install_deps_packages() {
    # shellcheck disable=SC2039
    local pkgs

    # dist_name (sh-test-lib) is expected to set ${dist}.
    dist_name
    # shellcheck disable=SC2154
    case "${dist}" in
        debian|ubuntu)
            pkgs="python3 python3-pip python3-venv python3-psutil git build-essential curl ca-certificates" ;;
        fedora|centos)
            pkgs="python3 python3-pip python3-psutil git gcc gcc-c++ make curl ca-certificates" ;;
        *)
            warn_msg "Unsupported distro: ${dist}! Package install skipped"
            return
            ;;
    esac

    install_deps "${pkgs}" "${SKIP_INSTALL}"
}

# Prepare the Python environment used for the benchmark run.
#
# Install mode (SKIP_INSTALL is False/false):
#   - create ${OUTPUT}/venv with --system-site-packages, via `uv venv` when
#     PACKAGE_MANAGER=uv, otherwise via `python3 -m venv`;
#   - activate it and install torch/torchvision/torchaudio from the PyTorch
#     wheel index for DEVICE (cu118 for cuda, cpu otherwise).
# Skip mode: activate an existing ${OUTPUT}/venv if present, else keep
# using the system Python.
#
# Globals read: SKIP_INSTALL, PACKAGE_MANAGER, PYTHON_VERSION, DEVICE, OUTPUT
setup_python_env() {
    # shellcheck disable=SC2039
    local torch_index

    if [ "${SKIP_INSTALL}" = "False" ] || [ "${SKIP_INSTALL}" = "false" ]; then
        if [ "${PACKAGE_MANAGER}" = "uv" ]; then
            install_uv

            info_msg "Creating uv environment with Python ${PYTHON_VERSION}..."
            if [ "${PYTHON_VERSION}" = "system" ]; then
                uv venv "${OUTPUT}/venv" --system-site-packages
            else
                uv venv "${OUTPUT}/venv" --python "${PYTHON_VERSION}" --system-site-packages
            fi || error_msg "Failed to create uv virtual environment in ${OUTPUT}/venv"
        else
            # Use traditional pip
            info_msg "Setting up Python virtual environment with pip..."
            python3 -m venv "${OUTPUT}/venv" --system-site-packages \
                || error_msg "Failed to create virtual environment in ${OUTPUT}/venv"
        fi

        # shellcheck disable=SC1091
        . "${OUTPUT}/venv/bin/activate"

        if [ "${PACKAGE_MANAGER}" != "uv" ]; then
            # Upgrade pip only after activation so the venv's pip is touched,
            # not the system one (which PEP 668 distros refuse to modify).
            python -m pip install --upgrade pip
        fi

        info_msg "Installing PyTorch..."
        if [ "${DEVICE}" = "cuda" ]; then
            torch_index="https://download.pytorch.org/whl/cu118"
        else
            torch_index="https://download.pytorch.org/whl/cpu"
        fi
        # uv exposes pip as a subcommand ("uv pip install"); in pip mode the
        # activated venv's pip is invoked directly.
        if [ "${PACKAGE_MANAGER}" = "uv" ]; then
            uv pip install torch torchvision torchaudio --index-url "${torch_index}"
        else
            python -m pip install torch torchvision torchaudio --index-url "${torch_index}"
        fi
    else
        if [ -d "${OUTPUT}/venv" ]; then
            # shellcheck disable=SC1091
            . "${OUTPUT}/venv/bin/activate"
        else
            warn_msg "No virtual environment found, using system Python"
        fi
    fi
}

# Obtain the TorchBench checkout and leave the shell inside it.
#
# Install mode (SKIP_INSTALL is False/false): re-clone TORCHBENCH_REPO at
# TORCHBENCH_VERSION (shallow) into TORCHBENCH_DIR, then install its
# requirements.txt with the configured package manager. Both steps are
# reported via check_return.
# Skip mode: require an existing TORCHBENCH_DIR (error_msg otherwise).
# Side effect: the working directory becomes TORCHBENCH_DIR.
clone_torchbench() {
    if [ "${SKIP_INSTALL}" = "False" ] || [ "${SKIP_INSTALL}" = "false" ]; then
        info_msg "Cloning TorchBench repository..."
        if [ -d "${TORCHBENCH_DIR}" ]; then
            # ${VAR:?} aborts instead of running `rm -rf ""` if ever unset.
            rm -rf "${TORCHBENCH_DIR:?}"
        fi
        git clone --depth 1 --branch "${TORCHBENCH_VERSION}" "${TORCHBENCH_REPO}" "${TORCHBENCH_DIR}"
        check_return "torchbench-clone"

        cd "${TORCHBENCH_DIR}" || exit 1

        # Install TorchBench requirements. uv needs "uv pip install"; in pip
        # mode call pip directly ("pip pip install" would be rejected).
        info_msg "Installing TorchBench requirements..."
        if [ "${PACKAGE_MANAGER}" = "uv" ]; then
            uv pip install -r requirements.txt
        else
            python -m pip install -r requirements.txt
        fi
        check_return "torchbench-requirements"
    else
        if [ ! -d "${TORCHBENCH_DIR}" ]; then
            error_msg "TorchBench directory not found. Run without SKIP_INSTALL first."
        fi
        cd "${TORCHBENCH_DIR}" || exit 1
    fi
}

# Install the requested TorchBench models with install.py (run from the
# TorchBench checkout, which is the current directory at this point).
# MODELS is either "all" (one bare install.py run) or a comma-separated list
# (one install.py run, and one check_return result, per model).
# Does nothing unless SKIP_INSTALL is False/false.
install_models() {
    case "${SKIP_INSTALL}" in
        False|false) ;;
        *) return 0 ;;
    esac

    info_msg "Installing models: ${MODELS}"

    if [ "${MODELS}" != "all" ]; then
        for model in $(echo "${MODELS}" | tr ',' ' '); do
            info_msg "Installing model: ${model}"
            python install.py "${model}"
            check_return "torchbench-install-${model}"
        done
        return
    fi

    # Installing every model takes a long time and lots of disk space.
    python install.py
    check_return "torchbench-install-all-models"
}

# Sanity-check the Python environment before benchmarking.
# Always records "pytorch-import"; additionally records "cuda-available"
# when DEVICE is cuda.
check_pytorch() {
    info_msg "Checking PyTorch installation..."

    # Step 1: torch must import; its version is printed for the log.
    python3 -c "import torch; print(f'PyTorch version: {torch.__version__}')"
    check_return "pytorch-import"

    # Step 2 (GPU runs only): torch must report a usable CUDA runtime.
    case "${DEVICE}" in
        cuda)
            python3 -c "import torch; assert torch.cuda.is_available(), 'CUDA not available'"
            check_return "cuda-available"
            ;;
    esac
}

# Run one TorchBench model through run.py and report the outcome.
#
# Arguments:    $1 - model name
# Globals read: OUTPUT, TORCHBENCH_DIR, DEVICE, TEST_MODE, ITERATIONS
# run.py's combined stdout/stderr is written to ${OUTPUT}/<model>_benchmark.txt.
# The result torchbench-<model>-<TEST_MODE> is reported as pass or fail. On
# success the log goes to parse_benchmark_results; on failure it is printed.
run_benchmark() {
    # shellcheck disable=SC2039
    local model="$1"
    # shellcheck disable=SC2039
    local benchmark_output="${OUTPUT}/${model}_benchmark.txt"
    # shellcheck disable=SC2039
    local exit_code=0

    info_msg "Running benchmark for model: ${model}"

    cd "${TORCHBENCH_DIR}" || exit 1

    # TEST_MODE is passed through as-is so the reported test name always
    # matches the mode that actually ran. NOTE(review): values other than
    # eval/train are presumably rejected by run.py and then reported as fail.
    python run.py "${model}" -d "${DEVICE}" -t "${TEST_MODE}" \
        --iterations "${ITERATIONS}" > "${benchmark_output}" 2>&1 || exit_code=$?

    if [ "${exit_code}" -eq 0 ]; then
        report_pass "torchbench-${model}-${TEST_MODE}"

        # Parse and report metrics from output
        parse_benchmark_results "${model}" "${benchmark_output}"
    else
        report_fail "torchbench-${model}-${TEST_MODE}"
        cat "${benchmark_output}"
    fi
}

# Extract optional latency/throughput figures from a captured run.py log,
# report them via add_metric, then echo the raw log for debugging.
#
# Arguments: $1 - model name (used in the metric names)
#            $2 - path to the captured run.py output
# Heuristic: for each metric, the first number found on the lines that
# mention it is used. NOTE(review): a number appearing before the keyword on
# such a line (e.g. the "18" in a model name like resnet18) would be picked
# instead; the units (ms, samples/sec) are assumed, not read from the log.
parse_benchmark_results() {
    # shellcheck disable=SC2039
    local model="$1"
    # shellcheck disable=SC2039
    local output_file="$2"
    # shellcheck disable=SC2039
    local latency throughput

    # Nothing to parse if the run left no log behind.
    [ -f "${output_file}" ] || return 0

    latency=$(grep -i 'latency' "${output_file}" | grep -oE '[0-9]+\.?[0-9]*' | head -n 1)
    # Extended regex alternation: "\|" in a basic regex is a GNU extension.
    throughput=$(grep -iE 'throughput|images/s|samples/s' "${output_file}" | grep -oE '[0-9]+\.?[0-9]*' | head -n 1)

    if [ -n "${latency}" ]; then
        add_metric "torchbench-${model}-latency" "pass" "${latency}" "ms"
    fi

    if [ -n "${throughput}" ]; then
        add_metric "torchbench-${model}-throughput" "pass" "${throughput}" "samples/sec"
    fi

    # Show raw output for debugging
    info_msg "Raw benchmark output:"
    cat "${output_file}"
}

# Print the model names TorchBench knows about (informational only).
# Tries torchbenchmark.list_models() first, then a plain listing of
# torchbenchmark/models/, and prints a notice if both fail.
# Side effect: the working directory becomes TORCHBENCH_DIR.
list_available_models() {
    info_msg "Available models in TorchBench:"
    cd "${TORCHBENCH_DIR}" || exit 1

    if python -c "from torchbenchmark import list_models; print('\n'.join([m.name for m in list_models()]))" 2>/dev/null; then
        return 0
    fi
    if ls -1 torchbenchmark/models/ 2>/dev/null; then
        return 0
    fi
    echo "Could not list models"
}

# Main execution
create_out_dir "${OUTPUT}"

install_deps_packages
setup_python_env
check_pytorch
clone_torchbench
install_models

# List available models for reference
list_available_models

# Run benchmarks
if [ "${MODELS}" = "all" ]; then
    # Ask TorchBench for its model list; fall back to scanning the models dir.
    cd "${TORCHBENCH_DIR}" || exit 1
    model_list=$(python -c "from torchbenchmark import list_models; print(' '.join([m.name for m in list_models()]))" 2>/dev/null)
    if [ -z "${model_list}" ]; then
        # Fallback: treat each sub-directory of torchbenchmark/models/ as one
        # model. Plain files and private directories (names starting with
        # "_", e.g. __pycache__) are skipped so they are not benchmarked and
        # reported as failures.
        for model_dir in torchbenchmark/models/*/; do
            [ -d "${model_dir}" ] || continue
            model_name=${model_dir%/}
            model_name=${model_name##*/}
            case "${model_name}" in
                _*) continue ;;
            esac
            model_list="${model_list} ${model_name}"
        done
    fi

    # Intentionally unquoted: model_list is a space-separated list of names.
    for model in ${model_list}; do
        run_benchmark "${model}"
    done
else
    for model in $(echo "${MODELS}" | tr ',' ' '); do
        run_benchmark "${model}"
    done
fi

info_msg "TorchBench testing completed"
