| import argparse |
| import logging |
| import os |
| import sys |
| import pandas as pd |
| from sentence_transformers import SentenceTransformer, InputExample, losses |
| from torch.utils.data import DataLoader |
|
|
| |
| logging.basicConfig( |
| format='%(asctime)s - %(message)s', |
| datefmt='%Y-%m-%d %H:%M:%S', |
| level=logging.INFO, |
| handlers=[ |
| logging.StreamHandler(sys.stdout), |
| logging.FileHandler('sbert_training.log') |
| ] |
| ) |
|
|
| def create_default_training_data(output_path): |
| """Create minimal training data if none exists""" |
| data = { |
| 'resume': [ |
| 'Experienced software engineer with Python and Java', |
| 'Data scientist with machine learning expertise', |
| 'Web developer specializing in React' |
| ], |
| 'job_description': [ |
| 'Seeking software engineer with Python experience', |
| 'Looking for data scientist with ML background', |
| 'Hiring frontend developer with React skills' |
| ], |
| 'match_score': [85, 90, 95] |
| } |
| df = pd.DataFrame(data) |
| os.makedirs(os.path.dirname(output_path), exist_ok=True) |
| df.to_csv(output_path, index=False) |
| logging.warning(f"Created minimal training data at {output_path}") |
| return df |
|
|
| def train_sbert_model(train_data_path, output_dir, |
| model_name='all-MiniLM-L6-v2', |
| num_epochs=10, batch_size=32): |
| """Train an SBERT model for semantic similarity""" |
| |
| |
| if not os.path.exists(train_data_path): |
| logging.warning(f"No training data found at {train_data_path}") |
| df = create_default_training_data(train_data_path) |
| else: |
| try: |
| df = pd.read_csv(train_data_path) |
| logging.info(f"Loaded training data with {len(df)} records") |
| except Exception as e: |
| logging.error(f"Error loading training data: {str(e)}") |
| df = create_default_training_data(train_data_path) |
| |
| |
| train_examples = [] |
| for _, row in df.iterrows(): |
| try: |
| |
| score = max(0, min(100, float(row['match_score']))) / 100 |
| train_examples.append(InputExample( |
| texts=[str(row['resume']), str(row['job_description'])], |
| label=score |
| )) |
| except Exception as e: |
| logging.warning(f"Skipping invalid row: {str(e)}") |
| |
| if not train_examples: |
| logging.error("No valid training examples found!") |
| return False |
| |
| |
| model = SentenceTransformer(model_name) |
| logging.info(f"Initialized model: {model_name}") |
| |
| |
| train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size) |
| train_loss = losses.CosineSimilarityLoss(model) |
| |
| |
| logging.info(f"Starting training for {num_epochs} epochs with {len(train_examples)} examples") |
| |
| try: |
| model.fit( |
| train_objectives=[(train_dataloader, train_loss)], |
| epochs=num_epochs, |
| show_progress_bar=True, |
| output_path=output_dir |
| ) |
| logging.info(f"Model successfully saved to {output_dir}") |
| return True |
| except Exception as e: |
| logging.error(f"Training failed: {str(e)}") |
| return False |
|
|
| if __name__ == '__main__': |
| parser = argparse.ArgumentParser(description='Train SBERT model for resume-JD matching') |
| parser.add_argument('--train_data', type=str, |
| default='data/training/train.csv', |
| help='Path to training data CSV') |
| parser.add_argument('--output_dir', type=str, |
| default='trained_models/sbert', |
| help='Output directory for trained model') |
| parser.add_argument('--model_name', type=str, |
| default='all-MiniLM-L6-v2', |
| help='Base SBERT model name') |
| parser.add_argument('--epochs', type=int, default=10, |
| help='Number of training epochs') |
| parser.add_argument('--batch_size', type=int, default=32, |
| help='Training batch size') |
| |
| args = parser.parse_args() |
| |
| |
| os.makedirs(args.output_dir, exist_ok=True) |
| |
| |
| success = train_sbert_model( |
| train_data_path=args.train_data, |
| output_dir=args.output_dir, |
| model_name=args.model_name, |
| num_epochs=args.epochs, |
| batch_size=args.batch_size |
| ) |
| |
| sys.exit(0 if success else 1) |