| import torch
|
| import torch.nn as nn
|
| from transformers import AutoModel, AutoTokenizer
|
| from torch.utils.data import Dataset
|
| import re
|
|
|
|
|
class IntentDataset(Dataset):
    """
    Dataset pairing student input with session context for 5-class intent
    categorization. Each item is tokenized as a (input, context) text pair.
    """

    def __init__(self, data, tokenizer, max_length=128):
        # data: sequence of dicts with 'student_input', 'session_context', 'label'
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        # String label -> class index mapping; unknown labels fall back to 0.
        self.label_map = {
            'On-Topic Question': 0,
            'Off-Topic Question': 1,
            'Emotional-State': 2,
            'Pace-Related': 3,
            'Repeat/clarification': 4
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]

        # Encode the (student input, session context) pair; pad/truncate to
        # a fixed length so batches stack without a collate function.
        encoding = self.tokenizer(
            str(record.get('student_input', '')),
            str(record.get('session_context', '')),
            padding='max_length',
            truncation='longest_first',
            max_length=self.max_length,
            return_tensors='pt'
        )

        raw_label = record.get('label', 0)
        # Labels may arrive as strings or as pre-mapped integers.
        label = self.label_map.get(raw_label, 0) if isinstance(raw_label, str) else raw_label

        # Drop the batch dimension added by return_tensors='pt'; include
        # token_type_ids only when the tokenizer produced them.
        sample = {
            key: encoding[key].squeeze(0)
            for key in ('input_ids', 'attention_mask', 'token_type_ids')
            if key in encoding
        }
        sample['labels'] = torch.tensor(label, dtype=torch.long)
        return sample
