#!/bin/sh

# shellcheck disable=SC1091
. ../../lib/sh-test-lib
# Output locations, relative to the directory the script is started from.
# RESULT_FILE is exported into the environment of child processes.
OUTPUT="$(pwd)/output"
export RESULT_FILE="${OUTPUT}/result.txt"

# Default settings; the trailing comment names the command-line option that
# overrides each one.
SKIP_INSTALL=False      # -s
WORKLOADS=all           # -w
DEVICE=cpu              # -d
MATRIX_SIZE=1024        # -m
BATCH_SIZE=32           # -b
ITERATIONS=100          # -i
PACKAGE_MANAGER=uv      # -p
PYTHON_VERSION=system   # -v
UV_VERSION=0.5.11       # -u

# Print the option summary and exit with status 1.
#
# The complete message is written to stderr, so the help text is never split
# between stdout and stderr.
usage() {
    cat 1>&2 << EOF
Usage: $0 [-s <true|false>] [-w WORKLOADS] [-d DEVICE] [-m MATRIX_SIZE] [-b BATCH_SIZE] [-i ITERATIONS] [-p PACKAGE_MANAGER] [-v PYTHON_VERSION] [-u UV_VERSION]
  -s: Skip install (default: False)
  -w: Workloads to run: tensor_ops, matrix_multiply, conv2d, linear, inference, training, all (default: all)
  -d: Device: cpu, cuda (default: cpu)
  -m: Matrix size for matrix multiply (default: 1024)
  -b: Batch size for neural network workloads (default: 32)
  -i: Number of iterations (default: 100)
  -p: Package manager: pip, uv (default: uv)
  -v: Python version for uv: 3.11, 3.12, 3.13, system (default: system)
  -u: uv version to install (default: 0.5.11)
EOF
    exit 1
}

# Parse command-line options into the settings variables. Every option except
# -h takes an argument; -h and any unrecognised option print the usage text.
while getopts "s:w:d:m:b:i:p:v:u:h" opt; do
    case "${opt}" in
        s) SKIP_INSTALL="${OPTARG}" ;;
        w) WORKLOADS="${OPTARG}" ;;
        d) DEVICE="${OPTARG}" ;;
        m) MATRIX_SIZE="${OPTARG}" ;;
        b) BATCH_SIZE="${OPTARG}" ;;
        i) ITERATIONS="${OPTARG}" ;;
        p) PACKAGE_MANAGER="${OPTARG}" ;;
        v) PYTHON_VERSION="${OPTARG}" ;;
        u) UV_VERSION="${OPTARG}" ;;
        h) usage ;;
        *) usage ;;
    esac
done

# Install the uv package manager unless a uv command is already available.
#
# Globals:  UV_VERSION (read), HOME (read), PATH (may be extended)
# Returns:  0 when uv can be found afterwards, 1 otherwise.
install_uv() {
    if command -v uv > /dev/null 2>&1; then
        return 0
    fi

    info_msg "Installing uv version ${UV_VERSION}..."
    curl -LsSf "https://astral.sh/uv/${UV_VERSION}/install.sh" | sh

    # Sourcing a file that does not exist aborts a non-interactive POSIX shell,
    # so only source the env script when it is really there; otherwise put the
    # directory it would live in on PATH directly.
    if [ -f "$HOME/.local/bin/env" ]; then
        # shellcheck disable=SC1091
        . "$HOME/.local/bin/env"
    else
        PATH="$HOME/.local/bin:${PATH}"
        export PATH
    fi

    # The pipeline status above is that of sh, which hides a curl failure, so
    # confirm the result by looking for the command itself.
    if ! command -v uv > /dev/null 2>&1; then
        warn_msg "uv ${UV_VERSION} installation failed: uv not found in PATH"
        return 1
    fi
}

