File size: 2,439 Bytes
8c28866 9a96194 8c28866 8277e31 8c28866 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 | import gradio as gr
import numpy as np
import pandas as pd
from tensorflow import keras
from PIL import Image
import io
import base64
from sklearn.metrics.pairwise import cosine_similarity
PATH_MODEL = "./autoencoder.keras"
PATH_DB = "./mnist_train_small.csv"

# ── Load model and reference data once at startup ─────────────────────────────
# The saved autoencoder must expose sub-models named "encoder" and "decoder";
# they are looked up by layer name.
model = keras.models.load_model(PATH_MODEL)
encoder = model.get_layer("encoder")
decoder = model.get_layer("decoder")

# Header-less CSV: column 0 is dropped (presumably the digit label — confirm
# against the dataset), the remaining columns are pixel values scaled from
# 0..255 into [0, 1] as float32.
data = pd.read_csv(PATH_DB, header=None)
X_ref = data.iloc[:, 1:].to_numpy(dtype="float32") / 255

# Latent codes for every reference image, computed once so that each query
# only has to encode the user's drawing.
X_latent = encoder.predict(X_ref, verbose=0)

# Derive the latent size from the encoder's actual output instead of
# hard-coding it, so it can never disagree with the loaded model.
LATENT_DIM = int(X_latent.shape[-1])
# ── Helper: Sketchpad drawing → model input array (1, 784) ───────────────────
def image_to_array(canva):
    """Convert a Gradio Sketchpad value into a (1, 784) float32 model input.

    Args:
        canva: Value produced by ``gr.Sketchpad(type="pil")``. It is read as a
            dict whose ``"composite"`` entry is a PIL image of the flattened
            drawing (assumes the Gradio 4+ editor value format — confirm).

    Returns:
        A ``(1, 784)`` float32 array in ``[0, 1]``. The grayscale image is
        inverted, so dark strokes on a light background become high values
        on a low background.

    Raises:
        gr.Error: If nothing has been drawn yet (no composite image).
    """
    composite = canva.get("composite") if canva else None
    if composite is None:
        raise gr.Error("Dibuja un dígito antes de buscar.")

    # convert("L") ignores alpha. Transparent pixels would therefore turn
    # black and, after the inversion below, be read as ink. Flatten onto an
    # opaque white background first; this is a no-op for opaque white canvases.
    if composite.mode in ("RGBA", "LA") or "transparency" in composite.info:
        rgba = composite.convert("RGBA")
        white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        composite = Image.alpha_composite(white, rgba)

    img = composite.convert("L").resize((28, 28))
    arr = 1 - np.array(img, dtype="float32") / 255
    return arr.reshape(1, 784)
def _to_pil_digit(vec):
    """Turn one flattened [0, 1] image row into a 28x28 8-bit grayscale PIL image."""
    # Round before casting: float32 values such as 254/255 * 255 can land at
    # 253.99998 and would otherwise truncate one intensity level low.
    pixels = np.clip(np.rint(vec.reshape(28, 28) * 255), 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)


def find_similar(img, top_k):
    """Return the reference images whose latent codes best match a drawing.

    Args:
        img: Sketchpad value, passed unchanged to ``image_to_array``.
        top_k: Number of neighbours to return. It is truncated with ``int``
            (the slider may deliver a float), and values below 0 are treated
            as 0.

    Returns:
        A tuple ``(table, gallery_imgs)``. ``table`` is a list of
        ``[index, cosine_similarity]`` rows: the index is a row of ``X_ref``
        and the similarity is rounded to 4 decimals, best match first.
        ``gallery_imgs`` is the matching list of 28x28 grayscale PIL images.
    """
    query = image_to_array(img)
    query_vec = encoder.predict(query, verbose=0)
    sims = cosine_similarity(query_vec, X_latent)[0]

    # Clamp k: a negative value would make the slice mean "all but the last k".
    k = max(0, int(top_k))
    top_idx = np.argsort(sims)[::-1][:k]  # descending similarity

    table = [[int(i), round(float(sims[i]), 4)] for i in top_idx]
    gallery_imgs = [_to_pil_digit(X_ref[i]) for i in top_idx]
    return table, gallery_imgs
# ── User interface ────────────────────────────────────────────────────────────
# Search tab: draw a digit and choose how many neighbours to retrieve.
# Metadata tab: ranking table (reference index + cosine similarity).
with gr.Blocks() as demo:
    with gr.Tab("Búsqueda"):
        gr.Markdown("## Búsqueda en espacio latente")
        with gr.Row():
            with gr.Column():
                canvas = gr.Sketchpad(label="Dibuja", type='pil')
            with gr.Column():
                topk = gr.Slider(1, 50, value=10, step=1, label="top_k")
                btn = gr.Button("Buscar similares")
                gallery = gr.Gallery(label="Imágenes similares", columns=5, object_fit="contain")
    with gr.Tab("Metadatos"):
        results = gr.Dataframe(
            headers=["index", "cosine_similarity"],
            datatype=["number", "number"],
            label="Ranking",
            interactive=False,
        )
    # Event listeners must be registered inside the Blocks context.
    btn.click(find_similar, inputs=[canvas, topk], outputs=[results, gallery])

# Launch only when run as a script, so importing this module (e.g. to reuse
# find_similar) does not start a blocking server.
if __name__ == "__main__":
    demo.launch(server_port=7860)