CULTURE / get_embeddings.py
pranamanam's picture
Upload 7 files
a0e0ff1 verified
Raw
History Blame Contribute Delete
1.56 kB
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from model import CustomBERTModel
from config import Config
import pandas as pd
def load_data(file_path):
    """Read a header-less CSV file into a 2-D ``float32`` tensor.

    pandas is told the file has no header line, so the first record is
    treated as data. Each CSV record becomes one tensor row and each column
    one tensor column. Every cell must be numeric; otherwise the conversion
    to a float tensor raises.

    Args:
        file_path: Path (or anything ``pandas.read_csv`` accepts) to the CSV.

    Returns:
        A ``torch.float32`` tensor of shape ``(n_records, n_columns)``.
    """
    frame = pd.read_csv(file_path, header=None)
    as_array = frame.to_numpy()
    return torch.tensor(as_array, dtype=torch.float32)
def get_embeddings(input_file, output_file, model_path="bert_mlm_model.pth"):
    """Encode every row of a CSV file with the trained model and save the result.

    Builds a ``CustomBERTModel`` from ``Config()``, restores its weights from
    ``model_path``, and feeds the rows of ``input_file`` through
    ``model.get_encoder_output`` in batches of ``config.batch_size``. The
    per-batch outputs are concatenated along axis 0 and written with
    ``np.save``.

    Args:
        input_file: Header-less CSV file, read by ``load_data`` into a
            ``float32`` tensor; all cells must be numeric.
        output_file: Destination passed to ``np.save``. When it is a path
            without a ``.npy`` suffix, NumPy appends one.
        model_path: Path to the saved ``state_dict``. Defaults to
            ``"bert_mlm_model.pth"``, resolved against the current working
            directory (the value this function previously hard-coded).

    Returns:
        The concatenated embeddings as a NumPy array.
    """
    import os  # local import: only used to report where np.save wrote

    config = Config()
    model = CustomBERTModel(config).to(config.device)
    # map_location remaps the stored tensors onto config.device. Without it,
    # a checkpoint saved on a GPU cannot be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=config.device)
    model.load_state_dict(state_dict)
    model.eval()

    input_data = load_data(input_file)
    data_loader = DataLoader(TensorDataset(input_data), batch_size=config.batch_size)

    batch_outputs = []
    with torch.no_grad():
        # Each batch from a single-tensor TensorDataset is a 1-element sequence.
        for (inputs,) in data_loader:
            # NOTE(review): assumes get_encoder_output accepts the raw float32
            # rows and returns a batch-first tensor — confirm against model.py.
            embeddings = model.get_encoder_output(inputs.to(config.device))
            batch_outputs.append(embeddings.cpu().numpy())
    all_embeddings = np.concatenate(batch_outputs, axis=0)
    print(f"Generated embeddings shape: {all_embeddings.shape}")

    # Save embeddings
    np.save(output_file, all_embeddings)
    # np.save appends ".npy" to path targets that lack it; report the real name.
    saved_to = output_file
    if not hasattr(output_file, "write"):
        saved_to = os.fspath(output_file)
        if not saved_to.endswith(".npy"):
            saved_to += ".npy"
    print(f"Embeddings saved as {saved_to}")
    return all_embeddings
if __name__ == "__main__":
    import argparse

    # Positional CLI arguments as (name, help text) pairs, in command-line order.
    cli_arguments = (
        ("input_file", "Path to the input CSV file containing growth curves"),
        ("output_file", "Path to save the output embeddings (as .npy file)"),
    )

    cli = argparse.ArgumentParser(
        description="Generate embeddings for microbial growth curves"
    )
    for arg_name, arg_help in cli_arguments:
        cli.add_argument(arg_name, help=arg_help)

    parsed = cli.parse_args()
    get_embeddings(parsed.input_file, parsed.output_file)