| import streamlit as st |
| from PIL import Image |
| import torch |
| from transformers import ( |
| ViTFeatureExtractor, |
| ViTForImageClassification, |
| pipeline, |
| AutoTokenizer, |
| AutoModelForSeq2SeqLM |
| ) |
| from diffusers import StableDiffusionPipeline |
|
|
| |
def _load_vit_classifier(repo_id):
    """Return the ``(model, feature_extractor)`` pair for one ViT checkpoint."""
    model = ViTForImageClassification.from_pretrained(repo_id)
    extractor = ViTFeatureExtractor.from_pretrained(repo_id)
    return model, extractor


@st.cache_resource
def load_models():
    """Load every model the app uses.

    The function is decorated with ``st.cache_resource`` so the loaded
    objects can be reused instead of being rebuilt on each call.

    Returns:
        An 11-tuple, in this order: age model, age feature extractor,
        gender model, gender feature extractor, emotion model, emotion
        feature extractor, object-detection pipeline, action model, action
        feature extractor, prompt-enhancer pipeline, Stable Diffusion
        pipeline.
    """
    age_model, age_transforms = _load_vit_classifier('nateraw/vit-age-classifier')
    gender_model, gender_transforms = _load_vit_classifier(
        'rizvandwiki/gender-classification-2'
    )
    emotion_model, emotion_transforms = _load_vit_classifier(
        'dima806/facial_emotions_image_detection'
    )

    object_detector = pipeline("object-detection", model="facebook/detr-resnet-50")

    action_model, action_transforms = _load_vit_classifier(
        'rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224'
    )

    prompt_enhancer_tokenizer = AutoTokenizer.from_pretrained("gokaygokay/Flux-Prompt-Enhance")
    prompt_enhancer_model = AutoModelForSeq2SeqLM.from_pretrained("gokaygokay/Flux-Prompt-Enhance")
    prompt_enhancer = pipeline(
        'text2text-generation',
        model=prompt_enhancer_model,
        tokenizer=prompt_enhancer_tokenizer,
        repetition_penalty=1.2,
        device="cpu",
    )

    # Half-precision weights are intended for GPU inference; on a CPU-only
    # host fall back to float32 so the diffusion pipeline does not run
    # float16 kernels on CPU. The pipeline is then placed on the chosen
    # device explicitly (the original never moved it off the CPU).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained("nota-ai/bk-sdm-tiny", torch_dtype=dtype)
    pipe = pipe.to(device)

    return (age_model, age_transforms, gender_model, gender_transforms,
            emotion_model, emotion_transforms, object_detector,
            action_model, action_transforms, prompt_enhancer, pipe)
