bbkdevops's picture
download
raw
13.3 kB
from __future__ import annotations
from collections import defaultdict
from datetime import datetime, timezone
import json
from pathlib import Path
import shutil
from typing import Any
import torch
from safetensors.torch import load_file, save_file
def _dir_size(path: Path) -> int:
if not path.exists():
return 0
return sum(p.stat().st_size for p in path.rglob("*") if p.is_file())
def _adapter_weight_path(adapter_dir: Path) -> Path:
safe = adapter_dir / "adapter_model.safetensors"
if safe.exists():
return safe
checkpoints = sorted(adapter_dir.glob("checkpoint-*"), key=lambda p: p.stat().st_mtime, reverse=True)
for ckpt in checkpoints:
safe = ckpt / "adapter_model.safetensors"
if safe.exists():
return safe
raise FileNotFoundError(f"adapter_model.safetensors not found under {adapter_dir}")
def _config_path(adapter_dir: Path, weight_path: Path) -> Path:
for candidate in (adapter_dir / "adapter_config.json", weight_path.parent / "adapter_config.json"):
if candidate.exists():
return candidate
raise FileNotFoundError(f"adapter_config.json not found for {adapter_dir}")
def _pair_key(key: str, marker: str) -> str:
return key.replace(marker, "{lora}")
def _component_energy(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
# LoRA delta is sum_i outer(B[:, i], A[i, :]). The squared Frobenius
# energy of each rank-1 component is ||B[:, i]||^2 * ||A[i, :]||^2.
return a.float().pow(2).sum(dim=1) * b.float().pow(2).sum(dim=0)
def _select_components(a: torch.Tensor, b: torch.Tensor, target_rank: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Pick the rank-1 components to keep when pruning a LoRA pair.

    Args:
        a: LoRA A matrix; its first dimension is the current rank.
        b: LoRA B matrix; its second dimension is the current rank.
        target_rank: number of components to keep.

    Returns:
        ``(keep, scores)`` where *scores* is the per-component energy from
        ``_component_energy`` in original component order and *keep* holds
        the indices of the retained components in ascending order. When
        *target_rank* is at least the current rank every index is kept.
    """
    scores = _component_energy(a, b)
    rank = a.shape[0]
    if target_rank >= rank:
        # Build the index tensor on the same device as the scores so both
        # branches return indices usable with index_select on the inputs.
        return torch.arange(rank, device=scores.device), scores
    strongest = torch.topk(scores, k=target_rank, largest=True).indices
    # Sort so the surviving components keep their original relative order.
    return strongest.sort().values, scores
def _lora_pairs(tensors: dict[str, torch.Tensor]) -> dict[str, dict[str, str]]:
    """Group LoRA tensor names into A/B pairs.

    Names containing ``.lora_A.`` or ``.lora_B.`` are normalised with
    ``_pair_key``; the result maps each normalised name to a dict holding
    the original name under ``"a"`` and/or ``"b"``. An entry may contain
    only one side when its partner is absent, and names with neither marker
    are ignored. A name containing both markers is treated as an A tensor.
    """
    role_by_marker = ((".lora_A.", "a"), (".lora_B.", "b"))
    grouped: dict[str, dict[str, str]] = defaultdict(dict)
    for name in tensors:
        for marker, role in role_by_marker:
            if marker in name:
                grouped[_pair_key(name, marker)][role] = name
                break
    return grouped
def _module_name_from_pair(base: str, target_modules: list[str]) -> str:
for module in sorted(target_modules, key=len, reverse=True):
if f".{module}" in base or module in base:
return module
parts = base.replace("{lora}", ".").split(".")
return parts[-2] if len(parts) >= 2 else base
def choose_lora_rank_plan(adapter_dir: str | Path, min_energy_retained: float = 0.995) -> dict[str, Any]:
    """Recommend LoRA ranks that retain a given share of component energy.

    Every compatible 2-D A/B pair in the adapter's weight file is scored
    with ``_component_energy``. Two recommendations are produced:

    * per pair: the smallest rank whose strongest components hold at least
      *min_energy_retained* of that pair's energy;
    * global: the smallest single rank that, applied to every pair (clamped
      to each pair's own rank), holds at least *min_energy_retained* of the
      energy summed over all pairs.

    If the threshold is never reached the full rank is recommended.

    Args:
        adapter_dir: adapter directory; resolved with ``_adapter_weight_path``.
        min_energy_retained: required retained-energy ratio.

    Returns:
        A JSON-serialisable plan dict (schema ``tinymind-lora-rank-plan-v1``)
        with the global recommendation, the evaluated global candidates and
        one entry per pair under ``per_pair_plans``.

    Raises:
        FileNotFoundError: if the weight file cannot be located.
        ValueError: if the file contains no compatible LoRA pair.
    """
    source = Path(adapter_dir)
    weight_path = _adapter_weight_path(source)
    tensors = load_file(str(weight_path))

    pair_energies: list[torch.Tensor] = []
    pair_plans: list[dict[str, Any]] = []
    max_rank = 0
    for pair_name, keys in sorted(_lora_pairs(tensors).items()):
        if "a" not in keys or "b" not in keys:
            continue
        a = tensors[keys["a"]].cpu()
        b = tensors[keys["b"]].cpu()
        if a.ndim != 2 or b.ndim != 2 or a.shape[0] != b.shape[1]:
            continue
        if a.shape[0] == 0:
            # A zero-rank pair has no component to rank; there is nothing to
            # plan and indexing its (empty) candidate list would fail.
            continue

        # Strongest components first, so a prefix is the best rank-k subset.
        energy = _component_energy(a, b).sort(descending=True).values
        old_rank = int(energy.numel())
        pair_energies.append(energy)
        max_rank = max(max_rank, old_rank)
        pair_total = float(energy.sum().item())

        pair_rank = old_rank  # used when the threshold is never reached
        pair_candidates: list[dict[str, Any]] = []
        for rank in range(1, old_rank + 1):
            kept = float(energy[:rank].sum().item())
            retained = kept / pair_total if pair_total > 0 else 1.0
            pair_candidates.append(
                {
                    "rank": rank,
                    "energy_retained_ratio": retained,
                    "energy_lost_ratio": 1.0 - retained,
                    "rank_fraction": rank / old_rank,
                }
            )
            if retained >= min_energy_retained:
                pair_rank = rank
                break

        chosen = pair_candidates[pair_rank - 1]
        pair_plans.append(
            {
                "pair": pair_name,
                "recommended_rank": pair_rank,
                "old_rank": old_rank,
                "energy_retained_ratio": chosen["energy_retained_ratio"],
                "energy_lost_ratio": chosen["energy_lost_ratio"],
                "component_energy_desc": [float(x) for x in energy.tolist()],
                "rank_candidates_until_selected": pair_candidates,
            }
        )

    if not pair_energies:
        raise ValueError(f"no compatible LoRA tensor pairs found under {adapter_dir}")

    total_energy = sum(float(e.sum().item()) for e in pair_energies)
    candidates: list[dict[str, Any]] = []
    recommended_rank = max_rank
    for rank in range(1, max_rank + 1):
        kept = sum(float(e[: min(rank, e.numel())].sum().item()) for e in pair_energies)
        retained = kept / total_energy if total_energy > 0 else 1.0
        candidates.append(
            {
                "rank": rank,
                "global_energy_retained_ratio": retained,
                "global_energy_lost_ratio": 1.0 - retained,
                "rank_fraction": rank / max_rank,
                "score": retained / (rank / max_rank),
            }
        )
        # recommended_rank still equal to max_rank means "nothing chosen yet";
        # the loop keeps going so every candidate is reported.
        if retained >= min_energy_retained and recommended_rank == max_rank:
            recommended_rank = rank

    selected = candidates[recommended_rank - 1]
    # A uniform target rank can never raise a pair above its existing rank,
    # so clamp per pair when totalling.
    uniform_total_rank = sum(min(recommended_rank, int(row["old_rank"])) for row in pair_plans)
    return {
        "schema_version": "tinymind-lora-rank-plan-v1",
        "source_adapter": str(source),
        "source_weight_path": str(weight_path),
        "min_energy_retained": min_energy_retained,
        "max_rank": max_rank,
        "recommended_rank": recommended_rank,
        "global_energy_retained_ratio": selected["global_energy_retained_ratio"],
        "global_energy_lost_ratio": selected["global_energy_lost_ratio"],
        "adaptive_total_rank": sum(int(row["recommended_rank"]) for row in pair_plans),
        "uniform_total_rank": uniform_total_rank,
        "per_pair_plans": pair_plans,
        "rank_candidates": candidates,
        "math_objective": {
            "rank_selection": "smallest rank satisfying retained LoRA delta Frobenius energy threshold",
            "adaptive_rank_selection": "smallest per-pair rank satisfying retained LoRA delta Frobenius energy threshold",
            "component_energy": "||B[:, i]||_2^2 * ||A[i, :]||_2^2",
        },
    }
def _peft_module_path(base: str) -> str:
    """Return the dotted module path of a ``{lora}``-normalised pair key, without a leading ``base_model.model.``."""
    path = base.split("{lora}", 1)[0]
    prefix = "base_model.model."
    return path[len(prefix):] if path.startswith(prefix) else path


def _rescaled_alpha(old_alpha: float, old_rank: int, new_rank: int, use_rslora: bool) -> int | float:
    """Return the alpha that keeps the LoRA scale (alpha/r, or alpha/sqrt(r) for rsLoRA) unchanged at *new_rank*."""
    safe_old_rank = max(1, old_rank)
    if use_rslora:
        value = old_alpha * (new_rank / safe_old_rank) ** 0.5
    else:
        # Multiply before dividing so exactly divisible cases stay exact.
        value = old_alpha * new_rank / safe_old_rank
    # Keep the conventional integer form whenever no precision is lost.
    return int(value) if value.is_integer() else value


def _rank_alpha_patterns(
    entries: list[tuple[str, str, int, int]],
    config: dict[str, Any],
) -> tuple[dict[str, int], dict[str, int | float]]:
    """Build ``rank_pattern`` / ``alpha_pattern`` dicts that match the compacted tensor shapes.

    *entries* holds ``(module_name, module_path, old_rank, new_rank)`` per
    compacted pair. A module name is used as the key only when every pair
    sharing it has the same old rank, new rank and old alpha; otherwise each
    pair is keyed by its own module path.
    """
    raw_alpha_pattern = config.get("alpha_pattern")
    existing_alpha = raw_alpha_pattern if isinstance(raw_alpha_pattern, dict) else {}
    use_rslora = bool(config.get("use_rslora", False))

    grouped: dict[str, list[tuple[str, int, int, float]]] = {}
    for module_name, path, old_rank, new_rank in entries:
        fallback_alpha = existing_alpha.get(module_name, config.get("lora_alpha", max(new_rank, 1)))
        old_alpha = float(existing_alpha.get(path, fallback_alpha))
        grouped.setdefault(module_name, []).append((path, old_rank, new_rank, old_alpha))

    rank_pattern: dict[str, int] = {}
    alpha_pattern: dict[str, int | float] = {}
    for module_name, members in grouped.items():
        distinct = {member[1:] for member in members}
        if len(distinct) == 1:
            keyed = [(module_name,) + members[0][1:]]
        else:
            keyed = members
        for key, old_rank, new_rank, old_alpha in keyed:
            rank_pattern[key] = new_rank
            alpha_pattern[key] = _rescaled_alpha(old_alpha, old_rank, new_rank, use_rslora)
    return rank_pattern, alpha_pattern


def compact_lora_adapter(
    adapter_dir: str | Path,
    out_dir: str | Path,
    target_rank: int,
    *,
    copy_tokenizer: bool = True,
    adaptive_rank_plan: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Prune a LoRA adapter to a lower rank and write it to *out_dir*.

    For every compatible 2-D A/B pair the components with the highest
    energy (see ``_component_energy``) are kept. The pruned tensors, a
    rewritten ``adapter_config.json`` and an
    ``adapter_compaction_manifest.json`` are written to *out_dir*.

    Args:
        adapter_dir: source adapter directory.
        out_dir: destination directory; created if missing.
        target_rank: rank applied to every pair that has no entry in
            *adaptive_rank_plan*; never raises a pair above its own rank.
        copy_tokenizer: also copy tokenizer files, the chat template and
            the README when they exist next to the weights or in
            *adapter_dir*.
        adaptive_rank_plan: optional plan whose ``per_pair_plans`` rows
            (``pair`` / ``recommended_rank``) override *target_rank* per pair.

    Returns:
        The manifest dict that was written, including its own path under
        ``manifest_path``. At most the first 80 pair reports are included.

    Raises:
        ValueError: if *target_rank* is not positive.
        FileNotFoundError: if the weight file or config cannot be located.
    """
    source = Path(adapter_dir)
    out = Path(out_dir)
    if target_rank <= 0:
        raise ValueError("target_rank must be positive")

    weight_path = _adapter_weight_path(source)
    config_path = _config_path(source, weight_path)
    tensors = load_file(str(weight_path))
    pairs = _lora_pairs(tensors)
    adaptive_ranks = {
        str(row["pair"]): int(row["recommended_rank"])
        for row in (adaptive_rank_plan or {}).get("per_pair_plans", [])
        if "pair" in row and "recommended_rank" in row
    }

    # Start from every tensor so non-LoRA and incompatible entries pass through.
    compacted = dict(tensors)
    pair_reports: list[dict[str, Any]] = []
    shape_entries: list[tuple[str, str, int, int]] = []
    global_total_energy = 0.0
    global_kept_energy = 0.0
    config = json.loads(config_path.read_text(encoding="utf-8"))
    raw_targets = config.get("target_modules")
    target_modules = [str(x) for x in raw_targets] if isinstance(raw_targets, list) else []

    for base, keys in sorted(pairs.items()):
        if "a" not in keys or "b" not in keys:
            continue
        a_key, b_key = keys["a"], keys["b"]
        a = tensors[a_key].cpu()
        b = tensors[b_key].cpu()
        if a.ndim != 2 or b.ndim != 2 or a.shape[0] != b.shape[1]:
            continue

        old_rank = int(a.shape[0])
        new_rank = min(adaptive_ranks.get(base, target_rank), old_rank)
        module_name = _module_name_from_pair(base, target_modules)
        shape_entries.append((module_name, _peft_module_path(base), old_rank, new_rank))

        keep, energy = _select_components(a, b, new_rank)
        total_energy = float(energy.sum().item())
        kept_energy = float(energy.index_select(0, keep).sum().item())
        global_total_energy += total_energy
        global_kept_energy += kept_energy
        retained_ratio = kept_energy / total_energy if total_energy > 0 else 1.0

        # Drop the same components from A's rows and B's columns.
        compacted[a_key] = a.index_select(0, keep).contiguous()
        compacted[b_key] = b.index_select(1, keep).contiguous()
        pair_reports.append(
            {
                "pair": base,
                "method": "delta_frobenius_energy_topk",
                "rank_mode": "adaptive_per_pair" if adaptive_rank_plan else "uniform",
                "old_rank": old_rank,
                "new_rank": new_rank,
                "kept_components": [int(i) for i in keep.tolist()],
                "component_energy": [float(x) for x in energy.tolist()],
                "total_delta_energy": total_energy,
                "kept_delta_energy": kept_energy,
                "energy_retained_ratio": retained_ratio,
                "energy_lost_ratio": 1.0 - retained_ratio,
                "rank_reduction": old_rank - new_rank,
            }
        )

    out.mkdir(parents=True, exist_ok=True)
    save_file(compacted, str(out / "adapter_model.safetensors"))

    # Patterns are derived before "alpha_pattern" is overwritten below,
    # because the helper reads the original alpha values from the config.
    rank_pattern, alpha_pattern = _rank_alpha_patterns(shape_entries, config)
    original_rank = config.get("r")
    config["r"] = min(target_rank, original_rank) if isinstance(original_rank, int) else target_rank
    if rank_pattern:
        config["rank_pattern"] = dict(sorted(rank_pattern.items()))
        config["alpha_pattern"] = dict(sorted(alpha_pattern.items()))
    config["tinymind_compaction"] = {
        "method": "lora_component_norm_pruning",
        "target_rank": target_rank,
        "source_adapter": str(source),
        "requires_eval_before_promotion": True,
        "rank_pattern_rewritten_from_tensor_shapes": bool(rank_pattern),
    }
    (out / "adapter_config.json").write_text(
        json.dumps(config, ensure_ascii=False, indent=2, sort_keys=True),
        encoding="utf-8",
    )

    if copy_tokenizer:
        auxiliary_files = (
            "tokenizer.json",
            "tokenizer.model",
            "tokenizer_config.json",
            "special_tokens_map.json",
            "chat_template.jinja",
            "README.md",
        )
        for name in auxiliary_files:
            # Prefer the copy next to the weights, then the adapter root.
            src = weight_path.parent / name
            if not src.exists():
                src = source / name
            if src.exists():
                shutil.copy2(src, out / name)

    source_bytes = _dir_size(weight_path.parent)
    out_bytes = _dir_size(out)
    global_retained = global_kept_energy / global_total_energy if global_total_energy > 0 else 1.0
    manifest = {
        "schema_version": "tinymind-lora-compaction-v1",
        "created_at": datetime.now(timezone.utc).isoformat(),
        "math_objective": {
            "rank_selection": "maximize retained LoRA delta Frobenius energy under target_rank",
            "adaptive_per_pair_rank": bool(adaptive_rank_plan),
            "component_energy": "||B[:, i]||_2^2 * ||A[i, :]||_2^2",
            "constraint": "target_rank and adapter tensor shape compatibility",
        },
        "source_adapter": str(source),
        "source_weight_path": str(weight_path),
        "output_adapter": str(out),
        "target_rank": target_rank,
        "adaptive_rank_plan": adaptive_rank_plan,
        "global_energy_retained_ratio": global_retained,
        "global_energy_lost_ratio": 1.0 - global_retained,
        "pairs_compacted": len(pair_reports),
        "pair_reports": pair_reports[:80],
        "size": {
            "source_bytes": source_bytes,
            "output_bytes": out_bytes,
            "source_mb": source_bytes / (1024 * 1024),
            "output_mb": out_bytes / (1024 * 1024),
            "reduction_ratio": 1.0 - (out_bytes / source_bytes) if source_bytes else 0.0,
        },
        "claim_gate": {
            "smaller_adapter_created": out_bytes > 0 and out_bytes < source_bytes,
            "quality_preserved_claim_allowed": False,
            "reason": "Compaction reduces adapter size, but promotion requires eval loss/drift/generation quality evidence.",
        },
    }
    manifest_path = out / "adapter_compaction_manifest.json"
    manifest["manifest_path"] = str(manifest_path)
    manifest_path.write_text(
        json.dumps(manifest, ensure_ascii=False, indent=2, sort_keys=True),
        encoding="utf-8",
    )
    return manifest

Xet Storage Details

Size:
13.3 kB
·
Xet hash:
c1c9f2fc16edb37ff0eac3a9ff38424b5893b4119ff183d2d2adc69d13d65f32

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.