| | import torch |
| | from PIL import Image |
| |
|
| | import random |
| | import pandas as pd |
| | import gradio as gr |
| | import numpy as np |
| | from sklearn.linear_model import LogisticRegression |
| | from sklearn.svm import SVC |
| | from sklearn import preprocessing |
| | import time |
| | import torch |
| | from matplotlib import pyplot as plt |
| |
|
| | from model import model, tokenizer, load_image |
| |
|
| | from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler |
| | from huggingface_hub import hf_hub_download |
| | from safetensors.torch import load_file |
| |
|
# Device / precision shared by the diffusion pipeline and the feature model
# (get_text / get_embed also cast to `dtype`).
device = 'cuda'
dtype = torch.bfloat16

# SDXL base model plus the SDXL-Lightning 8-step UNet checkpoint
# (get_image samples with num_inference_steps=8 to match).
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_8step_unet.safetensors"

# Build an SDXL UNet from the base model's UNet config, then overwrite its
# weights with the Lightning checkpoint fetched via hf_hub_download.
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, dtype)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
# All other SDXL components come from `base` (its "fp16" weight variant,
# requested as `dtype`); the UNet is the Lightning one built above.
pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16").to(device)

# Replace the scheduler with Euler using "trailing" timestep spacing, keeping the
# rest of the existing scheduler config.
# NOTE(review): this mirrors the SDXL-Lightning usage instructions for few-step
# sampling — confirm against the model card if the checkpoint changes.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
| |
|
| |
|
| |
|
| |
|
| |
|
# Feature embedding of 'blank.png', computed once at import time under
# mixed precision and returned detached, in float32.
with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
    pixel_values = load_image(image_file='blank.png', max_num=1).to(device)
    base_embed = (
        model.extract_feature(pixel_values.to(dtype))
        .detach()
        .float()
    )
| |
|
| |
|
| |
|
def get_text(embed):
    """Generate a short text response conditioned directly on visual features.

    Calls ``model.chat`` with an empty question and ``embed`` (cast to
    ``dtype``) as ``visual_features``, sampling up to 32 new tokens with
    temperature 0.5 and top-p 0.92. The response is printed and returned.

    Args:
        embed: feature tensor handed to ``model.chat`` as ``visual_features``.
            NOTE(review): shape expectations live in ``model.chat`` — confirm there.
    """
    sampling = {
        'max_new_tokens': 32,
        'do_sample': True,
        'temperature': .5,
        'top_p': .92,
    }
    with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
        # 0 stands in for pixel values; presumably ignored when
        # visual_features is supplied (verify against model.chat).
        no_pixels = 0
        empty_question = ''
        reply = model.chat(
            tokenizer,
            no_pixels,
            empty_question,
            sampling,
            visual_features=embed.to(dtype),
        )
        print(reply)
        return reply
| |
|
def get_image(text):
    """Render ``text`` with the SDXL-Lightning pipeline and return the first image.

    Uses 8 inference steps (matching the 8-step checkpoint) and
    ``guidance_scale=0``.
    """
    result = pipe(text, num_inference_steps=8, guidance_scale=0)
    return result.images[0]
| |
|
def get_embed(img):
    """Return the feature embedding of a PIL image as a float32 tensor.

    The image is preprocessed by ``load_image`` (``max_num=1``), moved to
    ``device`` and passed through ``model.extract_feature`` under autocast
    with ``dtype``.
    """
    with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
        batch = load_image(image_file='', pil_image=img, max_num=1).to(device)
        features = model.extract_feature(batch.to(dtype))
        return features.float()
| |
|
| |
|
| |
|
# Deduplicated, shuffled pool of prompts taken from the CSV's second column.
# Non-string cells (presumably NaN from empty cells) are dropped.
# NOTE(review): absolute, user-specific path — the app fails at import time
# on any other machine; consider making this configurable.
prompt_list = [
    p
    for p in set(pd.read_csv('/home/ryn_mote/Misc/twitter_prompts.csv').iloc[:, 1].tolist())
    if isinstance(p, str)
]
random.shuffle(prompt_list)
| |
|
| |
|
| |
|
# Not read anywhere in the visible code; appears to hold prompts kept out
# of calibration (unverified).
NOT_calibrate_prompts = [
    "an abstract painting",
    "unique streetwear design that blends the old with the new. Combine bold, "
    "urban typography with retro graphics, taking inspiration from distressed "
    "signage and graffiti. Use a range of earthy tones to give the design a "
    "vintage aesthetic, while adding a modern twist with a stylistic rendering "
    "of the graphics",
    "a photo of hell",
    "",
]

