Instructions to use lsmpp/kontextrefiner with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use lsmpp/kontextrefiner with Diffusers:
```bash
pip install -U diffusers transformers accelerate
```
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("lsmpp/kontextrefiner", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| import gc | |
| import inspect | |
| import logging | |
| import os | |
| import queue | |
| import threading | |
| from contextlib import nullcontext | |
| from dataclasses import dataclass | |
| from typing import Any, Callable, Dict, Optional, Union | |
| import pandas as pd | |
| import torch | |
| import torch.utils.benchmark as benchmark | |
| from diffusers.models.modeling_utils import ModelMixin | |
| from diffusers.utils.testing_utils import require_torch_gpu, torch_device | |
# Emit INFO-and-above records with a timestamp, level and logger name so that
# per-scenario progress and failures are visible when the benchmarks run.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)

# Number of untimed forward passes executed before the timed measurement.
NUM_WARMUP_ROUNDS = 5
def benchmark_fn(f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` and return the mean runtime as a float.

    The call is measured on a single thread with ``torch.utils.benchmark.Timer``
    using ``blocked_autorange``. The mean of that measurement is rounded to
    three decimal places before being returned.
    """
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
        num_threads=1,
    )
    measurement = timer.blocked_autorange()
    # Round-trip through a fixed-precision string to keep three decimals.
    return float(format(measurement.mean, ".3f"))
def flush():
    """Collect garbage and reset CUDA memory caches and peak-memory counters.

    Runs the Python garbage collector first so that tensors which are no longer
    referenced can actually be released. When CUDA is available, the caching
    allocator is emptied and the peak memory statistics are reset. On hosts
    without CUDA only the garbage collection step runs.
    """
    gc.collect()
    if not torch.cuda.is_available():
        # The CUDA bookkeeping calls raise when no CUDA runtime is present.
        return
    torch.cuda.empty_cache()
    # `reset_peak_memory_stats` resets every peak counter, which includes the
    # one the deprecated `reset_max_memory_allocated` used to reset.
    torch.cuda.reset_peak_memory_stats()
# Adapted from https://github.com/lucasb-eyer/cnn_vit_benchmarks/blob/15b665ff758e8062131353076153905cae00a71f/main.py
def calculate_flops(model, input_dict):
    """Estimate the FLOPs of one forward pass of ``model``.

    Args:
        model: Module whose ``forward`` is profiled. It is switched to eval
            mode as a side effect.
        input_dict: Keyword arguments for ``model.forward``. Positional
            parameters missing from the dict fall back to their defaults.

    Returns:
        Twice the number of multiply-accumulate operations reported by
        ``torchprofile.profile_macs``.

    Raises:
        ModuleNotFoundError: If ``torchprofile`` is not installed.
        KeyError: If a positional parameter of ``forward`` has neither a value
            in ``input_dict`` nor a default.
    """
    try:
        from torchprofile import profile_macs
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError(
            "`torchprofile` is required to calculate FLOPs. Install it with `pip install torchprofile`.",
            name=err.name,
        ) from err

    # This is a hacky way to convert the kwargs to args as `profile_macs` cries about kwargs.
    sig = inspect.signature(model.forward)
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    param_names = [
        p.name for p in sig.parameters.values() if p.kind in positional_kinds and p.name != "self"
    ]
    bound = sig.bind_partial(**input_dict)
    bound.apply_defaults()
    # Rebuild the call as a positional tuple in signature order.
    args = tuple(bound.arguments[name] for name in param_names)

    model.eval()
    with torch.no_grad():
        macs = profile_macs(model, args)
    flops = 2 * macs  # 1 MAC operation = 2 FLOPs (1 multiplication + 1 addition)
    return flops
def calculate_params(model):
    """Return the total number of elements over all parameters of ``model``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
| # Users can define their own in case this doesn't suffice. For most cases, | |
| # it should be sufficient. | |
| def model_init_fn(model_cls, group_offload_kwargs=None, layerwise_upcasting=False, **init_kwargs): | |
| model = model_cls.from_pretrained(**init_kwargs).eval() | |
| if group_offload_kwargs and isinstance(group_offload_kwargs, dict): | |
| model.enable_group_offload(**group_offload_kwargs) | |
| else: | |
| model.to(torch_device) | |
| if layerwise_upcasting: | |
| model.enable_layerwise_casting( | |
| storage_dtype=torch.float8_e4m3fn, compute_dtype=init_kwargs.get("torch_dtype", torch.bfloat16) | |
| ) | |
| return model | |
@dataclass
class BenchmarkScenario:
    """Description of one benchmark run: which model to load and how to feed it."""

    # Label used in log messages and in the "scenario" column of the results.
    name: str
    # Model class to benchmark; it is passed as the first argument to `model_init_fn`.
    model_cls: "type[ModelMixin]"
    # Keyword arguments forwarded to the initializer along with `model_cls`.
    model_init_kwargs: Dict[str, Any]
    # Callable invoked as `model_init_fn(model_cls, **model_init_kwargs)` to build the model.
    model_init_fn: Callable
    # Zero-argument callable returning the keyword arguments for the model's forward call.
    get_model_input_dict: Callable
    # Keyword arguments for `model.compile`; when falsy, the compiled phase is skipped.
    compile_kwargs: Optional[Dict[str, Any]] = None
class BenchmarkMixin:
    """Run benchmark scenarios for a model, plain and optionally compiled.

    For each scenario the mixin records the parameter count, an estimate of the
    FLOPs, and the time and peak CUDA memory of a forward pass. Results can be
    appended row by row to a CSV file.
    """

    def pre_benchmark(self):
        """Call ``flush()`` and reset ``torch.compiler`` state before a measurement."""
        flush()
        torch.compiler.reset()

    def post_benchmark(self, model):
        """Move ``model`` to CPU, then call ``flush()`` and reset ``torch.compiler`` state."""
        model.cpu()
        flush()
        torch.compiler.reset()

    @torch.no_grad()
    def run_benchmark(self, scenario: BenchmarkScenario):
        """Benchmark a single scenario and return one result record.

        Args:
            scenario: The scenario to run.

        Returns:
            A dict with the keys ``scenario``, ``model_cls``, ``num_params_B``,
            ``flops_G``, ``time_plain_s``, ``mem_plain_GB``, ``time_compile_s``,
            ``mem_compile_GB``, ``fullgraph`` and ``mode``. An empty dict is
            returned when the model cannot be initialized or the plain phase
            fails. A failing compiled phase leaves the compile columns ``None``.
        """
        # 0) Basic stats
        logger.info(f"Running scenario: {scenario.name}.")
        try:
            # NOTE(review): the stats use the module-level `model_init_fn`, not
            # `scenario.model_init_fn`; presumably so they are computed on a
            # plainly loaded model -- confirm.
            model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs)
            num_params = round(calculate_params(model) / 1e9, 2)
            try:
                flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e9, 2)
            except Exception as e:
                # A missing FLOPs figure should not abort the benchmark.
                logger.info(f"Problem in calculating FLOPs:\n{e}")
                flops = None
            model.cpu()
            del model
        except Exception as e:
            logger.info(f"Error while initializing the model and calculating FLOPs:\n{e}")
            return {}
        self.pre_benchmark()

        # 1) plain stats
        try:
            plain = self._run_phase(
                model_cls=scenario.model_cls,
                init_fn=scenario.model_init_fn,
                init_kwargs=scenario.model_init_kwargs,
                get_input_fn=scenario.get_model_input_dict,
                compile_kwargs=None,
            )
        except Exception as e:
            logger.info(f"Benchmark could not be run with the following error:\n{e}")
            return {}

        # 2) compiled stats (if any)
        compiled = {"time": None, "memory": None}
        if scenario.compile_kwargs:
            try:
                compiled = self._run_phase(
                    model_cls=scenario.model_cls,
                    init_fn=scenario.model_init_fn,
                    init_kwargs=scenario.model_init_kwargs,
                    get_input_fn=scenario.get_model_input_dict,
                    compile_kwargs=scenario.compile_kwargs,
                )
            except Exception as e:
                # Keep the plain numbers even if compilation fails.
                logger.info(f"Compilation benchmark could not be run with the following error\n: {e}")

        # 3) merge
        result = {
            "scenario": scenario.name,
            "model_cls": scenario.model_cls.__name__,
            "num_params_B": num_params,
            "flops_G": flops,
            "time_plain_s": plain["time"],
            "mem_plain_GB": plain["memory"],
            "time_compile_s": compiled["time"],
            "mem_compile_GB": compiled["memory"],
        }
        if scenario.compile_kwargs:
            result["fullgraph"] = scenario.compile_kwargs.get("fullgraph", False)
            result["mode"] = scenario.compile_kwargs.get("mode", "default")
        else:
            result["fullgraph"], result["mode"] = None, None
        return result

    def run_bencmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[BenchmarkScenario]], filename: str):
        """Run every scenario and append each non-empty record to a CSV file.

        Records are written by a background thread as soon as they are
        produced, so results of finished scenarios survive a later failure.
        The method returns only after every queued record has been written.

        Args:
            scenarios: One scenario or a list of scenarios.
            filename: CSV path to append to. The header row is written only
                when the file does not exist yet.
        """
        if not isinstance(scenarios, list):
            scenarios = [scenarios]
        record_queue = queue.Queue()
        stop_signal = object()

        def _writer_thread():
            while True:
                item = record_queue.get()
                try:
                    if item is stop_signal:
                        break
                    df_row = pd.DataFrame([item])
                    write_header = not os.path.exists(filename)
                    df_row.to_csv(filename, mode="a", header=write_header, index=False)
                except Exception as e:
                    # One bad row must not stop later rows from being written.
                    logger.info(f"Could not write record to {filename}:\n{e}")
                finally:
                    record_queue.task_done()

        writer = threading.Thread(target=_writer_thread, daemon=True)
        writer.start()
        try:
            for s in scenarios:
                try:
                    record = self.run_benchmark(s)
                    if record:
                        record_queue.put(record)
                    else:
                        logger.info(f"Record empty from scenario: {s.name}.")
                except Exception as e:
                    logger.info(f"Running scenario ({s.name}) led to error:\n{e}")
        finally:
            # The writer is a daemon thread: without joining it, pending rows
            # could be dropped when the interpreter exits.
            record_queue.put(stop_signal)
            writer.join()
        logger.info(f"Results serialized to {filename=}.")

    # Correctly spelled alias; the original name is kept for existing callers.
    run_benchmarks_and_collate = run_bencmarks_and_collate

    @torch.no_grad()
    def _run_phase(
        self,
        *,
        model_cls: type[ModelMixin],
        init_fn: Callable,
        init_kwargs: Dict[str, Any],
        get_input_fn: Callable,
        compile_kwargs: Optional[Dict[str, Any]],
    ) -> Dict[str, float]:
        """Initialize a model, time its forward pass and report peak CUDA memory.

        Args:
            model_cls: Model class handed to ``init_fn``.
            init_fn: Called as ``init_fn(model_cls, **init_kwargs)`` to build the model.
            init_kwargs: Keyword arguments for ``init_fn``.
            get_input_fn: Zero-argument callable returning the forward kwargs.
            compile_kwargs: When truthy, passed to ``model.compile`` and the
                measurement runs inside a fresh inductor cache.

        Returns:
            ``{"time": <value from benchmark_fn>, "memory": <peak allocated GiB, 2 decimals>}``.
        """
        # setup
        self.pre_benchmark()

        # init & (optional) compile
        model = init_fn(model_cls, **init_kwargs)
        try:
            if compile_kwargs:
                model.compile(**compile_kwargs)

            # build inputs
            inp = get_input_fn()

            # measure
            run_ctx = torch._inductor.utils.fresh_inductor_cache() if compile_kwargs else nullcontext()
            with run_ctx:
                for _ in range(NUM_WARMUP_ROUNDS):
                    _ = model(**inp)
                time_s = benchmark_fn(lambda m, d: m(**d), model, inp)
            mem_gb = round(torch.cuda.max_memory_allocated() / (1024**3), 2)
        finally:
            # teardown -- also on failure, so a broken phase does not keep the
            # model or compiler state around for the next one.
            self.post_benchmark(model)
            del model

        return {"time": time_s, "memory": mem_gb}