File size: 2,004 Bytes
90dd6a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e03aff
90dd6a4
8e03aff
90dd6a4
 
 
 
 
 
 
 
 
 
 
 
8e03aff
90dd6a4
 
 
 
8e03aff
90dd6a4
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
use crate::model::DartVisionModel;
use burn::backend::Wgpu;
use burn::backend::wgpu::WgpuDevice;
use burn::prelude::*;
use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};

/// Runs a single-image smoke test of [`DartVisionModel`] and prints its peak objectness score.
///
/// The function:
/// 1. builds a fresh model on `device` and tries to load weights from the `model_weights`
///    record file; if loading fails for any reason, the freshly initialised model is used;
/// 2. opens the image at `img_path` (falling back to a blank image if it cannot be opened or
///    decoded), resizes it to 800x800 and scales each RGB byte into `[0, 1]`;
/// 3. runs a forward pass, applies a sigmoid to output channel 4 and reports the largest
///    resulting value on stdout.
///
/// # Parameters
/// * `device` - WGPU device on which the model and the input tensor are created.
/// * `img_path` - path of the image to evaluate.
///
/// # Panics
/// Panics if the score cannot be read back from the backend as `f32`. The tensor operations
/// themselves may also panic if the model output does not have a channel index 4 on
/// dimension 1 (NOTE(review): the output layout is defined by `DartVisionModel::forward`,
/// which is not visible here — confirm).
pub fn test_model(device: WgpuDevice, img_path: &str) {
    // Side length (in pixels) of the square image fed to the network.
    const INPUT_SIZE: u32 = 800;
    // Index, along dimension 1 of the model output, of the channel treated as the
    // objectness logit.
    const OBJECTNESS_CHANNEL: usize = 4;
    // Scores above this value are reported as a promising detection.
    const SCORE_THRESHOLD: f32 = 0.1;

    println!("🔍 Testing model on: {}", img_path);
    let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
    let model = DartVisionModel::<Wgpu>::new(&device);

    // Keep the freshly initialised model when no usable weights are available. Report the
    // real cause: a corrupt or incompatible record is not the same as a missing file.
    let model = match recorder.load("model_weights".into(), &device) {
        Ok(record) => model.load_record(record),
        Err(err) => {
            println!(
                "⚠️ Could not load weights ({}), using initial model.",
                err
            );
            model
        }
    };

    // Best-effort: a missing or undecodable image is replaced by a blank image so the
    // forward pass can still be exercised.
    let img = image::open(img_path).unwrap_or_else(|err| {
        println!(
            "❌ Could not open image at {} ({}). Using a blank image instead.",
            img_path, err
        );
        image::DynamicImage::new_rgb8(INPUT_SIZE, INPUT_SIZE)
    });
    let resized = img.resize_exact(
        INPUT_SIZE,
        INPUT_SIZE,
        image::imageops::FilterType::Triangle,
    );

    // The raw RGB8 buffer is already laid out row-major with interleaved channels
    // (height, width, channel), so it can be normalised in one pass without building a
    // temporary collection per pixel.
    let pixels: Vec<f32> = resized
        .to_rgb8()
        .into_raw()
        .into_iter()
        .map(|channel| f32::from(channel) / 255.0)
        .collect();

    let side = INPUT_SIZE as usize; // lossless widening of a small constant
    let tensor_data = TensorData::new(pixels, [1, side, side, 3]);
    // Reorder from (batch, height, width, channel) to (batch, channel, height, width).
    let input = Tensor::<Wgpu, 4>::from_data(tensor_data, &device).permute([0, 3, 1, 2]);
    let (out, _): (Tensor<Wgpu, 4>, _) = model.forward(input);

    // Take the global maximum instead of reshaping to a fixed cell count, so the check
    // does not depend on the spatial resolution of the model output.
    let objectness = burn::tensor::activation::sigmoid(out.narrow(1, OBJECTNESS_CHANNEL, 1));
    let peak = objectness.max().into_data().convert::<f32>();
    let score = peak
        .as_slice::<f32>()
        .expect("tensor data was just converted to f32")[0];
    println!("📊 Max Objectness Score: {:.6}", score);

    if score > SCORE_THRESHOLD {
        println!("✅ Model detection looks promising!");
    } else {
        println!("⚠️ Low confidence detection. Training may still be in progress.");
    }
}