Related dataset: malteklaes/cpp-code-code_search_net-style (Hugging Face dataset viewer: 70k rows, 55 downloads, 1 like)
How to use malteklaes/based-CodeBERTa-language-id-llm-module_uniVienna with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

model_id = "malteklaes/based-CodeBERTa-language-id-llm-module_uniVienna"

pipe = pipeline("text-classification", model=model_id)

# Load model directly (alternative to the pipeline above)
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

# This model is a fine-tuned version of malteklaes/based-CodeBERTa-language-id-llm-module.
RobertaTokenizerFast(name_or_path='malteklaes/based-CodeBERTa-language-id-llm-module_uniVienna', vocab_size=52000, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True), added_tokens_decoder={
0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
4: AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False, special=True),
}
RobertaConfig {
"_name_or_path": "malteklaes/based-CodeBERTa-language-id-llm-module_uniVienna",
"_num_labels": 7,
"architectures": [
"RobertaForSequenceClassification"
],
"attention_probs_dropout_prob": 0.1,
"bos_token_id": 0,
"classifier_dropout": null,
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"id2label": {
"0": "go",
"1": "java",
"2": "javascript",
"3": "php",
"4": "python",
"5": "ruby",
"6": "cpp"
},
"initializer_range": 0.02,
"intermediate_size": 3072,
"label2id": {
"cpp": 6,
"go": 0,
"java": 1,
"javascript": 2,
"php": 3,
"python": 4,
"ruby": 5
},
"layer_norm_eps": 1e-05,
"max_position_embeddings": 514,
"model_type": "roberta",
"num_attention_heads": 12,
"num_hidden_layers": 6,
"pad_token_id": 1,
"position_embedding_type": "absolute",
"problem_type": "single_label_classification",
"torch_dtype": "float32",
"transformers_version": "4.39.3",
"type_vocab_size": 1,
"use_cache": true,
"vocab_size": 52000
}
Given a code snippet, the model predicts the programming language it is written in:
# TextClassificationPipeline must be imported explicitly; it is not brought in
# by the `pipeline` / Auto* imports used elsewhere in this card.
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TextClassificationPipeline,
)

checkpoint = "malteklaes/based-CodeBERTa-language-id-llm-module_uniVienna"

# Load the tokenizer and the fine-tuned classifier once and reuse them.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# No `ignore_mismatched_sizes`: this checkpoint already carries the 7-label
# head (see id2label in the config), so a size mismatch would be a real error
# and should not silently re-initialise the classifier.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

# Wraps tokenizer + model; calling it on a source string returns a list of
# {'label': <language>, 'score': <probability>} dicts.
myPipeline = TextClassificationPipeline(model=model, tokenizer=tokenizer)
# Sample Python source to classify (indentation restored so it is valid Python).
CODE_TO_IDENTIFY_py = """
def is_prime(n):
    if n <= 1:
        return False
    if n == 2 or n == 3:
        return True
    if n % 2 == 0:
        return False
    max_divisor = int(n ** 0.5)
    for i in range(3, max_divisor + 1, 2):
        if n % i == 0:
            return False
    return True
number = 17
if is_prime(number):
    print(f"{number} is a prime number.")
else:
    print(f"{number} is not a prime number.")
"""
myPipeline(CODE_TO_IDENTIFY_py)  # output: [{'label': 'python', 'score': 0.9999967813491821}]
The following hyperparameters (`TrainingArguments`) were used for training:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./based-CodeBERTa-language-id-llm-module_uniVienna",
    overwrite_output_dir=True,
    # Fractional value: training stops after 10% of one pass over the data
    # (the TrainOutput below reports 'epoch': 0.1).
    num_train_epochs=0.1,
    per_device_train_batch_size=8,
    save_steps=500,
    save_total_limit=2,  # keep only the two most recent checkpoints on disk
)
TrainOutput(global_step=24136, training_loss=0.005988701689750161, metrics={'train_runtime': 1936.0586, 'train_samples_per_second': 99.731, 'train_steps_per_second': 12.467, 'total_flos': 3197518224531456.0, 'train_loss': 0.005988701689750161, 'epoch': 0.1})