| """
|
| Comprehensive TorchForge Examples
|
|
|
| Demonstrates all major features of TorchForge framework.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.optim as optim
|
| from torch.utils.data import DataLoader, TensorDataset
|
|
|
| from torchforge import ForgeModel, ForgeConfig
|
| from torchforge.governance import ComplianceChecker, NISTFramework
|
| from torchforge.monitoring import ModelMonitor
|
| from torchforge.deployment import DeploymentManager
|
|
|
|
|
|
|
def example_basic_classification():
    """Train a small 3-class MLP wrapped in a ForgeModel and print its metrics.

    Builds a synthetic dataset of 1000 random 20-feature samples with labels
    in {0, 1, 2}. It runs a few full-batch Adam steps and records each epoch's
    predictions via ``track_prediction``. It then prints the summary returned
    by ``get_metrics_summary``.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Example 1: Basic Classification")
    print(rule)

    class Classifier(nn.Module):
        """Three-layer MLP: 20 -> 64 -> 32 -> 3 with ReLU activations."""

        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(20, 64)
            self.fc2 = nn.Linear(64, 32)
            self.fc3 = nn.Linear(32, 3)
            self.relu = nn.ReLU()

        def forward(self, x):
            hidden = self.relu(self.fc1(x))
            hidden = self.relu(self.fc2(hidden))
            return self.fc3(hidden)

    forge_config = ForgeConfig(
        model_name="simple_classifier",
        version="1.0.0",
        enable_monitoring=True,
        enable_governance=True,
    )
    model = ForgeModel(Classifier(), config=forge_config)

    # Synthetic data: 1000 samples, 20 features, 3 classes.
    features = torch.randn(1000, 20)
    labels = torch.randint(0, 3, (1000,))

    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=0.001)
    num_epochs = 5

    print("\nTraining model...")
    for epoch in range(num_epochs):
        model.train()
        opt.zero_grad()
        logits = model(features)
        batch_loss = loss_fn(logits, labels)
        batch_loss.backward()
        opt.step()

        # Record this epoch's predictions with the wrapper's tracker.
        model.track_prediction(logits, labels, metadata={"epoch": epoch})
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {batch_loss.item():.4f}")

    print("\nModel Metrics:")
    for name, val in model.get_metrics_summary().items():
        print(f" {name}: {val}")

    print("\n✓ Example 1 completed successfully!")
|
|
|
|
|
|
|
def example_governance():
    """Demonstrate governance and compliance features.

    Wraps a tiny linear model in a ForgeModel with bias detection, audit
    logging and lineage tracking enabled. It then runs a NIST AI RMF 1.0
    assessment, prints the per-check results and recommendations, and
    exports the report as JSON and PDF.
    """
    print("\n" + "=" * 60)
    print("Example 2: Governance & Compliance")
    print("=" * 60)

    class SimpleNet(nn.Module):
        """Single linear layer mapping 10 features to 2 outputs."""

        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(10, 2)

        def forward(self, x):
            return self.fc(x)

    config = ForgeConfig(
        model_name="compliant_model",
        version="1.0.0",
        enable_governance=True,
        enable_monitoring=True,
    )
    config.governance.bias_detection = True
    config.governance.audit_logging = True
    config.governance.lineage_tracking = True

    model = ForgeModel(SimpleNet(), config=config)

    print("\nRunning NIST AI RMF compliance check...")
    checker = ComplianceChecker(framework=NISTFramework.RMF_1_0)
    report = checker.assess_model(model)

    print("\nCompliance Results:")
    print(f" Overall Score: {report.overall_score:.1f}/100")
    print(f" Risk Level: {report.risk_level}")
    print("\nCompliance Checks:")
    for check in report.checks:
        status = "✓" if check.passed else "✗"
        print(f" {status} {check.check_name}: {check.score:.1f}/100")

    print("\nRecommendations:")
    for i, rec in enumerate(report.recommendations, 1):
        print(f" {i}. {rec}")

    print("\nExporting compliance report...")
    json_path = "compliance_report.json"
    pdf_path = "compliance_report.pdf"
    report.export_json(json_path)
    report.export_pdf(pdf_path)
    # Announce exactly the files that were written; the PDF export was
    # previously reported as an .html file that never existed.
    print(f" - {json_path}")
    print(f" - {pdf_path}")

    print("\n✓ Example 2 completed successfully!")
