Spaces:
Build error
Build error
File size: 5,414 Bytes
use burn::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
/// One labelled image record.
///
/// `DartDataset::load` deserialises these as the values of the labels JSON
/// object, and `DartBatcher::batch_manual` turns them into tensors.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Annotation {
/// Folder component of the image location; `batch_manual` reads the image
/// from `dataset/800/{img_folder}/{img_name}`.
pub img_folder: String,
/// File name of the image; also the key `DartDataset::load` sorts by.
pub img_name: String,
/// Not read by any code in this file.
/// NOTE(review): presumably a bounding box in pixels — confirm the layout
/// against the labelling tool.
pub bbox: Vec<i32>,
/// Keypoints, one inner vector per point. `batch_manual` reads element 0 as
/// x and element 1 as y and clamps both to `[0, 1)`, so normalised image
/// coordinates are presumably expected — TODO confirm. Points at index 0..=3
/// are treated as calibration points and index 4 and above as darts.
pub xy: Vec<Vec<f32>>,
}
/// In-memory collection of annotations, built by `DartDataset::load`.
pub struct DartDataset {
/// All annotations, in the order produced by `load` (sorted by image name).
pub annotations: Vec<Annotation>,
/// Copy of the `base_path` argument given to `load`. Nothing in this file
/// reads it back.
/// NOTE(review): presumably meant as the image root directory — confirm
/// against callers.
pub base_path: String,
}
impl DartDataset {
pub fn load(json_path: &str, base_path: &str) -> Self {
let file = File::open(json_path).expect("Labels JSON not found");
let reader = BufReader::new(file);
let raw_data: HashMap<String, Annotation> = serde_json::from_reader(reader).expect("JSON parse error");
let mut annotations: Vec<Annotation> = raw_data.into_values().collect();
annotations.sort_by(|a, b| a.img_name.cmp(&b.img_name));
Self {
annotations,
base_path: base_path.to_string(),
}
}
}
impl burn::data::dataset::Dataset<Annotation> for DartDataset {
    /// Returns an owned copy of the annotation at `index`, or `None` when
    /// `index` is past the end.
    fn get(&self, index: usize) -> Option<Annotation> {
        let annotation = self.annotations.get(index)?;
        Some(annotation.clone())
    }

    /// Number of annotations held by the dataset.
    fn len(&self) -> usize {
        let Self { annotations, .. } = self;
        annotations.len()
    }
}
/// A batch of images with their matching training targets, as produced by
/// `DartBatcher::batch_manual`.
#[derive(Clone, Debug)]
pub struct DartBatch<B: Backend> {
/// Image tensor. `batch_manual` builds it as `[batch, 3, 800, 800]` with RGB
/// values scaled to `0.0..=1.0`.
pub images: Tensor<B, 4>,
/// Target tensor. `batch_manual` builds it as `[batch, 30, 50, 50]`, where
/// the 30 channels are 3 anchors x 10 attributes (x, y, w, h, objectness and
/// five class slots).
pub targets: Tensor<B, 4>,
}
/// Converts `Annotation`s into `DartBatch`es for one backend device.
#[derive(Clone, Debug)]
pub struct DartBatcher<B: Backend> {
// Device on which the image and target tensors are created.
device: Device<B>,
}
use burn::data::dataloader::batcher::Batcher;
impl<B: Backend> Batcher<Annotation, DartBatch<B>> for DartBatcher<B> {
    /// Data-loader entry point; all the work happens in
    /// `DartBatcher::batch_manual`.
    fn batch(&self, items: Vec<Annotation>) -> DartBatch<B> {
        Self::batch_manual(self, items)
    }
}
impl<B: Backend> DartBatcher<B> {
    /// Creates a batcher whose tensors are placed on `device`.
    pub fn new(device: Device<B>) -> Self {
        Self { device }
    }