# Prompts shown, in order, before roaming starts; next_image() pops them from
# the front, so this must stay a mutable list.
calibrate_prompts = [
    "4k photo",
    "surrealist art",
    "a psychedelic, fractal view",
    "a beautiful collage",
    "an intricate portrait",
    "an impressionist painting",
    "abstract art",
    "an eldritch image",
    "a sketch",
    "a city full of darkness and graffiti",
    "a black & white photo",
    "a brilliant, timeless tarot card of the world",
    "eternity: a timeless, vivid painted portrait by ryan murdock",
    "a simple, timeless, & dark charcoal on canvas: death itself by ryan murdock",
    "a painted image with gorgeous red gradients: Persephone by ryan murdock",
    "a simple, timeless, & dark photo with gorgeous gradients: last night of my life by ryan murdock",
    "the sunflower -- a dark, simple painted still life by ryan murdock",
    "silence in the macrocosm -- a dark, intricate painting by ryan murdock",
    "beauty here -- a photograph by ryan murdock",
    "a timeless, haunting portrait: the necrotic jester",
    "a simple, timeless, & dark art piece with gorgeous gradients: serenity",
    "an elegant image of nature with gorgeous swirling gradients",
    "simple, timeless digital art with gorgeous purple spirals",
    "timeless digital art with gorgeous gradients: eternal slumber",
    "a simple, timeless image with gorgeous gradients",
    "a simple, timeless painted image of nature with beautiful gradients",
    "a timeless, dark digital art piece with gorgeous gradients: the hanged man",
    "",
]
| |
|
| |
|
| |
|
# Mutable session state shared by next_image() and choose().
embs = []       # one embedding appended per shown image ('Neither' pops it again)
ys = []         # one rating per kept image: 1 for 'Like', 0 otherwise
global_idx = 0  # incremented on every rating click
start_time = time.time()  # NOTE(review): not read anywhere in the visible code
| |
|
def _roaming_training_set():
    """Return ``(feature_embs, ys_t)`` used to steer the next roaming image.

    ``feature_embs`` stacks one feature vector per rated image and ``ys_t``
    holds the matching ratings. Until there is at least one liked (> .5) AND
    one not-liked (<= .5) rating, two random 1280-d vectors labelled ``[0, 1]``
    are returned instead, so the steering step always has both groups.
    """
    pos_indices = [i for i, y in enumerate(ys) if y > .5]
    neg_indices = [i for i, y in enumerate(ys) if y <= .5]

    if min(len(pos_indices), len(neg_indices)) < 1:
        placeholder = torch.stack([torch.randn(1280), torch.randn(1280)])
        print('Not enough ratings.')
        return placeholder, [0, 1]

    # embs[i] belongs to the image that received rating ys[i]: next_image
    # appends one embedding per shown image, and choose() pops it on
    # 'Neither' or appends a rating otherwise.
    # NOTE(review): [0, 0] assumes embeddings are (batch, tokens, dim) and keeps
    # the first token of the first batch item — confirm against extract_feature.
    feature_embs = torch.stack(
        [embs[i][0, 0].detach().cpu() for i in range(len(ys))]
    ).squeeze()
    return feature_embs, list(ys)


def _steer(feature_embs, ys_t):
    """Return a feature vector pushed from a liked example toward "liked".

    The direction is mean(liked) - mean(not liked), rescaled to unit std and
    then to the std of one randomly chosen liked feature (the anchor), and
    added to that anchor. The result has shape ``(1, D)`` (D = feature width)
    and lives on ``device`` in ``dtype``.
    """
    liked = [f for f, y in zip(feature_embs, ys_t) if y > .5]
    # `<= .5` matches the split used in _roaming_training_set.
    disliked = [f for f, y in zip(feature_embs, ys_t) if y <= .5]

    pos_sol = torch.stack(liked).mean(0, keepdim=True).to(device, dtype)
    neg_sol = torch.stack(disliked).mean(0, keepdim=True).to(device, dtype)
    anchor = random.choice(liked).to(device, dtype)

    dif = pos_sol - neg_sol
    dif_std = dif.std()
    if torch.isfinite(dif_std) and dif_std > 0:
        return anchor + (dif / dif_std) * anchor.std()
    # Identical group means give dif.std() == 0, and dif / dif.std() would be
    # NaN, feeding NaN features to get_text. Fall back to the anchor itself,
    # keeping the (1, D) shape.
    return anchor + torch.zeros_like(dif)