|
|
|
|
|
class CompoundSentenceSplitter:
    """
    Algorithm to split compound sentences containing 2 separate questions.
    Handles various patterns and conjunctions commonly used to combine questions.

    The phrase/conjunction heuristics target English; the punctuation-based
    strategies also recognize the Arabic question mark (؟) and Arabic
    semicolon (؛).
    """

    # Arabic letters, used to decide which question mark to append when
    # normalizing a split fragment in _clean_questions.
    _ARABIC_CHARS = 'أبتثجحخدذرزسشصضطظعغفقكلمنهويىةؤإآ'

    def __init__(self):
        # Wh-words and auxiliary verbs that commonly open an English question.
        self.question_words = [
            'what', 'when', 'where', 'which', 'who', 'whom', 'whose', 'why', 'how',
            'is', 'are', 'was', 'were', 'do', 'does', 'did', 'can', 'could',
            'will', 'would', 'should', 'may', 'might', 'must'
        ]

        # Conjunctions that may join two questions.
        self.conjunctions = [
            'and', 'or', 'also', 'plus', 'additionally', 'moreover'
        ]

        # Multi-word transition phrases checked before single conjunctions.
        self.transition_phrases = [
            'and also', 'and what about', 'and how about', 'or what about',
            'or how about', 'also what', 'also how', 'also when', 'also where',
            'also who', 'also why', 'plus what', 'plus how'
        ]

    def split_compound_question(self, text):
        """
        Split a compound sentence into 2 separate questions if applicable.

        Args:
            text (str): Input text that may contain compound questions

        Returns:
            list: List of separated questions. Returns [text] if no split is needed.
        """
        text = text.strip()

        # Only attempt a split when the text looks like a question at all.
        if not self._is_question(text):
            return [text]

        # Try strategies from most specific to most generic.
        for strategy in (
            self._split_by_transition_phrases,
            self._split_by_conjunction_pattern,
            self._split_by_punctuation_pattern,
            self._split_by_question_marks,
        ):
            questions = strategy(text)
            if len(questions) > 1:
                return self._clean_questions(questions)

        return [text]

    def _is_question(self, text):
        """Check if text is likely a question (English wording or ?/؟ mark)."""
        text_stripped = text.strip()

        # FIX: also recognize the Arabic question mark. Previously only the
        # ASCII '?' was checked, so purely Arabic questions never passed this
        # gate and the Arabic-aware strategies below were unreachable.
        if '?' in text or '؟' in text:
            return True

        words = text_stripped.split()
        if words and words[0].lower() in self.question_words:
            return True

        return False

    def _split_by_transition_phrases(self, text):
        """Split on an explicit transition phrase such as 'and what about'."""
        for phrase in self.transition_phrases:
            pattern = r'\s+' + re.escape(phrase) + r'\s+'
            if re.search(pattern, text, re.IGNORECASE):
                parts = re.split(pattern, text, maxsplit=1, flags=re.IGNORECASE)
                if len(parts) == 2 and parts[0] and parts[1]:
                    return parts

        return [text]

    def _split_by_conjunction_pattern(self, text):
        """Split on a conjunction immediately followed by a question word."""
        for conj in self.conjunctions:
            for qword in self.question_words:
                pattern = r'\s+' + re.escape(conj) + r'\s+' + re.escape(qword) + r'\b'
                match = re.search(pattern, text, re.IGNORECASE)
                if match:
                    split_pos = match.start()
                    part1 = text[:split_pos].strip()
                    part2 = text[split_pos:].strip()

                    # Drop any leading conjunction(s) from the second fragment.
                    # (All conjunctions are ASCII, so IGNORECASE always applies;
                    # the previous per-character Arabic test was dead code.)
                    for c in self.conjunctions:
                        part2 = re.sub(r'^\s*' + re.escape(c) + r'\s+', '', part2, flags=re.IGNORECASE)

                    # Only accept the split when the first half stands as a question.
                    if part1 and part2 and self._is_question(part1):
                        return [part1, part2]

        return [text]

    def _split_by_punctuation_pattern(self, text):
        """Split by semicolon (;/؛) or a comma followed by a question word."""
        if ';' in text or '؛' in text:
            parts = re.split(r'[;؛]', text, maxsplit=1)
            if len(parts) == 2:
                parts = [p.strip() for p in parts]
                if all(self._is_question(p) for p in parts):
                    return parts

        # Comma split only when the second fragment begins with a question word.
        pattern = r',\s+(?=' + '|'.join([re.escape(qw) for qw in self.question_words]) + r')'
        parts = re.split(pattern, text, maxsplit=1, flags=re.IGNORECASE)
        if len(parts) == 2:
            parts = [p.strip() for p in parts]
            if self._is_question(parts[1]):
                return parts

        return [text]

    def _split_by_question_marks(self, text):
        """Split after the first question mark when two or more exist (? or ؟)."""
        q_marks = text.count('?') + text.count('؟')
        if q_marks >= 2:
            match = re.search(r'[?؟]', text)
            if match:
                split_pos = match.end()
                part1 = text[:split_pos].strip()
                part2 = text[split_pos:].strip()
                if part2:
                    return [part1, part2]

        return [text]

    def _clean_questions(self, questions):
        """Clean split fragments and ensure each ends with a question mark."""
        cleaned = []
        for q in questions:
            q = q.strip()
            if not q:
                continue
            if self._is_question(q):
                if not (q.endswith('?') or q.endswith('؟')):
                    # Pick the mark matching the script of the fragment.
                    if any(c in self._ARABIC_CHARS for c in q):
                        q += '؟'
                    else:
                        q += '?'
                cleaned.append(q)

        # If cleaning collapsed the split, fall back to the rejoined input.
        return cleaned if len(cleaned) > 1 else [' '.join(questions)]
