Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content='{"title":"Implementarea GRPO în TRL","local":"implementarea-grpo-în-trl","sections":[{"title":"Componentele Cheie","local":"componentele-cheie","sections":[{"title":"1. Formatul Setului de Date","local":"1-formatul-setului-de-date","sections":[],"depth":3},{"title":"2. Funcția de Recompensă","local":"2-funcția-de-recompensă","sections":[],"depth":3},{"title":"3. Configurația Antrenamentului","local":"3-configurația-antrenamentului","sections":[],"depth":3}],"depth":2},{"title":"Sfaturi pentru Succes","local":"sfaturi-pentru-succes","sections":[],"depth":2},{"title":"Designul Funcției de Recompensă","local":"designul-funcției-de-recompensă","sections":[{"title":"1. Recompense Bazate pe Lungime","local":"1-recompense-bazate-pe-lungime","sections":[],"depth":3}],"depth":2},{"title":"2. Recompense Bazate pe Reguli pentru Sarcini Verificabile","local":"2-recompense-bazate-pe-reguli-pentru-sarcini-verificabile","sections":[],"depth":2},{"title":"3. Recompense Bazate pe Format","local":"3-recompense-bazate-pe-format","sections":[],"depth":2},{"title":"Asta e tot!","local":"asta-e-tot","sections":[],"depth":2}],"depth":1}'> | |
| <link href="/docs/course/pr_1069/rum/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/entry/start.1de7c3d2.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/chunks/scheduler.37c15a92.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/chunks/singletons.e13b7dfd.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/chunks/index.18351ede.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/chunks/paths.e130b7b0.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/entry/app.1f82014c.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/chunks/index.2bf4358c.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/nodes/0.3c83e1ab.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/nodes/31.023708c8.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/chunks/Tip.363c041f.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/chunks/CodeBlock.4e987730.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/chunks/getInferenceSnippets.24b50994.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"Implementarea GRPO în TRL","local":"implementarea-grpo-în-trl","sections":[{"title":"Componentele Cheie","local":"componentele-cheie","sections":[{"title":"1. Formatul Setului de Date","local":"1-formatul-setului-de-date","sections":[],"depth":3},{"title":"2. Funcția de Recompensă","local":"2-funcția-de-recompensă","sections":[],"depth":3},{"title":"3. Configurația Antrenamentului","local":"3-configurația-antrenamentului","sections":[],"depth":3}],"depth":2},{"title":"Sfaturi pentru Succes","local":"sfaturi-pentru-succes","sections":[],"depth":2},{"title":"Designul Funcției de Recompensă","local":"designul-funcției-de-recompensă","sections":[{"title":"1. Recompense Bazate pe Lungime","local":"1-recompense-bazate-pe-lungime","sections":[],"depth":3}],"depth":2},{"title":"2. Recompense Bazate pe Reguli pentru Sarcini Verificabile","local":"2-recompense-bazate-pe-reguli-pentru-sarcini-verificabile","sections":[],"depth":2},{"title":"3. 
Recompense Bazate pe Format","local":"3-recompense-bazate-pe-format","sections":[],"depth":2},{"title":"Asta e tot!","local":"asta-e-tot","sections":[],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="implementarea-grpo-în-trl" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#implementarea-grpo-în-trl"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Implementarea GRPO în TRL</span></h1> <p data-svelte-h="svelte-1293vxv">În această pagină, vom învăța cum să implementăm Optimizarea Relativă a Politicii de Grup (GRPO) folosind biblioteca Transformer Reinforcement Learning (TRL). 
Ne vom concentra pe implementarea practică cu cod minimal.</p> <p data-svelte-h="svelte-na6ewk">Vom explora conceptele centrale ale GRPO așa cum sunt întruchipate în GRPOTrainer din TRL, folosind fragmente din documentația oficială TRL pentru a ne ghida.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-4m2s3v">Acest capitol este destinat începătorilor TRL. Dacă ești deja familiar cu TRL, ai putea de asemenea să consulți <a href="https://github.com/huggingface/open-r1/blob/main/src/open_r1/grpo.py" rel="nofollow">implementarea Open R1</a> a GRPO.</p></div> <p data-svelte-h="svelte-1kpbqds">În primul rând, să ne reamintim unele dintre conceptele importante ale algoritmului GRPO:</p> <ul data-svelte-h="svelte-ez6ot5"><li>Formarea Grupului: Modelul generează multiple completări pentru fiecare prompt.</li> <li>Învățarea Preferințelor: Modelul învață dintr-o funcție de recompensă care compară grupuri de completări.</li> <li>Configurația Antrenamentului: Modelul folosește o configurație pentru a controla procesul de antrenare.</li></ul> <p data-svelte-h="svelte-1a3xk8g">Ce trebuie să facem pentru a implementa GRPO?</p> <ul data-svelte-h="svelte-xpn84y"><li>Să definim un set de date de prompt-uri.</li> <li>Să definim o funcție de recompensă care ia o listă de completări și returnează o listă de recompense.</li> <li>Să configurăm procesul de antrenare cu un GRPOConfig.</li> <li>Să antrenăm modelul folosind GRPOTrainer.</li></ul> <p data-svelte-h="svelte-f204iz">Iată un exemplu minimal pentru a începe antrenamentul GRPO:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 
text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> GRPOTrainer, GRPOConfig | |
| <span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| <span class="hljs-comment"># 1. Încarcă setul tău de date</span> | |
| dataset = load_dataset(<span class="hljs-string">"setul_tau_de_date"</span>, split=<span class="hljs-string">"train"</span>) | |
| <span class="hljs-comment"># 2. Definește o funcție de recompensă simplă</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">reward_func</span>(<span class="hljs-params">completions, **kwargs</span>): | |
| <span class="hljs-string">"""Exemplu: Recompensează completările mai lungi"""</span> | |
| <span class="hljs-keyword">return</span> [<span class="hljs-built_in">float</span>(<span class="hljs-built_in">len</span>(completion)) <span class="hljs-keyword">for</span> completion <span class="hljs-keyword">in</span> completions] | |
| <span class="hljs-comment"># 3. Configurează antrenamentul</span> | |
| training_args = GRPOConfig( | |
| output_dir=<span class="hljs-string">"output"</span>, | |
| num_train_epochs=<span class="hljs-number">3</span>, | |
| per_device_train_batch_size=<span class="hljs-number">4</span>, | |
| gradient_accumulation_steps=<span class="hljs-number">2</span>, | |
| logging_steps=<span class="hljs-number">10</span>, | |
| ) | |
| <span class="hljs-comment"># 4. Inițializează și antrenează</span> | |
| trainer = GRPOTrainer( | |
| model=<span class="hljs-string">"modelul_tau"</span>, <span class="hljs-comment"># de exemplu "Qwen/Qwen2-0.5B-Instruct"</span> | |
| args=training_args, | |
| train_dataset=dataset, | |
| reward_funcs=reward_func, | |
| ) | |
| trainer.train()<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="componentele-cheie" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#componentele-cheie"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Componentele Cheie</span></h2> <h3 class="relative group"><a id="1-formatul-setului-de-date" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#1-formatul-setului-de-date"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 
0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>1. Formatul Setului de Date</span></h3> <p data-svelte-h="svelte-j8bpdm">Setul tău de date ar trebui să conțină prompt-uri la care modelul va răspunde. Antrenorul GRPO va genera multiple completări pentru fiecare prompt și va folosi funcția de recompensă pentru a le compara.</p> <h3 class="relative group"><a id="2-funcția-de-recompensă" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#2-funcția-de-recompensă"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>2. Funcția de Recompensă</span></h3> <p data-svelte-h="svelte-iw81bf">Funcția de recompensă este crucială - determină cum învață modelul. 
Iată două exemple practice:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Exemplul 1: Recompensă bazată pe lungimea completării</span> | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">reward_length</span>(<span class="hljs-params">completions, **kwargs</span>): | |
| <span class="hljs-keyword">return</span> [<span class="hljs-built_in">float</span>(<span class="hljs-built_in">len</span>(completion)) <span class="hljs-keyword">for</span> completion <span class="hljs-keyword">in</span> completions] | |
| <span class="hljs-comment"># Exemplul 2: Recompensă bazată pe potrivirea unui șablon (pattern)</span> | |
| <span class="hljs-keyword">import</span> re | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">reward_format</span>(<span class="hljs-params">completions, **kwargs</span>): | |
| pattern = <span class="hljs-string">r"^&lt;think&gt;.*?&lt;/think&gt;&lt;answer&gt;.*?&lt;/answer&gt;$"</span> | |
| <span class="hljs-keyword">return</span> [<span class="hljs-number">1.0</span> <span class="hljs-keyword">if</span> re.<span class="hljs-keyword">match</span>(pattern, c) <span class="hljs-keyword">else</span> <span class="hljs-number">0.0</span> <span class="hljs-keyword">for</span> c <span class="hljs-keyword">in</span> completions]<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="3-configurația-antrenamentului" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#3-configurația-antrenamentului"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>3. 
Configurația Antrenamentului</span></h3> <p data-svelte-h="svelte-1cxbitg">Parametrii cheie de considerat în <code>GRPOConfig</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->training_args = GRPOConfig( | |
| <span class="hljs-comment"># Parametrii esențiali</span> | |
| output_dir=<span class="hljs-string">"output"</span>, | |
| num_train_epochs=<span class="hljs-number">3</span>, | |
| num_generations=<span class="hljs-number">4</span>, <span class="hljs-comment"># Numărul de completări de generat pentru fiecare prompt</span> | |
| per_device_train_batch_size=<span class="hljs-number">4</span>, <span class="hljs-comment"># Vrem ca toate generările să încapă într-un singur lot pe dispozitiv</span> | |
| <span class="hljs-comment"># Opțional dar util</span> | |
| gradient_accumulation_steps=<span class="hljs-number">2</span>, | |
| learning_rate=<span class="hljs-number">1e-5</span>, | |
| logging_steps=<span class="hljs-number">10</span>, | |
| <span class="hljs-comment"># Specific GRPO (opțional)</span> | |
| use_vllm=<span class="hljs-literal">True</span>, <span class="hljs-comment"># Accelerează generarea</span> | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-zu7ra">Parametrul <code>num_generations</code> este deosebit de important pentru GRPO deoarece definește dimensiunea grupului - câte completări diferite va genera modelul pentru fiecare prompt. Acesta este un element cheie care diferențiază GRPO de alte metode RL:</p> <ul data-svelte-h="svelte-1al5l6o"><li>Prea mic (de exemplu, 2-3): S-ar putea să nu ofere suficientă diversitate pentru comparații semnificative</li> <li>Recomandat (4-16): Oferă un echilibru bun între diversitate și eficiența computațională</li> <li>Valori mai mari: Pot îmbunătăți învățarea dar cresc semnificativ costul computațional</li></ul> <p data-svelte-h="svelte-1rh1v9n">Dimensiunea grupului ar trebui aleasă în funcție de resursele tale computaționale și complexitatea sarcinii tale. Pentru sarcini simple, grupuri mai mici (4-8) pot fi suficiente, în timp ce sarcinile de raționament mai complexe ar putea beneficia de grupuri mai mari (8-16).</p> <h2 class="relative group"><a id="sfaturi-pentru-succes" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#sfaturi-pentru-succes"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> 
<span>Sfaturi pentru Succes</span></h2> <ol data-svelte-h="svelte-db4915"><li><strong>Gestionarea Memoriei</strong>: Ajustează <code>per_device_train_batch_size</code> și <code>gradient_accumulation_steps</code> în funcție de memoria GPU-ului tău.</li> <li><strong>Viteza</strong>: Activează <code>use_vllm=True</code> pentru generare mai rapidă dacă modelul tău este suportat.</li> <li><strong>Monitorizarea</strong>: Urmărește metricile înregistrate în timpul antrenamentului:<ul><li><code>reward</code>: Recompensa medie pe completări</li> <li><code>reward_std</code>: Deviația standard în cadrul grupurilor de recompense</li> <li><code>kl</code>: Divergența KL de la modelul de referință</li></ul></li></ol> <h2 class="relative group"><a id="designul-funcției-de-recompensă" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#designul-funcției-de-recompensă"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Designul Funcției de Recompensă</span></h2> <p data-svelte-h="svelte-6gbmmo">Lucrarea DeepSeek R1 demonstrează mai multe abordări eficiente pentru designul funcției de recompensă pe care le poți adapta pentru propria ta implementare 
GRPO:</p> <h3 class="relative group"><a id="1-recompense-bazate-pe-lungime" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#1-recompense-bazate-pe-lungime"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>1. Recompense Bazate pe Lungime</span></h3> <p data-svelte-h="svelte-18bht15">Una dintre cele mai ușoare recompense de implementat este o recompensă bazată pe lungime. 
Poți recompensa completări mai lungi:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">reward_len</span>(<span class="hljs-params">completions, **kwargs</span>): | |
| ideal_length = <span class="hljs-number">20</span> | |
| <span class="hljs-keyword">return</span> [-<span class="hljs-built_in">abs</span>(ideal_length - <span class="hljs-built_in">len</span>(completion)) <span class="hljs-keyword">for</span> completion <span class="hljs-keyword">in</span> completions]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1xyrfe1">Această funcție de recompensă penalizează completările care sunt prea scurte sau prea lungi, încurajând modelul să genereze completări care sunt aproape de lungimea ideală de 20 de caractere.</p> <iframe src="https://marimo.app/gh/huggingface/notebooks/main/e?entrypoint=course%2Fen%2Fchapter13%2Fgrpo_length.py&embed=true&show-chrome=false" title="Marimo Notebook" width="100%" height="800px" frameborder="0" allow="clipboard-write"></iframe> <h2 class="relative group"><a id="2-recompense-bazate-pe-reguli-pentru-sarcini-verificabile" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#2-recompense-bazate-pe-reguli-pentru-sarcini-verificabile"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>2. 
Recompense Bazate pe Reguli pentru Sarcini Verificabile</span></h2> <p data-svelte-h="svelte-11f5ljd">Pentru sarcini cu răspunsuri obiectiv corecte (cum ar fi matematica sau codarea), poți implementa funcții de recompensă bazate pe reguli:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">problem_reward</span>(<span class="hljs-params">completions, answers, **kwargs</span>): | |
| <span class="hljs-string">"""Funcție de recompensă pentru probleme de matematică cu răspunsuri verificabile | |
| completions: lista de completări de evaluat | |
| answers: lista de răspunsuri la problemele din setul de date | |
| """</span> | |
| rewards = [] | |
| <span class="hljs-keyword">for</span> completion, correct_answer <span class="hljs-keyword">in</span> <span class="hljs-built_in">zip</span>(completions, answers): | |
| <span class="hljs-comment"># Extrage răspunsul din completare</span> | |
| <span class="hljs-keyword">try</span>: | |
| <span class="hljs-comment"># Acesta este un exemplu simplificat - ai avea nevoie de parsing corespunzător</span> | |
| answer = extract_final_answer(completion) | |
| <span class="hljs-comment"># Recompensă binară: 1 pentru corect, 0 pentru incorect</span> | |
| reward = <span class="hljs-number">1.0</span> <span class="hljs-keyword">if</span> answer == correct_answer <span class="hljs-keyword">else</span> <span class="hljs-number">0.0</span> | |
| rewards.append(reward) | |
| <span class="hljs-keyword">except</span>: | |
| <span class="hljs-comment"># Dacă nu putem parsa un răspuns, dăm o recompensă mică</span> | |
| rewards.append(<span class="hljs-number">0.0</span>) | |
| <span class="hljs-keyword">return</span> rewards<!-- HTML_TAG_END --></pre></div> <iframe src="https://marimo.app/gh/huggingface/notebooks/main/e?entrypoint=course%2Fen%2Fchapter13%2Fgrpo_math.py&embed=true&show-chrome=false" title="Marimo Notebook" width="100%" height="800px" frameborder="0" allow="clipboard-write"></iframe> <h2 class="relative group"><a id="3-recompense-bazate-pe-format" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#3-recompense-bazate-pe-format"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>3. 
Recompense Bazate pe Format</span></h2> <p data-svelte-h="svelte-15alen6">Poți de asemenea să recompensezi formatarea corespunzătoare, care a fost importantă în antrenamentul DeepSeek R1:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">format_reward</span>(<span class="hljs-params">completions, **kwargs</span>): | |
| <span class="hljs-string">"""Recompensează completările care urmează formatul dorit"""</span> | |
| <span class="hljs-comment"># Exemplu: Verifică dacă completarea urmează un format gândește-apoi-răspunde</span> | |
| pattern = <span class="hljs-string">r"&lt;think&gt;(.*?)&lt;/think&gt;\s*&lt;answer&gt;(.*?)&lt;/answer&gt;"</span> | |
| rewards = [] | |
| <span class="hljs-keyword">for</span> completion <span class="hljs-keyword">in</span> completions: | |
| <span class="hljs-keyword">match</span> = re.search(pattern, completion, re.DOTALL) | |
| <span class="hljs-keyword">if</span> <span class="hljs-keyword">match</span>: | |
| <span class="hljs-comment"># Verifică dacă există conținut substanțial în ambele secțiuni</span> | |
| think_content = <span class="hljs-keyword">match</span>.group(<span class="hljs-number">1</span>).strip() | |
| answer_content = <span class="hljs-keyword">match</span>.group(<span class="hljs-number">2</span>).strip() | |
| <span class="hljs-keyword">if</span> <span class="hljs-built_in">len</span>(think_content) > <span class="hljs-number">20</span> <span class="hljs-keyword">and</span> <span class="hljs-built_in">len</span>(answer_content) > <span class="hljs-number">0</span>: | |
| rewards.append(<span class="hljs-number">1.0</span>) | |
| <span class="hljs-keyword">else</span>: | |
| rewards.append( | |
| <span class="hljs-number">0.5</span> | |
| ) <span class="hljs-comment"># Recompensă parțială pentru format corect dar conținut limitat</span> | |
| <span class="hljs-keyword">else</span>: | |
| rewards.append(<span class="hljs-number">0.0</span>) <span class="hljs-comment"># Nicio recompensă pentru format incorect</span> | |
| <span class="hljs-keyword">return</span> rewards<!-- HTML_TAG_END --></pre></div> <iframe src="https://marimo.app/gh/huggingface/notebooks/main/e?entrypoint=course%2Fen%2Fchapter13%2Fgrpo_format.py&embed=true&show-chrome=false" title="Marimo Notebook" width="100%" height="800px" frameborder="0" allow="clipboard-write"></iframe> <p data-svelte-h="svelte-5dp67z">Aceste exemple demonstrează cum poți implementa funcții de recompensă inspirate din procesul de antrenare DeepSeek R1, concentrându-se pe corectitudine, formatare și semnale combinate.</p> <h2 class="relative group"><a id="asta-e-tot" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#asta-e-tot"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Asta e tot!</span></h2> <p data-svelte-h="svelte-11kcv9w">În următoarea secțiune, vei urma un exercițiu pentru a implementa GRPO în TRL.</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/course/blob/main/chapters/rum/chapter12/4.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1"><</span> <span data-svelte-h="svelte-x0xyl0">></span> <span 
data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
// Page bootstrap: publishes the deployment paths, loads the two entry modules,
// and starts the app inside the element that contains this <script> tag.
// NOTE(review): the `__sveltekit_*` name and the `_app/immutable` paths suggest
// this is SvelteKit's generated client start-up code; confirm before hand-editing.
| { | |
// Assigned without const/let/var, so this is not a block-scoped binding.
// Presumably read by the imported modules to resolve asset/base URLs (verify).
| __sveltekit_1ftlxhy = { | |
| assets: "/docs/course/pr_1069/rum", | |
| base: "/docs/course/pr_1069/rum", | |
| env: {} | |
| }; | |
// Parent element of the currently executing <script>; passed to kit.start below.
| const element = document.currentScript.parentElement; | |
// Two null entries, forwarded unchanged as the `data` option.
// NOTE(review): presumably one slot per entry in node_ids; confirm.
| const data = [null,null]; | |
// Both entry modules are fetched with dynamic import(); the callback runs only
// after both promises have resolved.
| Promise.all([ | |
| import("/docs/course/pr_1069/rum/_app/immutable/entry/start.1de7c3d2.js"), | |
| import("/docs/course/pr_1069/rum/_app/immutable/entry/app.1f82014c.js") | |
| ]).then(([kit, app]) => { | |
// Hands the app module, the target element, and the initial options to the
// start module. The meaning of node_ids/form/error is defined by that module
// and is not visible in this file.
| kit.start(app, element, { | |
| node_ids: [0, 31], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 39.8 kB
- Xet hash:
- e51f2ede32339d91bc5826376c078fafd03f15ba2b6dff17e7f01a6d13f23210
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.