baseten-admin commited on
Commit
4dd975e
·
verified ·
1 Parent(s): 8f5ba2f

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: google/gemma-4-E2B-it
3
+ library_name: transformers
4
+ pipeline_tag: text-classification
5
+ tags:
6
+ - gemma4
7
+ - sequence-classification
8
+ - text-classification
9
+ ---
10
+
11
+ # Gemma 4 E2B IT Sequence Classification
12
+
13
+ This checkpoint converts `google/gemma-4-E2B-it` into a two-label sequence
14
+ classifier by slicing the original LM head to the single-token labels `no` and
15
+ `yes`.
16
+
17
+ The classifier logits are exactly the original next-token logits for:
18
+
19
+ - `no`: token id `1904`
20
+ - `yes`: token id `4443`
21
+
22
+ Load with custom code enabled:
23
+
24
+ ```python
25
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
26
+
27
+ model_id = "baseten/gemma-4-e2b-it-sequence-classification"
28
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
29
+ model = AutoModelForSequenceClassification.from_pretrained(
30
+ model_id,
31
+ trust_remote_code=True,
32
+ dtype="auto",
33
+ )
34
+ ```
chat_template.jinja ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- macro format_parameters(properties, required, filter_keys=false) -%}
2
+ {%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
3
+ {%- set ns = namespace(found_first=false) -%}
4
+ {%- for key, value in properties | dictsort -%}
5
+ {%- set add_comma = false -%}
6
+ {%- if not filter_keys or key not in standard_keys -%}
7
+ {%- if ns.found_first %},{% endif -%}
8
+ {%- set ns.found_first = true -%}
9
+ {{ key }}:{
10
+ {%- if value['description'] -%}
11
+ description:<|"|>{{ value['description'] }}<|"|>
12
+ {%- set add_comma = true -%}
13
+ {%- endif -%}
14
+ {%- if value['type'] | upper == 'STRING' -%}
15
+ {%- if value['enum'] -%}
16
+ {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
17
+ enum:{{ format_argument(value['enum']) }}
18
+ {%- endif -%}
19
+ {%- elif value['type'] | upper == 'ARRAY' -%}
20
+ {%- if value['items'] is mapping and value['items'] -%}
21
+ {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
22
+ items:{
23
+ {%- set ns_items = namespace(found_first=false) -%}
24
+ {%- for item_key, item_value in value['items'] | dictsort -%}
25
+ {%- if item_value is not none -%}
26
+ {%- if ns_items.found_first %},{% endif -%}
27
+ {%- set ns_items.found_first = true -%}
28
+ {%- if item_key == 'properties' -%}
29
+ properties:{
30
+ {%- if item_value is mapping -%}
31
+ {{- format_parameters(item_value, value['items']['required'] | default([])) -}}
32
+ {%- endif -%}
33
+ }
34
+ {%- elif item_key == 'required' -%}
35
+ required:[
36
+ {%- for req_item in item_value -%}
37
+ <|"|>{{- req_item -}}<|"|>
38
+ {%- if not loop.last %},{% endif -%}
39
+ {%- endfor -%}
40
+ ]
41
+ {%- elif item_key == 'type' -%}
42
+ {%- if item_value is string -%}
43
+ type:{{ format_argument(item_value | upper) }}
44
+ {%- else -%}
45
+ type:{{ format_argument(item_value | map('upper') | list) }}
46
+ {%- endif -%}
47
+ {%- else -%}
48
+ {{ item_key }}:{{ format_argument(item_value) }}
49
+ {%- endif -%}
50
+ {%- endif -%}
51
+ {%- endfor -%}
52
+ }
53
+ {%- endif -%}
54
+ {%- endif -%}
55
+ {%- if value['nullable'] %}
56
+ {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
57
+ nullable:true
58
+ {%- endif -%}
59
+ {%- if value['type'] | upper == 'OBJECT' -%}
60
+ {%- if value['properties'] is defined and value['properties'] is mapping -%}
61
+ {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
62
+ properties:{
63
+ {{- format_parameters(value['properties'], value['required'] | default([])) -}}
64
+ }
65
+ {%- elif value is mapping -%}
66
+ {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
67
+ properties:{
68
+ {{- format_parameters(value, value['required'] | default([]), filter_keys=true) -}}
69
+ }
70
+ {%- endif -%}
71
+ {%- if value['required'] -%}
72
+ {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
73
+ required:[
74
+ {%- for item in value['required'] | default([]) -%}
75
+ <|"|>{{- item -}}<|"|>
76
+ {%- if not loop.last %},{% endif -%}
77
+ {%- endfor -%}
78
+ ]
79
+ {%- endif -%}
80
+ {%- endif -%}
81
+ {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
82
+ type:<|"|>{{ value['type'] | upper }}<|"|>}
83
+ {%- endif -%}
84
+ {%- endfor -%}
85
+ {%- endmacro -%}
86
+ {%- macro format_function_declaration(tool_data) -%}
87
+ declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|>
88
+ {%- set params = tool_data['function']['parameters'] -%}
89
+ {%- if params -%}
90
+ ,parameters:{
91
+ {%- if params['properties'] -%}
92
+ properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
93
+ {%- endif -%}
94
+ {%- if params['required'] -%}
95
+ required:[
96
+ {%- for item in params['required'] -%}
97
+ <|"|>{{- item -}}<|"|>
98
+ {{- ',' if not loop.last -}}
99
+ {%- endfor -%}
100
+ ],
101
+ {%- endif -%}
102
+ {%- if params['type'] -%}
103
+ type:<|"|>{{- params['type'] | upper -}}<|"|>}
104
+ {%- endif -%}
105
+ {%- endif -%}
106
+ {%- if 'response' in tool_data['function'] -%}
107
+ {%- set response_declaration = tool_data['function']['response'] -%}
108
+ ,response:{
109
+ {%- if response_declaration['description'] -%}
110
+ description:<|"|>{{- response_declaration['description'] -}}<|"|>,
111
+ {%- endif -%}
112
+ {%- if response_declaration['type'] | upper == 'OBJECT' -%}
113
+ type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>}
114
+ {%- endif -%}
115
+ {%- endif -%}
116
+ }
117
+ {%- endmacro -%}
118
+ {%- macro format_argument(argument, escape_keys=True) -%}
119
+ {%- if argument is string -%}
120
+ {{- '<|"|>' + argument + '<|"|>' -}}
121
+ {%- elif argument is boolean -%}
122
+ {{- 'true' if argument else 'false' -}}
123
+ {%- elif argument is mapping -%}
124
+ {{- '{' -}}
125
+ {%- set ns = namespace(found_first=false) -%}
126
+ {%- for key, value in argument | dictsort -%}
127
+ {%- if ns.found_first %},{% endif -%}
128
+ {%- set ns.found_first = true -%}
129
+ {%- if escape_keys -%}
130
+ {{- '<|"|>' + key + '<|"|>' -}}
131
+ {%- else -%}
132
+ {{- key -}}
133
+ {%- endif -%}
134
+ :{{- format_argument(value, escape_keys=escape_keys) -}}
135
+ {%- endfor -%}
136
+ {{- '}' -}}
137
+ {%- elif argument is sequence -%}
138
+ {{- '[' -}}
139
+ {%- for item in argument -%}
140
+ {{- format_argument(item, escape_keys=escape_keys) -}}
141
+ {%- if not loop.last %},{% endif -%}
142
+ {%- endfor -%}
143
+ {{- ']' -}}
144
+ {%- else -%}
145
+ {{- argument -}}
146
+ {%- endif -%}
147
+ {%- endmacro -%}
148
+ {%- macro strip_thinking(text) -%}
149
+ {%- set ns = namespace(result='') -%}
150
+ {%- for part in text.split('<channel|>') -%}
151
+ {%- if '<|channel>' in part -%}
152
+ {%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
153
+ {%- else -%}
154
+ {%- set ns.result = ns.result + part -%}
155
+ {%- endif -%}
156
+ {%- endfor -%}
157
+ {{- ns.result | trim -}}
158
+ {%- endmacro -%}
159
+
160
+ {%- macro format_tool_response_block(tool_name, response) -%}
161
+ {{- '<|tool_response>' -}}
162
+ {%- if response is mapping -%}
163
+ {{- 'response:' + tool_name + '{' -}}
164
+ {%- for key, value in response | dictsort -%}
165
+ {{- key -}}:{{- format_argument(value, escape_keys=False) -}}
166
+ {%- if not loop.last %},{% endif -%}
167
+ {%- endfor -%}
168
+ {{- '}' -}}
169
+ {%- else -%}
170
+ {{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}}
171
+ {%- endif -%}
172
+ {{- '<tool_response|>' -}}
173
+ {%- endmacro -%}
174
+
175
+ {%- set ns = namespace(prev_message_type=None) -%}
176
+ {%- set loop_messages = messages -%}
177
+ {{- bos_token -}}
178
+ {#- Handle System/Tool Definitions Block -#}
179
+ {%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
180
+ {{- '<|turn>system\n' -}}
181
+ {#- Inject Thinking token at the very top of the FIRST system turn -#}
182
+ {%- if enable_thinking is defined and enable_thinking -%}
183
+ {{- '<|think|>\n' -}}
184
+ {%- set ns.prev_message_type = 'think' -%}
185
+ {%- endif -%}
186
+ {%- if messages[0]['role'] in ['system', 'developer'] -%}
187
+ {%- if messages[0]['content'] is string -%}
188
+ {{- messages[0]['content'] | trim -}}
189
+ {%- elif messages[0]['content'] is sequence -%}
190
+ {%- for item in messages[0]['content'] -%}
191
+ {{- item['text'] | trim + ' '-}}
192
+ {%- endfor -%}
193
+ {%- endif -%}
194
+ {%- set loop_messages = messages[1:] -%}
195
+ {%- endif -%}
196
+ {%- if tools -%}
197
+ {%- for tool in tools %}
198
+ {{- '<|tool>' -}}
199
+ {{- format_function_declaration(tool) | trim -}}
200
+ {{- '<tool|>' -}}
201
+ {%- endfor %}
202
+ {%- set ns.prev_message_type = 'tool' -%}
203
+ {%- endif -%}
204
+ {{- '<turn|>\n' -}}
205
+ {%- endif %}
206
+
207
+ {#- Pre-scan: find last user message index for reasoning guard -#}
208
+ {%- set ns_turn = namespace(last_user_idx=-1) -%}
209
+ {%- for i in range(loop_messages | length) -%}
210
+ {%- if loop_messages[i]['role'] == 'user' -%}
211
+ {%- set ns_turn.last_user_idx = i -%}
212
+ {%- endif -%}
213
+ {%- endfor -%}
214
+
215
+ {#- Loop through messages -#}
216
+ {%- for message in loop_messages -%}
217
+ {%- if message['role'] != 'tool' -%}
218
+ {%- set ns.prev_message_type = None -%}
219
+ {%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
220
+ {#- Detect continuation: suppress duplicate <|turn>model when previous non-tool message was also assistant -#}
221
+ {%- set prev_nt = namespace(role=None, found=false) -%}
222
+ {%- if loop.index0 > 0 -%}
223
+ {%- for j in range(loop.index0 - 1, -1, -1) -%}
224
+ {%- if not prev_nt.found -%}
225
+ {%- if loop_messages[j]['role'] != 'tool' -%}
226
+ {%- set prev_nt.role = loop_messages[j]['role'] -%}
227
+ {%- set prev_nt.found = true -%}
228
+ {%- endif -%}
229
+ {%- endif -%}
230
+ {%- endfor -%}
231
+ {%- endif -%}
232
+ {%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%}
233
+ {%- if not continue_same_model_turn -%}
234
+ {{- '<|turn>' + role + '\n' }}
235
+ {%- endif -%}
236
+
237
+ {#- Render reasoning/reasoning_content as thinking channel -#}
238
+ {%- set thinking_text = message.get('reasoning') or message.get('reasoning_content') -%}
239
+ {%- if thinking_text and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%}
240
+ {{- '<|channel>thought\n' + thinking_text + '\n<channel|>' -}}
241
+ {%- endif -%}
242
+
243
+ {%- if message['tool_calls'] -%}
244
+ {%- for tool_call in message['tool_calls'] -%}
245
+ {%- set function = tool_call['function'] -%}
246
+ {{- '<|tool_call>call:' + function['name'] + '{' -}}
247
+ {%- if function['arguments'] is mapping -%}
248
+ {%- set ns_args = namespace(found_first=false) -%}
249
+ {%- for key, value in function['arguments'] | dictsort -%}
250
+ {%- if ns_args.found_first %},{% endif -%}
251
+ {%- set ns_args.found_first = true -%}
252
+ {{- key -}}:{{- format_argument(value, escape_keys=False) -}}
253
+ {%- endfor -%}
254
+ {%- elif function['arguments'] is string -%}
255
+ {{- function['arguments'] -}}
256
+ {%- endif -%}
257
+ {{- '}<tool_call|>' -}}
258
+ {%- endfor -%}
259
+ {%- set ns.prev_message_type = 'tool_call' -%}
260
+ {%- endif -%}
261
+
262
+ {%- set ns_tr_out = namespace(flag=false) -%}
263
+ {%- if message.get('tool_responses') -%}
264
+ {#- Legacy: tool_responses embedded on the assistant message (Google/Gemma native) -#}
265
+ {%- for tool_response in message['tool_responses'] -%}
266
+ {{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}}
267
+ {%- set ns_tr_out.flag = true -%}
268
+ {%- set ns.prev_message_type = 'tool_response' -%}
269
+ {%- endfor -%}
270
+ {%- elif message.get('tool_calls') -%}
271
+ {#- OpenAI Chat Completions: forward-scan consecutive role:tool messages -#}
272
+ {%- set ns_tool_scan = namespace(stopped=false) -%}
273
+ {%- for k in range(loop.index0 + 1, loop_messages | length) -%}
274
+ {%- if ns_tool_scan.stopped -%}
275
+ {%- elif loop_messages[k]['role'] != 'tool' -%}
276
+ {%- set ns_tool_scan.stopped = true -%}
277
+ {%- else -%}
278
+ {%- set follow = loop_messages[k] -%}
279
+ {#- Resolve tool_call_id to function name -#}
280
+ {%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%}
281
+ {%- for tc in message['tool_calls'] -%}
282
+ {%- if tc.get('id') == follow.get('tool_call_id') -%}
283
+ {%- set ns_tname.name = tc['function']['name'] -%}
284
+ {%- endif -%}
285
+ {%- endfor -%}
286
+ {#- Handle content as string or content-parts array -#}
287
+ {%- set tool_body = follow.get('content') -%}
288
+ {%- if tool_body is string -%}
289
+ {{- format_tool_response_block(ns_tname.name, tool_body) -}}
290
+ {%- elif tool_body is sequence and tool_body is not string -%}
291
+ {%- set ns_txt = namespace(s='') -%}
292
+ {%- for part in tool_body -%}
293
+ {%- if part.get('type') == 'text' -%}
294
+ {%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%}
295
+ {%- endif -%}
296
+ {%- endfor -%}
297
+ {{- format_tool_response_block(ns_tname.name, ns_txt.s) -}}
298
+ {%- for part in tool_body -%}
299
+ {%- if part.get('type') == 'image' -%}
300
+ {{- '<|image|>' -}}
301
+ {%- elif part.get('type') == 'audio' -%}
302
+ {{- '<|audio|>' -}}
303
+ {%- elif part.get('type') == 'video' -%}
304
+ {{- '<|video|>' -}}
305
+ {%- endif -%}
306
+ {%- endfor -%}
307
+ {%- else -%}
308
+ {{- format_tool_response_block(ns_tname.name, tool_body) -}}
309
+ {%- endif -%}
310
+ {%- set ns_tr_out.flag = true -%}
311
+ {%- set ns.prev_message_type = 'tool_response' -%}
312
+ {%- endif -%}
313
+ {%- endfor -%}
314
+ {%- endif -%}
315
+
316
+ {%- set captured_content -%}
317
+ {%- if message['content'] is string -%}
318
+ {%- if role == 'model' -%}
319
+ {{- strip_thinking(message['content']) -}}
320
+ {%- else -%}
321
+ {{- message['content'] | trim -}}
322
+ {%- endif -%}
323
+ {%- elif message['content'] is sequence -%}
324
+ {%- for item in message['content'] -%}
325
+ {%- if item['type'] == 'text' -%}
326
+ {%- if role == 'model' -%}
327
+ {{- strip_thinking(item['text']) -}}
328
+ {%- else -%}
329
+ {{- item['text'] | trim -}}
330
+ {%- endif -%}
331
+ {%- elif item['type'] == 'image' -%}
332
+ {{- '<|image|>' -}}
333
+ {%- set ns.prev_message_type = 'image' -%}
334
+ {%- elif item['type'] == 'audio' -%}
335
+ {{- '<|audio|>' -}}
336
+ {%- set ns.prev_message_type = 'audio' -%}
337
+ {%- elif item['type'] == 'video' -%}
338
+ {{- '<|video|>' -}}
339
+ {%- set ns.prev_message_type = 'video' -%}
340
+ {%- endif -%}
341
+ {%- endfor -%}
342
+ {%- endif -%}
343
+ {%- endset -%}
344
+
345
+ {{- captured_content -}}
346
+ {%- set has_content = captured_content | trim | length > 0 -%}
347
+
348
+ {%- if ns.prev_message_type == 'tool_call' and not ns_tr_out.flag -%}
349
+ {{- '<|tool_response>' -}}
350
+ {%- elif not (ns_tr_out.flag and not has_content) -%}
351
+ {{- '<turn|>\n' -}}
352
+ {%- endif -%}
353
+ {%- endif -%}
354
+ {%- endfor -%}
355
+
356
+ {%- if add_generation_prompt -%}
357
+ {%- if ns.prev_message_type != 'tool_response' and ns.prev_message_type != 'tool_call' -%}
358
+ {{- '<|turn>model\n' -}}
359
+ {%- endif -%}
360
+ {%- endif -%}
config.json ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Gemma4ForTokenSequenceClassification"
4
+ ],
5
+ "audio_config": {
6
+ "_name_or_path": "",
7
+ "architectures": null,
8
+ "attention_chunk_size": 12,
9
+ "attention_context_left": 13,
10
+ "attention_context_right": 0,
11
+ "attention_invalid_logits_value": -1000000000.0,
12
+ "attention_logit_cap": 50.0,
13
+ "chunk_size_feed_forward": 0,
14
+ "conv_kernel_size": 5,
15
+ "dtype": "bfloat16",
16
+ "gradient_clipping": 10000000000.0,
17
+ "hidden_act": "silu",
18
+ "hidden_size": 1024,
19
+ "id2label": {
20
+ "0": "LABEL_0",
21
+ "1": "LABEL_1"
22
+ },
23
+ "initializer_range": 0.02,
24
+ "is_encoder_decoder": false,
25
+ "label2id": {
26
+ "LABEL_0": 0,
27
+ "LABEL_1": 1
28
+ },
29
+ "model_type": "gemma4_audio",
30
+ "num_attention_heads": 8,
31
+ "num_hidden_layers": 12,
32
+ "output_attentions": false,
33
+ "output_hidden_states": false,
34
+ "output_proj_dims": 1536,
35
+ "problem_type": null,
36
+ "residual_weight": 0.5,
37
+ "return_dict": true,
38
+ "rms_norm_eps": 1e-06,
39
+ "subsampling_conv_channels": [
40
+ 128,
41
+ 32
42
+ ],
43
+ "use_clipped_linears": true
44
+ },
45
+ "audio_token_id": 258881,
46
+ "auto_map": {
47
+ "AutoModelForSequenceClassification": "modeling_gemma4_sequence.Gemma4ForTokenSequenceClassification"
48
+ },
49
+ "boa_token_id": 256000,
50
+ "boi_token_id": 255999,
51
+ "classifier_token_ids": {
52
+ "no": 1904,
53
+ "yes": 4443
54
+ },
55
+ "dtype": "bfloat16",
56
+ "eoa_token_id": 258883,
57
+ "eoa_token_index": 258883,
58
+ "eoi_token_id": 258882,
59
+ "eos_token_id": [
60
+ 1,
61
+ 106
62
+ ],
63
+ "id2label": {
64
+ "0": "no",
65
+ "1": "yes"
66
+ },
67
+ "image_token_id": 258880,
68
+ "initializer_range": 0.02,
69
+ "label2id": {
70
+ "no": 0,
71
+ "yes": 1
72
+ },
73
+ "model_type": "gemma4",
74
+ "pad_token_id": 0,
75
+ "problem_type": "single_label_classification",
76
+ "text_config": {
77
+ "attention_bias": false,
78
+ "attention_dropout": 0.0,
79
+ "attention_k_eq_v": false,
80
+ "bos_token_id": 2,
81
+ "dtype": "bfloat16",
82
+ "enable_moe_block": false,
83
+ "eos_token_id": 1,
84
+ "expert_intermediate_size": null,
85
+ "final_logit_softcapping": 30.0,
86
+ "global_head_dim": 512,
87
+ "head_dim": 256,
88
+ "hidden_activation": "gelu_pytorch_tanh",
89
+ "hidden_size": 1536,
90
+ "hidden_size_per_layer_input": 256,
91
+ "initializer_range": 0.02,
92
+ "intermediate_size": 6144,
93
+ "layer_types": [
94
+ "sliding_attention",
95
+ "sliding_attention",
96
+ "sliding_attention",
97
+ "sliding_attention",
98
+ "full_attention",
99
+ "sliding_attention",
100
+ "sliding_attention",
101
+ "sliding_attention",
102
+ "sliding_attention",
103
+ "full_attention",
104
+ "sliding_attention",
105
+ "sliding_attention",
106
+ "sliding_attention",
107
+ "sliding_attention",
108
+ "full_attention",
109
+ "sliding_attention",
110
+ "sliding_attention",
111
+ "sliding_attention",
112
+ "sliding_attention",
113
+ "full_attention",
114
+ "sliding_attention",
115
+ "sliding_attention",
116
+ "sliding_attention",
117
+ "sliding_attention",
118
+ "full_attention",
119
+ "sliding_attention",
120
+ "sliding_attention",
121
+ "sliding_attention",
122
+ "sliding_attention",
123
+ "full_attention",
124
+ "sliding_attention",
125
+ "sliding_attention",
126
+ "sliding_attention",
127
+ "sliding_attention",
128
+ "full_attention"
129
+ ],
130
+ "max_position_embeddings": 131072,
131
+ "model_type": "gemma4_text",
132
+ "moe_intermediate_size": null,
133
+ "num_attention_heads": 8,
134
+ "num_experts": null,
135
+ "num_global_key_value_heads": null,
136
+ "num_hidden_layers": 35,
137
+ "num_key_value_heads": 1,
138
+ "num_kv_shared_layers": 20,
139
+ "pad_token_id": 0,
140
+ "rms_norm_eps": 1e-06,
141
+ "rope_parameters": {
142
+ "full_attention": {
143
+ "partial_rotary_factor": 0.25,
144
+ "rope_theta": 1000000.0,
145
+ "rope_type": "proportional"
146
+ },
147
+ "sliding_attention": {
148
+ "rope_theta": 10000.0,
149
+ "rope_type": "default"
150
+ }
151
+ },
152
+ "sliding_window": 512,
153
+ "tie_word_embeddings": true,
154
+ "top_k_experts": null,
155
+ "use_bidirectional_attention": null,
156
+ "use_cache": true,
157
+ "use_double_wide_mlp": true,
158
+ "vocab_size": 262144,
159
+ "vocab_size_per_layer_input": 262144
160
+ },
161
+ "tie_word_embeddings": true,
162
+ "transformers_version": "5.9.0",
163
+ "video_token_id": 258884,
164
+ "vision_config": {
165
+ "_name_or_path": "",
166
+ "architectures": null,
167
+ "attention_bias": false,
168
+ "attention_dropout": 0.0,
169
+ "chunk_size_feed_forward": 0,
170
+ "default_output_length": 280,
171
+ "dtype": "bfloat16",
172
+ "global_head_dim": 64,
173
+ "head_dim": 64,
174
+ "hidden_activation": "gelu_pytorch_tanh",
175
+ "hidden_size": 768,
176
+ "id2label": {
177
+ "0": "LABEL_0",
178
+ "1": "LABEL_1"
179
+ },
180
+ "initializer_range": 0.02,
181
+ "intermediate_size": 3072,
182
+ "is_encoder_decoder": false,
183
+ "label2id": {
184
+ "LABEL_0": 0,
185
+ "LABEL_1": 1
186
+ },
187
+ "max_position_embeddings": 131072,
188
+ "model_type": "gemma4_vision",
189
+ "num_attention_heads": 12,
190
+ "num_hidden_layers": 16,
191
+ "num_key_value_heads": 12,
192
+ "output_attentions": false,
193
+ "output_hidden_states": false,
194
+ "patch_size": 16,
195
+ "pooling_kernel_size": 3,
196
+ "position_embedding_size": 10240,
197
+ "problem_type": null,
198
+ "return_dict": true,
199
+ "rms_norm_eps": 1e-06,
200
+ "rope_parameters": {
201
+ "rope_theta": 100.0,
202
+ "rope_type": "default"
203
+ },
204
+ "standardize": false,
205
+ "use_clipped_linears": true
206
+ },
207
+ "vision_soft_tokens_per_image": 280
208
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbe019c39e43a59531ad8cc07f6cfc5478dfef94d3e3d929836b7be6431dbe9c
3
+ size 10208859150
modeling_gemma4_sequence.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gemma 4 sequence classifier backed by selected next-token logits.
2
+
3
+ This module is intentionally small: it reuses the Gemma 4 multimodal backbone and
4
+ replaces the LM head with a classifier head containing selected token rows.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from collections.abc import Sequence
10
+
11
+ import torch
12
+ from torch import nn
13
+ from transformers.modeling_outputs import SequenceClassifierOutputWithPast
14
+ from transformers.models.gemma4.configuration_gemma4 import Gemma4Config
15
+ from transformers.models.gemma4.modeling_gemma4 import Gemma4Model, Gemma4PreTrainedModel
16
+
17
+
18
+ class Gemma4ForTokenSequenceClassification(Gemma4PreTrainedModel):
19
+ """Pool the last text position and score it with selected Gemma 4 token rows."""
20
+
21
+ config_class = Gemma4Config
22
+ base_model_prefix = "model"
23
+
24
+ @classmethod
25
+ def _can_set_experts_implementation(cls) -> bool:
26
+ return True
27
+
28
+ def __init__(
29
+ self,
30
+ config: Gemma4Config,
31
+ source_model: nn.Module | None = None,
32
+ classifier_weight: torch.Tensor | None = None,
33
+ ) -> None:
34
+ super().__init__(config)
35
+ self.num_labels = config.num_labels
36
+ self.model = source_model.model if source_model is not None else Gemma4Model(config)
37
+ self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)
38
+
39
+ if classifier_weight is not None:
40
+ self.score.to(device=classifier_weight.device, dtype=classifier_weight.dtype)
41
+ self.score.weight.data.copy_(classifier_weight)
42
+
43
+ if source_model is None and classifier_weight is None:
44
+ self.post_init()
45
+
46
+ @classmethod
47
+ def from_conditional_generation(
48
+ cls,
49
+ model_lm: nn.Module,
50
+ selected_token_ids: Sequence[int],
51
+ labels: Sequence[str],
52
+ ) -> "Gemma4ForTokenSequenceClassification":
53
+ token_ids = torch.tensor(selected_token_ids, device=model_lm.lm_head.weight.device)
54
+ classifier_weight = model_lm.lm_head.weight.index_select(0, token_ids).detach().clone()
55
+ cls.configure_classification_config(model_lm.config, selected_token_ids, labels)
56
+ return cls(model_lm.config, source_model=model_lm, classifier_weight=classifier_weight)
57
+
58
+ @classmethod
59
+ def configure_classification_config(
60
+ cls,
61
+ config: Gemma4Config,
62
+ selected_token_ids: Sequence[int],
63
+ labels: Sequence[str],
64
+ ) -> None:
65
+ config.num_labels = len(labels)
66
+ config.id2label = {i: label for i, label in enumerate(labels)}
67
+ config.label2id = {label: i for i, label in enumerate(labels)}
68
+ config.classifier_token_ids = {
69
+ label: int(token_id) for label, token_id in zip(labels, selected_token_ids)
70
+ }
71
+ config.architectures = [cls.__name__]
72
+ config.problem_type = "single_label_classification"
73
+ if getattr(config, "pad_token_id", None) is None:
74
+ config.pad_token_id = config.text_config.pad_token_id
75
+
76
+ def get_input_embeddings(self):
77
+ return self.model.get_input_embeddings()
78
+
79
+ def set_input_embeddings(self, value):
80
+ self.model.set_input_embeddings(value)
81
+
82
+ def get_per_layer_input_embeddings(self):
83
+ return self.model.get_per_layer_input_embeddings()
84
+
85
+ def set_per_layer_input_embeddings(self, value):
86
+ self.model.set_per_layer_input_embeddings(value)
87
+
88
+ def _last_non_pad_token(
89
+ self,
90
+ logits: torch.Tensor,
91
+ input_ids: torch.LongTensor | None,
92
+ attention_mask: torch.Tensor | None,
93
+ inputs_embeds: torch.FloatTensor | None,
94
+ ) -> torch.Tensor | int:
95
+ batch_size = logits.shape[0]
96
+ if attention_mask is not None:
97
+ token_indices = torch.arange(logits.shape[1], device=logits.device)
98
+ return (attention_mask.to(logits.device) * token_indices).argmax(-1)
99
+
100
+ pad_token_id = getattr(self.config, "pad_token_id", None)
101
+ if input_ids is not None and pad_token_id is not None:
102
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
103
+ non_pad = input_ids.to(logits.device).ne(pad_token_id)
104
+ return (non_pad * token_indices).argmax(-1)
105
+
106
+ if batch_size != 1:
107
+ raise ValueError(
108
+ "Cannot infer sequence lengths for a padded batch without a pad token."
109
+ )
110
+
111
+ if input_ids is None and inputs_embeds is None:
112
+ raise ValueError("Expected input_ids or inputs_embeds.")
113
+
114
+ return -1
115
+
116
+ def _apply_final_logit_softcapping(self, logits: torch.Tensor) -> torch.Tensor:
117
+ final_logit_softcapping = self.config.get_text_config().final_logit_softcapping
118
+ if final_logit_softcapping is None:
119
+ return logits
120
+ logits = logits / final_logit_softcapping
121
+ logits = torch.tanh(logits)
122
+ return logits * final_logit_softcapping
123
+
124
+ def forward(
125
+ self,
126
+ input_ids: torch.LongTensor | None = None,
127
+ pixel_values: torch.FloatTensor | None = None,
128
+ pixel_values_videos: torch.FloatTensor | None = None,
129
+ input_features: torch.FloatTensor | None = None,
130
+ attention_mask: torch.Tensor | None = None,
131
+ input_features_mask: torch.Tensor | None = None,
132
+ position_ids: torch.LongTensor | None = None,
133
+ image_position_ids: torch.LongTensor | None = None,
134
+ video_position_ids: torch.LongTensor | None = None,
135
+ past_key_values=None,
136
+ mm_token_type_ids: torch.LongTensor | None = None,
137
+ inputs_embeds: torch.FloatTensor | None = None,
138
+ labels: torch.LongTensor | None = None,
139
+ use_cache: bool | None = None,
140
+ return_dict: bool | None = None,
141
+ **kwargs,
142
+ ):
143
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
144
+ outputs = self.model(
145
+ input_ids=input_ids,
146
+ pixel_values=pixel_values,
147
+ pixel_values_videos=pixel_values_videos,
148
+ input_features=input_features,
149
+ attention_mask=attention_mask,
150
+ input_features_mask=input_features_mask,
151
+ position_ids=position_ids,
152
+ past_key_values=past_key_values,
153
+ mm_token_type_ids=mm_token_type_ids,
154
+ inputs_embeds=inputs_embeds,
155
+ use_cache=use_cache,
156
+ image_position_ids=image_position_ids,
157
+ video_position_ids=video_position_ids,
158
+ return_dict=True,
159
+ **kwargs,
160
+ )
161
+
162
+ logits = self.score(outputs.last_hidden_state)
163
+ logits = self._apply_final_logit_softcapping(logits)
164
+ sequence_lengths = self._last_non_pad_token(
165
+ logits,
166
+ input_ids,
167
+ attention_mask,
168
+ inputs_embeds,
169
+ )
170
+ pooled_logits = logits[
171
+ torch.arange(logits.shape[0], device=logits.device),
172
+ sequence_lengths,
173
+ ]
174
+
175
+ loss = None
176
+ if labels is not None:
177
+ labels = labels.to(pooled_logits.device)
178
+ if self.config.problem_type == "regression":
179
+ loss = nn.MSELoss()(pooled_logits.squeeze(), labels.squeeze())
180
+ elif self.config.problem_type == "multi_label_classification":
181
+ loss = nn.BCEWithLogitsLoss()(pooled_logits, labels)
182
+ else:
183
+ loss = nn.CrossEntropyLoss()(
184
+ pooled_logits.view(-1, self.num_labels),
185
+ labels.view(-1),
186
+ )
187
+
188
+ if not return_dict:
189
+ output = (pooled_logits,) + outputs[1:]
190
+ return ((loss,) + output) if loss is not None else output
191
+
192
+ return SequenceClassifierOutputWithPast(
193
+ loss=loss,
194
+ logits=pooled_logits,
195
+ past_key_values=outputs.past_key_values,
196
+ hidden_states=outputs.hidden_states,
197
+ attentions=outputs.attentions,
198
+ )
199
+
200
+
201
+ Gemma4ForTokenSequenceClassification.register_for_auto_class("AutoModelForSequenceClassification")
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc8d3a0ce36466ccc1278bf987df5f71db1719b9ca6b4118264f45cb627bfe0f
3
+ size 32169626
tokenizer_config.json ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "audio_token": "<|audio|>",
3
+ "backend": "tokenizers",
4
+ "boa_token": "<|audio>",
5
+ "boi_token": "<|image>",
6
+ "bos_token": "<bos>",
7
+ "eoa_token": "<audio|>",
8
+ "eoc_token": "<channel|>",
9
+ "eoi_token": "<image|>",
10
+ "eos_token": "<eos>",
11
+ "eot_token": "<turn|>",
12
+ "escape_token": "<|\"|>",
13
+ "etc_token": "<tool_call|>",
14
+ "etd_token": "<tool|>",
15
+ "etr_token": "<tool_response|>",
16
+ "extra_special_tokens": [
17
+ "<|video|>"
18
+ ],
19
+ "image_token": "<|image|>",
20
+ "is_local": false,
21
+ "local_files_only": false,
22
+ "mask_token": "<mask>",
23
+ "model_max_length": 1000000000000000019884624838656,
24
+ "model_specific_special_tokens": {
25
+ "audio_token": "<|audio|>",
26
+ "boa_token": "<|audio>",
27
+ "boi_token": "<|image>",
28
+ "eoa_token": "<audio|>",
29
+ "eoc_token": "<channel|>",
30
+ "eoi_token": "<image|>",
31
+ "eot_token": "<turn|>",
32
+ "escape_token": "<|\"|>",
33
+ "etc_token": "<tool_call|>",
34
+ "etd_token": "<tool|>",
35
+ "etr_token": "<tool_response|>",
36
+ "image_token": "<|image|>",
37
+ "soc_token": "<|channel>",
38
+ "sot_token": "<|turn>",
39
+ "stc_token": "<|tool_call>",
40
+ "std_token": "<|tool>",
41
+ "str_token": "<|tool_response>",
42
+ "think_token": "<|think|>"
43
+ },
44
+ "pad_token": "<pad>",
45
+ "padding_side": "left",
46
+ "processor_class": "Gemma4Processor",
47
+ "response_schema": {
48
+ "properties": {
49
+ "content": {
50
+ "type": "string"
51
+ },
52
+ "role": {
53
+ "const": "assistant"
54
+ },
55
+ "thinking": {
56
+ "type": "string"
57
+ },
58
+ "tool_calls": {
59
+ "items": {
60
+ "properties": {
61
+ "function": {
62
+ "properties": {
63
+ "arguments": {
64
+ "additionalProperties": {},
65
+ "type": "object",
66
+ "x-parser": "gemma4-tool-call"
67
+ },
68
+ "name": {
69
+ "type": "string"
70
+ }
71
+ },
72
+ "type": "object",
73
+ "x-regex": "call\\:(?P<name>\\w+)(?P<arguments>\\{.*\\})"
74
+ },
75
+ "type": {
76
+ "const": "function"
77
+ }
78
+ },
79
+ "type": "object"
80
+ },
81
+ "type": "array",
82
+ "x-regex-iterator": "<\\|tool_call>(.*?)<tool_call\\|>"
83
+ }
84
+ },
85
+ "type": "object",
86
+ "x-regex": "(\\<\\|channel\\>thought\\n(?P<thinking>.*?)\\<channel\\|\\>)?(?P<tool_calls>\\<\\|tool_call\\>.*\\<tool_call\\|\\>)?(?P<content>(?:(?!\\<turn\\|\\>)(?!\\<\\|tool_response\\>).)+)?(?:\\<turn\\|\\>|\\<\\|tool_response\\>)?"
87
+ },
88
+ "soc_token": "<|channel>",
89
+ "sot_token": "<|turn>",
90
+ "stc_token": "<|tool_call>",
91
+ "std_token": "<|tool>",
92
+ "str_token": "<|tool_response>",
93
+ "think_token": "<|think|>",
94
+ "tokenizer_class": "GemmaTokenizer",
95
+ "unk_token": "<unk>"
96
+ }