File size: 3,421 Bytes
cd16f07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from __future__ import annotations

import importlib.util
import inspect
import shutil
from dataclasses import dataclass, fields
from typing import Dict

import torch


@dataclass(frozen=True)
class XQSBackendReport:
    """Immutable snapshot of the compute capabilities found in this environment.

    Instances are built by ``detect_xqs_backends``; the comments below describe
    how that function populates each field.
    """

    torch_version: str  # ``torch.__version__`` of the running interpreter
    cuda_available: bool  # result of ``torch.cuda.is_available()``
    cuda_device_name: str  # name of CUDA device 0, or "cpu" when CUDA is absent
    bf16_supported: bool  # CUDA bf16 support; always False without CUDA
    torch_compile_available: bool  # whether ``torch`` exposes a ``compile`` attribute
    triton_available: bool  # ``triton`` is locatable by the import system
    deepspeed_available: bool  # ``deepspeed`` is locatable by the import system
    bitsandbytes_available: bool  # ``bitsandbytes`` is locatable by the import system
    flash_attn_available: bool  # ``flash_attn`` is locatable by the import system
    nvcc_available: bool  # an ``nvcc`` executable was found on PATH

    def as_dict(self) -> Dict[str, object]:
        """Return the report as ``{field_name: value}`` in field-declaration order.

        The mapping is derived from ``dataclasses.fields`` so that a newly added
        field is included automatically instead of needing a second, manually
        synchronised list. Values are returned as-is (shallow, no copying).
        """
        return {field.name: getattr(self, field.name) for field in fields(self)}


def _has_module(name: str) -> bool:
    """Return True if the import system can locate module *name*.

    Uses ``importlib.util.find_spec``, which does not execute the target
    module itself. ``find_spec`` can raise instead of returning None:

    * ``ModuleNotFoundError`` (an ``ImportError``) for a dotted name whose
      parent package does not exist;
    * ``ValueError`` when *name* is already in ``sys.modules`` with
      ``__spec__`` set to None.

    Both are reported as "not available" so capability probing never crashes.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        return False



def detect_xqs_backends() -> XQSBackendReport:
    """Probe the running environment and return an ``XQSBackendReport``.

    CUDA specifics (device name, bf16 support) are queried only when
    ``torch.cuda.is_available()`` is true; otherwise the device name is
    reported as ``"cpu"`` and bf16 as unsupported. Optional libraries are
    detected with ``_has_module`` and ``nvcc`` with a PATH lookup.
    """
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        # Only device index 0 is inspected.
        gpu_name = torch.cuda.get_device_name(0)
        bf16_ok = bool(torch.cuda.is_bf16_supported())
    else:
        gpu_name = "cpu"
        bf16_ok = False

    compile_ok = hasattr(torch, "compile")
    triton_ok, deepspeed_ok, bnb_ok, flash_ok = (
        _has_module(pkg)
        for pkg in ("triton", "deepspeed", "bitsandbytes", "flash_attn")
    )
    nvcc_ok = shutil.which("nvcc") is not None

    return XQSBackendReport(
        torch_version=torch.__version__,
        cuda_available=has_cuda,
        cuda_device_name=gpu_name,
        bf16_supported=bf16_ok,
        torch_compile_available=compile_ok,
        triton_available=triton_ok,
        deepspeed_available=deepspeed_ok,
        bitsandbytes_available=bnb_ok,
        flash_attn_available=flash_ok,
        nvcc_available=nvcc_ok,
    )



def choose_attention_backend(prefer_flash: bool = True) -> str:
    """Return the attention implementation name to use.

    Without CUDA this is always ``"eager"``. With CUDA, ``"flash_attn"`` is
    returned when *prefer_flash* is set and the ``flash_attn`` package is
    locatable; otherwise ``"scaled_dot_product_attention"``.
    """
    caps = detect_xqs_backends()
    if not caps.cuda_available:
        return "eager"
    if prefer_flash and caps.flash_attn_available:
        return "flash_attn"
    return "scaled_dot_product_attention"



def choose_optimizer_backend(prefer_low_memory: bool = True) -> str:
    """Return the optimizer backend name to use; the first matching rule wins.

    1. ``"adamw_fused"``: CUDA is available and ``torch.optim.AdamW`` accepts
       a ``fused`` argument.
    2. ``"adam8bit"``: *prefer_low_memory* is set and ``bitsandbytes`` is
       locatable.
    3. ``"adafactor"``: ``transformers`` is locatable.
    4. ``"sgd"``: fallback.

    Note that *prefer_low_memory* is only consulted when rule 1 does not apply.
    """
    report = detect_xqs_backends()

    # Check the real constructor signature instead of
    # ``__init__.__code__.co_varnames``. ``co_varnames`` also lists local
    # variables, and it describes the wrapper if ``__init__`` is ever
    # decorated. ``inspect.signature`` follows ``__wrapped__`` and reports
    # parameters only.
    try:
        fused_supported = "fused" in inspect.signature(torch.optim.AdamW).parameters
    except (TypeError, ValueError):
        fused_supported = False

    if report.cuda_available and fused_supported:
        return "adamw_fused"
    # NOTE(review): this branch does not require CUDA; confirm callers can use
    # an 8-bit optimizer on a CPU-only bitsandbytes install.
    if prefer_low_memory and report.bitsandbytes_available:
        return "adam8bit"
    if _has_module("transformers"):
        return "adafactor"
    return "sgd"



def choose_moe_backend(prefer_deepspeed: bool = True) -> str:
    """Return ``"deepspeed"`` or ``"native"`` as the mixture-of-experts backend.

    DeepSpeed is selected only when *prefer_deepspeed* is set, the package is
    locatable, and CUDA is available.
    """
    caps = detect_xqs_backends()
    use_deepspeed = prefer_deepspeed and caps.deepspeed_available and caps.cuda_available
    return "deepspeed" if use_deepspeed else "native"



def choose_quant_backend(prefer_triton: bool = True) -> str:
    """Return ``"triton"`` or ``"pytorch"`` as the quantization backend.

    Triton is selected only when *prefer_triton* is set, the package is
    locatable, and CUDA is available.
    """
    caps = detect_xqs_backends()
    use_triton = prefer_triton and caps.triton_available and caps.cuda_available
    return "triton" if use_triton else "pytorch"



def format_backend_report(report: XQSBackendReport) -> str:
    """Render *report* as newline-separated ``key=value`` lines.

    Lines follow the key order of ``report.as_dict()``; there is no trailing
    newline, and an empty mapping yields an empty string.
    """
    rendered = []
    for field_name, field_value in report.as_dict().items():
        rendered.append(f"{field_name}={field_value}")
    return "\n".join(rendered)