import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
import os
import numpy as np
from bs4 import BeautifulSoup
import argparse
import logging
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import json
import matplotlib.pyplot as plt


def get_ground_truth(image, cells, otsl, split_width=5):
    """
    Parse OTSL to derive row/column split positions.
    This is the ground truth for split model training.

    Args:
        image: PIL Image
        cells: nested list - cells[0] contains actual cell data
        otsl: OTSL token sequence
        split_width: width of split regions in pixels (default: 5)
    """
    orig_width, orig_height = image.size
    target_size = 960

    cells_flat = cells[0]

    # Rebuild the logical grid from the OTSL token stream
    grid = []
    current_row = []
    cell_idx = 0

    for token in otsl:
        if token == 'nl':
            if current_row:
                grid.append(current_row)
                current_row = []
        elif token in ['fcel', 'ecel']:
            current_row.append({'type': token, 'cell_idx': cell_idx})
            cell_idx += 1
        elif token in ['lcel', 'ucel', 'xcel']:
            # Span-continuation tokens carry no cell of their own
            current_row.append({'type': token, 'cell_idx': None})

    if current_row:
        grid.append(current_row)
    # A row's split position is the bottom edge of its lowest cell
    row_splits = []
    for row in grid:
        row_cell_indices = [item['cell_idx'] for item in row if item['cell_idx'] is not None]
        if row_cell_indices:
            max_y = max(cells_flat[i]['bbox'][3] for i in row_cell_indices)
            row_splits.append(max_y)

    # A column's split position is the rightmost edge of its cells,
    # ignoring cells whose horizontal span continues into the next column
    num_cols = len(grid[0]) if grid else 0
    col_splits = []
    for col_idx in range(num_cols):
        col_max_x = []
        for row in grid:
            if col_idx < len(row) and row[col_idx]['cell_idx'] is not None:
                next_is_lcel = (col_idx + 1 < len(row) and row[col_idx + 1]['type'] == 'lcel')
                if not next_is_lcel:
                    cell_id = row[col_idx]['cell_idx']
                    col_max_x.append(cells_flat[cell_id]['bbox'][2])
        if col_max_x:
            col_splits.append(max(col_max_x))

    # The last entries are the table's outer edges; those are marked
    # separately below, so keep only the internal splits
    row_splits = row_splits[:-1]
    col_splits = col_splits[:-1]
    # Scale the split coordinates from the original size to the 960x960 target
    y_scaled = [(y * target_size / orig_height) for y in row_splits]
    x_scaled = [(x * target_size / orig_width) for x in col_splits]

    # One binary label per pixel row / pixel column of the target image
    horizontal_gt = [0] * target_size
    vertical_gt = [0] * target_size

    # Table bounding box = union of all cell boxes, scaled to the target
    all_x1 = [c['bbox'][0] for c in cells_flat]
    all_y1 = [c['bbox'][1] for c in cells_flat]
    all_x2 = [c['bbox'][2] for c in cells_flat]
    all_y2 = [c['bbox'][3] for c in cells_flat]
    table_bbox = [min(all_x1), min(all_y1), max(all_x2), max(all_y2)]
    table_y1 = int(round(table_bbox[1] * target_size / orig_height))
    table_y2 = int(round(table_bbox[3] * target_size / orig_height))
    table_x1 = int(round(table_bbox[0] * target_size / orig_width))
    table_x2 = int(round(table_bbox[2] * target_size / orig_width))
    # Mark split_width-wide bands along the four table borders:
    # top edge (band extends downward)
    for offset in range(split_width):
        pos = table_y1 + offset
        if 0 <= pos < target_size:
            horizontal_gt[pos] = 1

    # bottom edge (band extends upward)
    for offset in range(split_width):
        pos = table_y2 - offset
        if 0 <= pos < target_size:
            horizontal_gt[pos] = 1

    # left edge
    for offset in range(split_width):
        pos = table_x1 + offset
        if 0 <= pos < target_size:
            vertical_gt[pos] = 1

    # right edge
    for offset in range(split_width):
        pos = table_x2 - offset
        if 0 <= pos < target_size:
            vertical_gt[pos] = 1

    # Mark split_width-wide bands at each internal row split
    for y in y_scaled:
        y_int = int(round(y))
        if 0 <= y_int < target_size:
            for offset in range(split_width):
                pos = y_int + offset
                if 0 <= pos < target_size:
                    horizontal_gt[pos] = 1

    # ... and at each internal column split
    for x in x_scaled:
        x_int = int(round(x))
        if 0 <= x_int < target_size:
            for offset in range(split_width):
                pos = x_int + offset
                if 0 <= pos < target_size:
                    vertical_gt[pos] = 1

    return horizontal_gt, vertical_gt
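
# Illustration (hypothetical tokens): for OTSL ['fcel', 'fcel', 'nl',
# 'fcel', 'lcel', 'nl'] the grid is [[cell 0, cell 1], [cell 2, span]].
# The second row's single cell spans both columns, so its right edge is
# skipped when collecting column splits (the following token is 'lcel');
# after the outer edge is dropped, the single internal column split comes
# from cell 0's right edge alone.
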

def get_ground_truth_auto_gap(image, cells, otsl):
    """
    Parse OTSL to derive row/column split positions with DYNAMIC gap widths.
    This creates ground truth for the split model training.

    Args:
        image: PIL Image
        cells: nested list - cells[0] contains actual cell data
        otsl: OTSL token sequence
    """
    orig_width, orig_height = image.size
    target_size = 960

    cells_flat = cells[0]

    # Rebuild the logical grid from the OTSL token stream
    grid = []
    current_row = []
    cell_idx = 0

    for token in otsl:
        if token == 'nl':
            if current_row:
                grid.append(current_row)
                current_row = []
        elif token in ['fcel', 'ecel']:
            current_row.append({'type': token, 'cell_idx': cell_idx})
            cell_idx += 1
        elif token in ['lcel', 'ucel', 'xcel']:
            # Span-continuation tokens carry no cell of their own
            current_row.append({'type': token, 'cell_idx': None})

    if current_row:
        grid.append(current_row)
    # Vertical extent (min y1, max y2) of each grid row
    row_boundaries = []
    for row in grid:
        row_cell_indices = [item['cell_idx'] for item in row if item['cell_idx'] is not None]
        if row_cell_indices:
            min_y1 = min(cells_flat[i]['bbox'][1] for i in row_cell_indices)
            max_y2 = max(cells_flat[i]['bbox'][3] for i in row_cell_indices)
            row_boundaries.append({'min_y': min_y1, 'max_y': max_y2})

    # Horizontal extent (min x1, max x2) of each grid column,
    # ignoring cells whose horizontal span continues into the next column
    num_cols = len(grid[0]) if grid else 0
    col_boundaries = []
    for col_idx in range(num_cols):
        col_cells = []
        for row in grid:
            if col_idx < len(row) and row[col_idx]['cell_idx'] is not None:
                next_is_lcel = (col_idx + 1 < len(row) and row[col_idx + 1]['type'] == 'lcel')
                if not next_is_lcel:
                    cell_id = row[col_idx]['cell_idx']
                    col_cells.append(cell_id)
        if col_cells:
            min_x1 = min(cells_flat[i]['bbox'][0] for i in col_cells)
            max_x2 = max(cells_flat[i]['bbox'][2] for i in col_cells)
            col_boundaries.append({'min_x': min_x1, 'max_x': max_x2})
    # One binary label per pixel row / pixel column of the target image
    # (the table bounding box is not needed here: the outer regions are
    # marked directly from the first/last boundaries below)
    horizontal_gt = [0] * target_size
    vertical_gt = [0] * target_size
    def mark_range(gt_array, start, end, orig_dim):
        """Mark all pixels from start to end inclusive (scaled to target_size)."""
        start_scaled = int(round(start * target_size / orig_dim))
        end_scaled = int(round(end * target_size / orig_dim))
        for pos in range(start_scaled, min(end_scaled + 1, target_size)):
            if 0 <= pos < target_size:
                gt_array[pos] = 1
    # Horizontal ground truth: everything above the first row, each
    # inter-row gap, and everything below the last row is a split region
    if row_boundaries:
        mark_range(horizontal_gt, 0, row_boundaries[0]['min_y'], orig_height)

    for i in range(len(row_boundaries) - 1):
        gap_start = row_boundaries[i]['max_y']
        gap_end = row_boundaries[i + 1]['min_y']
        if gap_end > gap_start:
            mark_range(horizontal_gt, gap_start, gap_end, orig_height)

    if row_boundaries:
        mark_range(horizontal_gt, row_boundaries[-1]['max_y'], orig_height, orig_height)

    # Vertical ground truth: same scheme for the columns
    if col_boundaries:
        mark_range(vertical_gt, 0, col_boundaries[0]['min_x'], orig_width)

    for i in range(len(col_boundaries) - 1):
        gap_start = col_boundaries[i]['max_x']
        gap_end = col_boundaries[i + 1]['min_x']
        if gap_end > gap_start:
            mark_range(vertical_gt, gap_start, gap_end, orig_width)

    if col_boundaries:
        mark_range(vertical_gt, col_boundaries[-1]['max_x'], orig_width, orig_width)

    return horizontal_gt, vertical_gt
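
# Illustration: if one row's boxes end at y=100 and the next row starts at
# y=110 (original pixels), the fixed-width variant above marks a 5 px band
# starting at the scaled y=100 line, while this variant marks the entire
# 100..110 gutter -- wider gaps yield proportionally wider positive regions.
# Note that overlapping rows (gap_end <= gap_start) get no marking at all
# here; the variant below handles that case.
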

def get_ground_truth_auto_gap_expand_min5pix_overlap_cells(image, cells, otsl, split_width=5):
    """
    Parse OTSL to derive row/column split positions with DYNAMIC gap widths,
    falling back to a fixed-width band when adjacent rows/columns touch or
    overlap. This creates ground truth for the split model training.

    Args:
        image: PIL Image
        cells: nested list - cells[0] contains actual cell data
        otsl: OTSL token sequence
        split_width: width of split when there's no gap (default: 5)
    """
    orig_width, orig_height = image.size
    target_size = 960

    cells_flat = cells[0]

    # Rebuild the logical grid from the OTSL token stream
    grid = []
    current_row = []
    cell_idx = 0

    for token in otsl:
        if token == 'nl':
            if current_row:
                grid.append(current_row)
                current_row = []
        elif token in ['fcel', 'ecel']:
            current_row.append({'type': token, 'cell_idx': cell_idx})
            cell_idx += 1
        elif token in ['lcel', 'ucel', 'xcel']:
            # Span-continuation tokens carry no cell of their own
            current_row.append({'type': token, 'cell_idx': None})

    if current_row:
        grid.append(current_row)
    # Vertical extent of each grid row, plus its member cell indices
    row_boundaries = []
    for row in grid:
        row_cell_indices = [item['cell_idx'] for item in row if item['cell_idx'] is not None]
        if row_cell_indices:
            min_y1 = min(cells_flat[i]['bbox'][1] for i in row_cell_indices)
            max_y2 = max(cells_flat[i]['bbox'][3] for i in row_cell_indices)
            row_boundaries.append({'min_y': min_y1, 'max_y': max_y2, 'row_cells': row_cell_indices})

    # Horizontal extent of each grid column, plus its member cell indices,
    # ignoring cells whose horizontal span continues into the next column
    num_cols = len(grid[0]) if grid else 0
    col_boundaries = []
    for col_idx in range(num_cols):
        col_cells = []
        for row in grid:
            if col_idx < len(row) and row[col_idx]['cell_idx'] is not None:
                next_is_lcel = (col_idx + 1 < len(row) and row[col_idx + 1]['type'] == 'lcel')
                if not next_is_lcel:
                    cell_id = row[col_idx]['cell_idx']
                    col_cells.append(cell_id)
        if col_cells:
            min_x1 = min(cells_flat[i]['bbox'][0] for i in col_cells)
            max_x2 = max(cells_flat[i]['bbox'][2] for i in col_cells)
            col_boundaries.append({'min_x': min_x1, 'max_x': max_x2, 'col_cells': col_cells})
    # One binary label per pixel row / pixel column of the target image
    # (the table bounding box is not needed here: the outer regions are
    # marked directly from the first/last boundaries below)
    horizontal_gt = [0] * target_size
    vertical_gt = [0] * target_size
    def mark_range(gt_array, start, end, orig_dim):
        """Mark all pixels from start to end inclusive (scaled to target_size)."""
        start_scaled = int(round(start * target_size / orig_dim))
        end_scaled = int(round(end * target_size / orig_dim))
        for pos in range(start_scaled, min(end_scaled + 1, target_size)):
            if 0 <= pos < target_size:
                gt_array[pos] = 1
    # Horizontal ground truth: region above the first row
    if row_boundaries:
        mark_range(horizontal_gt, 0, row_boundaries[0]['min_y'], orig_height)

    # Inter-row regions
    for i in range(len(row_boundaries) - 1):
        gap_start = row_boundaries[i]['max_y']
        gap_end = row_boundaries[i + 1]['min_y']
        if gap_end > gap_start:
            mark_range(horizontal_gt, gap_start, gap_end, orig_height)
        else:
            # No positive gap between stored extents: recheck from the raw
            # cell boxes, else centre a fixed split_width band on the midpoint
            curr_row_y2 = [cells_flat[cell_id]['bbox'][3] for cell_id in row_boundaries[i]['row_cells']]
            next_row_y1 = [cells_flat[cell_id]['bbox'][1] for cell_id in row_boundaries[i + 1]['row_cells']]

            max_curr_y2 = max(curr_row_y2)
            min_next_y1 = min(next_row_y1)

            if min_next_y1 > max_curr_y2:
                mark_range(horizontal_gt, max_curr_y2, min_next_y1, orig_height)
            else:
                split_pos = (max_curr_y2 + min_next_y1) / 2
                mark_range(horizontal_gt, split_pos - split_width/2, split_pos + split_width/2, orig_height)

    # Region below the last row
    if row_boundaries:
        mark_range(horizontal_gt, row_boundaries[-1]['max_y'], orig_height, orig_height)
    # Vertical ground truth: region left of the first column
    if col_boundaries:
        mark_range(vertical_gt, 0, col_boundaries[0]['min_x'], orig_width)

    # Inter-column regions
    for i in range(len(col_boundaries) - 1):
        gap_start = col_boundaries[i]['max_x']
        gap_end = col_boundaries[i + 1]['min_x']

        if gap_end > gap_start:
            mark_range(vertical_gt, gap_start, gap_end, orig_width)
        else:
            # No positive gap between stored extents: recheck from the raw
            # cell boxes, else centre a fixed split_width band on the midpoint
            curr_col_x2 = [cells_flat[cell_id]['bbox'][2] for cell_id in col_boundaries[i]['col_cells']]
            next_col_x1 = [cells_flat[cell_id]['bbox'][0] for cell_id in col_boundaries[i + 1]['col_cells']]

            max_curr_x2 = max(curr_col_x2)
            min_next_x1 = min(next_col_x1)

            if min_next_x1 > max_curr_x2:
                mark_range(vertical_gt, max_curr_x2, min_next_x1, orig_width)
            else:
                split_pos = (max_curr_x2 + min_next_x1) / 2
                mark_range(vertical_gt, split_pos - split_width/2, split_pos + split_width/2, orig_width)

    # Region right of the last column
    if col_boundaries:
        mark_range(vertical_gt, col_boundaries[-1]['max_x'], orig_width, orig_width)

    return horizontal_gt, vertical_gt
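
# Illustration: if the current row ends at y=105 while the next row starts at
# y=100 (overlapping boxes), there is no gutter to mark, so a split_width band
# is centred on the midpoint (105 + 100) / 2 = 102.5 -- guaranteeing at least
# one positive split region between every pair of adjacent rows/columns.
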

class BasicBlock(nn.Module):
    """Basic ResNet block with halved channels"""
    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # 1x1 projection for the residual when the shape or stride changes
        self.downsample = None
        if stride != 1 or inplanes != planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out

class ModifiedResNet18(nn.Module):
    """ResNet-18 with removed maxpool and halved channels"""
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        # Stock ResNet-18 would apply a stride-2 maxpool here; it is omitted
        # so the feature maps keep twice the spatial resolution
        self.layer1 = self._make_layer(32, 32, 2, stride=1)
        self.layer2 = self._make_layer(32, 64, 2, stride=2)
        self.layer3 = self._make_layer(64, 128, 2, stride=2)
        self.layer4 = self._make_layer(128, 256, 2, stride=2)

    def _make_layer(self, inplanes, planes, blocks, stride=1):
        layers = []
        layers.append(BasicBlock(inplanes, planes, stride))
        for _ in range(1, blocks):
            layers.append(BasicBlock(planes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x
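
# Shape note: for the 960x960 inputs used below, conv1 (stride 2) gives
# 480x480 and layers 2-4 halve again each, so the backbone emits
# (B, 256, 60, 60) -- an overall stride of 16 instead of ResNet-18's 32.
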

class FPN(nn.Module):
    """Simplified FPN head: 128 output channels at H/2 x W/2"""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(256, 128, kernel_size=1)

    def forward(self, x):
        # Project the deepest backbone features to 128 channels...
        x = self.conv(x)
        # ...and upsample to 480x480, i.e. half the 960-pixel input
        x = F.interpolate(x, size=(480, 480), mode='bilinear', align_corners=False)
        return x
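
# Design note: unlike a full feature pyramid, this head uses only the deepest
# feature map with a single 1x1 lateral conv, keeping one shared 128-channel
# map for both the row and column branches.
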

class SplitModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = ModifiedResNet18()
        self.fpn = FPN()

        # Learned weights for collapsing the width (resp. height) axis
        # into a global per-row / per-column descriptor
        self.h_global_weight = nn.Parameter(torch.randn(480))
        self.v_global_weight = nn.Parameter(torch.randn(480))

        # 1x1 convs producing a single-channel local feature map
        self.h_local_conv = nn.Conv2d(128, 1, kernel_size=1)
        self.v_local_conv = nn.Conv2d(128, 1, kernel_size=1)

        # 128 global channels + 120 local positions (480 pooled by 4)
        feature_dim = 128 + 120

        # Learned positional embeddings, one per output row/column
        self.h_pos_embed = nn.Parameter(torch.randn(480, feature_dim))
        self.v_pos_embed = nn.Parameter(torch.randn(480, feature_dim))

        # Sequence encoders over the 480 row (resp. column) positions
        self.h_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=feature_dim, nhead=8, dim_feedforward=2048,
                dropout=0.1, batch_first=True
            ),
            num_layers=3
        )
        self.v_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=feature_dim, nhead=8, dim_feedforward=2048,
                dropout=0.1, batch_first=True
            ),
            num_layers=3
        )

        # Per-position split / no-split classifiers
        self.h_classifier = nn.Linear(feature_dim, 1)
        self.v_classifier = nn.Linear(feature_dim, 1)

    def forward(self, x):
        features = self.backbone(x)          # (B, 256, 60, 60)
        F_half = self.fpn(features)          # (B, 128, 480, 480)

        B, C, H, W = F_half.shape

        # --- Horizontal (row) branch ---
        # Global feature: weighted sum over the width axis -> (B, 480, 128)
        F_RG = torch.einsum('bchw,w->bch', F_half, self.h_global_weight)
        F_RG = F_RG.transpose(1, 2)

        # Local feature: pool width by 4, project to 1 channel -> (B, 480, 120)
        F_RL_pooled = F.avg_pool2d(F_half, kernel_size=(1, 4))
        F_RL = self.h_local_conv(F_RL_pooled)
        F_RL = F_RL.squeeze(1)

        # Concatenate global + local -> (B, 480, 248)
        F_RG_L = torch.cat([F_RG, F_RL], dim=2)

        # Add positional embedding
        F_RG_L = F_RG_L + self.h_pos_embed

        # --- Vertical (column) branch ---
        # Global feature: weighted sum over the height axis -> (B, 480, 128)
        F_CG = torch.einsum('bchw,h->bcw', F_half, self.v_global_weight)
        F_CG = F_CG.transpose(1, 2)

        # Local feature: pool height by 4, project to 1 channel -> (B, 480, 120)
        F_CL_pooled = F.avg_pool2d(F_half, kernel_size=(4, 1))
        F_CL = self.v_local_conv(F_CL_pooled)
        F_CL = F_CL.squeeze(1)
        F_CL = F_CL.transpose(1, 2)

        # Concatenate global + local -> (B, 480, 248)
        F_CG_L = torch.cat([F_CG, F_CL], dim=2)

        # Add positional embedding
        F_CG_L = F_CG_L + self.v_pos_embed

        # Encode each 480-step sequence with its transformer
        F_R = self.h_transformer(F_RG_L)
        F_C = self.v_transformer(F_CG_L)

        # Per-position logits -> (B, 480)
        h_logits = self.h_classifier(F_R).squeeze(-1)
        v_logits = self.v_classifier(F_C).squeeze(-1)

        # Split probabilities for each row / column position
        return torch.sigmoid(h_logits), torch.sigmoid(v_logits)
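
# Illustrative shape check (not part of the training pipeline; assumes the
# 960x960 inputs produced by TableDataset below):
def _check_split_model_shapes():
    model = SplitModel().eval()
    with torch.no_grad():
        h_prob, v_prob = model(torch.randn(1, 3, 960, 960))
    # Each head emits one split probability per position of the 480 grid
    assert h_prob.shape == (1, 480) and v_prob.shape == (1, 480)
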

def focal_loss(predictions, targets, alpha=1.0, gamma=2.0):
    """Binary focal loss, as specified in the paper"""
    ce_loss = F.binary_cross_entropy(predictions, targets, reduction='none')
    # pt = the model's probability for the true class
    pt = torch.where(targets == 1, predictions, 1 - predictions)
    focal_weight = alpha * (1 - pt) ** gamma
    return (focal_weight * ce_loss).mean()
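
# Illustration: with gamma=2, a confidently-correct position (p=0.9, y=1) has
# pt=0.9 and focal weight (1-0.9)^2 = 0.01, while a hard one (p=0.1, y=1)
# keeps weight 0.81 -- down-weighting the many easy positions that dominate
# the split/no-split label imbalance.
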

def post_process_predictions(h_pred, v_pred, threshold=0.5):
    """
    Simple post-processing to convert predictions to binary masks
    """
    h_binary = (h_pred > threshold).float()
    v_binary = (v_pred > threshold).float()

    return h_binary, v_binary

class TableDataset(Dataset):
    def __init__(self, hf_dataset):
        self.hf_dataset = hf_dataset
        self.transform = transforms.Compose([
            transforms.Resize((960, 960)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        item = self.hf_dataset[idx]

        image = item['image'].convert('RGB')
        image_transformed = self.transform(image)

        # Ground truth at the full 960 resolution (auto-gap variant)
        h_gt_960, v_gt_960 = get_ground_truth_auto_gap(
            item['image'],
            item['cells'],
            item['otsl'],
        )

        # Downsample to 480 (the model's output resolution) by taking
        # every other label
        h_gt_480 = [h_gt_960[i] for i in range(0, 960, 2)]
        v_gt_480 = [v_gt_960[i] for i in range(0, 960, 2)]

        return (
            image_transformed,
            torch.tensor(h_gt_480, dtype=torch.float),
            torch.tensor(v_gt_480, dtype=torch.float),
            torch.tensor(h_gt_960, dtype=torch.float),
            torch.tensor(v_gt_960, dtype=torch.float),
        )
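
# Usage sketch (hypothetical handle: assumes a Hugging Face dataset whose
# items carry the 'image', 'cells', and 'otsl' fields parsed above):
#   train_ds = TableDataset(hf_dataset['train'])
#   loader = DataLoader(train_ds, batch_size=8, shuffle=True)
#   for images, h_480, v_480, h_960, v_960 in loader:
#       h_pred, v_pred = model(images)
#       loss = focal_loss(h_pred, h_480) + focal_loss(v_pred, v_480)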