    /// Builds the image and target tensors for `items`.
    ///
    /// For each annotation the image at `dataset/800/{img_folder}/{img_name}`
    /// is loaded, resized to 800x800 and scaled to `0.0..=1.0`. An image that
    /// cannot be opened is reported on stdout and replaced by a black frame.
    ///
    /// Each keypoint is written into a `[batch, 30, 50, 50]` target laid out
    /// as `[batch, anchor, attr, gy, gx]` (3 anchors x 10 attributes: x, y, w,
    /// h, objectness, cls0..cls4). Keypoints with fewer than two coordinates
    /// or a NaN coordinate are reported on stdout and skipped; they keep
    /// their position in the class numbering of the remaining keypoints.
    ///
    /// NOTE(review): an empty `items` is handed straight to `Tensor::stack`;
    /// confirm how the burn version in use handles an empty list.
    pub fn batch_manual(&self, items: Vec<Annotation>) -> DartBatch<B> {
        // 800 matches the original Python training config
        // (configs/deepdarts_d1.yaml: input_size: 800).
        const INPUT_RES: usize = 800;
        // For tiny YOLO: grid = input_res / 16. 800 / 16 = 50.
        const GRID_SIZE: usize = INPUT_RES / 16;
        const NUM_ANCHORS: usize = 3;
        const NUM_ATTRS: usize = 10; // x, y, w, h, obj, cls0..cls4
        const NUM_CHANNELS: usize = NUM_ANCHORS * NUM_ATTRS; // = 30
        // Number of cells in one attribute plane of the target.
        const PLANE: usize = GRID_SIZE * GRID_SIZE;
        // Fixed box width/height written for every keypoint (bbox_size from config).
        const BBOX_SIZE: f32 = 0.025;

        let batch_size = items.len();
        let side = INPUT_RES as u32;

        let mut images_list = Vec::with_capacity(batch_size);
        let mut target_raw = vec![0.0f32; batch_size * NUM_CHANNELS * PLANE];

        for (b_idx, item) in items.iter().enumerate() {
            // 1. Image: load, resize, convert to interleaved RGB floats.
            let path = format!("dataset/800/{}/{}", item.img_folder, item.img_name);
            let img = image::open(&path).unwrap_or_else(|_| {
                println!("⚠️ [Data] Image not found: {}", path);
                image::DynamicImage::new_rgb8(side, side)
            });
            let resized = img.resize_exact(side, side, image::imageops::FilterType::Triangle);
            // A fixed-size array per pixel avoids one heap allocation per pixel.
            let pixels: Vec<f32> = resized
                .to_rgb8()
                .pixels()
                .flat_map(|p| [p[0], p[1], p[2]])
                .map(|channel| f32::from(channel) / 255.0)
                .collect();
            images_list.push(TensorData::new(pixels, [INPUT_RES, INPUT_RES, 3]));

            // 2. Targets: one grid cell per keypoint.
            for (i, p) in item.xy.iter().enumerate() {
                // Reject points that would panic on indexing or put NaN into the target.
                let (raw_x, raw_y) = match p.as_slice() {
                    [x, y, ..] if !x.is_nan() && !y.is_nan() => (*x, *y),
                    _ => {
                        println!("⚠️ [Data] Skipping malformed keypoint {} in {}", i, path);
                        continue;
                    }
                };

                // Clamp coordinates to the valid grid range.
                let norm_x = raw_x.clamp(0.0, 1.0 - 1e-5);
                let norm_y = raw_y.clamp(0.0, 1.0 - 1e-5);
                // `min` is a defensive cap in case float rounding ever reaches GRID_SIZE.
                let gx = ((norm_x * GRID_SIZE as f32).floor() as usize).min(GRID_SIZE - 1);
                let gy = ((norm_y * GRID_SIZE as f32).floor() as usize).min(GRID_SIZE - 1);

                // Grid-relative offset (0..1 within the cell).
                let tx = norm_x * GRID_SIZE as f32 - gx as f32;
                let ty = norm_y * GRID_SIZE as f32 - gy as f32;

                // Python convention: cal points i=0..3 -> cls=1..4, dart i>=4 -> cls=0.
                let cls = if i < 4 { i + 1 } else { 0 };
                // Spread keypoints over anchors by class so all 3 anchors get used.
                let anchor_idx = cls % NUM_ANCHORS;

                // flat = b * (3*10*G*G) + anchor * (10*G*G) + attr * (G*G) + gy*G + gx
                let cell_base = b_idx * NUM_CHANNELS * PLANE
                    + anchor_idx * NUM_ATTRS * PLANE
                    + gy * GRID_SIZE
                    + gx;

                target_raw[cell_base] = tx; // x offset
                target_raw[cell_base + PLANE] = ty; // y offset
                target_raw[cell_base + 2 * PLANE] = BBOX_SIZE; // w
                target_raw[cell_base + 3 * PLANE] = BBOX_SIZE; // h
                target_raw[cell_base + 4 * PLANE] = 1.0; // objectness
                target_raw[cell_base + (5 + cls) * PLANE] = 1.0; // class prob
            }
        }

        // Stack [H, W, 3] images into [batch, H, W, 3], then move channels first.
        let images = Tensor::stack(
            images_list
                .into_iter()
                .map(|d| Tensor::<B, 3>::from_data(d, &self.device))
                .collect(),
            0,
        )
        .permute([0, 3, 1, 2]);

        let targets = Tensor::from_data(
            TensorData::new(target_raw, [batch_size, NUM_CHANNELS, GRID_SIZE, GRID_SIZE]),
            &self.device,
        );

        DartBatch { images, targets }
    }
}
|