| import html |
| import os |
| import re |
|
|
| import gradio as gr |
| import modules.textual_inversion.preprocess |
| import modules.textual_inversion.textual_inversion |
| from modules import devices, sd_hijack, shared |
| from modules.hypernetworks import hypernetwork |
|
|
# Activation names that are excluded from `keys` below.
not_available = ["hardswish", "multiheadattention"]

# Every activation name registered on HypernetworkModule, minus the excluded ones,
# in the registry's own iteration order.
keys = [
    activation
    for activation in hypernetwork.HypernetworkModule.activation_dict
    if activation not in not_available
]
|
|
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
    """Create a new hypernetwork, save it to disk and refresh the loaded list.

    Args:
        name: Requested name. Characters other than alphanumerics and
            ``"._- "`` are dropped before the name is used.
        enable_sizes: Iterable of values convertible to ``int``.
        overwrite_old: When falsy, refuse to replace an existing file.
        layer_structure: Either a comma separated string of numbers (parsed
            into a list of floats) or a value passed through unchanged.
        activation_func: Passed through to ``Hypernetwork``.
        weight_init: Passed through to ``Hypernetwork``.
        add_layer_norm: Passed through to ``Hypernetwork``.
        use_dropout: Passed through to ``Hypernetwork``.

    Returns:
        A 3-tuple: a ``gr.Dropdown`` update whose choices are the sorted
        hypernetwork names, a ``"Created: <path>"`` message, and an empty
        string (presumably a placeholder for a second UI output -- confirm
        against the caller).

    Raises:
        AssertionError: If the target file exists and ``overwrite_old`` is
            falsy.
        ValueError: If ``layer_structure`` is a string containing an entry
            that cannot be converted to ``float``.
    """
    # Keep only characters that are allowed in the name / file name.
    allowed_punctuation = "._- "
    name = "".join(ch for ch in name if ch.isalnum() or ch in allowed_punctuation)

    fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
    if not overwrite_old and os.path.exists(fn):
        # Raised explicitly (instead of via `assert`) so the guard still runs
        # under `python -O`; the exception type is unchanged for callers.
        raise AssertionError(f"file {fn} already exists")

    if isinstance(layer_structure, str):
        layer_structure = [float(part.strip()) for part in layer_structure.split(",")]

    hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
        name=name,
        enable_sizes=[int(size) for size in enable_sizes],
        layer_structure=layer_structure,
        activation_func=activation_func,
        weight_init=weight_init,
        add_layer_norm=add_layer_norm,
        use_dropout=use_dropout,
    )
    hypernet.save(fn)

    # Refresh the shared registry so the new file shows up in the choices.
    shared.reload_hypernetworks()

    return gr.Dropdown.update(choices=sorted(shared.hypernetworks.keys())), f"Created: {fn}", ""
|
|
|
|
def train_hypernetwork(*args):
    """Run hypernetwork training and return a status message for the UI.

    All positional arguments are forwarded unchanged to
    ``modules.hypernetworks.hypernetwork.train_hypernetwork``.

    Returns:
        A 2-tuple of the status text (whether training finished or was
        interrupted, the step count and the HTML-escaped output file name)
        and an empty string.

    Raises:
        AssertionError: If ``shared.cmd_opts.lowvram`` is set.
        Exception: Anything raised by the underlying training call propagates
            unchanged after the cleanup in ``finally`` has run.
    """
    # Remember what was loaded so it can be restored once training ends.
    initial_hypernetwork = shared.loaded_hypernetwork

    if shared.cmd_opts.lowvram:
        # Raised explicitly (instead of via `assert`) so the guard still runs
        # under `python -O`; the exception type is unchanged for callers.
        raise AssertionError('Training models with lowvram is not possible')

    try:
        sd_hijack.undo_optimizations()

        # Named `trained` so the module-level `hypernetwork` import is not shadowed.
        trained, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)

        outcome = 'interrupted' if shared.state.interrupted else 'finished'
        # Explicit newlines keep the message independent of source indentation.
        res = (
            f"\nTraining {outcome} at {trained.step} steps.\n"
            f"Hypernetwork saved to {html.escape(filename)}\n"
        )
        return res, ""
    finally:
        # Runs on success and on failure: restore the previously loaded
        # hypernetwork, put both sub-models on the configured device
        # (presumably training moved them elsewhere -- confirm), and
        # re-apply the optimizations undone above.
        shared.loaded_hypernetwork = initial_hypernetwork
        shared.sd_model.cond_stage_model.to(devices.device)
        shared.sd_model.first_stage_model.to(devices.device)
        sd_hijack.apply_optimizations()
|
|
|
|