# Install OS packages, create a virtual environment in ${OUTPUT}/venv,
# activate it and install PyTorch (torch + torchvision) into it.
#
# Globals read: SKIP_INSTALL, PACKAGE_MANAGER, PYTHON_VERSION, DEVICE, OUTPUT,
#               dist (set by dist_name)
# When SKIP_INSTALL is neither "False" nor "false", nothing is installed and
# an already existing ${OUTPUT}/venv is merely activated.
install() {
    dist_name
    # shellcheck disable=SC2154
    case "${dist}" in
        debian|ubuntu)
            install_deps "python3 python3-pip python3-venv curl ca-certificates" "${SKIP_INSTALL}"
            ;;
        fedora|centos)
            install_deps "python3 python3-pip curl ca-certificates" "${SKIP_INSTALL}"
            ;;
        *)
            warn_msg "Unsupported distro: ${dist}! Package install skipped"
            ;;
    esac

    if [ "${SKIP_INSTALL}" = "False" ] || [ "${SKIP_INSTALL}" = "false" ]; then
        if [ "${PACKAGE_MANAGER}" = "uv" ]; then
            install_uv

            info_msg "Creating uv environment with Python ${PYTHON_VERSION}..."
            if [ "${PYTHON_VERSION}" = "system" ]; then
                uv venv "${OUTPUT}/venv" --system-site-packages
            else
                uv venv "${OUTPUT}/venv" --python "${PYTHON_VERSION}" --system-site-packages
            fi
        else
            # Use traditional pip
            info_msg "Setting up Python virtual environment with pip..."
            python3 -m venv "${OUTPUT}/venv" --system-site-packages
        fi

        # shellcheck disable=SC1091
        . "${OUTPUT}/venv/bin/activate"

        if [ "${PACKAGE_MANAGER}" != "uv" ]; then
            # Upgrade pip only after activation, so that the pip found on PATH
            # is the virtual environment's and not the system one.
            pip install --upgrade pip
        fi

        # Wheel index matching the requested device.
        if [ "${DEVICE}" = "cuda" ]; then
            torch_index_url="https://download.pytorch.org/whl/cu118"
        else
            torch_index_url="https://download.pytorch.org/whl/cpu"
        fi

        info_msg "Installing PyTorch..."
        # uv exposes its installer as "uv pip install", plain pip as
        # "pip install"; prefixing the manager name to "pip install" is only
        # valid for uv.
        if [ "${PACKAGE_MANAGER}" = "uv" ]; then
            uv pip install torch torchvision --index-url "${torch_index_url}"
        else
            pip install torch torchvision --index-url "${torch_index_url}"
        fi
    else
        if [ -d "${OUTPUT}/venv" ]; then
            # shellcheck disable=SC1091
            . "${OUTPUT}/venv/bin/activate"
        fi
    fi
}

# Sanity-check the PyTorch installation before any benchmark runs.
#
# Always imports torch and prints its version; for DEVICE=cuda it additionally
# asserts that torch.cuda.is_available(). check_return is called directly
# after each python3 invocation with the matching test id.
check_pytorch() {
    info_msg "Checking PyTorch installation..."
    python3 -c "import torch; print(f'PyTorch version: {torch.__version__}')"
    check_return "pytorch-import"

    case "${DEVICE}" in
        cuda)
            python3 -c "import torch; assert torch.cuda.is_available(), 'CUDA not available'"
            check_return "cuda-available"
            ;;
    esac
}

# Benchmark basic tensor operations: creation, element-wise add, element-wise
# multiply and sin on size x size tensors.
#
# Globals read: DEVICE, ITERATIONS, MATRIX_SIZE (expanded by the shell into
#               the Python script below, so they must be valid Python values).
# Output:       one "<name> pass <avg> ms" line per operation on stdout.
# NOTE(review): the measurement lines are only printed; confirm whether they
#               are also meant to be appended to RESULT_FILE.
run_tensor_ops() {
    info_msg "Running tensor operations benchmark..."
    python3 << EOF
import torch
import time

device = torch.device("${DEVICE}")
iterations = ${ITERATIONS}
size = ${MATRIX_SIZE}

# Tensor creation
start = time.time()
for _ in range(iterations):
    x = torch.randn(size, size, device=device)
    # CUDA work is queued asynchronously; wait for it like the loops below,
    # otherwise only the launch cost would be timed.
    if device.type == "cuda":
        torch.cuda.synchronize()
elapsed = time.time() - start
print(f"tensor_creation pass {elapsed/iterations*1000:.3f} ms")

# Element-wise operations
x = torch.randn(size, size, device=device)
y = torch.randn(size, size, device=device)

start = time.time()
for _ in range(iterations):
    z = x + y
    if device.type == "cuda":
        torch.cuda.synchronize()
elapsed = time.time() - start
print(f"tensor_add pass {elapsed/iterations*1000:.3f} ms")

start = time.time()
for _ in range(iterations):
    z = x * y
    if device.type == "cuda":
        torch.cuda.synchronize()
elapsed = time.time() - start
print(f"tensor_mul pass {elapsed/iterations*1000:.3f} ms")

start = time.time()
for _ in range(iterations):
    z = torch.sin(x)
    if device.type == "cuda":
        torch.cuda.synchronize()
elapsed = time.time() - start
print(f"tensor_sin pass {elapsed/iterations*1000:.3f} ms")
EOF
    check_return "tensor-ops"
}