|
|
# Fetch the full model bundle once at import time, then bind each member to a
# module-level name so the helper functions below can reference it directly.
models = load_models()
(
    age_model,
    age_transforms,
    gender_model,
    gender_transforms,
    emotion_model,
    emotion_transforms,
    object_detector,
    action_model,
    action_transforms,
    prompt_enhancer,
    pipe,
) = models
|
|
def predict(image, model, transforms):
    """Return the index of the highest-scoring class for a single image.

    Args:
        image: Image object exposing ``mode`` and ``convert`` (for example a
            PIL image). It is converted to RGB when it is in any other mode.
        model: Callable that accepts the preprocessed inputs as keyword
            arguments and returns an object with a ``logits`` tensor.
        transforms: Preprocessor, called as
            ``transforms(images=[image], return_tensors='pt')``.

    Returns:
        int: Position of the largest value along dimension 1 of the logits.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')

    inputs = transforms(images=[image], return_tensors='pt')
    # Inference only: disable autograd so no graph is recorded for the
    # forward pass, matching how generate_image runs its pipeline.
    with torch.no_grad():
        output = model(**inputs)
    proba = output.logits.softmax(1)
    return proba.argmax(1).item()
|
|
def detect_attributes(image):
    """Classify an image with every loaded model and collect readable labels.

    Args:
        image: The image handed to ``predict`` for each classifier and to
            ``object_detector``.

    Returns:
        dict: Keys ``'age'``, ``'gender'``, ``'emotion'`` and ``'action'``
        map to the label looked up in the matching model's
        ``config.id2label``; ``'objects'`` is the list of ``'label'`` values
        of every detection, in the order returned (duplicates kept).
    """
    classifiers = (
        ('age', age_model, age_transforms),
        ('gender', gender_model, gender_transforms),
        ('emotion', emotion_model, emotion_transforms),
        ('action', action_model, action_transforms),
    )

    # Run the classifiers in the fixed order above so the resulting dict
    # keeps the age / gender / emotion / action key order.
    report = {}
    for field, classifier, extractor in classifiers:
        class_index = predict(image, classifier, extractor)
        report[field] = classifier.config.id2label[class_index]

    detections = object_detector(image)
    report['objects'] = [hit['label'] for hit in detections]
    return report
|
|
def generate_prompt(attributes):
    """Build the plain-text description prompt from detected attributes.

    Args:
        attributes: Mapping with the keys ``'age'``, ``'gender'``,
            ``'emotion'``, ``'action'`` and ``'objects'`` (an iterable of
            strings; may be empty).

    Returns:
        str: One sentence describing the person, followed by an
        ``"Image has ..."`` sentence when ``attributes['objects']`` is
        non-empty. Each sentence ends with a trailing space.
    """
    age = attributes['age']
    gender = attributes['gender']
    emotion = attributes['emotion']
    action = attributes['action']

    sentences = [
        f"A {age} year old {gender} person feeling {emotion} while {action}. "
    ]

    detected = attributes['objects']
    if detected:
        listing = ', '.join(detected)
        sentences.append(f"Image has {listing}. ")

    return "".join(sentences)
|
|
def enhance_prompt(prompt):
    """Rewrite ``prompt`` with the prompt-enhancer pipeline.

    Args:
        prompt: Text to enhance; it is sent to the pipeline prefixed with
            ``"enhance prompt: "``.

    Returns:
        The ``'generated_text'`` field of the first result returned by
        ``prompt_enhancer``.
    """
    request = "enhance prompt: " + prompt
    candidates = prompt_enhancer(request, max_length=256)
    best = candidates[0]
    return best['generated_text']
|
|
@st.cache_data
def generate_image(prompt):
    """Generate one image for ``prompt`` with the diffusion pipeline.

    The function is decorated with ``st.cache_data``, so Streamlit may
    return a previously generated image when the same prompt is requested
    again.

    Args:
        prompt: Text prompt passed to ``pipe``.

    Returns:
        The first entry of the pipeline result's ``images``.
    """
    # Gradients are never needed for generation.
    with torch.no_grad():
        result = pipe(prompt, num_inference_steps=50)
    return result.images[0]
|
|
# ---------------------------------------------------------------------------
# Streamlit page: upload -> analyze -> build prompt -> generate image
# ---------------------------------------------------------------------------
st.title("Image Attribute Detection and Image Generation")

uploaded_file = st.file_uploader(
    "Choose an image...",
    type=["jpg", "jpeg", "png"],
)

if uploaded_file is not None:
    # Echo the upload back so the user can confirm what will be analyzed.
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button('Analyze Image'):
        # Step 1: run every classifier / detector on the uploaded image.
        with st.spinner('Detecting attributes...'):
            attributes = detect_attributes(image)

        st.write("Detected Attributes:")
        for label, detected in attributes.items():
            st.write(f"{label.capitalize()}: {detected}")

        # Step 2: turn the attributes into a prompt and enhance it.
        with st.spinner('Generating prompt...'):
            initial_prompt = generate_prompt(attributes)
            enhanced_prompt = enhance_prompt(initial_prompt)

            st.write("Initial Prompt:")
            st.write(initial_prompt)
            st.write("Enhanced Prompt:")
            st.write(enhanced_prompt)

        # Step 3: render a new image from the enhanced prompt.
        with st.spinner('Generating image...'):
            generated_image = generate_image(enhanced_prompt)
            st.image(
                generated_image,
                caption='Generated Image',
                use_column_width=True,
            )