| | import os |
| | import torch |
| | import gradio as gr |
| |
|
| | import modules.scripts as scripts |
| | import modules.shared as shared |
| | from modules.script_callbacks import on_cfg_denoiser, remove_current_script_callbacks |
| | from modules.prompt_parser import SdConditioning |
| |
|
| |
|
class Script(scripts.Script):
    """Always-visible script that scales the influence of the negative prompt.

    While a weight other than 1 is active, a cfg-denoiser callback replaces
    the unconditional (negative-prompt) conditioning with a linear
    interpolation between the conditioning of an empty prompt and the
    conditioning of the actual negative prompt.
    """

    def title(self):
        """Return the display name of the script."""
        return "Negative Prompt Weight Extention"

    def show(self, is_img2img):
        """Make the script always visible, regardless of the active tab."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the accordion holding the weight slider, number box and reset button.

        Returns:
            A one-element list with the number component; its value is what
            `process` receives as ``weight``.
        """
        with gr.Accordion("Negative Prompt Weight", open=True, elem_id="npw"):
            with gr.Row(equal_height=True):
                with gr.Column(scale=100):
                    weight_input_slider = gr.Slider(
                        minimum=0.00,
                        maximum=2.00,
                        step=.05,
                        value=1.00,
                        label="Weight",
                        interactive=True,
                        elem_id="npw-slider",
                    )
                with gr.Column(scale=1, min_width=120):
                    with gr.Row():
                        weight_input = gr.Number(
                            value=1.00,
                            precision=4,
                            label="Negative Prompt Weight",
                            show_label=False,
                            elem_id="npw-number",
                        )
                        reset_but = gr.Button(value='✕', elem_id='npw-x', size='sm')

        # Client-side handler: outlines the reset button of the currently
        # displayed tab with an opacity of sqrt(|v - 1|), then passes the
        # value through unchanged so it can be mirrored into the slider.
        js = """(v) => {
            ['#tab_txt2img #npw-x', '#tab_img2img #npw-x'].forEach((selector, index) => {
                const element = document.querySelector(selector);
                if (document.querySelector(`#tab_${index ? 'img2img' : 'txt2img'}`).style.display === 'block') {
                    element.style.cssText += `outline:4px solid rgba(255,186,0,${Math.sqrt(Math.abs(v-1))}); border-radius: 0.4em !important;`;
                }
            });
            return v;
        }"""

        # Keep the number box and the slider in sync; all three handlers run
        # purely in the browser (no Python function is attached).
        weight_input.change(None, [weight_input], weight_input_slider, _js=js)
        weight_input_slider.change(None, weight_input_slider, weight_input, _js="(x) => x")
        reset_but.click(None, [], [weight_input, weight_input_slider], _js="(x) => [1,1]")

        # Map the number box to the "NPW_weight" infotext key.
        self.infotext_fields = [(weight_input, "NPW_weight")]
        self.paste_field_names = [field_name for _, field_name in self.infotext_fields]

        return [weight_input]

    def process(self, p, weight):
        """Prepare per-run state and (re)register the denoiser callback.

        Args:
            p: Processing object; ``width``, ``height`` and
                ``extra_generation_params`` are read or updated here. An
                ``NPW_weight`` attribute on it, when present, overrides
                ``weight``.
            weight: Value of the number component returned by `ui`.
        """
        weight = getattr(p, 'NPW_weight', weight)
        if weight is None:
            # NOTE(review): a cleared gr.Number field is expected to arrive as
            # None; treat that as the neutral weight instead of failing on the
            # numeric comparisons below.
            weight = 1.0

        if weight != 1:
            self.print_warning(weight)

        self.width = p.width
        self.height = p.height
        self.weight = weight
        self.empty_uncond = None

        # Drop a callback left over from a previous run before deciding
        # whether a new one is needed.
        if hasattr(self, 'callbacks_added'):
            remove_current_script_callbacks()
            delattr(self, 'callbacks_added')

        if self.weight != 1.0:
            self.empty_uncond = self.make_empty_uncond(self.width, self.height)
            on_cfg_denoiser(self.denoiser_callback)
            self.callbacks_added = True

        p.extra_generation_params.update({
            "NPW_weight": self.weight,
        })

        return

    def postprocess(self, p, processed, *args):
        """Remove the denoiser callback registered by `process`, if any."""
        if hasattr(self, 'callbacks_added'):
            remove_current_script_callbacks()
            delattr(self, 'callbacks_added')

    def denoiser_callback(self, params):
        """Blend ``params.text_uncond`` towards the empty-prompt conditioning.

        Handles both a plain tensor and a dict with ``'vector'`` and
        ``'crossattn'`` entries; the result is written back to ``params``.
        """

        def concat_and_lerp(empty, tensor, weight):
            """Return ``lerp(empty, tensor, weight)`` after matching shapes.

            ``empty`` is expanded along dim 0 and repeated along dim 1 until
            it lines up with ``tensor``.
            """
            if empty.shape[0] != tensor.shape[0]:
                empty = empty.expand(tensor.shape[0], *empty.shape[1:])

            if tensor.shape[1] <= empty.shape[1]:
                return torch.lerp(empty, tensor, weight)

            num_concatenations = tensor.shape[1] // empty.shape[1]
            empty_concat = torch.cat([empty] * num_concatenations, dim=1)
            if tensor.shape[1] == empty_concat.shape[1] + 1:
                # The tensor carries one extra leading entry along dim 1;
                # reuse it unchanged so both operands have the same length.
                # NOTE(review): origin of that extra entry is not visible
                # here - confirm.
                empty_concat = torch.cat([tensor[:, :1, :], empty_concat], dim=1)
            return torch.lerp(empty_concat, tensor, weight)

        uncond = params.text_uncond
        is_dict = isinstance(uncond, dict)

        # Rebuild the cached empty conditioning when its type does not match
        # what the sampler hands us (this also covers the None case).
        if type(self.empty_uncond) is not type(uncond):
            self.empty_uncond = self.make_empty_uncond(self.width, self.height)
        empty_uncond = self.empty_uncond

        if is_dict:
            uncond, cross = uncond['vector'], uncond['crossattn']
            empty_uncond, empty_cross = empty_uncond['vector'], empty_uncond['crossattn']
            params.text_uncond['vector'] = concat_and_lerp(empty_uncond, uncond, self.weight)
            params.text_uncond['crossattn'] = concat_and_lerp(empty_cross, cross, self.weight)
        else:
            params.text_uncond = concat_and_lerp(empty_uncond, uncond, self.weight)

    def make_empty_uncond(self, w, h):
        """Return the loaded model's conditioning for an empty negative prompt.

        Args:
            w: Width passed on to ``SdConditioning``.
            h: Height passed on to ``SdConditioning``.
        """
        prompt = SdConditioning([""], is_negative_prompt=True, width=w, height=h)
        empty_uncond = shared.sd_model.get_learned_conditioning(prompt)
        return empty_uncond

    def print_warning(self, value):
        """Print a colored console notice when the weight is not neutral.

        Nothing is printed for a weight of 1 or for a missing (None) weight.
        """
        if value is None or value == 1:
            return
        # ANSI yellow by default, bright yellow when the weight is far from 1.
        color_code = '\033[33m'
        if value < 0.5 or value > 1.5:
            color_code = '\033[93m'
        print(f"\n{color_code}ATTENTION: Negative prompt weight is set to {value}\033[0m")
| |
|
| |
|
def xyz_support():
    """Register a '[NPW] Weight' axis with the XYZ grid script, if it is loaded.

    Scans ``scripts.scripts_data`` for entries whose file name is
    ``xyz_grid.py`` and appends a float axis option that applies the
    ``NPW_weight`` field.
    """
    for script_data in scripts.scripts_data:
        if os.path.basename(script_data.path) != 'xyz_grid.py':
            continue

        xy_grid = script_data.module
        weight_axis = xy_grid.AxisOption(
            '[NPW] Weight',
            float,
            xy_grid.apply_field('NPW_weight'),
        )
        xy_grid.axis_options.extend([weight_axis])
# Best-effort registration at import time: a failure here is reported on the
# console but does not stop the rest of the module from loading.
try:
    xyz_support()
except Exception as err:
    print('Error trying to add XYZ plot options for NPW', err)
| |
|