| | from transformers import BertTokenizer, BertModel |
| | import torch |
| | import torch.nn.functional as F |
| | import json |
| |
|
# Prefer the GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Sample records: each pairs an object name with the text to compare it against.
data = [
    {"obj_name": "apple", "llava_text": "wheel"},
    {"obj_name": "car", "llava_text": "engine"},
]

# A pair counts as correct only when its cosine similarity is strictly above this value.
SIMILARITY_THRESHOLD = 0.8

# Load the pretrained tokenizer and encoder, move the encoder to the chosen
# device, and switch it to evaluation mode.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")
model = model.to(device)
model.eval()
| |
|
def get_bert_embedding(text):
    """
    Extract a feature vector for ``text`` using BERT.

    The text is tokenized (truncated to at most 128 tokens), moved to the
    module-level ``device`` and passed through the module-level ``model``
    with gradient tracking disabled.

    Returns:
        The slice ``last_hidden_state[:, 0, :]`` of the model output, i.e.
        the hidden state at the first token position for every input row.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    # Every tensor produced by the tokenizer must live on the model's device.
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state

    # Keep only the first token position (the [CLS] slot for BERT tokenizers).
    return hidden_states[:, 0, :]
| |
|
def calculate_cosine_similarity(embedding1, embedding2):
    """
    Compute the cosine similarity between two feature vectors.

    Args:
        embedding1: A tensor holding one vector, either 1-D or with a
            leading batch axis of size one (as returned by
            ``get_bert_embedding`` for a single text).
        embedding2: A tensor broadcastable against ``embedding1``.

    Returns:
        The similarity as a Python float.

    Raises:
        RuntimeError: If the result holds more than one value (for example
            a batch of several rows), because ``.item()`` only converts
            single-element tensors.
    """
    # Batched inputs keep their features on axis 1; plain 1-D vectors have no
    # axis 1, so reduce over their only axis instead of raising an IndexError.
    feature_dim = 1 if max(embedding1.dim(), embedding2.dim()) > 1 else 0
    similarity = F.cosine_similarity(embedding1, embedding2, dim=feature_dim)
    return similarity.item()
| |
|
| | |
# Number of pairs whose similarity is strictly above SIMILARITY_THRESHOLD.
correct_count = 0
# Cosine similarity of every pair, in the order the records are processed.
similarity_list = []

for item in data:
    obj_name = item["obj_name"]
    llava_text = item["llava_text"]

    # Embed both sides of the pair with the same encoder.
    obj_embedding = get_bert_embedding(obj_name)
    llava_embedding = get_bert_embedding(llava_text)

    # Score the pair and keep the value for the final average.
    similarity = calculate_cosine_similarity(obj_embedding, llava_embedding)
    print("similarity:", similarity)
    similarity_list.append(similarity)

    if similarity > SIMILARITY_THRESHOLD:
        correct_count += 1

# Summary statistics. An empty dataset reports zeros instead of raising
# ZeroDivisionError.
total_samples = len(data)
accuracy = correct_count / total_samples if total_samples else 0.0
average_similarity = sum(similarity_list) / total_samples if total_samples else 0.0

print(f"正确样本数: {correct_count}")
print(f"总样本数: {total_samples}")
print(f"正确样本比例: {accuracy:.2%}")
print(f"平均相似性: {average_similarity:.4f}")