Instructions to use abdullah890/malconv with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Keras
How to use abdullah890/malconv with Keras:
```python
# Available backend options are: "jax", "torch", "tensorflow".
import os
os.environ["KERAS_BACKEND"] = "jax"

import keras

model = keras.saving.load_model("hf://abdullah890/malconv")
```
- Notebooks
- Google Colab
- Kaggle
| import os | |
| import sys | |
| import argparse | |
| import numpy as np | |
| import tensorflow as tf | |
| from sklearn.model_selection import train_test_split | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from src.model import create_malconv_model | |
| from src.utils import ( | |
| configure_gpu_memory, | |
| plot_training_history, | |
| evaluate_model, | |
| get_file_paths_and_labels, | |
| data_generator, | |
| read_binary_file | |
| ) | |
def train_malconv(data_source,
                  epochs=10,
                  batch_size=256,
                  max_length=2_000_000,
                  validation_split=0.2,
                  save_path="models/malconv_model.h5"):
    """Train a MalConv model, feeding samples through data generators.

    The file list is split into stratified train/validation subsets, the
    model is fitted from generators, a capped subset of the validation
    files is then loaded into memory for a final evaluation, and the
    training history is plotted.

    Args:
        data_source: A ``(malware_dir, benign_dir)`` tuple.
        epochs: Number of training epochs.
        batch_size: Batch size passed to the data generators.
        max_length: Maximum input length in bytes (default 2,000,000).
        validation_split: Fraction of the files held out for validation.
        save_path: File the ``ModelCheckpoint`` callback writes the best
            model to.

    Returns:
        A ``(model, history, results)`` tuple. ``results`` is whatever
        ``evaluate_model`` returns, or ``{}`` when there was nothing to
        evaluate.

    Raises:
        ValueError: If ``data_source`` is not a tuple of length 2.

    Note:
        ``EarlyStopping`` restores the weights with the best ``val_loss``
        while ``ModelCheckpoint`` keeps the file with the best ``val_auc``,
        so the returned model and the saved file may hold different weights.
    """
    # NOTE(review): the non-ASCII runtime strings below are kept exactly as
    # they appear in this copy of the file.
    print("=" * 60)
    print("MalConv ๋ชจ๋ธ ํ๋ จ ์์ (๋ฐ์ดํฐ ์ ๋๋ ์ดํฐ ๋ชจ๋)")
    print("=" * 60)

    # GPU setup.
    configure_gpu_memory()

    # Load file paths and labels.
    if isinstance(data_source, tuple) and len(data_source) == 2:
        malware_dir, benign_dir = data_source
        filepaths, labels = get_file_paths_and_labels(malware_dir, benign_dir)
    else:
        raise ValueError("data_source๋ (malware_dir, benign_dir) ํํ์ด์ด์ผ ํฉ๋๋ค.")

    # Train/validation split, performed on file paths (not loaded contents).
    filepaths_train, filepaths_val, labels_train, labels_val = train_test_split(
        filepaths, labels, test_size=validation_split, random_state=42, stratify=labels
    )
    print(f"์ด ๋ฐ์ดํฐ: {len(filepaths)}")
    print(f"ํ๋ จ ๋ฐ์ดํฐ: {len(filepaths_train)}, ๊ฒ์ฆ ๋ฐ์ดํฐ: {len(filepaths_val)}")

    # Build the data generators; validation data is not shuffled.
    train_gen = data_generator(filepaths_train, labels_train, batch_size, max_length)
    val_gen = data_generator(filepaths_val, labels_val, batch_size, max_length, shuffle=False)

    # Create the model.
    print("MalConv ๋ชจ๋ธ ์์ฑ ์ค...")
    model = create_malconv_model(max_length)

    # Run a dummy input through the model so it is built before summary().
    dummy_input = np.zeros((1, max_length), dtype=np.uint8)
    model(dummy_input)

    print("\n=== ๋ชจ๋ธ ์ํคํ ์ฒ ===")
    model.summary()
    print(f"์ด ํ๋ผ๋ฏธํฐ ์: {model.count_params():,}")

    # Callbacks.
    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=5,  # patience increased
            restore_best_weights=True,
            verbose=1
        ),
        tf.keras.callbacks.ModelCheckpoint(
            save_path,
            monitor='val_auc',
            save_best_only=True,
            verbose=1,
            mode='max'  # higher AUC is better
        )
    ]

    # Floor division yields 0 steps when a split holds fewer files than
    # batch_size; clamp to one step so fit() is never asked for an empty
    # epoch or an empty validation pass.
    # NOTE(review): assumes data_generator can always produce at least one
    # batch — confirm against src.utils.
    steps_per_epoch = max(1, len(filepaths_train) // batch_size)
    validation_steps = max(1, len(filepaths_val) // batch_size)

    # Training.
    print("\n=== ํ๋ จ ์์ ===")
    print(f"๋ฐฐ์น ํฌ๊ธฐ: {batch_size}")
    print(f"์ํฌํฌ: {epochs}")
    history = model.fit(
        train_gen,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_data=val_gen,
        validation_steps=validation_steps,
        callbacks=callbacks,
        verbose=1
    )

    # Evaluation: only part of the validation set is used, because the
    # samples are loaded into memory all at once.
    print("\n=== ์ต์ข ํ๊ฐ ===")
    num_eval_samples = min(len(filepaths_val), 1024)  # cap on evaluation samples
    X_eval = np.array([read_binary_file(fp, max_length) for fp in filepaths_val[:num_eval_samples]])
    y_eval = np.array(labels_val[:num_eval_samples])
    if X_eval.size > 0:
        # Half the training batch size, but never 0 (batch_size == 1).
        eval_batch_size = max(1, batch_size // 2)
        results = evaluate_model(model, X_eval, y_eval, batch_size=eval_batch_size)
    else:
        print("ํ๊ฐํ ๋ฐ์ดํฐ๊ฐ ์์ต๋๋ค.")
        results = {}

    # Visualisation.
    plot_training_history(history)
    print(f"\n๋ชจ๋ธ์ด ์ ์ฅ๋์์ต๋๋ค: {save_path}")
    return model, history, results
def main():
    """Parse command-line arguments and run ``train_malconv``.

    Required options are ``--malware_dir`` and ``--benign_dir``. Optional
    ones are ``--epochs`` (default 20), ``--batch_size`` (default 64),
    ``--max_length`` (default 2,000,000) and ``--save_path`` (default
    ``models/malconv_model.h5``). The directory part of ``--save_path`` is
    created before training starts.
    """
    parser = argparse.ArgumentParser(description='MalConv ๋ชจ๋ธ ํ๋ จ')

    # Data source options.
    parser.add_argument('--malware_dir', required=True, help='์ ์ฑ์ฝ๋ ๋๋ ํ ๋ฆฌ')
    parser.add_argument('--benign_dir', required=True, help='์ ์ํ์ผ ๋๋ ํ ๋ฆฌ')

    # Training options.
    parser.add_argument('--epochs', type=int, default=20, help='์ํฌํฌ ์')  # epochs increased
    parser.add_argument('--batch_size', type=int, default=64, help='๋ฐฐ์น ํฌ๊ธฐ')  # batch size adjusted
    parser.add_argument('--max_length', type=int, default=2_000_000, help='์ต๋ ์ ๋ ฅ ๊ธธ์ด')
    parser.add_argument('--save_path', default='models/malconv_model.h5', help='๋ชจ๋ธ ์ ์ฅ ๊ฒฝ๋ก')
    args = parser.parse_args()

    data_source = (args.malware_dir, args.benign_dir)

    # Create the output directory. A bare file name such as "model.h5" has
    # an empty dirname, and os.makedirs("") raises FileNotFoundError, so the
    # call is skipped in that case.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Train the model.
    train_malconv(
        data_source=data_source,
        epochs=args.epochs,
        batch_size=args.batch_size,
        max_length=args.max_length,
        save_path=args.save_path
    )


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()