Spaces:
Running
Running
Update train.py
Browse files
train.py
CHANGED
|
@@ -37,11 +37,12 @@ logger = logging.getLogger("train")
|
|
| 37 |
# =============================================================================
|
| 38 |
|
| 39 |
def export_dataset(output_path: str = None):
|
| 40 |
-
output = Path(output_path) if output_path else TRAIN_DATA
|
| 41 |
"""
|
| 42 |
Export HF dataset logs to JSONL format for training.
|
| 43 |
Filters: only HIGH_PRIORITY and MEDIUM_PRIORITY entries with actual responses.
|
| 44 |
"""
|
|
|
|
|
|
|
| 45 |
logger.info("Loading dataset from HF...")
|
| 46 |
entries = model_module.load_logs()
|
| 47 |
|
|
@@ -49,9 +50,7 @@ def export_dataset(output_path: str = None):
|
|
| 49 |
logger.warning("Dataset empty — nothing to export")
|
| 50 |
return
|
| 51 |
|
| 52 |
-
output = Path(output_path)
|
| 53 |
count = 0
|
| 54 |
-
|
| 55 |
with open(output, "w") as f:
|
| 56 |
for entry in entries:
|
| 57 |
# Only export entries where SmolLM2 actually responded
|
|
@@ -97,15 +96,14 @@ def validate_adi():
|
|
| 97 |
accuracy = analyzer.validate_weights(labeled)
|
| 98 |
logger.info(f"ADI Validation accuracy: {accuracy:.1%} on {len(labeled)} samples")
|
| 99 |
|
| 100 |
-
# Save results
|
| 101 |
result = {
|
| 102 |
"timestamp": datetime.utcnow().isoformat(),
|
| 103 |
-
"accuracy":
|
| 104 |
-
"samples":
|
| 105 |
-
"weights":
|
| 106 |
}
|
| 107 |
VALID_RESULT.write_text(json.dumps(result, indent=2))
|
| 108 |
-
logger.info("Results saved →
|
| 109 |
|
| 110 |
|
| 111 |
# =============================================================================
|
|
@@ -113,9 +111,14 @@ def validate_adi():
|
|
| 113 |
# =============================================================================
|
| 114 |
|
| 115 |
def finetune():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
if not TRAIN_DATA.exists():
|
| 117 |
-
logger.error(f"train_data.jsonl not found at {TRAIN_DATA}")
|
| 118 |
return
|
|
|
|
| 119 |
lines = TRAIN_DATA.read_text().strip().splitlines()
|
| 120 |
logger.info(f"Training samples available: {len(lines)}")
|
| 121 |
|
|
@@ -125,8 +128,8 @@ def finetune():
|
|
| 125 |
# TODO: implement finetuning with transformers Trainer
|
| 126 |
# Rough plan:
|
| 127 |
# 1. Load base model via model.get_model_id()
|
| 128 |
-
# 2. Tokenize
|
| 129 |
-
# 3. TrainingArguments + Trainer
|
| 130 |
# 4. Save to PRIVATE_MODEL repo via model.push_model_card()
|
| 131 |
logger.info("Finetune placeholder — not yet implemented")
|
| 132 |
logger.info("Next step: implement with transformers.Trainer or TRL SFTTrainer")
|
|
@@ -144,7 +147,7 @@ if __name__ == "__main__":
|
|
| 144 |
required=True,
|
| 145 |
help="export: dump dataset to JSONL | validate: test ADI weights | finetune: train model"
|
| 146 |
)
|
| 147 |
-
parser.add_argument("--output", default=
|
| 148 |
args = parser.parse_args()
|
| 149 |
|
| 150 |
if args.mode == "export":
|
|
@@ -152,4 +155,4 @@ if __name__ == "__main__":
|
|
| 152 |
elif args.mode == "validate":
|
| 153 |
validate_adi()
|
| 154 |
elif args.mode == "finetune":
|
| 155 |
-
finetune()
|
|
|
|
| 37 |
# =============================================================================
|
| 38 |
|
| 39 |
def export_dataset(output_path: str = None):
|
|
|
|
| 40 |
"""
|
| 41 |
Export HF dataset logs to JSONL format for training.
|
| 42 |
Filters: only HIGH_PRIORITY and MEDIUM_PRIORITY entries with actual responses.
|
| 43 |
"""
|
| 44 |
+
output = Path(output_path) if output_path else TRAIN_DATA
|
| 45 |
+
|
| 46 |
logger.info("Loading dataset from HF...")
|
| 47 |
entries = model_module.load_logs()
|
| 48 |
|
|
|
|
| 50 |
logger.warning("Dataset empty — nothing to export")
|
| 51 |
return
|
| 52 |
|
|
|
|
| 53 |
count = 0
|
|
|
|
| 54 |
with open(output, "w") as f:
|
| 55 |
for entry in entries:
|
| 56 |
# Only export entries where SmolLM2 actually responded
|
|
|
|
| 96 |
accuracy = analyzer.validate_weights(labeled)
|
| 97 |
logger.info(f"ADI Validation accuracy: {accuracy:.1%} on {len(labeled)} samples")
|
| 98 |
|
|
|
|
| 99 |
result = {
|
| 100 |
"timestamp": datetime.utcnow().isoformat(),
|
| 101 |
+
"accuracy": accuracy,
|
| 102 |
+
"samples": len(labeled),
|
| 103 |
+
"weights": analyzer.weights,
|
| 104 |
}
|
| 105 |
VALID_RESULT.write_text(json.dumps(result, indent=2))
|
| 106 |
+
logger.info(f"Results saved → {VALID_RESULT}")
|
| 107 |
|
| 108 |
|
| 109 |
# =============================================================================
|
|
|
|
| 111 |
# =============================================================================
|
| 112 |
|
| 113 |
def finetune():
|
| 114 |
+
"""
|
| 115 |
+
Finetune SmolLM2 on collected dataset.
|
| 116 |
+
Requires export first + enough data (>500 samples recommended).
|
| 117 |
+
"""
|
| 118 |
if not TRAIN_DATA.exists():
|
| 119 |
+
logger.error(f"train_data.jsonl not found at {TRAIN_DATA} — run export first")
|
| 120 |
return
|
| 121 |
+
|
| 122 |
lines = TRAIN_DATA.read_text().strip().splitlines()
|
| 123 |
logger.info(f"Training samples available: {len(lines)}")
|
| 124 |
|
|
|
|
| 128 |
# TODO: implement finetuning with transformers Trainer
|
| 129 |
# Rough plan:
|
| 130 |
# 1. Load base model via model.get_model_id()
|
| 131 |
+
# 2. Tokenize TRAIN_DATA
|
| 132 |
+
# 3. TrainingArguments + Trainer (or TRL SFTTrainer)
|
| 133 |
# 4. Save to PRIVATE_MODEL repo via model.push_model_card()
|
| 134 |
logger.info("Finetune placeholder — not yet implemented")
|
| 135 |
logger.info("Next step: implement with transformers.Trainer or TRL SFTTrainer")
|
|
|
|
| 147 |
required=True,
|
| 148 |
help="export: dump dataset to JSONL | validate: test ADI weights | finetune: train model"
|
| 149 |
)
|
| 150 |
+
parser.add_argument("--output", default=None, help="Output file for export mode (default: auto)")
|
| 151 |
args = parser.parse_args()
|
| 152 |
|
| 153 |
if args.mode == "export":
|
|
|
|
| 155 |
elif args.mode == "validate":
|
| 156 |
validate_adi()
|
| 157 |
elif args.mode == "finetune":
|
| 158 |
+
finetune()
|