Buckets:

rtrm's picture
download
raw
39.8 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Implementarea GRPO în TRL&quot;,&quot;local&quot;:&quot;implementarea-grpo-în-trl&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Componentele Cheie&quot;,&quot;local&quot;:&quot;componentele-cheie&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;1. Formatul Setului de Date&quot;,&quot;local&quot;:&quot;1-formatul-setului-de-date&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;2. Funcția de Recompensă&quot;,&quot;local&quot;:&quot;2-funcția-de-recompensă&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;3. Configurația Antrenamentului&quot;,&quot;local&quot;:&quot;3-configurația-antrenamentului&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Sfaturi pentru Succes&quot;,&quot;local&quot;:&quot;sfaturi-pentru-succes&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Designul Funcției de Recompensă&quot;,&quot;local&quot;:&quot;designul-funcției-de-recompensă&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;1. Recompense Bazate pe Lungime&quot;,&quot;local&quot;:&quot;1-recompense-bazate-pe-lungime&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;2. Recompense Bazate pe Reguli pentru Sarcini Verificabile&quot;,&quot;local&quot;:&quot;2-recompense-bazate-pe-reguli-pentru-sarcini-verificabile&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;3. Recompense Bazate pe Format&quot;,&quot;local&quot;:&quot;3-recompense-bazate-pe-format&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Asta e tot!&quot;,&quot;local&quot;:&quot;asta-e-tot&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/course/pr_1069/rum/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/entry/start.1de7c3d2.js">
<link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/chunks/scheduler.37c15a92.js">
<link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/chunks/singletons.e13b7dfd.js">
<link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/chunks/index.18351ede.js">
<link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/chunks/paths.e130b7b0.js">
<link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/entry/app.1f82014c.js">
<link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/chunks/index.2bf4358c.js">
<link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/nodes/0.3c83e1ab.js">
<link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/nodes/31.023708c8.js">
<link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/chunks/Tip.363c041f.js">
<link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/chunks/CodeBlock.4e987730.js">
<link rel="modulepreload" href="/docs/course/pr_1069/rum/_app/immutable/chunks/getInferenceSnippets.24b50994.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Implementarea GRPO în TRL&quot;,&quot;local&quot;:&quot;implementarea-grpo-în-trl&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Componentele Cheie&quot;,&quot;local&quot;:&quot;componentele-cheie&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;1. Formatul Setului de Date&quot;,&quot;local&quot;:&quot;1-formatul-setului-de-date&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;2. Funcția de Recompensă&quot;,&quot;local&quot;:&quot;2-funcția-de-recompensă&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;3. Configurația Antrenamentului&quot;,&quot;local&quot;:&quot;3-configurația-antrenamentului&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Sfaturi pentru Succes&quot;,&quot;local&quot;:&quot;sfaturi-pentru-succes&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Designul Funcției de Recompensă&quot;,&quot;local&quot;:&quot;designul-funcției-de-recompensă&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;1. Recompense Bazate pe Lungime&quot;,&quot;local&quot;:&quot;1-recompense-bazate-pe-lungime&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;2. Recompense Bazate pe Reguli pentru Sarcini Verificabile&quot;,&quot;local&quot;:&quot;2-recompense-bazate-pe-reguli-pentru-sarcini-verificabile&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;3. 
Recompense Bazate pe Format&quot;,&quot;local&quot;:&quot;3-recompense-bazate-pe-format&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Asta e tot!&quot;,&quot;local&quot;:&quot;asta-e-tot&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="implementarea-grpo-în-trl" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#implementarea-grpo-în-trl"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Implementarea GRPO în TRL</span></h1> <p data-svelte-h="svelte-1293vxv">În această pagină, vom învăța cum să implementăm Optimizarea Relativă a Politicii de Grup (GRPO) folosind biblioteca Transformer Reinforcement Learning (TRL). 
Ne vom concentra pe implementarea practică cu cod minimal.</p> <p data-svelte-h="svelte-na6ewk">Vom explora conceptele centrale ale GRPO așa cum sunt întruchipate în GRPOTrainer din TRL, folosind fragmente din documentația oficială TRL pentru a ne ghida.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-4m2s3v">Acest capitol este destinat începătorilor TRL. Dacă ești deja familiarizat cu TRL, ai putea de asemenea să consulți <a href="https://github.com/huggingface/open-r1/blob/main/src/open_r1/grpo.py" rel="nofollow">implementarea Open R1</a> a GRPO.</p></div> <p data-svelte-h="svelte-1kpbqds">În primul rând, să ne reamintim unele dintre conceptele importante ale algoritmului GRPO:</p> <ul data-svelte-h="svelte-ez6ot5"><li>Formarea Grupului: Modelul generează mai multe completări pentru fiecare prompt.</li> <li>Învățarea Preferințelor: Modelul învață dintr-o funcție de recompensă, comparând între ele completările din cadrul aceluiași grup.</li> <li>Configurația Antrenamentului: Modelul folosește o configurație pentru a controla procesul de antrenare.</li></ul> <p data-svelte-h="svelte-1a3xk8g">Ce trebuie să facem pentru a implementa GRPO?</p> <ul data-svelte-h="svelte-xpn84y"><li>Să definim un set de date de prompt-uri.</li> <li>Să definim o funcție de recompensă care ia o listă de completări și returnează o listă de recompense.</li> <li>Să configurăm procesul de antrenare cu un GRPOConfig.</li> <li>Să antrenăm modelul folosind GRPOTrainer.</li></ul> <p data-svelte-h="svelte-f204iz">Iată un exemplu minimal pentru a începe antrenamentul GRPO:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 
text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> GRPOTrainer, GRPOConfig
<span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
<span class="hljs-comment"># 1. Încarcă setul tău de date</span>
dataset = load_dataset(<span class="hljs-string">&quot;setul_tau_de_date&quot;</span>, split=<span class="hljs-string">&quot;train&quot;</span>)
<span class="hljs-comment"># 2. Definește o funcție de recompensă simplă</span>
<span class="hljs-keyword">def</span> <span class="hljs-title function_">reward_func</span>(<span class="hljs-params">completions, **kwargs</span>):
<span class="hljs-string">&quot;&quot;&quot;Exemplu: Recompensează completările mai lungi&quot;&quot;&quot;</span>
<span class="hljs-keyword">return</span> [<span class="hljs-built_in">float</span>(<span class="hljs-built_in">len</span>(completion)) <span class="hljs-keyword">for</span> completion <span class="hljs-keyword">in</span> completions]
<span class="hljs-comment"># 3. Configurează antrenamentul</span>
training_args = GRPOConfig(
output_dir=<span class="hljs-string">&quot;output&quot;</span>,
num_train_epochs=<span class="hljs-number">3</span>,
per_device_train_batch_size=<span class="hljs-number">4</span>,
gradient_accumulation_steps=<span class="hljs-number">2</span>,
logging_steps=<span class="hljs-number">10</span>,
)
<span class="hljs-comment"># 4. Inițializează și antrenează</span>
trainer = GRPOTrainer(
model=<span class="hljs-string">&quot;modelul_tau&quot;</span>, <span class="hljs-comment"># de exemplu &quot;Qwen/Qwen2-0.5B-Instruct&quot;</span>
args=training_args,
train_dataset=dataset,
reward_funcs=reward_func,
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="componentele-cheie" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#componentele-cheie"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Componentele Cheie</span></h2> <h3 class="relative group"><a id="1-formatul-setului-de-date" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#1-formatul-setului-de-date"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 
1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>1. Formatul Setului de Date</span></h3> <p data-svelte-h="svelte-j8bpdm">Setul tău de date ar trebui să conțină prompt-uri la care modelul va răspunde. Antrenorul GRPO va genera multiple completări pentru fiecare prompt și va folosi funcția de recompensă pentru a le compara.</p> <h3 class="relative group"><a id="2-funcția-de-recompensă" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#2-funcția-de-recompensă"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>2. Funcția de Recompensă</span></h3> <p data-svelte-h="svelte-iw81bf">Funcția de recompensă este crucială - determină cum învață modelul. 
Iată două exemple practice:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Exemplul 1: Recompensă bazată pe lungimea completării</span>
<span class="hljs-keyword">def</span> <span class="hljs-title function_">reward_length</span>(<span class="hljs-params">completions, **kwargs</span>):
<span class="hljs-keyword">return</span> [<span class="hljs-built_in">float</span>(<span class="hljs-built_in">len</span>(completion)) <span class="hljs-keyword">for</span> completion <span class="hljs-keyword">in</span> completions]
<span class="hljs-comment"># Exemplul 2: Recompensă bazată pe potrivirea unui model</span>
<span class="hljs-keyword">import</span> re
<span class="hljs-keyword">def</span> <span class="hljs-title function_">reward_format</span>(<span class="hljs-params">completions, **kwargs</span>):
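<span class="hljs-comment"># (?s) activează re.DOTALL, astfel încât „.” să potrivească și liniile noi din raționament</span>
pattern = <span class="hljs-string">r&quot;(?s)^&lt;think&gt;.*?&lt;/think&gt;&lt;answer&gt;.*?&lt;/answer&gt;$&quot;</span>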
<span class="hljs-keyword">return</span> [<span class="hljs-number">1.0</span> <span class="hljs-keyword">if</span> re.<span class="hljs-keyword">match</span>(pattern, c) <span class="hljs-keyword">else</span> <span class="hljs-number">0.0</span> <span class="hljs-keyword">for</span> c <span class="hljs-keyword">in</span> completions]<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="3-configurația-antrenamentului" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#3-configurația-antrenamentului"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>3. 
Configurația Antrenamentului</span></h3> <p data-svelte-h="svelte-1cxbitg">Parametrii cheie de considerat în <code>GRPOConfig</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->training_args = GRPOConfig(
<span class="hljs-comment"># Parametrii esențiali</span>
output_dir=<span class="hljs-string">&quot;output&quot;</span>,
num_train_epochs=<span class="hljs-number">3</span>,
num_generations=<span class="hljs-number">4</span>, <span class="hljs-comment"># Numărul de completări de generat pentru fiecare prompt</span>
per_device_train_batch_size=<span class="hljs-number">4</span>, <span class="hljs-comment"># Vrem să obținem toate generările într-un lot de dispozitiv</span>
<span class="hljs-comment"># Opțional dar util</span>
gradient_accumulation_steps=<span class="hljs-number">2</span>,
learning_rate=<span class="hljs-number">1e-5</span>,
logging_steps=<span class="hljs-number">10</span>,
<span class="hljs-comment"># Specific GRPO (opțional)</span>
use_vllm=<span class="hljs-literal">True</span>, <span class="hljs-comment"># Accelerează generarea</span>
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-zu7ra">Parametrul <code>num_generations</code> este deosebit de important pentru GRPO deoarece definește dimensiunea grupului - câte completări diferite va genera modelul pentru fiecare prompt. Aceasta este o diferență cheie față de alte metode RL:</p> <ul data-svelte-h="svelte-1al5l6o"><li>Prea mic (de exemplu, 2-3): S-ar putea să nu ofere suficientă diversitate pentru comparații semnificative</li> <li>Recomandat (4-16): Oferă un echilibru bun între diversitate și eficiența computațională</li> <li>Valori mai mari: Pot îmbunătăți învățarea dar cresc semnificativ costul computațional</li></ul> <p data-svelte-h="svelte-1rh1v9n">Dimensiunea grupului ar trebui aleasă în funcție de resursele tale computaționale și complexitatea sarcinii tale. Pentru sarcini simple, grupuri mai mici (4-8) pot fi suficiente, în timp ce sarcinile de raționament mai complexe ar putea beneficia de grupuri mai mari (8-16).</p> <h2 class="relative group"><a id="sfaturi-pentru-succes" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#sfaturi-pentru-succes"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> 
<span>Sfaturi pentru Succes</span></h2> <ol data-svelte-h="svelte-db4915"><li><strong>Gestionarea Memoriei</strong>: Ajustează <code>per_device_train_batch_size</code> și <code>gradient_accumulation_steps</code> în funcție de memoria GPU-ului tău.</li> <li><strong>Viteza</strong>: Activează <code>use_vllm=True</code> pentru generare mai rapidă dacă modelul tău este suportat.</li> <li><strong>Monitorizarea</strong>: Urmărește metricile înregistrate în timpul antrenamentului:<ul><li><code>reward</code>: Recompensa medie pe completări</li> <li><code>reward_std</code>: Deviația standard a recompenselor în cadrul grupurilor</li> <li><code>kl</code>: Divergența KL față de modelul de referință</li></ul></li></ol> <h2 class="relative group"><a id="designul-funcției-de-recompensă" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#designul-funcției-de-recompensă"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Designul Funcției de Recompensă</span></h2> <p data-svelte-h="svelte-6gbmmo">Lucrarea DeepSeek R1 demonstrează mai multe abordări eficiente pentru designul funcției de recompensă pe care le poți adapta pentru propria ta implementare 
GRPO:</p> <h3 class="relative group"><a id="1-recompense-bazate-pe-lungime" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#1-recompense-bazate-pe-lungime"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>1. Recompense Bazate pe Lungime</span></h3> <p data-svelte-h="svelte-18bht15">Una dintre cele mai ușoare recompense de implementat este o recompensă bazată pe lungime. 
Poți recompensa completări mai lungi:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">reward_len</span>(<span class="hljs-params">completions, **kwargs</span>):
ideal_length = <span class="hljs-number">20</span>
<span class="hljs-keyword">return</span> [-<span class="hljs-built_in">abs</span>(ideal_length - <span class="hljs-built_in">len</span>(completion)) <span class="hljs-keyword">for</span> completion <span class="hljs-keyword">in</span> completions]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1xyrfe1">Această funcție de recompensă penalizează completările care sunt prea scurte sau prea lungi, încurajând modelul să genereze completări care sunt aproape de lungimea ideală de 20 de caractere (<code>len</code> aplicat unui șir de text numără caractere, nu token-uri).</p> <iframe src="https://marimo.app/gh/huggingface/notebooks/main/e?entrypoint=course%2Fen%2Fchapter13%2Fgrpo_length.py&embed=true&show-chrome=false" title="Marimo Notebook" width="100%" height="800px" frameborder="0" allow="clipboard-write"></iframe> <h2 class="relative group"><a id="2-recompense-bazate-pe-reguli-pentru-sarcini-verificabile" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#2-recompense-bazate-pe-reguli-pentru-sarcini-verificabile"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>2. 
Recompense Bazate pe Reguli pentru Sarcini Verificabile</span></h2> <p data-svelte-h="svelte-11f5ljd">Pentru sarcini cu răspunsuri obiectiv corecte (cum ar fi matematica sau codarea), poți implementa funcții de recompensă bazate pe reguli:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">problem_reward</span>(<span class="hljs-params">completions, answers, **kwargs</span>):
<span class="hljs-string">&quot;&quot;&quot;Funcție de recompensă pentru probleme de matematică cu răspunsuri verificabile
completions: lista de completări de evaluat
answers: lista de răspunsuri la problemele din setul de date
&quot;&quot;&quot;</span>
rewards = []
<span class="hljs-keyword">for</span> completion, correct_answer <span class="hljs-keyword">in</span> <span class="hljs-built_in">zip</span>(completions, answers):
<span class="hljs-comment"># Extrage răspunsul din completare</span>
<span class="hljs-keyword">try</span>:
<span class="hljs-comment"># Acesta este un exemplu simplificat - extract_final_answer nu este definită aici; implementeaz-o în funcție de formatul răspunsurilor tale</span>
answer = extract_final_answer(completion)
<span class="hljs-comment"># Recompensă binară: 1 pentru corect, 0 pentru incorect</span>
reward = <span class="hljs-number">1.0</span> <span class="hljs-keyword">if</span> answer == correct_answer <span class="hljs-keyword">else</span> <span class="hljs-number">0.0</span>
rewards.append(reward)
<span class="hljs-keyword">except</span> <span class="hljs-built_in">Exception</span>:
<span class="hljs-comment"># Dacă nu putem parsa un răspuns, acordăm recompensă zero</span>
rewards.append(<span class="hljs-number">0.0</span>)
<span class="hljs-keyword">return</span> rewards<!-- HTML_TAG_END --></pre></div> <iframe src="https://marimo.app/gh/huggingface/notebooks/main/e?entrypoint=course%2Fen%2Fchapter13%2Fgrpo_math.py&embed=true&show-chrome=false" title="Marimo Notebook" width="100%" height="800px" frameborder="0" allow="clipboard-write"></iframe> <h2 class="relative group"><a id="3-recompense-bazate-pe-format" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#3-recompense-bazate-pe-format"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>3. 
Recompense Bazate pe Format</span></h2> <p data-svelte-h="svelte-15alen6">Poți de asemenea să recompensezi formatarea corespunzătoare, care a fost importantă în antrenamentul DeepSeek R1:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">format_reward</span>(<span class="hljs-params">completions, **kwargs</span>):
<span class="hljs-string">&quot;&quot;&quot;Recompensează completările care urmează formatul dorit&quot;&quot;&quot;</span>
<span class="hljs-comment"># Exemplu: Verifică dacă completarea urmează un format gândește-apoi-răspunde</span>
pattern = <span class="hljs-string">r&quot;&lt;think&gt;(.*?)&lt;/think&gt;\s*&lt;answer&gt;(.*?)&lt;/answer&gt;&quot;</span>
rewards = []
<span class="hljs-keyword">for</span> completion <span class="hljs-keyword">in</span> completions:
<span class="hljs-keyword">match</span> = re.search(pattern, completion, re.DOTALL)
<span class="hljs-keyword">if</span> <span class="hljs-keyword">match</span>:
<span class="hljs-comment"># Verifică dacă raționamentul are conținut substanțial (peste 20 de caractere) și dacă răspunsul nu este gol</span>
think_content = <span class="hljs-keyword">match</span>.group(<span class="hljs-number">1</span>).strip()
answer_content = <span class="hljs-keyword">match</span>.group(<span class="hljs-number">2</span>).strip()
<span class="hljs-keyword">if</span> <span class="hljs-built_in">len</span>(think_content) &gt; <span class="hljs-number">20</span> <span class="hljs-keyword">and</span> <span class="hljs-built_in">len</span>(answer_content) &gt; <span class="hljs-number">0</span>:
rewards.append(<span class="hljs-number">1.0</span>)
<span class="hljs-keyword">else</span>:
rewards.append(
<span class="hljs-number">0.5</span>
) <span class="hljs-comment"># Recompensă parțială pentru format corect dar conținut limitat</span>
<span class="hljs-keyword">else</span>:
rewards.append(<span class="hljs-number">0.0</span>) <span class="hljs-comment"># Nicio recompensă pentru format incorect</span>
<span class="hljs-keyword">return</span> rewards<!-- HTML_TAG_END --></pre></div> <iframe src="https://marimo.app/gh/huggingface/notebooks/main/e?entrypoint=course%2Fen%2Fchapter13%2Fgrpo_format.py&embed=true&show-chrome=false" title="Marimo Notebook" width="100%" height="800px" frameborder="0" allow="clipboard-write"></iframe> <p data-svelte-h="svelte-5dp67z">Aceste exemple demonstrează cum poți implementa funcții de recompensă inspirate din procesul de antrenare DeepSeek R1, concentrându-se pe corectitudine, formatare și semnale combinate.</p> <h2 class="relative group"><a id="asta-e-tot" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#asta-e-tot"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Asta e tot!</span></h2> <p data-svelte-h="svelte-11kcv9w">În următoarea secțiune, vei urma un exercițiu pentru a implementa GRPO în TRL.</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/course/blob/main/chapters/rum/chapter12/4.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span 
data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
// SvelteKit client bootstrap (framework-generated): sets the runtime config
// global, then loads the client entry modules and starts the app on this page.
{
// Assigned without a declaration, so in this classic (non-module) script it
// becomes a global. Presumably read by the SvelteKit runtime to resolve asset
// and route URLs under the PR-preview path prefix - NOTE(review): confirm.
__sveltekit_1ftlxhy = {
assets: "/docs/course/pr_1069/rum",
base: "/docs/course/pr_1069/rum",
env: {}
};
// document.currentScript is only set while this script executes synchronously,
// so the start target (the element that contains this script tag) has to be
// captured here, before the asynchronous imports below resolve.
// The surrounding { } block keeps `element` and `data` (const) out of the global scope.
const element = document.currentScript.parentElement;
// No serialized page data is embedded in the HTML. NOTE(review): presumably one
// entry per id in node_ids below - confirm against the SvelteKit version in use.
const data = [null,null];
// Load the runtime entry (start.*.js, also modulepreloaded in the document head)
// and the app entry in parallel, then start the client on `element`.
// NOTE(review): there is no .catch(), so a failed import surfaces only as an
// unhandled rejection and the page presumably stays as non-hydrated static markup.
Promise.all([
import("/docs/course/pr_1069/rum/_app/immutable/entry/start.1de7c3d2.js"),
import("/docs/course/pr_1069/rum/_app/immutable/entry/app.1f82014c.js")
]).then(([kit, app]) => {
// NOTE(review): node_ids presumably identify the root layout (0) and this
// page's route node (31); form/error null = no form-action result or error state.
kit.start(app, element, {
node_ids: [0, 31],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
39.8 kB
·
Xet hash:
e51f2ede32339d91bc5826376c078fafd03f15ba2b6dff17e7f01a6d13f23210

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.