| | import torch |
| |
|
| |
|
| |
|
def retrieve_st_by_image(image_embeddings, all_text_embeddings, dataframe, k=3):
    """
    Retrieve the top-k ST samples whose embeddings best match an image embedding.

    Similarity is the plain dot product between the image embedding and every
    ST embedding; no normalisation is applied here, so pass L2-normalised
    embeddings if cosine similarity is wanted.

    :param image_embeddings: A numpy array or torch tensor holding one image
        embedding (shape: [1, embedding_dim] or [embedding_dim]).
    :param all_text_embeddings: A numpy array or torch tensor containing ST
        embeddings (shape: [n_samples, embedding_dim]).
    :param dataframe: A pandas DataFrame whose 'img_idx' column gives, row by
        row, the identifier of each ST sample in ``all_text_embeddings``.
    :param k: The number of top similar samples to retrieve. Default is 3.
        Values larger than ``n_samples`` are clamped to ``n_samples``.
    :return: A list with the 'img_idx' values of the top-k samples, ordered
        from most to least similar.
    """
    # Accept numpy input as the docstring promises; as_tensor is a no-op for
    # inputs that are already tensors.
    text = torch.as_tensor(all_text_embeddings)
    # Align the (small) query with the (large) candidate matrix so matmul does
    # not fail on a dtype/device mismatch and the big matrix is never copied.
    image = torch.as_tensor(image_embeddings, dtype=text.dtype, device=text.device)

    dot_similarity = image @ text.T
    # Drop the leading batch axis of a [1, n_samples] result. A 1-D query
    # already yields a 1-D result, which must not be squeezed down to 0-d.
    if dot_similarity.dim() > 1:
        dot_similarity = dot_similarity.squeeze(0)

    # torch.topk raises if k exceeds the number of candidates; clamp instead.
    k = min(k, dot_similarity.shape[-1])
    _, indices = torch.topk(dot_similarity, k)

    image_filenames = dataframe['img_idx'].values
    # .tolist() gives plain Python ints, which index the numpy array safely
    # regardless of which device the tensor lives on.
    return [image_filenames[idx] for idx in indices.tolist()]
| |
|
| |
|
| |
|