| import string |
| from typing import Any, List |
|
|
| from langchain.evaluation.schema import StringEvaluator |
|
|
|
|
class ExactMatchStringEvaluator(StringEvaluator):
    """Compute an exact match between the prediction and the reference.

    Before comparing, both strings can optionally be normalized:

    * ``ignore_case``: lower-case both strings with ``str.lower``;
    * ``ignore_punctuation``: remove every character in ``string.punctuation``;
    * ``ignore_numbers``: remove every character in ``string.digits``.

    The score is ``1`` when the (normalized) strings are equal, else ``0``.

    Examples
    --------
    >>> evaluator = ExactMatchStringEvaluator()
    >>> evaluator.evaluate_strings(
    ...     prediction="Mindy is the CTO",
    ...     reference="Mindy is the CTO",
    ... )  # This will return {'score': 1}

    >>> evaluator.evaluate_strings(
    ...     prediction="Mindy is the CTO",
    ...     reference="Mindy is the CEO",
    ... )  # This will return {'score': 0}
    """

    # Deletion tables for ``str.translate``. They depend only on constants,
    # so build them once here instead of on every evaluation.
    _PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)
    _DIGITS_TABLE = str.maketrans("", "", string.digits)

    def __init__(
        self,
        *,
        ignore_case: bool = False,
        ignore_punctuation: bool = False,
        ignore_numbers: bool = False,
        **kwargs: Any,
    ):
        """Configure the normalization applied before comparing strings.

        Args:
            ignore_case: Lower-case both strings before comparing.
            ignore_punctuation: Remove ASCII punctuation
                (``string.punctuation``) before comparing.
            ignore_numbers: Remove ASCII digits (``string.digits``) before
                comparing.
            **kwargs: Accepted for call-site compatibility; currently ignored.
        """
        super().__init__()
        self.ignore_case = ignore_case
        self.ignore_punctuation = ignore_punctuation
        self.ignore_numbers = ignore_numbers

    @property
    def requires_input(self) -> bool:
        """
        This evaluator does not require input.
        """
        return False

    @property
    def requires_reference(self) -> bool:
        """
        This evaluator requires a reference.
        """
        return True

    @property
    def input_keys(self) -> List[str]:
        """
        Get the input keys.

        Returns:
            List[str]: The input keys.
        """
        return ["reference", "prediction"]

    @property
    def evaluation_name(self) -> str:
        """
        Get the evaluation name.

        Returns:
            str: The evaluation name.
        """
        return "exact_match"

    def _normalize_text(self, text: str) -> str:
        """Apply the configured case/punctuation/digit normalization to ``text``.

        The flags are read at call time, so changing them on an existing
        instance takes effect on the next evaluation.
        """
        if self.ignore_case:
            text = text.lower()
        if self.ignore_punctuation:
            text = text.translate(self._PUNCTUATION_TABLE)
        if self.ignore_numbers:
            text = text.translate(self._DIGITS_TABLE)
        return text

    def _evaluate_strings(
        self,
        *,
        prediction: str,
        reference: str,
        **kwargs: Any,
    ) -> dict:
        """
        Evaluate the exact match between the prediction and the reference.

        Args:
            prediction (str): The prediction string.
            reference (str): The reference string.
            **kwargs: Extra keyword arguments; ignored by this evaluator.

        Returns:
            dict: ``{"score": 1}`` if the normalized strings are equal,
            otherwise ``{"score": 0}``.
        """
        # Both sides go through the identical pipeline so the comparison is
        # symmetric in the configured normalizations.
        normalized_prediction = self._normalize_text(prediction)
        normalized_reference = self._normalize_text(reference)
        return {"score": int(normalized_prediction == normalized_reference)}
|
|