# File size: 2,354 Bytes
# 88da18c
import argparse
import logging
import os
import sys
from gensim.models import Doc2Vec
from gensim.models.doc2vec import TaggedDocument
import numpy as np
# Logging setup: request INFO-level root logging with timestamped records,
# written both to stdout and to 'doc2vec_training.log' in the working directory.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s : %(levelname)s : %(message)s',
    handlers=[logging.StreamHandler(sys.stdout), logging.FileHandler('doc2vec_training.log')],
)
def train_doc2vec_model(documents_path, output_path, vector_size: int = 200, window: int = 5,
                        min_count: int = 5, epochs: int = 20) -> bool:
    """Train a Doc2Vec model on a one-document-per-line text file and save it.

    Args:
        documents_path: Path of a UTF-8 text file. Every non-blank line is
            treated as one document (surrounding whitespace is stripped).
        output_path: Destination handed to ``model.save``. A missing parent
            directory is created so a finished training run is not lost.
        vector_size: ``vector_size`` argument forwarded to ``Doc2Vec``.
        window: ``window`` argument forwarded to ``Doc2Vec``.
        min_count: ``min_count`` argument forwarded to ``Doc2Vec``.
        epochs: ``epochs`` argument forwarded to ``Doc2Vec``.

    Returns:
        True when the model was trained and saved; False when the documents
        could not be read, no document could be tokenized, or training or
        saving failed. Errors are logged rather than raised. A failure of the
        similarity sanity check that runs after saving is logged as a warning
        and does not change the result.

    Note:
        Relies on ``tokenize_text``, ``Doc2Vec`` and ``TaggedDocument`` being
        available at module level; ``tokenize_text`` is not defined in the
        part of the file visible here -- TODO confirm where it comes from.
    """
    # Read documents: keep only lines that are non-empty after stripping.
    try:
        with open(documents_path, 'r', encoding='utf-8') as f:
            documents = [text for text in (line.strip() for line in f) if text]
    except Exception as e:
        # Broad on purpose: callers get a bool, never an exception.
        logging.exception("Error reading documents: %s", e)
        return False

    # Prepare tagged documents; the tag is the document's position among the
    # non-blank lines, so tags stay stable even when some documents are skipped.
    tagged_data = []
    for i, doc in enumerate(documents):
        try:
            tokens = tokenize_text(doc)
            tagged_data.append(TaggedDocument(words=tokens, tags=[str(i)]))
        except Exception as e:
            logging.warning("Skipping document %s: %s", i, e)

    if not tagged_data:
        logging.error("No valid documents for training")
        return False

    # Train model
    try:
        model = Doc2Vec(
            documents=tagged_data,
            vector_size=vector_size,
            window=window,
            min_count=min_count,
            # os.cpu_count() may return None; Doc2Vec needs an integer.
            workers=os.cpu_count() or 1,
            epochs=epochs,
            dm=1,  # Use PV-DM (Distributed Memory) mode
            dbow_words=0,  # Only relevant in DBOW mode (dm=0); kept explicit
        )
    except Exception as e:
        logging.exception("Training failed: %s", e)
        return False

    # Save model (kept separate from training so the log names the real failure).
    try:
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        model.save(output_path)
    except Exception as e:
        logging.exception("Failed to save Doc2Vec model to %s: %s", output_path, e)
        return False
    logging.info("Doc2Vec model saved to %s", output_path)

    # The model is already on disk; the sanity check below is informational
    # only and must not turn a successful run into a reported failure.
    _log_sample_similarities(model)
    return True


def _log_sample_similarities(model):
    """Log the 3 most similar training documents for a few fixed probe phrases.

    Any error is logged as a warning and swallowed; nothing is returned.
    """
    test_phrases = (
        "software engineer python java",
        "data scientist machine learning",
        "cloud engineer aws docker",
    )
    try:
        for phrase in test_phrases:
            vec = model.infer_vector(tokenize_text(phrase))
            similar = model.dv.most_similar([vec], topn=3)
            logging.info("Similar to '%s': %s", phrase, similar)
    except Exception as e:
        logging.warning("Post-training similarity check failed: %s", e)