| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
"""nl2bash metric: weighted command/option/argument similarity for generated bash commands."""
| import re |
| import string |
|
|
| import datasets |
| import numpy as np |
|
|
| import evaluate |
|
|
|
|
_DESCRIPTION = """
returns a score that indicates how close the bash command generated is to the actual command with a perfect score out of 1.0
"""

_KWARGS_DESCRIPTION = """
Args:
    predictions: List of predicted texts.
    references: List of reference texts (one list of acceptable commands per prediction).
    cmd_weight: The weight you want to put on getting the command correct (default 0.65)
    opt_weight: The weight you want to put on getting the options correct (default 0.25)
    arg_weight: The weight you want to put on getting the args correct (default 0.15)
    ignore_case: If True (the default), lowercase both sides before comparing.
    ignore_numbers: If True (the default), strip digit characters before comparing.
Returns:
    nl2bash metric: Dictionary containing nl2bash score. Possible values are between 0.0 and 1.0, inclusive.
Examples:

    >>> metric = evaluate.load("Josh98/nl2bash_m")
    >>> preds = ["ls -l /home/userr", "ls -l /home/josh", "lss /home/josh some argument"]
    >>> refs = [["ls -l /home/user"], ["ls -l --v /home/josh"], ["ls /home/josh"]]
    >>> results = metric.compute(references=refs, predictions=preds)
    >>> print(round(results["nl2bash_m"], 3))
    0.708
"""

_CITATION = """
"""
|
|
|
|
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class nl2bash_m(evaluate.Metric):
    """Heuristic similarity metric for generated bash commands.

    A prediction is compared with each reference by splitting the command
    line into three parts -- the command name (first token), the options
    (tokens starting with '-'), and the positional arguments -- and scoring
    each part separately under a configurable weight.
    """

    def _info(self):
        # Two accepted input layouts: several references per prediction
        # (Sequence of strings), or a single reference string per prediction.
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            reference_urls=[],
        )

    def get_score(self, pred, ref):
        """Return the fraction of positions where the two token lists agree.

        Comparison is positional (token i of `pred` vs token i of `ref`)
        and the count is normalized by the longer list, so extra or missing
        tokens lower the score.  Two empty lists count as a perfect match.
        """
        if not pred and not ref:
            return 1
        # zip stops at the shorter list, matching the original index loop.
        correct = sum(1 for p, r in zip(pred, ref) if p == r)
        return correct / max(len(pred), len(ref))

    def _compute(
        self,
        predictions,
        references,
        cmd_weight=0.65,
        opt_weight=0.25,
        arg_weight=0.15,
        ignore_case=True,
        ignore_numbers=True,
    ):
        """Average, over all predictions, the best weighted score against any reference.

        Args:
            predictions: list of predicted bash commands (strings).
            references: list of reference lists, one per prediction.
            cmd_weight: weight for matching the command name exactly.
            opt_weight: weight for the positional overlap of '-' options.
            arg_weight: weight for the positional overlap of arguments.
                NOTE(review): the defaults sum to 1.05, so a perfect match
                scores slightly above the documented 1.0 ceiling; kept as-is
                because the published example score (0.708) depends on it.
            ignore_case: lowercase both sides before comparing.
            ignore_numbers: strip digit characters before comparing.

        Returns:
            dict with key "nl2bash_m" mapped to the mean best score.
        """
        predictions = np.asarray(predictions)
        # NOTE(review): ragged reference lists (different counts per
        # prediction) may not survive np.asarray on recent numpy -- verify.
        references = np.asarray(references)

        if ignore_case:
            predictions = np.char.lower(predictions)
            references = np.char.lower(references)

        if ignore_numbers:
            # Translation table that deletes every ASCII digit.
            repl_table = str.maketrans("", "", string.digits)
            predictions = np.char.translate(predictions, table=repl_table)
            references = np.char.translate(references, table=repl_table)

        if len(predictions) == 0:
            # Guard: the original divided by len(predictions) unconditionally
            # and raised ZeroDivisionError on empty input.
            return {"nl2bash_m": 0.0}

        final_score = 0
        for pred, refs in zip(predictions, references):
            best_score = 0
            ref_lens = [len(ref) for ref in refs]
            if len(pred) == 0 and min(ref_lens) == 0:
                # Empty prediction matching an empty reference is perfect.
                best_score = 1
            elif len(pred) == 0 or min(ref_lens) == 0:
                best_score = 0
            else:
                # Score against every reference and keep the best.
                for ref in refs:
                    pred_words, ref_words = pred.split(), ref.split()

                    # Command name: the first token of each side.
                    cmd_corr = 1 if pred_words.pop(0) == ref_words.pop(0) else 0

                    # Options start with '-'; everything else is an argument.
                    pred_option = [x for x in pred_words if x[0] == "-"]
                    ref_option = [x for x in ref_words if x[0] == "-"]
                    pred_args = [x for x in pred_words if x[0] != "-"]
                    ref_args = [x for x in ref_words if x[0] != "-"]

                    score = (
                        cmd_weight * cmd_corr
                        + opt_weight * self.get_score(pred_option, ref_option)
                        + arg_weight * self.get_score(pred_args, ref_args)
                    )
                    best_score = max(best_score, score)

            final_score += best_score

        return {"nl2bash_m": final_score / len(predictions)}