from flask import Flask, render_template_string, request, jsonify
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import os
import sys
import threading
import time
app = Flask(__name__)
CORS(app)  # allow cross-origin requests from the browser demo
# Module-level model state, written by the background loader thread and
# read by the request handlers. NOTE(review): individual assignments are
# atomic under CPython's GIL, but this is not formally synchronized —
# confirm no handler depends on seeing a consistent multi-variable state.
model_name = "openai/privacy-filter"
classifier = None       # set to a transformers pipeline once loading succeeds
model_loading = False   # True while the background load is in progress
model_error = None      # last load failure message, or None
model_thread = None     # handle to the background loader thread
# Background model loading
def load_model_async():
    """Load the HF token-classification pipeline in a background thread.

    On success, publishes the pipeline via the module-level ``classifier``
    and clears ``model_error``; on failure, records the error message in
    ``model_error``. Always resets ``model_loading`` when finished so the
    /health endpoint can distinguish "loading" from "failed".
    """
    # One declaration covers the whole function; the original had a
    # redundant second `global classifier` just before the assignment.
    global classifier, model_loading, model_error
    model_loading = True
    print("=" * 60, flush=True)
    print("BACKGROUND: Loading OpenAI Privacy Filter model...", flush=True)
    print("=" * 60, flush=True)
    try:
        print(f"Loading tokenizer and model: {model_name}", flush=True)
        print("This may take 5-10 minutes on first run...", flush=True)
        # Use AutoModelForTokenClassification directly for better performance
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir="/app/.cache/huggingface"
        )
        model = AutoModelForTokenClassification.from_pretrained(
            model_name,
            cache_dir="/app/.cache/huggingface"
        )
        classifier = pipeline(
            task="token-classification",
            model=model,
            tokenizer=tokenizer,
            aggregation_strategy="simple",  # merge sub-tokens into entity groups
            device=-1  # Force CPU
        )
        print("✓ Model loaded successfully!", flush=True)
        model_error = None
    except Exception as e:
        # Keep the server alive on load failure; /health reports the error.
        model_error = str(e)
        print(f"✗ ERROR loading model: {e}", flush=True)
        import traceback
        traceback.print_exc()
    finally:
        model_loading = False
# Start model loading in background
# Start model loading in the background at import time so the HTTP server
# can come up immediately; daemon=True lets the process exit even if the
# load is still running.
model_thread = threading.Thread(target=load_model_async, daemon=True)
model_thread.start()
# Page template rendered by the '/' route. NOTE(review): despite the name,
# this contains no HTML markup or loading-state script — presumably the
# real template was stripped or is still TODO; confirm against the UI.
HTML_TEMPLATE = '''
OpenAI Privacy Filter - PII Detection Demo
OpenAI Privacy Filter
PII Detection & Masking Demo using Flask
Waiting for server to start...
Detects 8 Types of PII:
- private_person - Names and personal identifiers
- private_email - Email addresses
- private_phone - Phone numbers
- private_address - Physical addresses
- account_number - Account/ID numbers
- secret - Passwords, tokens, credentials
- private_url - Personal/private URLs
- private_date - Personal dates (birthdays, etc.)
'''
@app.route('/')
def index():
    """Render and return the demo landing page."""
    page = render_template_string(HTML_TEMPLATE)
    return page
@app.route('/health')
def health():
    """Health check reporting the model loading status.

    Returns 200 with status 'healthy' (model ready) or 'loading'
    (background load in progress), and 503 with status 'unhealthy' once
    loading has finished without producing a classifier.
    """
    # This handler only reads the module globals, so no `global`
    # declaration is needed (the original one was redundant).
    if classifier is not None:
        return jsonify({
            'status': 'healthy',
            'model_loaded': True,
            'model_loading': False
        })
    if model_loading:
        return jsonify({
            'status': 'loading',
            'model_loaded': False,
            'model_loading': True,
            'message': 'Model is still loading, please wait...'
        })
    # Loading finished (or the thread died) without a usable model.
    return jsonify({
        'status': 'unhealthy',
        'model_loaded': False,
        'model_loading': False,
        'error': model_error or 'Model loading failed or thread terminated unexpectedly'
    }), 503
@app.route('/analyze', methods=['POST', 'OPTIONS'])
def analyze():
    """Run PII token-classification over the posted text.

    Expects JSON ``{"text": "..."}``. Returns ``{'success': True,
    'entities': [...], 'entity_count': N}`` where each entity carries
    label, matched text, character offsets, and confidence score.
    Errors are returned as JSON with an appropriate HTTP status.
    """
    if request.method == 'OPTIONS':
        # CORS preflight; flask-cors supplies the actual headers.
        return '', 204
    if classifier is None:
        # Read-only access to the globals needs no `global` declaration.
        return jsonify({
            'success': False,
            'error': f'Model not yet loaded. Current status: {"loading" if model_loading else "failed"}. Please wait and refresh in a few minutes.'
        }), 503
    try:
        # silent=True: malformed JSON yields None instead of an unhandled
        # werkzeug BadRequest, so the client always gets this JSON shape.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'success': False, 'error': 'No JSON data received'}), 400
        text = data.get('text', '')
        if not isinstance(text, str):
            # Previously a non-string 'text' crashed on .strip() and
            # surfaced as a generic 500; report it as a client error.
            return jsonify({'success': False, 'error': "'text' must be a string"}), 400
        if not text.strip():
            return jsonify({'success': True, 'entities': [], 'entity_count': 0})
        # Run classification
        results = classifier(text)
        entities = [
            {
                'label': entity.get('entity_group', entity.get('entity', 'unknown')),
                'text': entity.get('word', ''),
                'start': entity.get('start', 0),
                'end': entity.get('end', 0),
                'score': float(entity.get('score', 0))
            }
            for entity in results
        ]
        return jsonify({
            'success': True,
            'entities': entities,
            'entity_count': len(entities)
        })
    except Exception as e:
        print(f"Error during analysis: {e}", flush=True)
        import traceback
        traceback.print_exc()
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
if __name__ == '__main__':
    # Port from the environment; default 7860 — presumably the Hugging
    # Face Spaces convention, confirm against the deployment config.
    port = int(os.environ.get('PORT', 7860))
    # threaded=True so /health stays responsive while /analyze runs.
    app.run(host='0.0.0.0', port=port, debug=False, threaded=True)