smcleod commited on
Commit
c1e2ab3
·
verified ·
1 Parent(s): 13a2a0e

Upload modeling_seed_diffcoder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_seed_diffcoder.py +30 -0
modeling_seed_diffcoder.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 ByteDance Ltd. and/or its affiliates
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ from .generation_utils import generate_block
5
+ from transformers.models.llama.modeling_llama import LlamaForCausalLM
6
+ from transformers.generation.utils import GenerationConfig
7
+ import torch
8
+
9
+ class SeedDiffcoderForCausalLM(LlamaForCausalLM):
10
+ @torch.no_grad()
11
+ def generate(
12
+ self,
13
+ input_ids=None,
14
+ generation_config: GenerationConfig = None,
15
+ **kwargs,
16
+ ):
17
+ if input_ids is None:
18
+ raise ValueError("input_ids must be provided")
19
+
20
+ if generation_config is None:
21
+ generation_config = self.generation_config
22
+
23
+ prompt = input_ids
24
+ output_ids, nfe = generate_block(
25
+ model=self,
26
+ prompt=prompt,
27
+ **kwargs,
28
+ )
29
+
30
+ return output_ids