# skillsync-cli/model/training/train_sbert.py
# Mr-Haseeb786 - "Sanitized Production Build" (56c7b6d), 4.63 kB
import argparse
import logging
import os
import sys
import pandas as pd
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
# Root logger setup: timestamped messages at INFO level, sent both to
# stdout and to the file 'sbert_training.log'.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('sbert_training.log'),
    ],
)
def create_default_training_data(output_path):
    """Write a three-row placeholder training CSV and return it as a DataFrame.

    The frame has the columns ``resume``, ``job_description`` and
    ``match_score`` (the placeholder scores are 85, 90 and 95). A warning
    is logged so that the use of placeholder data is visible in the logs.

    Args:
        output_path: Destination path of the CSV file. Missing parent
            directories are created. A bare filename (no directory part)
            is written relative to the current working directory.

    Returns:
        The pandas DataFrame that was written to ``output_path``.
    """
    data = {
        'resume': [
            'Experienced software engineer with Python and Java',
            'Data scientist with machine learning expertise',
            'Web developer specializing in React',
        ],
        'job_description': [
            'Seeking software engineer with Python experience',
            'Looking for data scientist with ML background',
            'Hiring frontend developer with React skills',
        ],
        'match_score': [85, 90, 95],
    }
    df = pd.DataFrame(data)

    # os.path.dirname() returns '' for a bare filename and os.makedirs('')
    # raises FileNotFoundError, so only create the parent when there is one.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    df.to_csv(output_path, index=False)
    logging.warning(f"Created minimal training data at {output_path}")
    return df
def _row_to_training_example(row):
    """Build one InputExample from a training-data row.

    Raises:
        ValueError: If the resume, the job description or the score is
            missing (NaN/None), or the score is not a number.
        TypeError: If the score has a type ``float()`` cannot convert.
    """
    resume = row['resume']
    job_description = row['job_description']
    # Without this check a missing text cell would be trained on as the
    # literal string "nan" after str() conversion.
    if pd.isna(resume) or pd.isna(job_description):
        raise ValueError("missing resume or job_description")

    raw_score = float(row['match_score'])
    # NaN must be rejected explicitly: min(100, nan) evaluates to 100, so a
    # missing score would otherwise be labelled as a perfect match.
    if pd.isna(raw_score):
        raise ValueError("missing match_score")

    # Clamp to 0-100, then normalize to the 0-1 range.
    score = max(0.0, min(100.0, raw_score)) / 100
    return InputExample(texts=[str(resume), str(job_description)], label=score)


def train_sbert_model(train_data_path, output_dir,
                      model_name='all-MiniLM-L6-v2',
                      num_epochs=10, batch_size=32):
    """Train an SBERT model for resume/job-description semantic similarity.

    Training pairs are read from a CSV with the columns ``resume``,
    ``job_description`` and ``match_score``. Scores are clamped to 0-100
    and divided by 100 before being used as labels for
    ``CosineSimilarityLoss``. Rows with missing or non-numeric values are
    skipped with a warning.

    Args:
        train_data_path: Path to the training CSV. If the file does not
            exist, or cannot be read, placeholder data is written to this
            path and used instead.
        output_dir: Passed to ``model.fit`` as ``output_path``.
        model_name: Name of the base SentenceTransformer model.
        num_epochs: Number of training epochs.
        batch_size: Batch size of the training DataLoader.

    Returns:
        True if ``model.fit`` completed. False if required columns are
        missing, no valid training example was found, or ``model.fit``
        raised. Errors raised while constructing the model or the
        DataLoader are not caught and propagate to the caller.
    """
    # Handle missing training data
    if not os.path.exists(train_data_path):
        logging.warning(f"No training data found at {train_data_path}")
        df = create_default_training_data(train_data_path)
    else:
        try:
            df = pd.read_csv(train_data_path)
            logging.info(f"Loaded training data with {len(df)} records")
        except Exception as e:
            # NOTE(review): this fallback overwrites the unreadable file at
            # train_data_path with placeholder data - confirm that is intended.
            logging.error(f"Error loading training data: {str(e)}")
            df = create_default_training_data(train_data_path)

    # Fail once with a clear message instead of one warning per row.
    required_columns = ('resume', 'job_description', 'match_score')
    missing_columns = [name for name in required_columns if name not in df.columns]
    if missing_columns:
        logging.error(f"Training data is missing required columns: {missing_columns}")
        return False

    # Prepare training examples
    train_examples = []
    for _, row in df.iterrows():
        try:
            train_examples.append(_row_to_training_example(row))
        except (KeyError, TypeError, ValueError, OverflowError) as e:
            logging.warning(f"Skipping invalid row: {str(e)}")

    if not train_examples:
        logging.error("No valid training examples found!")
        return False

    # Initialize model
    model = SentenceTransformer(model_name)
    logging.info(f"Initialized model: {model_name}")

    # Data loader
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
    train_loss = losses.CosineSimilarityLoss(model)

    # Training
    logging.info(f"Starting training for {num_epochs} epochs with {len(train_examples)} examples")
    try:
        model.fit(
            train_objectives=[(train_dataloader, train_loss)],
            epochs=num_epochs,
            show_progress_bar=True,
            output_path=output_dir
        )
    except Exception as e:
        logging.error(f"Training failed: {str(e)}")
        return False

    logging.info(f"Model successfully saved to {output_dir}")
    return True
if __name__ == '__main__':
    # Command-line entry point: parse the options, make sure the output
    # directory exists, run training and report the outcome as exit status
    # (0 on success, 1 on failure).
    cli = argparse.ArgumentParser(description='Train SBERT model for resume-JD matching')

    # (flag, type, default, help) for every supported option.
    option_specs = (
        ('--train_data', str, 'data/training/train.csv', 'Path to training data CSV'),
        ('--output_dir', str, 'trained_models/sbert', 'Output directory for trained model'),
        ('--model_name', str, 'all-MiniLM-L6-v2', 'Base SBERT model name'),
        ('--epochs', int, 10, 'Number of training epochs'),
        ('--batch_size', int, 32, 'Training batch size'),
    )
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    opts = cli.parse_args()

    # The output directory must exist before training starts.
    os.makedirs(opts.output_dir, exist_ok=True)

    trained = train_sbert_model(
        train_data_path=opts.train_data,
        output_dir=opts.output_dir,
        model_name=opts.model_name,
        num_epochs=opts.epochs,
        batch_size=opts.batch_size,
    )

    if trained:
        sys.exit(0)
    sys.exit(1)