def _save_histogram(values, path):
    """Overwrite ``path`` with a histogram of every element of ``values`` (debug aid)."""
    plt.close()
    plt.hist(values.detach().cpu().float().flatten())
    plt.savefig(path)


def _roam():
    """Generate one image from text produced by get_text on steered features."""
    print('######### Roaming #########')
    feature_embs, ys_t = _roaming_training_set()
    # Same output as printing np.array(...).shape, without copying the tensor.
    print(tuple(feature_embs.shape), (len(ys_t),))

    sol = _steer(feature_embs, ys_t)
    print(sol.shape)

    text = get_text(sol)
    image = get_image(text)
    embed = get_embed(image)
    embs.append(embed)

    _save_histogram(sol, 'sol.jpg')
    _save_histogram(embed, 'embed.jpg')
    return image, text


def next_image():
    """Produce the next ``(image, text)`` pair for the UI.

    While ``calibrate_prompts`` is non-empty, its first prompt is popped,
    rendered and returned verbatim. After that ("roaming"), the text comes
    from ``get_text`` applied to features steered toward liked images.
    Either way the new image's embedding is appended to ``embs`` so that the
    next rating appended to ``ys`` lines up with it. Runs under
    ``torch.no_grad()``.
    """
    with torch.no_grad():
        if calibrate_prompts:
            prompt = calibrate_prompts.pop(0)
            print(f'######### Calibrating with sample: {prompt} #########')
            image = get_image(prompt)
            # get_embed opens its own autocast context; no outer one is needed.
            embs.append(get_embed(image))
            return image, prompt
        return _roam()
| | |
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def start(_):
    """Handle the Start click: enable rating buttons, lock Start, show an image.

    The returned list lines up with the Start button's outputs
    ``[b1, b2, b3, b4, img, txt]``: Like, Neither, Dislike, Start, then the
    ``(image, text)`` pair from ``next_image()``.
    """
    rating_buttons = [
        gr.Button(value=label, interactive=True)
        for label in ('Like', 'Neither', 'Dislike')
    ]
    locked_start = gr.Button(value='Start', interactive=False)
    return rating_buttons + [locked_start] + list(next_image())
| |
|
| |
|
def choose(choice):
    """Record a rating for the current image and return the next one.

    ``choice`` is the clicked button's label. 'Like' records 1. 'Neither'
    discards the current image's embedding without recording a rating. Any
    other label (i.e. 'Dislike') records 0.
    """
    global global_idx
    global_idx += 1

    if choice == 'Neither':
        # Drop the embedding so embs and ys stay index-aligned.
        embs.pop()
    else:
        ys.append(1 if choice == 'Like' else 0)
    return next_image()
| |
|
# Gradio UI: an image and its prompt text, three rating buttons (disabled
# until Start is clicked), and a Start button that kicks off calibration.
css = "div#output-image {height: 512px !important; width: 512px !important; margin:auto;}"
with gr.Blocks(css=css) as demo:
    with gr.Row():
        # Fixed markup: the closing tag was `</ div>` (not a valid end tag) and
        # font-size had no unit, so browsers ignored it.
        html = gr.HTML('''<div style='text-align:center; font-size:32px'>You will calibrate for several prompts and then roam.</div>''')
    # NOTE(review): the Row and the component inside it share the same elem_id,
    # which yields duplicate DOM ids; the css above targets `div#output-image`.
    with gr.Row(elem_id='output-image'):
        img = gr.Image(interactive=False, elem_id='output-image',)
    with gr.Row(elem_id='output-txt'):
        txt = gr.Textbox(interactive=False, elem_id='output-txt',)
    with gr.Row(equal_height=True):
        b3 = gr.Button(value='Dislike', interactive=False,)
        b2 = gr.Button(value='Neither', interactive=False,)
        b1 = gr.Button(value='Like', interactive=False,)
        # Each rating button passes its own label to choose() and refreshes
        # the image and text.
        for rating_button in (b1, b2, b3):
            rating_button.click(choose, [rating_button], [img, txt])
    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start, [b4], [b1, b2, b3, b4, img, txt])

demo.launch(share=True)
| |
|
| |
|
| |
|
| | |
| |
|
| |
|