|
|
|
|
|
|
|
def example_deployment():
    """Deploy a small MLP through DeploymentManager and print its status.

    Configures a ForgeModel with monitoring, governance and optimization
    enabled. It deploys the model on AWS with an ``ml.g4dn.xlarge`` instance,
    autoscaling between 2 and 10 instances, and a ``/health`` check path.
    It prints the deployment info returned by ``deploy`` and then the
    metrics from ``get_metrics(window="1h")``.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Example 3: Production Deployment")
    print(rule)

    class ProductionModel(nn.Module):
        """Two-layer MLP: 10 -> 64 -> 2 with a ReLU in between."""

        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(10, 64),
                nn.ReLU(),
                nn.Linear(64, 2),
            )

        def forward(self, x):
            return self.net(x)

    forge_config = ForgeConfig(
        model_name="production_model",
        version="2.0.0",
        enable_monitoring=True,
        enable_governance=True,
        enable_optimization=True,
    )
    forge_model = ForgeModel(ProductionModel(), config=forge_config)

    print("\nDeploying to AWS SageMaker...")
    manager = DeploymentManager(
        model=forge_model,
        cloud_provider="aws",
        instance_type="ml.g4dn.xlarge",
    )
    deploy_info = manager.deploy(
        enable_autoscaling=True,
        min_instances=2,
        max_instances=10,
        health_check_path="/health",
    )

    # (display label, key in the dict returned by deploy()), in print order.
    info_rows = (
        ("Status", "status"),
        ("Endpoint", "endpoint_url"),
        ("Cloud Provider", "cloud_provider"),
        ("Instance Type", "instance_type"),
        ("Autoscaling", "autoscaling_enabled"),
        ("Min Instances", "min_instances"),
        ("Max Instances", "max_instances"),
    )
    print("\nDeployment Information:")
    for label, field in info_rows:
        print(f" {label}: {deploy_info[field]}")

    print("\nDeployment Metrics (1h window):")
    stats = manager.get_metrics(window="1h")
    print(f" P95 Latency: {stats.latency_p95:.2f}ms")
    print(f" P99 Latency: {stats.latency_p99:.2f}ms")
    print(f" Requests/sec: {stats.requests_per_second:.1f}")
    print(f" Error Rate: {stats.error_rate:.3%}")

    print("\n✓ Example 3 completed successfully!")
|
|
|
|
|
|
|
def example_monitoring():
    """Demonstrate monitoring features.

    Enables drift detection, fairness tracking and Prometheus export in the
    config. It attaches a ModelMonitor and sends 100 single-sample inference
    requests through the model. It then prints the monitor's health status
    and performance metrics.
    """
    print("\n" + "=" * 60)
    print("Example 4: Monitoring & Observability")
    print("=" * 60)

    class MonitoredNet(nn.Module):
        """Single linear layer mapping 10 features to 2 outputs."""

        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(10, 2)

        def forward(self, x):
            return self.fc(x)

    config = ForgeConfig(
        model_name="monitored_model",
        version="1.0.0",
        enable_monitoring=True,
    )
    config.monitoring.drift_detection = True
    config.monitoring.fairness_tracking = True
    config.monitoring.prometheus_enabled = True

    model = ForgeModel(MonitoredNet(), config=config)

    print("\nSetting up model monitor...")
    monitor = ModelMonitor(model)
    monitor.enable_drift_detection()
    monitor.enable_fairness_tracking()

    print("\nSimulating production traffic...")
    # Serve requests in inference mode. eval() switches off training-only
    # behavior, and no_grad() avoids building an autograd graph for
    # predictions that are never backpropagated.
    model.eval()
    with torch.no_grad():
        for _ in range(100):
            model(torch.randn(1, 10))

    print("\nModel Health Status:")
    health = monitor.get_health_status()
    print(f" Status: {health['status']}")
    print(f" Drift Detection: {health['drift_detection']}")
    print(f" Fairness Tracking: {health['fairness_tracking']}")

    metrics = health['metrics']
    print("\nPerformance Metrics:")
    print(f" Total Inferences: {metrics['inference_count']}")
    print(f" Mean Latency: {metrics['latency_mean_ms']:.2f}ms")
    print(f" P95 Latency: {metrics['latency_p95_ms']:.2f}ms")
    print(f" Error Rate: {metrics['error_rate']:.3%}")

    print("\n✓ Example 4 completed successfully!")
|
|
|
|
|
|
|
def example_complete_pipeline():
    """Demonstrate a complete ML pipeline.

    Runs seven steps: configure, train, evaluate, check compliance,
    save a checkpoint, deploy, and set up monitoring. The synthetic dataset
    is split so that evaluation uses samples the model was not trained on.
    """
    print("\n" + "=" * 60)
    print("Example 5: Complete ML Pipeline")
    print("=" * 60)

    class MLPipeline(nn.Module):
        """MLP: 20 -> 128 -> 64 -> 2, with ReLU and 20% dropout after the first layer."""

        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(20, 128),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 2),
            )

        def forward(self, x):
            return self.net(x)

    print("\n1. Configuring model...")
    config = ForgeConfig(
        model_name="ml_pipeline",
        version="1.0.0",
        description="Complete ML pipeline with all features",
        author="Anil Prasad",
        tags=["production", "classification"],
        enable_monitoring=True,
        enable_governance=True,
        enable_optimization=True,
    )

    model = ForgeModel(MLPipeline(), config=config)

    print("\n2. Training model...")
    X = torch.randn(1000, 20)
    y = torch.randint(0, 2, (1000,))

    # Hold out the last 20% of samples for evaluation. Scoring on the
    # training data would overstate how well the model generalizes.
    split = int(0.8 * len(X))
    X_train, X_test = X[:split], X[split:]
    y_train, y_test = y[:split], y[split:]

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    num_epochs = 10
    model.train()
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        output = model(X_train)
        loss = criterion(output, y_train)
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 2 == 0:
            print(f" Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}")

    print("\n3. Evaluating model...")
    # eval() disables dropout; no_grad() skips autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        predictions = model(X_test).argmax(dim=1)
        accuracy = (predictions == y_test).float().mean().item()
    print(f" Accuracy: {accuracy:.2%}")

    print("\n4. Checking compliance...")
    checker = ComplianceChecker()
    report = checker.assess_model(model)
    print(f" Compliance Score: {report.overall_score:.1f}/100")
    print(f" Risk Level: {report.risk_level}")

    print("\n5. Saving checkpoint...")
    model.save_checkpoint("ml_pipeline_checkpoint.pt")
    print(" ✓ Checkpoint saved")

    print("\n6. Deploying to production...")
    deployment = DeploymentManager(model=model)
    info = deployment.deploy(enable_autoscaling=True)
    print(f" ✓ Deployed to {info['endpoint_url']}")

    print("\n7. Setting up monitoring...")
    monitor = ModelMonitor(model)
    monitor.enable_drift_detection()
    monitor.enable_fairness_tracking()
    print(" ✓ Monitoring enabled")

    print("\n✓ Example 5 completed successfully!")
    print("\nComplete ML pipeline executed end-to-end!")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: print a banner, run every example in order, then
    # print a closing banner.
    banner = "=" * 60

    print("\n" + banner)
    print("TorchForge - Comprehensive Examples")
    print("Author: Anil Prasad")
    print(banner)

    for run_example in (
        example_basic_classification,
        example_governance,
        example_deployment,
        example_monitoring,
        example_complete_pipeline,
    ):
        run_example()

    print("\n" + banner)
    print("All examples completed successfully! 🎉")
    print(banner)
|
|
|