|
|
|
|
|
class TinyBertCNN(nn.Module):
    """
    TinyBERT-CNN model for intent classification.

    Combines TinyBERT token embeddings with parallel 1-D CNN branches
    (Conv1d + BatchNorm + ReLU + global max-pool over time), a hidden
    fully-connected layer with BatchNorm, and a final classification layer.
    """

    def __init__(
        self,
        num_classes,
        bert_model_name='huawei-noah/TinyBERT_General_4L_312D',
        num_filters=256,
        filter_sizes=(2, 3, 4),  # tuple default: avoids the mutable-default pitfall
        dropout=0.5,
        hidden_dim=128,
        freeze_bert=False
    ):
        """
        Args:
            num_classes (int): Number of intent classes
            bert_model_name (str): Pre-trained TinyBERT model name
            num_filters (int): Number of filters for each filter size
            filter_sizes (sequence of int): CNN kernel sizes, one conv branch each
            dropout (float): Dropout rate
            hidden_dim (int): Hidden FC layer dimension
            freeze_bert (bool): Whether to freeze BERT parameters
        """
        super().__init__()

        # Pre-trained transformer encoder; hidden size drives conv in_channels.
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.bert_hidden_size = self.bert.config.hidden_size

        # Optionally train only the CNN/FC head.
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

        # One Conv1d + BatchNorm1d branch per kernel size.
        self.convs = nn.ModuleList([
            nn.Conv1d(
                in_channels=self.bert_hidden_size,
                out_channels=num_filters,
                kernel_size=fs
            )
            for fs in filter_sizes
        ])
        self.batchnorms = nn.ModuleList([
            nn.BatchNorm1d(num_filters)
            for _ in filter_sizes
        ])

        self.dropout = nn.Dropout(dropout)

        # Pooled branch outputs are concatenated before the hidden FC layer.
        cnn_out_dim = len(filter_sizes) * num_filters
        self.fc_hidden = nn.Linear(cnn_out_dim, hidden_dim)
        self.bn_hidden = nn.BatchNorm1d(hidden_dim)

        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """
        Forward pass.

        Args:
            input_ids: Token IDs (batch_size, seq_len)
            attention_mask: Attention mask (batch_size, seq_len)
            token_type_ids: Token type IDs (batch_size, seq_len), optional

        Returns:
            logits: Classification logits (batch_size, num_classes)
        """
        # Only pass token_type_ids when provided (some tokenizers omit them).
        bert_kwargs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }
        if token_type_ids is not None:
            bert_kwargs['token_type_ids'] = token_type_ids

        bert_output = self.bert(**bert_kwargs)

        # (batch, seq_len, hidden) -> (batch, hidden, seq_len) for Conv1d.
        sequence_output = bert_output.last_hidden_state.transpose(1, 2)

        # Right-pad so even the largest kernel fits very short sequences.
        max_kernel = max(conv.kernel_size[0] for conv in self.convs)
        if sequence_output.size(2) < max_kernel:
            pad_size = max_kernel - sequence_output.size(2)
            sequence_output = torch.nn.functional.pad(sequence_output, (0, pad_size))

        # Each branch: conv -> batchnorm -> relu -> global max-pool over time.
        conv_outputs = []
        for conv, bn in zip(self.convs, self.batchnorms):
            conv_out = torch.relu(bn(conv(sequence_output)))
            pooled = torch.max_pool1d(conv_out, conv_out.size(2)).squeeze(2)
            conv_outputs.append(pooled)

        # (batch, len(filter_sizes) * num_filters)
        concatenated = self.dropout(torch.cat(conv_outputs, dim=1))

        hidden = torch.relu(self.bn_hidden(self.fc_hidden(concatenated)))
        hidden = self.dropout(hidden)

        return self.fc(hidden)
