File size: 7,405 Bytes
492772b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
"""
Test script for Binary Segmentation API

Run this to verify the API is working correctly.
"""

import requests
import sys
import time
from pathlib import Path


def test_api(base_url: str = "http://localhost:7860") -> bool:
    """Run a basic smoke-test suite against the Binary Segmentation API.

    Steps, in order: ``GET /health``, ``GET /models``, creation of a
    synthetic test image, then ``POST /segment``, ``/segment/mask`` and
    ``/segment/base64``. Progress and results are printed to stdout.

    Artifacts (``test_image.jpg``, ``test_output.png``, ``test_mask.png``)
    are written to the current working directory. Once the test image has
    been created, they are removed in a ``finally`` clause, so cleanup also
    runs if a later step is interrupted.

    Args:
        base_url: Root URL of the running API. A trailing ``/`` is ignored.

    Returns:
        False if a critical step (health check or test-image creation)
        fails, True otherwise. Failures of the model-listing and
        segmentation steps are reported but do not change the result.
    """
    # Avoid "http://host//health" when the caller passes a trailing slash.
    base_url = base_url.rstrip("/")

    test_path = Path("test_image.jpg")
    output_path = Path("test_output.png")
    mask_path = Path("test_mask.png")
    # Form fields shared by every segmentation request.
    seg_params = {'model': 'u2netp', 'threshold': '0.5'}

    print("=" * 60)
    print("Binary Segmentation API - Test Suite")
    print("=" * 60)
    print(f"\nTesting API at: {base_url}\n")

    # Test 1: Health Check -- critical; nothing else is meaningful without it.
    print("Test 1: Health Check")
    try:
        response = requests.get(f"{base_url}/health", timeout=5)
        if response.status_code == 200:
            print("✓ Health check passed")
            print(f"  Response: {response.json()}")
        else:
            print(f"✗ Health check failed: {response.status_code}")
            return False
    except Exception as e:  # top-level test boundary: report, then bail out
        print(f"✗ Health check failed: {e}")
        print("\n  Make sure the API is running:")
        print("  python app.py")
        print("  or")
        print("  uvicorn app:app --host 0.0.0.0 --port 7860")
        return False

    print()

    # Test 2: List Models -- informational only.
    print("Test 2: List Models")
    try:
        response = requests.get(f"{base_url}/models", timeout=5)
        if response.status_code == 200:
            print("✓ Models endpoint working")
            models = response.json().get('models', [])
            print(f"  Available models: {len(models)}")
            for model in models:
                print(f"    - {model['name']}: {model['description']}")
        else:
            print(f"✗ Models endpoint failed: {response.status_code}")
    except Exception as e:
        print(f"✗ Models endpoint failed: {e}")

    print()

    # Test 3: Create test image -- critical; tests 4-6 need it.
    print("Test 3: Create Test Image")
    try:
        # Imported lazily so a missing numpy/Pillow is reported as a test
        # failure instead of crashing the script at import time.
        import numpy as np
        from PIL import Image

        # 200x200 white canvas with a 100x100 red square in the middle.
        img = np.full((200, 200, 3), 255, dtype=np.uint8)
        img[50:150, 50:150] = [255, 0, 0]
        Image.fromarray(img).save(test_path)

        print(f"✓ Test image created: {test_path}")
    except Exception as e:
        print(f"✗ Failed to create test image: {e}")
        return False

    print()

    try:
        # Test 4: Segmentation
        if test_path.exists():
            print("Test 4: Image Segmentation")
            try:
                start_time = time.perf_counter()
                response = _post_image(base_url, "/segment", test_path, seg_params)
                elapsed = time.perf_counter() - start_time

                if response.status_code == 200:
                    output_path.write_bytes(response.content)
                    print(f"✓ Segmentation successful ({elapsed:.2f}s)")
                    print(f"  Output saved to: {output_path}")
                    print(f"  Output size: {len(response.content)} bytes")
                else:
                    print(f"✗ Segmentation failed: {response.status_code}")
                    print(f"  Response: {response.text}")
            except Exception as e:
                print(f"✗ Segmentation failed: {e}")

        print()

        # Test 5: Mask endpoint
        if test_path.exists():
            print("Test 5: Binary Mask")
            try:
                response = _post_image(base_url, "/segment/mask", test_path, seg_params)

                if response.status_code == 200:
                    mask_path.write_bytes(response.content)
                    print("✓ Mask generation successful")
                    print(f"  Mask saved to: {mask_path}")
                else:
                    print(f"✗ Mask generation failed: {response.status_code}")
            except Exception as e:
                print(f"✗ Mask generation failed: {e}")

        print()

        # Test 6: Base64 endpoint
        if test_path.exists():
            print("Test 6: Base64 Output")
            try:
                response = _post_image(
                    base_url,
                    "/segment/base64",
                    test_path,
                    {**seg_params, 'return_type': 'both'},
                )

                if response.status_code == 200:
                    result = response.json()
                    print("✓ Base64 output successful")
                    print(f"  Has RGBA: {'rgba' in result}")
                    print(f"  Has Mask: {'mask' in result}")
                else:
                    print(f"✗ Base64 output failed: {response.status_code}")
            except Exception as e:
                print(f"✗ Base64 output failed: {e}")

        print()
    finally:
        # Runs even if a step above is interrupted, so no stray files remain.
        _remove_test_files((test_path, output_path, mask_path))

    print()
    print("=" * 60)
    print("Test Suite Complete!")
    print("=" * 60)

    return True


def _post_image(base_url, endpoint, image_path, data, timeout=30):
    """POST *image_path* as multipart field ``file`` plus form *data* to ``base_url + endpoint``.

    The image file handle is closed before this returns.
    """
    with open(image_path, 'rb') as f:
        return requests.post(
            f"{base_url}{endpoint}",
            files={'file': f},
            data=data,
            timeout=timeout,
        )


def _remove_test_files(paths):
    """Delete each existing path in *paths*; report per-file failures instead of raising."""
    print("Cleanup:")
    for path in paths:
        # Per-file try: one failed unlink must not leave the others behind.
        try:
            if path.exists():
                path.unlink()
                print(f"  Removed: {path}")
        except OSError as e:
            print(f"  Warning: Cleanup failed: {e}")


if __name__ == "__main__":
    # Optional first CLI argument overrides the default base URL; a trailing
    # "/" is stripped so endpoint URLs and the hint below stay well-formed.
    base_url = (sys.argv[1] if len(sys.argv) > 1 else "http://localhost:7860").rstrip("/")

    success = test_api(base_url)

    if success:
        print("\n✓ All critical tests passed!")
        print("\nNext steps:")
        # Point at the URL that was actually tested, not a hard-coded default.
        print(f"1. Open {base_url} in your browser")
        print("2. Upload an image and test the web interface")
        print("3. Deploy to Hugging Face Spaces (see DEPLOYMENT.md)")
        sys.exit(0)
    else:
        print("\n✗ Some tests failed!")
        print("\nTroubleshooting:")
        print("1. Make sure the server is running:")
        print("   uvicorn app:app --host 0.0.0.0 --port 7860")
        print("2. Check that u2netp.pth is in .model_cache/")
        print("3. Check logs for errors")
        sys.exit(1)