# Benchmark torch.mm on two random size x size matrices.
#
# Globals read: DEVICE, ITERATIONS, MATRIX_SIZE (expanded into the script).
# Output:       average time per product in ms and the derived GFLOPS figure
#               (2 * size^3 floating point operations per product).
run_matrix_multiply() {
    info_msg "Running matrix multiplication benchmark..."
    python3 << EOF
import time

import torch

device = torch.device("${DEVICE}")
iterations = ${ITERATIONS}
size = ${MATRIX_SIZE}

lhs = torch.randn(size, size, device=device)
rhs = torch.randn(size, size, device=device)


def multiply(count):
    # Issue count products, then wait for any outstanding CUDA work.
    for _ in range(count):
        product = torch.mm(lhs, rhs)
    if device.type == "cuda":
        torch.cuda.synchronize()


# Warmup
multiply(10)

start = time.time()
multiply(iterations)
elapsed = time.time() - start

avg_time_ms = elapsed / iterations * 1000
flops = 2 * size**3 * iterations / elapsed / 1e9  # GFLOPS
print(f"matrix_multiply pass {avg_time_ms:.3f} ms")
print(f"matrix_multiply_gflops pass {flops:.2f} GFLOPS")
EOF
    check_return "matrix-multiply"
}

# Benchmark the forward pass of a single 3x3 Conv2d layer (64 -> 128 channels)
# on a random batch of 56x56 feature maps.
#
# Globals read: DEVICE, ITERATIONS, BATCH_SIZE (expanded into the script).
# Output:       average forward time in ms on stdout.
run_conv2d() {
    info_msg "Running Conv2D benchmark..."
    python3 << EOF
import time

import torch
from torch import nn

device = torch.device("${DEVICE}")
iterations = ${ITERATIONS}
batch_size = ${BATCH_SIZE}

# Create a typical conv layer (like in ResNet)
layer = nn.Conv2d(64, 128, kernel_size=3, padding=1).to(device)
inputs = torch.randn(batch_size, 64, 56, 56, device=device)


def forward_passes(count):
    # Run count forward passes, then wait for any outstanding CUDA work.
    for _ in range(count):
        outputs = layer(inputs)
    if device.type == "cuda":
        torch.cuda.synchronize()


# Warmup
forward_passes(10)

start = time.time()
forward_passes(iterations)
elapsed = time.time() - start

avg_time_ms = elapsed / iterations * 1000
print(f"conv2d_forward pass {avg_time_ms:.3f} ms")
EOF
    check_return "conv2d"
}

# Benchmark the forward pass of a single 1024 -> 1024 Linear layer on a random
# batch.
#
# Globals read: DEVICE, ITERATIONS, BATCH_SIZE (expanded into the script).
# Output:       average forward time in ms on stdout.
run_linear() {
    info_msg "Running Linear layer benchmark..."
    python3 << EOF
import time

import torch
from torch import nn

device = torch.device("${DEVICE}")
iterations = ${ITERATIONS}
batch_size = ${BATCH_SIZE}

# Create a linear layer
layer = nn.Linear(1024, 1024).to(device)
inputs = torch.randn(batch_size, 1024, device=device)


def forward_passes(count):
    # Run count forward passes, then wait for any outstanding CUDA work.
    for _ in range(count):
        outputs = layer(inputs)
    if device.type == "cuda":
        torch.cuda.synchronize()


# Warmup
forward_passes(10)

start = time.time()
forward_passes(iterations)
elapsed = time.time() - start

avg_time_ms = elapsed / iterations * 1000
print(f"linear_forward pass {avg_time_ms:.3f} ms")
EOF
    check_return "linear"
}