|
|
|
|
|
class IntentClassifier:
    """
    Wrapper around TinyBertCNN handling tokenization, compound-question
    splitting, training steps, evaluation, and checkpointing.
    """

    def __init__(
        self,
        num_classes,
        bert_model_name='huawei-noah/TinyBERT_General_4L_312D',
        num_filters=256,
        filter_sizes=(2, 3, 4),  # tuple default: avoids the mutable-default pitfall
        dropout=0.5,
        freeze_bert=False,
        device=None
    ):
        """
        Args:
            num_classes (int): Number of intent classes
            bert_model_name (str): Pre-trained TinyBERT model name
            num_filters (int): Number of filters per CNN kernel size
            filter_sizes (sequence of int): CNN kernel sizes
            dropout (float): Dropout rate
            freeze_bert (bool): Whether to freeze BERT parameters
            device: torch.device to run on; auto-selects CUDA when None
        """
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model = TinyBertCNN(
            num_classes=num_classes,
            bert_model_name=bert_model_name,
            num_filters=num_filters,
            filter_sizes=filter_sizes,
            dropout=dropout,
            freeze_bert=freeze_bert
        ).to(self.device)

        self.tokenizer = AutoTokenizer.from_pretrained(bert_model_name)

        # Splits compound inputs into individual questions before prediction.
        self.sentence_splitter = CompoundSentenceSplitter()

        self.num_classes = num_classes

    def preprocess_text(self, text):
        """
        Preprocess text by splitting compound questions if detected.

        Args:
            text (str): Input text (English or Arabic)

        Returns:
            list: List of individual questions
        """
        return self.sentence_splitter.split_compound_question(text)

    def predict(self, student_inputs, session_contexts=None, max_length=128, split_compound=False):
        """
        Predict intents for input texts.

        Args:
            student_inputs (list): List of student input texts (English or Arabic)
            session_contexts (list): List of session context texts
            max_length (int): Maximum sequence length
            split_compound (bool): Whether to split compound questions before prediction

        Returns:
            If split_compound=False:
                predictions: Predicted class indices
                probabilities: Prediction probabilities
            If split_compound=True:
                predictions: List of predictions (may contain multiple per text if split)
                probabilities: List of probabilities
                split_info: Dictionary with information about splits
        """
        if split_compound:
            return self._predict_with_splitting(student_inputs, session_contexts, max_length)

        self.model.eval()

        # Encode as text pairs when contexts are given, single texts otherwise.
        if session_contexts is not None:
            text_args = (student_inputs, session_contexts)
        else:
            text_args = (student_inputs,)

        encoded = self.tokenizer(
            *text_args,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        )

        input_ids = encoded['input_ids'].to(self.device)
        attention_mask = encoded['attention_mask'].to(self.device)
        token_type_ids = encoded.get('token_type_ids')
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(self.device)

        with torch.no_grad():
            logits = self.model(input_ids, attention_mask, token_type_ids=token_type_ids)
            probabilities = torch.softmax(logits, dim=1)
            predictions = torch.argmax(probabilities, dim=1)

        return predictions.cpu().numpy(), probabilities.cpu().numpy()

    def _predict_with_splitting(self, student_inputs, session_contexts=None, max_length=128):
        """
        Predict intents after splitting compound questions.

        Args:
            student_inputs (list): List of input texts
            session_contexts (list): List of session context texts
            max_length (int): Maximum sequence length

        Returns:
            predictions: List of per-original-text prediction arrays
            probabilities: List of per-original-text probability arrays
            split_info: Dictionary with information about splits
        """
        all_predictions = []
        all_probabilities = []
        split_info = {
            'original_texts': student_inputs,
            'split_texts': [],
            'was_split': [],
            'split_indices': []
        }

        # Flatten all (possibly split) questions into one batch, duplicating
        # the session context for each fragment of a split input.
        all_questions = []
        all_contexts = []
        for i, text in enumerate(student_inputs):
            questions = self.preprocess_text(text)
            split_info['split_texts'].append(questions)
            split_info['was_split'].append(len(questions) > 1)

            for _ in questions:
                split_info['split_indices'].append(i)
                if session_contexts is not None:
                    all_contexts.append(session_contexts[i])

            all_questions.extend(questions)

        if all_questions:
            contexts_to_pass = all_contexts if session_contexts is not None else None
            predictions, probabilities = self.predict(all_questions, contexts_to_pass, max_length, split_compound=False)

            # Regroup flat batch results back to one entry per original text.
            idx = 0
            for i, text in enumerate(student_inputs):
                num_questions = len(split_info['split_texts'][i])
                all_predictions.append(predictions[idx:idx + num_questions])
                all_probabilities.append(probabilities[idx:idx + num_questions])
                idx += num_questions

        return all_predictions, all_probabilities, split_info

    def train_step(self, batch, optimizer, criterion):
        """
        Single training step.

        Args:
            batch: Dictionary with 'input_ids', 'attention_mask', 'labels'
                   (and optionally 'token_type_ids')
            optimizer: Optimizer
            criterion: Loss function

        Returns:
            loss: Training loss (float)
        """
        self.model.train()

        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        labels = batch['labels'].to(self.device)
        token_type_ids = batch.get('token_type_ids')
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(self.device)

        # Clear stale gradients before the backward pass.
        optimizer.zero_grad()

        logits = self.model(input_ids, attention_mask, token_type_ids=token_type_ids)
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()

        return loss.item()

    def evaluate(self, dataloader, criterion):
        """
        Evaluate model on a validation/test set.

        Args:
            dataloader: DataLoader for evaluation
            criterion: Loss function

        Returns:
            avg_loss: Average loss (0.0 for an empty dataloader)
            accuracy: Classification accuracy (0.0 for an empty dataloader)
        """
        self.model.eval()

        total_loss = 0
        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)
                token_type_ids = batch.get('token_type_ids')
                if token_type_ids is not None:
                    token_type_ids = token_type_ids.to(self.device)

                logits = self.model(input_ids, attention_mask, token_type_ids=token_type_ids)
                loss = criterion(logits, labels)

                # Weight batch loss by batch size for a true per-sample average.
                predictions = torch.argmax(logits, dim=1)
                total_loss += loss.item() * labels.size(0)
                total_correct += (predictions == labels).sum().item()
                total_samples += labels.size(0)

        # FIX: guard against ZeroDivisionError on an empty dataloader.
        if total_samples == 0:
            return 0.0, 0.0

        avg_loss = total_loss / total_samples
        accuracy = total_correct / total_samples

        return avg_loss, accuracy

    def save_model(self, path):
        """Save model checkpoint"""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'num_classes': self.num_classes
        }, path)
        print(f"Model saved to {path}")

    def load_model(self, path):
        """Load model checkpoint"""
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Model loaded from {path}")
|
|
|
|
|