Alibrown committed on
Commit
231c7d9
·
verified ·
1 Parent(s): 5913e40

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +16 -13
train.py CHANGED
@@ -37,11 +37,12 @@ logger = logging.getLogger("train")
37
  # =============================================================================
38
 
39
  def export_dataset(output_path: str = None):
40
- output = Path(output_path) if output_path else TRAIN_DATA
41
  """
42
  Export HF dataset logs to JSONL format for training.
43
  Filters: only HIGH_PRIORITY and MEDIUM_PRIORITY entries with actual responses.
44
  """
 
 
45
  logger.info("Loading dataset from HF...")
46
  entries = model_module.load_logs()
47
 
@@ -49,9 +50,7 @@ def export_dataset(output_path: str = None):
49
  logger.warning("Dataset empty — nothing to export")
50
  return
51
 
52
- output = Path(output_path)
53
  count = 0
54
-
55
  with open(output, "w") as f:
56
  for entry in entries:
57
  # Only export entries where SmolLM2 actually responded
@@ -97,15 +96,14 @@ def validate_adi():
97
  accuracy = analyzer.validate_weights(labeled)
98
  logger.info(f"ADI Validation accuracy: {accuracy:.1%} on {len(labeled)} samples")
99
 
100
- # Save results
101
  result = {
102
  "timestamp": datetime.utcnow().isoformat(),
103
- "accuracy": accuracy,
104
- "samples": len(labeled),
105
- "weights": analyzer.weights,
106
  }
107
  VALID_RESULT.write_text(json.dumps(result, indent=2))
108
- logger.info("Results saved → validation_results.json")
109
 
110
 
111
  # =============================================================================
@@ -113,9 +111,14 @@ def validate_adi():
113
  # =============================================================================
114
 
115
  def finetune():
 
 
 
 
116
  if not TRAIN_DATA.exists():
117
- logger.error(f"train_data.jsonl not found at {TRAIN_DATA}")
118
  return
 
119
  lines = TRAIN_DATA.read_text().strip().splitlines()
120
  logger.info(f"Training samples available: {len(lines)}")
121
 
@@ -125,8 +128,8 @@ def finetune():
125
  # TODO: implement finetuning with transformers Trainer
126
  # Rough plan:
127
  # 1. Load base model via model.get_model_id()
128
- # 2. Tokenize train_data.jsonl
129
- # 3. TrainingArguments + Trainer
130
  # 4. Save to PRIVATE_MODEL repo via model.push_model_card()
131
  logger.info("Finetune placeholder — not yet implemented")
132
  logger.info("Next step: implement with transformers.Trainer or TRL SFTTrainer")
@@ -144,7 +147,7 @@ if __name__ == "__main__":
144
  required=True,
145
  help="export: dump dataset to JSONL | validate: test ADI weights | finetune: train model"
146
  )
147
- parser.add_argument("--output", default="train_data.jsonl", help="Output file for export mode")
148
  args = parser.parse_args()
149
 
150
  if args.mode == "export":
@@ -152,4 +155,4 @@ if __name__ == "__main__":
152
  elif args.mode == "validate":
153
  validate_adi()
154
  elif args.mode == "finetune":
155
- finetune()
 
37
  # =============================================================================
38
 
39
  def export_dataset(output_path: str = None):
 
40
  """
41
  Export HF dataset logs to JSONL format for training.
42
  Filters: only HIGH_PRIORITY and MEDIUM_PRIORITY entries with actual responses.
43
  """
44
+ output = Path(output_path) if output_path else TRAIN_DATA
45
+
46
  logger.info("Loading dataset from HF...")
47
  entries = model_module.load_logs()
48
 
 
50
  logger.warning("Dataset empty — nothing to export")
51
  return
52
 
 
53
  count = 0
 
54
  with open(output, "w") as f:
55
  for entry in entries:
56
  # Only export entries where SmolLM2 actually responded
 
96
  accuracy = analyzer.validate_weights(labeled)
97
  logger.info(f"ADI Validation accuracy: {accuracy:.1%} on {len(labeled)} samples")
98
 
 
99
  result = {
100
  "timestamp": datetime.utcnow().isoformat(),
101
+ "accuracy": accuracy,
102
+ "samples": len(labeled),
103
+ "weights": analyzer.weights,
104
  }
105
  VALID_RESULT.write_text(json.dumps(result, indent=2))
106
+ logger.info(f"Results saved → {VALID_RESULT}")
107
 
108
 
109
  # =============================================================================
 
111
  # =============================================================================
112
 
113
  def finetune():
114
+ """
115
+ Finetune SmolLM2 on collected dataset.
116
+ Requires export first + enough data (>500 samples recommended).
117
+ """
118
  if not TRAIN_DATA.exists():
119
+ logger.error(f"train_data.jsonl not found at {TRAIN_DATA} — run export first")
120
  return
121
+
122
  lines = TRAIN_DATA.read_text().strip().splitlines()
123
  logger.info(f"Training samples available: {len(lines)}")
124
 
 
128
  # TODO: implement finetuning with transformers Trainer
129
  # Rough plan:
130
  # 1. Load base model via model.get_model_id()
131
+ # 2. Tokenize TRAIN_DATA
132
+ # 3. TrainingArguments + Trainer (or TRL SFTTrainer)
133
  # 4. Save to PRIVATE_MODEL repo via model.push_model_card()
134
  logger.info("Finetune placeholder — not yet implemented")
135
  logger.info("Next step: implement with transformers.Trainer or TRL SFTTrainer")
 
147
  required=True,
148
  help="export: dump dataset to JSONL | validate: test ADI weights | finetune: train model"
149
  )
150
+ parser.add_argument("--output", default=None, help="Output file for export mode (default: auto)")
151
  args = parser.parse_args()
152
 
153
  if args.mode == "export":
 
155
  elif args.mode == "validate":
156
  validate_adi()
157
  elif args.mode == "finetune":
158
+ finetune()