Buckets:

rtrm's picture
download
raw
39.8 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Affinare il modello con la Trainer API&quot;,&quot;local&quot;:&quot;affinare-il-modello-con-la-trainer-api&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Addestramento&quot;,&quot;local&quot;:&quot;addestramento&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Valutazione&quot;,&quot;local&quot;:&quot;valutazione&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:1}">
<link href="/docs/course/pr_1069/it/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/course/pr_1069/it/_app/immutable/entry/start.693d748d.js">
<link rel="modulepreload" href="/docs/course/pr_1069/it/_app/immutable/chunks/scheduler.37c15a92.js">
<link rel="modulepreload" href="/docs/course/pr_1069/it/_app/immutable/chunks/singletons.60b4c7a2.js">
<link rel="modulepreload" href="/docs/course/pr_1069/it/_app/immutable/chunks/index.18351ede.js">
<link rel="modulepreload" href="/docs/course/pr_1069/it/_app/immutable/chunks/paths.43b6516c.js">
<link rel="modulepreload" href="/docs/course/pr_1069/it/_app/immutable/entry/app.e9cfd099.js">
<link rel="modulepreload" href="/docs/course/pr_1069/it/_app/immutable/chunks/index.2bf4358c.js">
<link rel="modulepreload" href="/docs/course/pr_1069/it/_app/immutable/nodes/0.bb8a536c.js">
<link rel="modulepreload" href="/docs/course/pr_1069/it/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/course/pr_1069/it/_app/immutable/nodes/22.689ed549.js">
<link rel="modulepreload" href="/docs/course/pr_1069/it/_app/immutable/chunks/Tip.363c041f.js">
<link rel="modulepreload" href="/docs/course/pr_1069/it/_app/immutable/chunks/Youtube.1e50a667.js">
<link rel="modulepreload" href="/docs/course/pr_1069/it/_app/immutable/chunks/CodeBlock.4e987730.js">
<link rel="modulepreload" href="/docs/course/pr_1069/it/_app/immutable/chunks/DocNotebookDropdown.efc1fb7c.js">
<link rel="modulepreload" href="/docs/course/pr_1069/it/_app/immutable/chunks/FrameworkSwitchCourse.8d4d4ab6.js">
<link rel="modulepreload" href="/docs/course/pr_1069/it/_app/immutable/chunks/getInferenceSnippets.24b50994.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Affinare il modello con la Trainer API&quot;,&quot;local&quot;:&quot;affinare-il-modello-con-la-trainer-api&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Addestramento&quot;,&quot;local&quot;:&quot;addestramento&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Valutazione&quot;,&quot;local&quot;:&quot;valutazione&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="bg-white leading-none border border-gray-100 rounded-lg flex p-0.5 w-56 text-sm mb-4"><a class="flex justify-center flex-1 py-1.5 px-2.5 focus:outline-none !no-underline rounded-l bg-red-50 dark:bg-transparent text-red-600" href="?fw=pt"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><defs><clipPath id="a"><rect x="3.05" y="0.5" width="25.73" height="31" fill="none"></rect></clipPath></defs><g clip-path="url(#a)"><path d="M24.94,9.51a12.81,12.81,0,0,1,0,18.16,12.68,12.68,0,0,1-18,0,12.81,12.81,0,0,1,0-18.16l9-9V5l-.84.83-6,6a9.58,9.58,0,1,0,13.55,0ZM20.44,9a1.68,1.68,0,1,1,1.67-1.67A1.68,1.68,0,0,1,20.44,9Z" fill="#ee4c2c"></path></g></svg> Pytorch </a><a class="flex justify-center flex-1 py-1.5 px-2.5 focus:outline-none !no-underline rounded-r text-gray-500 filter grayscale" href="?fw=tf"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="0.94em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 274"><path d="M145.726 42.065v42.07l72.861 42.07v-42.07l-72.86-42.07zM0 84.135v42.07l36.43 21.03V105.17L0 84.135zm109.291 21.035l-36.43 21.034v126.2l36.43 21.035v-84.135l36.435 21.035v-42.07l-36.435-21.034V105.17z" fill="#E55B2D"></path><path d="M145.726 42.065L36.43 105.17v42.065l72.861-42.065v42.065l36.435-21.03v-84.14zM255.022 63.1l-36.435 21.035v42.07l36.435-21.035V63.1zm-72.865 84.135l-36.43 21.035v42.07l36.43-21.036v-42.07zm-36.43 63.104l-36.436-21.035v84.135l36.435-21.035V210.34z" fill="#ED8E24"></path><path d="M145.726 0L0 84.135l36.43 21.035l109.296-63.105l72.861 42.07L255.022 63.1L145.726 0zm0 126.204l-36.435 21.03l36.435 21.036l36.43-21.035l-36.43-21.03z" fill="#F8BF3C"></path></svg> TensorFlow </a></div> <h1 class="relative group"><a id="affinare-il-modello-con-la-trainer-api" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#affinare-il-modello-con-la-trainer-api"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Affinare il modello con la Trainer API</span></h1> <div class="flex space-x-1 absolute z-10 right-0 top-0"> <a href="https://colab.research.google.com/github/huggingface/notebooks/blob/master/course/it/chapter3/section3.ipynb" target="_blank"><img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/master/course/it/chapter3/section3.ipynb" target="_blank"><img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"></a></div> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/nvBXf7s7vTI" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-w466ks">🤗 Transformers fornisce una classe <code>Trainer</code> (addestratore) per aiutare con l’affinamento di uno qualsiasi dei modelli pre-addestrati nel dataset. Dopo tutto il lavoro di preprocessing nella sezione precedente, rimangono giusto gli ultimi passi per definire il <code>Trainer</code>. Probabilmente la parte più complicata sarà preparare l’ambiente per eseguire <code>Trainer.train()</code>, poiché sarà molto lento su una CPU. Se non avete una GPU a disposizione, potete avere accesso gratuitamente a GPU o TPU su <a href="https://colab.research.google.com/" rel="nofollow">Google Colab</a>.</p> <p data-svelte-h="svelte-11al6ht">Gli esempi di codice qui sotto partono dal presupposto che gli esempi nella sezione precedente siano già stati eseguiti. Ecco un breve riassunto di cosa serve:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, DataCollatorWithPadding
raw_datasets = load_dataset(<span class="hljs-string">&quot;glue&quot;</span>, <span class="hljs-string">&quot;mrpc&quot;</span>)
checkpoint = <span class="hljs-string">&quot;bert-base-uncased&quot;</span>
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
<span class="hljs-keyword">def</span> <span class="hljs-title function_">tokenize_function</span>(<span class="hljs-params">example</span>):
<span class="hljs-keyword">return</span> tokenizer(example[<span class="hljs-string">&quot;sentence1&quot;</span>], example[<span class="hljs-string">&quot;sentence2&quot;</span>], truncation=<span class="hljs-literal">True</span>)
tokenized_datasets = raw_datasets.<span class="hljs-built_in">map</span>(tokenize_function, batched=<span class="hljs-literal">True</span>)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="addestramento" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#addestramento"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Addestramento</span></h3> <p data-svelte-h="svelte-3oidne">Il primo passo per definire un <code>Trainer</code> è la definizione di una classe <code>TrainingArguments</code> che contenga tutti gli iperparametri che verranno usati dal <code>Trainer</code> per l’addestramento e la valutazione. L’unico parametro da fornire è la cartella dove verranno salvati il modello addestrato e i vari checkpoint. Per tutto il resto si possono lasciare i parametri di default, che dovrebbero funzionare bene per un affinamento di base.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments
training_args = TrainingArguments(<span class="hljs-string">&quot;test-trainer&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1oyy0rx">💡 Se si vuole caricare automaticamente il modello all’Hub durante l’addestramento, basta passare <code>push_to_hub=True</code> come parametro nei <code>TrainingArguments</code>. Maggiori dettagli verranno forniti nel <a href="/course/chapter4/3">Capitolo 4</a>.</p></div> <p data-svelte-h="svelte-nsi4y1">Il secondo passo è definire il modello. Come nel <a href="/course/chapter2">capitolo precedente</a>, utilizzeremo la classe <code>AutoModelForSequenceClassification</code> con due label:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1580ewa">Diversamente dal <a href="/course/chapter2">Capitolo 2</a>, un avviso di avvertimento verrà visualizzato dopo aver istanziato questo modello pre-addestrato. Ciò avviene perché BERT non è stato pre-addestrato per classificare coppie di frasi, quindi la testa del modello pre-addestrato viene scartata e una nuova testa adeguata per il compito di classificazione di sequenze è stata inserita. Gli avvertimenti indicano che alcuni pesi non verranno usati (quelli corrispondenti alla testa scartata del modello pre-addestrato) e che altri pesi sono stati inizializzati con valori casuali (quelli per la nuova testa). L’avvertimento viene concluso con un’esortazione ad addestrare il modello, che è esattamente ciò che stiamo per fare.</p> <p data-svelte-h="svelte-zpy7k9">Una volta ottenuto il modello, si può definire un <code>Trainer</code> passandogli tutti gli oggetti costruiti fino ad adesso — il <code>model</code>, i <code>training_args</code>, i dataset di addestramento e validazione, il <code>data_collator</code>, e il <code>tokenizer</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Trainer
trainer = Trainer(
model,
training_args,
train_dataset=tokenized_datasets[<span class="hljs-string">&quot;train&quot;</span>],
eval_dataset=tokenized_datasets[<span class="hljs-string">&quot;validation&quot;</span>],
data_collator=data_collator,
tokenizer=tokenizer,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-gnh7rc">Quando si passa l’argomento <code>tokenizer</code> come appena fatto, il <code>data_collator</code> usato di default dal <code>Trainer</code> sarà del tipo <code>DataCollatorWithPadding</code>, come definito precedentemente, quindi si potrebbe evitare di specificare l’argomento <code>data_collator=data_collator</code> in questa chiamata. Tuttavia era comunque importante mostrare questa parte del processing nella sezione 2!</p> <p data-svelte-h="svelte-xznjnr">Per affinare il modello sul nostro dataset, bisogna solo chiamare il metodo <code>train()</code> del <code>Trainer</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-rtrulk">Questo farà partire l’affinamento (che richiederà un paio di minuti su una GPU) e produrrà un report della funzione obiettivo dell’addestramento ogni 500 passi. Tuttavia, non vi farà sapere quanto sia buona (o cattiva) la performance del modello. Ciò è dovuto al fatto che:</p> <ol data-svelte-h="svelte-pjo77w"><li>Non è stato detto al <code>Trainer</code> di valutare il modello durante l’addestramento, settando <code>evaluation_strategy</code> o al valore <code>&quot;steps&quot;</code> (valuta il modello ogni <code>eval_steps</code>) oppure al valore <code>&quot;epoch&quot;</code> (valuta il modello alla fine di ogni epoca).</li> <li>Non è stato fornito al <code>Trainer</code> una funzione <code>compute_metrics()</code> per calcolare le metriche di valutazione (altrimenti la valutazione stamperebbe solo il valore della funzione obiettivo, che non è un valore molto intuitivo).</li></ol> <h3 class="relative group"><a id="valutazione" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#valutazione"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Valutazione</span></h3> <p data-svelte-h="svelte-pup2lr">Vediamo come si può costruire una funzione <code>compute_metrics()</code> utile e usarla per il prossimo addestramento. La funzione deve prendere come parametro un oggetto <code>EvalPrediction</code> (che è una named tuple avente un campo <code>predictions</code> – predizioni – e un campo <code>label_ids</code> – id delle etichette –) e restituirà un dizionario che associa stringhe a numeri floating point (le stringhe saranno i nomi delle metriche, i numeri i loro valori). Per ottenere delle predizioni, si può usare il comando <code>Trainer.predict()</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->predictions = trainer.predict(tokenized_datasets[<span class="hljs-string">&quot;validation&quot;</span>])
<span class="hljs-built_in">print</span>(predictions.predictions.shape, predictions.label_ids.shape)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->(<span class="hljs-number">408</span>, <span class="hljs-number">2</span>) (<span class="hljs-number">408</span>,)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-mg62sf">Il risultato del metodo <code>predict()</code> è un’altra named tuple con tre campi: <code>predictions</code>, <code>label_ids</code>, e <code>metrics</code>. Il campo <code>metrics</code> conterrà solo il valore della funzione obiettivo sul dataset, in aggiunta ad alcune metriche legate al tempo (il tempo necessario per calcolare le predizioni, in totale e in media). Una volta completata la funzione <code>compute_metrics()</code> e passata al <code>Trainer</code>, quel campo conterrà anche le metriche restituite da <code>compute_metrics()</code>.</p> <p data-svelte-h="svelte-bv1sad">Come si può vedere, <code>predictions</code> è un array bi-dimensionale con dimensioni 408 x 2 (poiché 408 è il numero di elementi nel dataset). Questi sono i logit per ogni elemento del dataset passato a <code>predict()</code> (come già visto nel <a href="/course/chapter2">capitolo precedente</a>, tutti i modelli Transformer restituiscono logit). Per trasformarli in predizioni associabili alle etichette, bisogna prendere l’indice col valore massimo sul secondo asse:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
preds = np.argmax(predictions.predictions, axis=-<span class="hljs-number">1</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-19co3eg">Ora si possono paragonare i <code>preds</code> con le etichette. Per costruire la funzione <code>compute_metric()</code>, verranno utilizzate le metriche dalla libreria 🤗 Dataset. Si possono caricare le metriche associate con il dataset MRPC in maniera semplice, utilizzando la funzione <code>load_metric()</code>. L’oggetto restituito ha un metodo <code>compute()</code> (calcola) che possiamo usare per calcolare le metriche:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_metric
metric = load_metric(<span class="hljs-string">&quot;glue&quot;</span>, <span class="hljs-string">&quot;mrpc&quot;</span>)
metric.compute(predictions=preds, references=predictions.label_ids)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">&#x27;accuracy&#x27;</span>: <span class="hljs-number">0.8578431372549019</span>, <span class="hljs-string">&#x27;f1&#x27;</span>: <span class="hljs-number">0.8996539792387542</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-xgb3dx">L’esatto valore dei risultati potrebbe essere diverso nel vostro caso, a casa dell’inizializzazione casuale della testa del modello. In questo caso il nostro modello ha un’accuratezza del 85.78% sul set di validazione e un valore F1 di 89.97. Queste sono le due metriche utilizzate per valutare i risultati sul dataset MRPC per il benchmark GLUE. La tabella nell’<a href="https://arxiv.org/pdf/1810.04805.pdf" rel="nofollow">articolo su BERT</a> riportava un F1 di 88.9 per il modello base. Quello era il modello <code>uncased</code> (senza distinzione fra minuscole e maiuscole) mentre noi stiamo usando quello <code>cased</code>, il che spiega il risultato migliore.</p> <p data-svelte-h="svelte-10eff0k">Mettendo tutto insieme si ottiene la funzione <code>compute_metrics()</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_metrics</span>(<span class="hljs-params">eval_preds</span>):
metric = load_metric(<span class="hljs-string">&quot;glue&quot;</span>, <span class="hljs-string">&quot;mrpc&quot;</span>)
logits, labels = eval_preds
predictions = np.argmax(logits, axis=-<span class="hljs-number">1</span>)
<span class="hljs-keyword">return</span> metric.compute(predictions=predictions, references=labels)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-az7fqa">Per vederla in azione e fare il report delle metriche alla fine di ogni epoca, ecco come si definisce un nuovo <code>Trainer</code> che includa questa funzione <code>compute_metrics()</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->training_args = TrainingArguments(<span class="hljs-string">&quot;test-trainer&quot;</span>, evaluation_strategy=<span class="hljs-string">&quot;epoch&quot;</span>)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>)
trainer = Trainer(
model,
training_args,
train_dataset=tokenized_datasets[<span class="hljs-string">&quot;train&quot;</span>],
eval_dataset=tokenized_datasets[<span class="hljs-string">&quot;validation&quot;</span>],
data_collator=data_collator,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-19ejq6q">Da notare che bisogna creare un nuovo oggetto <code>TrainingArguments</code> con il valore di <code>evaluation_strategy</code> pari a <code>&quot;epoch&quot;</code> e un nuovo modello — altrimenti si continuerebbe l’addestramento del modello già addestrato. Per lanciare una nuova esecuzione dell’addestramento si usa:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.trai<span class="hljs-meta">n</span>()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-18vmmqt">Stavolta vi sarà il report della funzione obiettivo di validazione alla fine di ogni epoca, in aggiunta alla funzione obiettivo dell’addestramento. Di nuovo, i valori esatti di accuratezza/F1 ottenuti da voi potrebbero variare leggermente da quelli mostrati qui a causa dell’inizializzazione casuale della testa del modello, ma dovrebbero essere comparabili.</p> <p data-svelte-h="svelte-1lvweaa">Il <code>Trainer</code> funzionerà direttamente su svariate GPU e TPU e ha molte opzioni, tra cui addestramento in precisione mista (utilizzare <code>fp16 = True</code> negli argomenti). I dettagli delle opzioni verranno esplorati nel Capitolo 10.</p> <p data-svelte-h="svelte-19uw8j0">Qui si conclude l’introduzione all’affinamento usando l’API del <code>Trainer</code>. Esempi per i compiti più comuni in NLP verranno forniti nel Capitolo 7, ma per ora vediamo come ottenere la stessa cosa usando puramente Pytorch.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1jn42dw">✏️ <strong>Prova tu!</strong> Affinare un modello sul dataset GLUE SST-2 utilizzando il processing dei dati già fatto nella sezione 2.</p></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/course/blob/main/chapters/it/chapter3/3.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
{
__sveltekit_u8ez91 = {
assets: "/docs/course/pr_1069/it",
base: "/docs/course/pr_1069/it",
env: {}
};
const element = document.currentScript.parentElement;
const data = [null,null];
Promise.all([
import("/docs/course/pr_1069/it/_app/immutable/entry/start.693d748d.js"),
import("/docs/course/pr_1069/it/_app/immutable/entry/app.e9cfd099.js")
]).then(([kit, app]) => {
kit.start(app, element, {
node_ids: [0, 22],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
39.8 kB
·
Xet hash:
71e0a021b53449d52c57d70577fd138e335eedb6362b4208f399eaf8b60d7a8c

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.