# examples/applications/computing-embeddings/computing_embeddings_streaming.py
| """ | |
| This example starts multiple processes (1 per GPU), which encode | |
| sentences in parallel. This gives a near linear speed-up | |
| when encoding large text collections. | |
| It also demonstrates how to stream data which is helpful in case you don't | |
| want to wait for an extremely large dataset to download, or if you want to | |
| limit the amount of memory used. More info about dataset streaming: | |
| https://huggingface.co/docs/datasets/stream | |
| """ | |
import logging

from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

from sentence_transformers import LoggingHandler, SentenceTransformer
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
# Important: guard your code with `if __name__ == '__main__'`. Otherwise,
# CUDA runs into issues when spawning new processes.
if __name__ == '__main__':
    # Set params
    data_stream_size = 16384  # Size of the data that is loaded into memory at once
    chunk_size = 1024  # Size of the chunks that are sent to each process
    encode_batch_size = 128  # Batch size of the model
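    # Sketch of the arithmetic behind these values (not part of the original example):
    # each streamed batch of 16384 texts is split into 16384 / 1024 = 16 chunks, which
    # the pool distributes across the worker processes; each worker then encodes its
    # chunk in sub-batches of 128 sentences.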
    # Load a large dataset in streaming mode. More info: https://huggingface.co/docs/datasets/stream
    dataset = load_dataset('yahoo_answers_topics', split='train', streaming=True)
    dataloader = DataLoader(dataset.with_format("torch"), batch_size=data_stream_size)
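    # The DataLoader yields one dict per batch, mapping each dataset column to a
    # sequence of `data_stream_size` entries, so only one such batch of raw text
    # needs to be held in memory at a time.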
    # Define the model
    model = SentenceTransformer('all-MiniLM-L6-v2')

    # Start the multi-process pool on all available CUDA devices
    pool = model.start_multi_process_pool()
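    # Called without arguments, start_multi_process_pool() targets all available
    # CUDA devices. You can also pass an explicit device list, e.g.
    # model.start_multi_process_pool(['cuda:0', 'cuda:1'])
    # (example device names; adjust to your machine).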
    for i, batch in enumerate(tqdm(dataloader)):
        # Compute the embeddings using the multi-process pool
        sentences = batch['best_answer']
        batch_emb = model.encode_multi_process(sentences, pool, chunk_size=chunk_size, batch_size=encode_batch_size)
        print("Embeddings computed for 1 batch. Shape:", batch_emb.shape)
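        # In a real pipeline you would typically persist each batch here rather than
        # only printing its shape, e.g. (illustrative, not part of the original example):
        # np.save(f'embeddings_batch_{i}.npy', batch_emb)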
    # Optional: Stop the processes in the pool
    model.stop_multi_process_pool(pool)