# Benchmark ResNet-18 inference (eval mode, gradients disabled) on a random
# batch of 3x224x224 inputs.
#
# Globals read: DEVICE, ITERATIONS, BATCH_SIZE (expanded into the script).
# Output:       average batch latency in ms and throughput in images/sec.
run_inference() {
    info_msg "Running inference benchmark with ResNet-18..."
    python3 << EOF
import time

import torch
from torchvision import models

device = torch.device("${DEVICE}")
iterations = ${ITERATIONS}
batch_size = ${BATCH_SIZE}

# Load pre-defined model (no pretrained weights to avoid download)
model = models.resnet18(weights=None).to(device)
model.eval()

images = torch.randn(batch_size, 3, 224, 224, device=device)


def forward_passes(count):
    # Run count gradient-free forward passes, then wait for CUDA work.
    with torch.no_grad():
        for _ in range(count):
            logits = model(images)
    if device.type == "cuda":
        torch.cuda.synchronize()


# Warmup
forward_passes(10)

start = time.time()
forward_passes(iterations)
elapsed = time.time() - start

avg_time_ms = elapsed / iterations * 1000
throughput = batch_size * iterations / elapsed
print(f"resnet18_inference pass {avg_time_ms:.3f} ms")
print(f"resnet18_throughput pass {throughput:.2f} images/sec")
EOF
    check_return "inference"
}

# Benchmark full training steps (zero_grad, forward, loss, backward, Adam
# step) of a small MLP on one fixed random batch.
#
# Globals read: DEVICE, ITERATIONS, BATCH_SIZE (expanded into the script).
# Output:       average time per training step in ms on stdout.
run_training() {
    info_msg "Running training benchmark with simple model..."
    python3 << EOF
import time

import torch
from torch import nn, optim

device = torch.device("${DEVICE}")
iterations = ${ITERATIONS}
batch_size = ${BATCH_SIZE}


# Simple MLP for training benchmark
class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        return self.net(x)


model = SimpleMLP().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

features = torch.randn(batch_size, 784, device=device)
labels = torch.randint(0, 10, (batch_size,), device=device)


def train_steps(count):
    # Run count optimisation steps, then wait for outstanding CUDA work.
    for _ in range(count):
        optimizer.zero_grad()
        output = model(features)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
    if device.type == "cuda":
        torch.cuda.synchronize()


# Warmup
train_steps(10)

start = time.time()
train_steps(iterations)
elapsed = time.time() - start

avg_time_ms = elapsed / iterations * 1000
print(f"mlp_training_step pass {avg_time_ms:.3f} ms")
EOF
    check_return "training"
}

# Test run: prepare the output directory, set up the Python environment and
# verify that torch is importable before benchmarking.
create_out_dir "${OUTPUT}"

install
check_pytorch

# Run selected workloads. "all" stands for every workload in the fixed order
# below; anything else is treated as a comma-separated list of names.
if [ "${WORKLOADS}" = "all" ]; then
    selected_workloads="tensor_ops matrix_multiply conv2d linear inference training"
else
    selected_workloads="$(echo "${WORKLOADS}" | tr ',' ' ')"
fi

# Word splitting of the list is intentional here.
# shellcheck disable=SC2086
for workload in ${selected_workloads}; do
    case "${workload}" in
        tensor_ops) run_tensor_ops ;;
        matrix_multiply) run_matrix_multiply ;;
        conv2d) run_conv2d ;;
        linear) run_linear ;;
        inference) run_inference ;;
        training) run_training ;;
        *) warn_msg "Unknown workload: ${workload}" ;;
    esac
done
