Files
saw_mill_knot_detection/annotation_gui.py

1271 lines
52 KiB
Python
Raw Normal View History

"""
Simple customizable annotation GUI with auto-labeling support.
Built with Gradio - easy to modify and extend.
Run: python annotation_gui.py
To set default paths, edit config.py
"""
from __future__ import annotations
import argparse
2025-12-23 16:55:59 -07:00
import html
import json
import subprocess
import threading
from pathlib import Path
from typing import Any
import gradio as gr
from PIL import Image, ImageDraw
# Defaults come from an optional sibling ``config`` module. If that module is
# missing (or does not define every name below) the import raises ImportError
# and the hard-coded fallbacks are used for all of them.
try:
    from config import (
        DEFAULT_BATCH_SIZE,
        DEFAULT_DETECTION_THRESHOLD,
        DEFAULT_IMAGES_DIR,
        DEFAULT_LEARNING_RATE,
        DEFAULT_MODEL_SIZE,
        DEFAULT_MODEL_WEIGHTS,
        DEFAULT_PORT,
        DEFAULT_TRAIN_EPOCHS,
    )
except ImportError:
    # Paths default to "not configured"; the GUI lets the user pick them.
    DEFAULT_IMAGES_DIR = None
    DEFAULT_MODEL_WEIGHTS = None
    # Server / inference / training knobs.
    DEFAULT_PORT = 7860
    DEFAULT_DETECTION_THRESHOLD = 0.5
    DEFAULT_TRAIN_EPOCHS = 20
    DEFAULT_BATCH_SIZE = 4
    DEFAULT_LEARNING_RATE = 1e-4
    DEFAULT_MODEL_SIZE = "small"
2025-12-23 17:04:59 -07:00
# Gradio 6 sanitizes <script> tags inside gr.HTML content. Also, js_on_load only
# runs when the component is first mounted, not when its HTML updates. We
# therefore install a global initializer (via demo.launch(js=...)) and have
# js_on_load call into it when available.
#
# ``element`` is the identifier the snippet expects to be in scope when it is
# evaluated; it is forwarded to the global initializer if that has been
# installed on ``window``.
CANVAS_JS_ON_LOAD = r"""
(() => {
  if (window.__initAnnotationCanvas) {
    window.__initAnnotationCanvas(element);
  }
})();
"""
2025-12-23 16:55:59 -07:00
2025-12-23 17:04:59 -07:00
# Page-level JavaScript that makes the annotation canvas interactive.
#
# It exposes ``window.__initAnnotationCanvas(root)`` and also re-scans the
# document (MutationObserver + a few timers) for ``[data-annotation-root="1"]``
# elements, so a canvas whose HTML was replaced after mount is bound as well.
# Box edits are serialised as JSON into the input found under the element with
# id ``canvas-boxes-data`` and announced with a bubbling ``input`` event.
CANVAS_GLOBAL_JS = r"""
(() => {
  function init(root) {
    if (!root) return;
    const canvas = root.querySelector('#annotation-canvas');
    const imgEl = root.querySelector('#annotation-img');
    // The initial-boxes <textarea> is emitted next to (not inside) the
    // data-annotation-root element, so when init() is reached through scan()
    // it is not a descendant of root. Fall back to the parent / document so
    // the canvas is still bound in that case.
    const initialBoxesEl =
      root.querySelector('#annotation-initial-boxes') ||
      (root.parentElement &&
        root.parentElement.querySelector('#annotation-initial-boxes')) ||
      document.getElementById('annotation-initial-boxes');
    if (!canvas || !imgEl || !initialBoxesEl) return;

    // Ensure we don't double-bind if Gradio reuses the DOM node.
    if (canvas.dataset.bound === '1') {
      // Still redraw in case boxes were updated.
      if (canvas.__redraw) canvas.__redraw();
      return;
    }
    canvas.dataset.bound = '1';

    const ctx = canvas.getContext('2d');
    const displayWidth = canvas.width;
    const displayHeight = canvas.height;

    let boxes = [];
    try {
      const raw = initialBoxesEl.value || initialBoxesEl.textContent || '[]';
      boxes = JSON.parse(raw);
      if (!Array.isArray(boxes)) boxes = [];
    } catch (_) {
      boxes = [];
    }

    const hiddenRoot = document.getElementById('canvas-boxes-data');
    const hiddenInput = hiddenRoot
      ? (hiddenRoot.querySelector('textarea, input') || hiddenRoot)
      : null;
    const syncHidden = () => {
      if (!hiddenInput) return;
      if (!('value' in hiddenInput)) return;
      const serialized = JSON.stringify(boxes);
      // Only announce real changes: a freshly re-rendered canvas that holds
      // the same boxes must not fire another event, otherwise a re-render
      // triggered by that event would re-init and fire again.
      if (hiddenInput.value === serialized) return;
      hiddenInput.value = serialized;
      hiddenInput.dispatchEvent(new Event('input', { bubbles: true }));
    };
    syncHidden();

    let isDragging = false;
    let dragStart = null;
    let selectedCorner = null;
    let selectedBoxIndex = -1;
    let creatingBox = false;
    let createStart = null;

    function redraw() {
      ctx.clearRect(0, 0, displayWidth, displayHeight);
      // Base image is rendered via <img> below the canvas.

      boxes.forEach((box) => {
        const [x1, y1, x2, y2] = box.bbox;
        const label = box.label || 'knot';
        // A confidence of 0 is a valid number; only fall back when the field
        // is missing or not numeric (toFixed below needs a number).
        const conf = typeof box.confidence === 'number' ? box.confidence : 1.0;

        ctx.strokeStyle = 'red';
        ctx.lineWidth = 3;
        ctx.strokeRect(x1, y1, x2 - x1, y2 - y1);

        const handleSize = 6;
        ctx.fillStyle = 'red';
        ctx.fillRect(x1 - handleSize, y1 - handleSize, handleSize * 2, handleSize * 2);
        ctx.fillRect(x2 - handleSize, y1 - handleSize, handleSize * 2, handleSize * 2);
        ctx.fillRect(x1 - handleSize, y2 - handleSize, handleSize * 2, handleSize * 2);
        ctx.fillRect(x2 - handleSize, y2 - handleSize, handleSize * 2, handleSize * 2);
        ctx.fillStyle = 'red';
        ctx.font = '16px Arial';
        const text = conf < 1.0 ? `${label} ${conf.toFixed(2)}` : label;
        ctx.fillText(text, x1, y1 - 5);
      });

      // Rubber-band preview while a new box is being dragged out.
      if (creatingBox && createStart && dragStart) {
        ctx.strokeStyle = 'blue';
        ctx.lineWidth = 2;
        ctx.setLineDash([5, 5]);
        const x = Math.min(createStart.x, dragStart.x);
        const y = Math.min(createStart.y, dragStart.y);
        const w = Math.abs(createStart.x - dragStart.x);
        const h = Math.abs(createStart.y - dragStart.y);
        ctx.strokeRect(x, y, w, h);
        ctx.setLineDash([]);
      }
    }
    canvas.__redraw = redraw;

    // Return the first box corner whose handle square contains (x, y).
    function getCornerAt(x, y) {
      const handleSize = 6;
      for (let i = 0; i < boxes.length; i++) {
        const [x1, y1, x2, y2] = boxes[i].bbox;
        const corners = [
          { x: x1, y: y1, type: 'top-left' },
          { x: x2, y: y1, type: 'top-right' },
          { x: x1, y: y2, type: 'bottom-left' },
          { x: x2, y: y2, type: 'bottom-right' },
        ];
        for (const corner of corners) {
          if (
            x >= corner.x - handleSize &&
            x <= corner.x + handleSize &&
            y >= corner.y - handleSize &&
            y <= corner.y + handleSize
          ) {
            return { boxIndex: i, corner: corner.type, pos: corner };
          }
        }
      }
      return null;
    }

    imgEl.addEventListener(
      'error',
      () => {
        ctx.clearRect(0, 0, displayWidth, displayHeight);
        ctx.fillStyle = '#ffcccc';
        ctx.fillRect(0, 0, displayWidth, displayHeight);
        ctx.fillStyle = 'black';
        ctx.font = '16px Arial';
        ctx.fillText('Image failed to load', 10, 30);
      },
      { once: true }
    );

    redraw();

    canvas.addEventListener('mousedown', (e) => {
      // Convert client coordinates to canvas pixel coordinates.
      const rect = canvas.getBoundingClientRect();
      const x = (e.clientX - rect.left) * (displayWidth / rect.width);
      const y = (e.clientY - rect.top) * (displayHeight / rect.height);

      selectedCorner = getCornerAt(x, y);
      if (selectedCorner) {
        isDragging = true;
        selectedBoxIndex = selectedCorner.boxIndex;
        canvas.style.cursor = 'move';
      } else {
        creatingBox = true;
        createStart = { x, y };
        dragStart = { x, y };
        canvas.style.cursor = 'crosshair';
      }
    });

    canvas.addEventListener('mousemove', (e) => {
      const rect = canvas.getBoundingClientRect();
      const x = (e.clientX - rect.left) * (displayWidth / rect.width);
      const y = (e.clientY - rect.top) * (displayHeight / rect.height);

      if (isDragging && selectedCorner) {
        // Resize by moving one corner; keep each side at least 10px long.
        const box = boxes[selectedBoxIndex];
        if (selectedCorner.corner === 'top-left') {
          box.bbox[0] = Math.min(x, box.bbox[2] - 10);
          box.bbox[1] = Math.min(y, box.bbox[3] - 10);
        } else if (selectedCorner.corner === 'top-right') {
          box.bbox[2] = Math.max(x, box.bbox[0] + 10);
          box.bbox[1] = Math.min(y, box.bbox[3] - 10);
        } else if (selectedCorner.corner === 'bottom-left') {
          box.bbox[0] = Math.min(x, box.bbox[2] - 10);
          box.bbox[3] = Math.max(y, box.bbox[1] + 10);
        } else if (selectedCorner.corner === 'bottom-right') {
          box.bbox[2] = Math.max(x, box.bbox[0] + 10);
          box.bbox[3] = Math.max(y, box.bbox[1] + 10);
        }
        syncHidden();
        redraw();
        return;
      }

      if (creatingBox && createStart) {
        dragStart = { x, y };
        redraw();
        return;
      }

      const corner = getCornerAt(x, y);
      canvas.style.cursor = corner ? 'move' : 'crosshair';
    });

    canvas.addEventListener('mouseup', () => {
      if (creatingBox && createStart && dragStart) {
        const x1 = Math.min(createStart.x, dragStart.x);
        const y1 = Math.min(createStart.y, dragStart.y);
        const x2 = Math.max(createStart.x, dragStart.x);
        const y2 = Math.max(createStart.y, dragStart.y);
        const w = x2 - x1;
        const h = y2 - y1;
        if (w > 10 && h > 10) {
          boxes.push({ bbox: [x1, y1, x2, y2], label: 'knot', confidence: 1.0, source: 'manual' });
        } else {
          // Click without drag: create a default-size box around the click.
          const size = 120;
          const cx = createStart.x;
          const cy = createStart.y;
          const bx1 = Math.max(0, cx - size / 2);
          const by1 = Math.max(0, cy - size / 2);
          const bx2 = Math.min(displayWidth, cx + size / 2);
          const by2 = Math.min(displayHeight, cy + size / 2);
          boxes.push({ bbox: [bx1, by1, bx2, by2], label: 'knot', confidence: 1.0, source: 'manual' });
        }
        syncHidden();
        redraw();
      }
      isDragging = false;
      creatingBox = false;
      selectedCorner = null;
      selectedBoxIndex = -1;
      createStart = null;
      dragStart = null;
      canvas.style.cursor = 'crosshair';
    });
  }

  window.__initAnnotationCanvas = init;

  function scan() {
    document.querySelectorAll('[data-annotation-root="1"]').forEach((root) => init(root));
  }

  const obs = new MutationObserver(() => scan());
  obs.observe(document.documentElement, { childList: true, subtree: true });
  window.addEventListener('load', () => scan());
  scan();
  setTimeout(() => scan(), 0);
  setTimeout(() => scan(), 100);
  setTimeout(() => scan(), 250);
  setTimeout(() => scan(), 500);
  setTimeout(() => scan(), 1000);
  setTimeout(() => scan(), 2000);
})();
"""
class AnnotationApp:
def __init__(self, images_dir: Path | None = None, model_weights: Path | None = None):
self.images_dir = images_dir if images_dir else Path.cwd()
self.current_model_path = model_weights
self.current_model_type = None # Track model type: 'rf-detr', 'rt-detr', 'yolov6', 'yolox'
self.available_models = [] # List of discovered models for quick switching
self.image_paths = []
self.current_idx = 0
self.annotations = {} # image_name -> list of boxes
self.model = None
# Load images if directory provided
if images_dir and images_dir.exists():
self._load_images(images_dir)
if model_weights and model_weights.exists():
self._load_model(model_weights)
def _load_images(self, images_dir: Path):
"""Load images from directory."""
self.images_dir = images_dir
self.image_paths = sorted(
list(images_dir.glob("*.jpg")) + list(images_dir.glob("*.png"))
)
self.current_idx = 0
# Load existing annotations if present
self.ann_file = images_dir / "annotations.json"
if self.ann_file.exists():
with self.ann_file.open("r") as f:
self.annotations = json.load(f)
else:
self.annotations = {}
return f"✓ Loaded {len(self.image_paths)} images from {images_dir}"
def find_best_weights(self, directory: Path) -> tuple[Path | None, str | None]:
    """Locate the preferred weights file in a run directory and name its type.

    Probe order: ``checkpoint_best_total.pth`` (RF-DETR), ``weights/best.pt``,
    ``training/weights/best.pt``, ``best.pt``, then any ``*.pth`` / ``*.pt``
    file (preferring one with "best" in its name).

    Returns:
        ``(weights_path, model_type)``, or ``(None, None)`` when the directory
        does not exist or holds no weights.
    """
    if not directory.exists():
        return None, None

    # RF-DETR writes a distinctive checkpoint name.
    rf_checkpoint = directory / "checkpoint_best_total.pth"
    if rf_checkpoint.exists():
        return rf_checkpoint, "rf-detr"

    # For the Ultralytics layouts the type is sniffed from the directory name;
    # each layout has its own default when the name gives no hint.
    folder_name = directory.name.lower()
    name_hints = (("rtdetr", "rt-detr"), ("yolov6", "yolov6"), ("yolox", "yolox"))
    hinted_type = next((kind for needle, kind in name_hints if needle in folder_name), None)
    layouts = (
        (directory / "weights" / "best.pt", "rt-detr"),
        (directory / "training" / "weights" / "best.pt", "yolox"),
    )
    for candidate, layout_default in layouts:
        if candidate.exists():
            return candidate, hinted_type if hinted_type is not None else layout_default

    # A bare best.pt is assumed to be RT-DETR.
    bare_best = directory / "best.pt"
    if bare_best.exists():
        return bare_best, "rt-detr"  # Default assumption

    # Last resort: any loose weights file, preferring "best" in the name.
    loose_files = list(directory.glob("*.pth")) + list(directory.glob("*.pt"))
    if not loose_files:
        return None, None
    chosen = next((f for f in loose_files if "best" in f.name.lower()), loose_files[0])
    return chosen, self._guess_model_type_from_path(chosen)
def _guess_model_type_from_path(self, path: Path) -> str:
"""Guess model type from file path."""
path_str = str(path).lower()
if "rf" in path_str or "checkpoint" in path_str:
return "rf-detr"
elif "rtdetr" in path_str:
return "rt-detr"
elif "yolov6" in path_str:
return "yolov6"
elif "yolox" in path_str:
return "yolox"
else:
return "rt-detr" # Default
def _load_model(self, weights_path: Path, model_type: str | None = None) -> str:
    """Load model for auto-labeling based on type.

    Args:
        weights_path: Weights/checkpoint file to load.
        model_type: One of 'rf-detr', 'rt-detr', 'yolov6', 'yolox'. When None
            it is guessed from the path via _guess_model_type_from_path().

    Returns:
        A status string. On success the model, its path and its type are
        stored on the instance and the model is registered in
        ``available_models``. Any exception is caught, reported in the
        returned string, and leaves ``model`` / ``current_model_type`` as None.
    """
    try:
        # Imported lazily so the GUI can start without torch installed.
        import torch
        if model_type is None:
            model_type = self._guess_model_type_from_path(weights_path)
        print(f"Loading {model_type} model from {weights_path}...")
        # Check for GPU availability
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")
        if model_type == "rf-detr":
            # RF-DETR uses custom loader - try different model sizes
            from rfdetr import RFDETRBase, RFDETRMedium, RFDETRNano, RFDETRSmall
            # Try to determine model size from checkpoint or use nano as default
            # NOTE: weights_only=False unpickles arbitrary objects; only load
            # checkpoints from a trusted source.
            checkpoint = torch.load(weights_path, map_location='cpu', weights_only=False)
            if 'model' in checkpoint:
                # Training checkpoint - check the model size from the state dict
                state_dict = checkpoint['model']
                # Look for clues about model size in the state dict keys
                if any('backbone.0.encoder.encoder.embeddings.position_embeddings' in key for key in state_dict.keys()):
                    # Try different model sizes to find the right one
                    models_to_try = [
                        ("nano", RFDETRNano),
                        ("small", RFDETRSmall),
                        ("medium", RFDETRMedium),
                        ("base", RFDETRBase)
                    ]
                    for size_name, model_class in models_to_try:
                        try:
                            self.model = model_class()
                            # Try loading with strict=False to handle mismatches
                            missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
                            if len(missing_keys) < len(self.model.state_dict()):  # Some keys matched
                                print(f"✓ Loaded RF-DETR {size_name} model (with {len(missing_keys)} missing keys)")
                                # Move to GPU if available
                                self.model = self.model.to(device)
                                break
                        except Exception as e:
                            # This size did not fit; try the next one.
                            continue
                    else:
                        # for/else: reached only when no size triggered `break`.
                        # NOTE(review): indentation was reconstructed; this reading
                        # (else belongs to the for loop) matches the message below —
                        # confirm against the original file.
                        raise Exception("Could not load checkpoint with any RF-DETR model size")
                else:
                    # Direct weights file
                    self.model = RFDETRNano(pretrain_weights=str(weights_path))
                    self.model = self.model.to(device)
            else:
                # Direct weights file
                self.model = RFDETRNano(pretrain_weights=str(weights_path))
                self.model = self.model.to(device)
        else:
            # RT-DETR, YOLOv6, YOLOX all use Ultralytics
            if model_type == "rt-detr":
                from ultralytics import RTDETR
                self.model = RTDETR(str(weights_path))
            else:
                from ultralytics import YOLO
                self.model = YOLO(str(weights_path))
            # Ultralytics models should automatically use GPU if available
            # but let's ensure they're on the right device
            if hasattr(self.model, 'to'):
                self.model = self.model.to(device)
        self.current_model_path = weights_path
        self.current_model_type = model_type
        # Add to available models for quick switching
        model_display = f"Custom: {weights_path.name} ({model_type.upper()})"
        # Skip registration when an entry with the same path already exists.
        existing_model = next((m for m in self.available_models if m['path'] == weights_path), None)
        if not existing_model:
            self.available_models.append({
                "path": weights_path,
                "type": model_type,
                "dir": weights_path.parent.name,
                "display": model_display
            })
        print("✓ Model loaded")
        return f"{model_type.upper()} model loaded from {weights_path.name}"
    except Exception as e:
        # Best-effort: report the failure to the caller instead of raising.
        error_msg = f"⚠ Could not load {model_type or 'model'}: {e}"
        print(error_msg)
        self.model = None
        self.current_model_type = None
        return error_msg
def load_new_model(self, weights_path: str, model_type: str = "Auto-detect") -> str:
    """Load a weights file chosen in the GUI.

    Args:
        weights_path: Path typed or picked by the user.
        model_type: Dropdown value. "Auto-detect" means "guess from the path";
            any other value is handed to the loader unchanged.

    Returns:
        The loader's status string, or a "file not found" message.
    """
    candidate = Path(weights_path)
    if not candidate.exists():
        return f"❌ File not found: {weights_path}"

    # Convert dropdown value to internal type: only the auto-detect sentinel
    # needs translating, the concrete names are already the internal ones.
    resolved_type = None if model_type == "Auto-detect" else model_type
    return self._load_model(candidate, resolved_type)
def load_model_from_directory(self, directory_path: str) -> str:
    """Find the best weights inside a directory and load them.

    Returns:
        The loader's status string, or an error message when the path is
        missing, is not a directory, or contains no weights.
    """
    folder = Path(directory_path)
    if not folder.exists():
        return f"❌ Directory not found: {directory_path}"
    if not folder.is_dir():
        return f"❌ Not a directory: {directory_path}"

    found, detected_kind = self.find_best_weights(folder)
    if found is None:
        return f"❌ No model weights found in {directory_path}"
    return self._load_model(found, detected_kind)
def scan_for_models(self, return_info: bool = True) -> str:
    """Discover trained models under ``runs/`` (relative to the working dir).

    Every sub-directory of ``runs`` is probed with find_best_weights(); hits
    replace ``self.available_models``.

    Args:
        return_info: When False, only refresh the list and return "".

    Returns:
        A human-readable listing, an error line if nothing was found, or ""
        when ``return_info`` is False.
    """
    runs_root = Path("runs")
    discovered = []
    entries = runs_root.iterdir() if runs_root.exists() else ()
    for entry in entries:
        if not entry.is_dir():
            continue
        weights, kind = self.find_best_weights(entry)
        if not weights:
            continue
        discovered.append({
            "path": weights,
            "type": kind,
            "dir": entry.name,
            "display": f"{entry.name} ({kind.upper()})",
        })

    # Store available models for quick access
    self.available_models = discovered

    if not return_info:
        return ""
    if not discovered:
        return "❌ No trained models found in 'runs/' directory"

    # Format as readable list
    header = ["📂 Available Models:"]
    rows = [
        f"{number}. {entry['dir']} -> {entry['path'].name} ({entry['type'].upper()})"
        for number, entry in enumerate(discovered, 1)
    ]
    footer = ["\n💡 Use the Model Selector dropdown above to quickly switch models"]
    return "\n".join(header + rows + footer)
def get_available_models_list(self) -> list:
    """Return the display names for the model dropdown.

    Triggers a silent scan when the cached list is empty; if that still finds
    nothing, a single placeholder entry is returned instead.
    """
    if not self.available_models:
        # Populates self.available_models as a side effect.
        self.scan_for_models(return_info=False)
        if not self.available_models:
            return ["No models found - click '🔍 Scan for Models'"]
    return [entry['display'] for entry in self.available_models]
def load_model_by_index(self, model_display: str) -> str:
    """Load the discovered model whose dropdown label equals ``model_display``.

    Returns:
        The loader's status string, or an error message when no models are
        known or none carries that label.
    """
    known_models = getattr(self, 'available_models', None)
    if not known_models:
        return "❌ No models available. Click '🔍 Scan for Models' first."

    match = next((entry for entry in known_models if entry['display'] == model_display), None)
    if match is None:
        return f"❌ Model '{model_display}' not found"
    return self._load_model(match['path'], match['type'])
2025-12-23 17:38:43 -07:00
def load_new_images_dir(self, images_dir: str) -> tuple[str, str, str, str]:
    """Switch to another images directory chosen in the GUI.

    Returns:
        ``(canvas_html, boxes_html, image_label, status)``. When the path is
        unusable or holds no images, the first three are placeholders and the
        status explains why.
    """
    folder = Path(images_dir)
    if not folder.exists():
        return "<div>Directory not found</div>", "", "Image: -/-", f"❌ Directory not found: {images_dir}"
    if not folder.is_dir():
        return "<div>Not a directory</div>", "", "Image: -/-", f"❌ Not a directory: {images_dir}"

    load_status = self._load_images(folder)
    if not self.image_paths:
        return "<div>No images found</div>", "", "Image: -/-", f"{load_status}\n⚠️ No .jpg or .png images found in directory"

    # Show the first image together with whatever was already annotated.
    _, first_name = self.get_current_image()
    saved_boxes = self.annotations.get(first_name, [])
    canvas_html = self.generate_interactive_canvas(saved_boxes)
    listing_html = self._format_boxes_html(saved_boxes)
    label = self._current_image_label(first_name)
    return canvas_html, listing_html, label, load_status
def get_current_model_info(self) -> str:
    """Describe the loaded model (path and type) for display in the GUI."""
    if not self.model:
        return "⚠️ No model loaded"
    if not self.current_model_path:
        # A model object exists but no weights path was recorded for it.
        return "📦 Model loaded (pretrained)"

    type_info = ""
    if self.current_model_type:
        type_info = f" ({self.current_model_type.upper()})"
    return f"📦 Loaded: {self.current_model_path}{type_info}"
def get_current_dir_info(self) -> str:
    """Summarise the active images directory and how many images it holds."""
    image_count = len(self.image_paths)
    return f"📁 {self.images_dir} ({image_count} images)"
2025-12-23 17:38:43 -07:00
def _current_image_label(self, filename: str) -> str:
"""Stable image index display (kept separate from status messages)."""
if not filename or not self.image_paths:
return "Image: -/-"
return f"Image {self.current_idx + 1}/{len(self.image_paths)}: {filename}"
def get_current_image(self) -> tuple[Image.Image, str]:
    """Open the image at ``current_idx`` as RGB and return it with its name.

    Returns ``(None, "")`` when no images are loaded.
    """
    if not self.image_paths:
        return None, ""
    current_path = self.image_paths[self.current_idx]
    rgb_image = Image.open(current_path).convert("RGB")
    return rgb_image, current_path.name
def draw_boxes_on_image(self, img: Image.Image, boxes: list[dict]) -> Image.Image:
    """Return a copy of ``img`` with every box, its handles and label drawn.

    Each box is outlined in red with a small filled square on every corner
    (the editing handles) and a text label above its top-left corner. The
    confidence is appended to the label only when it is below 1.0.
    """
    annotated = img.copy()
    painter = ImageDraw.Draw(annotated)
    half = 6  # half the side length of a corner handle

    for entry in boxes:
        x1, y1, x2, y2 = entry["bbox"]
        label = entry.get("label", "knot")
        conf = entry.get("confidence", 1.0)

        # Box outline
        painter.rectangle([x1, y1, x2, y2], outline="red", width=3)

        # Corner handles for editing (small squares)
        for cx, cy in ((x1, y1), (x2, y1), (x1, y2), (x2, y2)):
            painter.rectangle([cx - half, cy - half, cx + half, cy + half], fill="red")

        # Label text
        if conf < 1.0:
            caption = f"{label} {conf:.2f}"
        else:
            caption = label
        painter.text((x1, y1 - 20), caption, fill="red")

    return annotated
2025-12-23 16:17:19 -07:00
def generate_interactive_canvas(self, boxes: list[dict] | None = None) -> str:
    """Generate HTML with interactive canvas for annotation.

    The current image is embedded as a base64 PNG underneath a transparent
    <canvas>; the boxes are shipped as HTML-escaped JSON in a hidden
    <textarea> for the page script to read.

    Args:
        boxes: Boxes to show. When None, the saved annotations of the current
            image are used.

    Returns:
        The HTML snippet, or a placeholder <div> when no image is loaded.
    """
    # Local imports: only this method needs them.
    import base64
    from io import BytesIO

    img, filename = self.get_current_image()
    if not img:
        return "<div>No image loaded</div>"
    if boxes is None:
        boxes = self.annotations.get(filename, [])

    # Resize image for display if too large
    max_width = 1200
    max_height = 800
    if img.width > max_width or img.height > max_height:
        ratio = min(max_width / img.width, max_height / img.height)
        display_width = int(img.width * ratio)
        display_height = int(img.height * ratio)
        img = img.resize((display_width, display_height), Image.Resampling.LANCZOS)
    else:
        display_width = img.width
        display_height = img.height
    # NOTE(review): boxes are passed through unscaled, so for a downscaled
    # image they are interpreted in display pixels by the canvas — confirm
    # which coordinate space stored annotations are meant to use.

    # Convert PIL image to base64
    buffer = BytesIO()
    img.save(buffer, format="PNG", optimize=True)
    img_base64 = base64.b64encode(buffer.getvalue()).decode()

    # Build HTML without <script> tags (Gradio sanitizes them). The interactive
    # canvas logic runs via CANVAS_JS_ON_LOAD.
    boxes_json = json.dumps(boxes)
    boxes_escaped = html.escape(boxes_json)
    html_out = f"""
<div style="display: inline-block; border: 1px solid #ccc; padding: 5px;">
  <div style="font-size: 12px; color: #666; margin-bottom: 4px;">Canvas Size: {display_width}x{display_height}</div>
  <textarea id="annotation-initial-boxes" style="display:none;">{boxes_escaped}</textarea>
  <div data-annotation-root="1" style="position: relative; width: {display_width}px; height: {display_height}px;">
    <img id="annotation-img" src="data:image/png;base64,{img_base64}"
         style="position:absolute; left:0; top:0; width:{display_width}px; height:{display_height}px; z-index: 1;" />
    <canvas id="annotation-canvas" width="{display_width}" height="{display_height}"
            style="position:absolute; left:0; top:0; width:{display_width}px; height:{display_height}px; cursor: crosshair; background: transparent; z-index: 2; pointer-events: auto;"></canvas>
  </div>
</div>
"""
    return html_out
2025-12-23 17:38:43 -07:00
def _format_boxes_html(self, boxes: list[dict]) -> str:
"""Format boxes as HTML list with delete buttons."""
if not boxes:
2025-12-23 17:38:43 -07:00
return "<div style='color: #999; font-style: italic;'>No annotations</div>"
2025-12-23 17:38:43 -07:00
lines = ["<div style='font-family: monospace; font-size: 13px;'>"]
for i, box in enumerate(boxes):
x1, y1, x2, y2 = box["bbox"]
conf = box.get("confidence", 1.0)
source = box.get("source", "manual")
2025-12-23 17:38:43 -07:00
lines.append(
f"<div style='margin: 4px 0; padding: 4px; background: #f5f5f5; border-radius: 3px; display: flex; justify-content: space-between; align-items: center;'>"
f"<span style='flex: 1;'>{i}: [{x1:.0f},{y1:.0f},{x2:.0f},{y2:.0f}] {conf:.2f} ({source})</span>"
f"<button onclick='window.deleteBox({i})' style='background: #ff4444; color: white; border: none; padding: 2px 8px; border-radius: 3px; cursor: pointer; margin-left: 8px;'>✕</button>"
f"</div>"
)
lines.append("</div>")
return "".join(lines)
def _parse_boxes_text(self, text: str) -> list[dict] | None:
"""Parse edited JSON from the Annotations textbox."""
if not text:
return None
try:
data = json.loads(text)
except json.JSONDecodeError:
return None
if not isinstance(data, list):
return None
cleaned: list[dict] = []
for item in data:
if not isinstance(item, dict) or "bbox" not in item:
continue
bbox = item.get("bbox")
if not (isinstance(bbox, list) and len(bbox) == 4):
continue
try:
x1, y1, x2, y2 = [float(v) for v in bbox]
except Exception:
continue
cleaned.append({
"bbox": [x1, y1, x2, y2],
"label": item.get("label", "knot"),
"confidence": float(item.get("confidence", 1.0)),
"source": item.get("source", "manual"),
})
return cleaned
2025-12-23 17:38:43 -07:00
def load_image(self, direction: str = "current") -> tuple[str, str, str, str]:
    """Load image (current/next/prev).

    Args:
        direction: "next" or "prev" to step through the image list (clamped
            at both ends); any other value re-renders the current image.

    Returns:
        ``(canvas_html, boxes_html, image_label, status)``; placeholders when
        no images are loaded.
    """
    if direction == "next":
        # Clamp at 0 as well: with an empty image list the upper bound is -1
        # and the index must not be left negative.
        last_idx = len(self.image_paths) - 1
        self.current_idx = max(0, min(self.current_idx + 1, last_idx))
    elif direction == "prev":
        self.current_idx = max(self.current_idx - 1, 0)

    img, filename = self.get_current_image()
    if not img:
        return "<div>No images</div>", "", "Image: -/-", "No images"

    # Load existing annotations
    boxes = self.annotations.get(filename, [])
    img_html = self.generate_interactive_canvas(boxes)
    boxes_html = self._format_boxes_html(boxes)
    image_label = self._current_image_label(filename)
    return img_html, boxes_html, image_label, ""
2025-12-23 17:38:43 -07:00
def add_box_manual(self, x1: int, y1: int, x2: int, y2: int) -> tuple[str, str, str, str]:
    """Manually add a bounding box.

    The two corners may be given in any order; the stored ``bbox`` is always
    ``[left, top, right, bottom]`` so later width/height arithmetic and
    drawing never see a negative extent.

    Returns:
        ``(canvas_html, boxes_html, image_label, status)``; placeholders when
        no images are loaded.
    """
    img, filename = self.get_current_image()
    if not img:
        return "<div>No images</div>", "", "Image: -/-", "No images"

    # Normalise corner order before storing.
    xa, ya, xb, yb = float(x1), float(y1), float(x2), float(y2)
    box = {
        "bbox": [min(xa, xb), min(ya, yb), max(xa, xb), max(ya, yb)],
        "label": "knot",
        "confidence": 1.0,
        "source": "manual"
    }
    self.annotations.setdefault(filename, []).append(box)
    self._save_annotations()

    # Redraw
    boxes = self.annotations[filename]
    img_html = self.generate_interactive_canvas(boxes)
    boxes_html = self._format_boxes_html(boxes)
    return img_html, boxes_html, self._current_image_label(filename), f"✓ Added box: {len(boxes)} total"
2025-12-23 17:38:43 -07:00
def delete_last_box(self) -> tuple[str, str, str, str]:
    """Remove the most recently added box of the current image, if any.

    Returns:
        ``(canvas_html, boxes_html, image_label, status)``; placeholders when
        no images are loaded.
    """
    img, filename = self.get_current_image()
    if not img:
        return "<div>No images</div>", "", "Image: -/-", "No images"

    # Only touch the file when something was actually removed.
    current_boxes = self.annotations.get(filename)
    if current_boxes:
        current_boxes.pop()
        self._save_annotations()

    # Redraw
    boxes = self.annotations.get(filename, [])
    canvas_html = self.generate_interactive_canvas(boxes)
    listing_html = self._format_boxes_html(boxes)
    label = self._current_image_label(filename)
    return canvas_html, listing_html, label, f"✓ Deleted last box: {len(boxes)} remaining"
def delete_box_by_index(self, index: int) -> tuple[str, str, str, str]:
    """Remove the box at ``index`` from the current image.

    An out-of-range index changes nothing and is reported in the status.

    Returns:
        ``(canvas_html, boxes_html, image_label, status)``; placeholders when
        no images are loaded.
    """
    img, filename = self.get_current_image()
    if not img:
        return "<div>No images</div>", "", "Image: -/-", "No images"

    boxes = self.annotations.get(filename, [])
    in_range = 0 <= index < len(boxes)
    if in_range:
        boxes.pop(index)
        self.annotations[filename] = boxes
        self._save_annotations()

    # Both outcomes re-render the (possibly unchanged) boxes.
    canvas_html = self.generate_interactive_canvas(boxes)
    listing_html = self._format_boxes_html(boxes)
    label = self._current_image_label(filename)
    status = f"✓ Deleted box {index}" if in_range else "❌ Invalid box index"
    return canvas_html, listing_html, label, status
2025-12-23 17:38:43 -07:00
def clear_boxes(self) -> tuple[str, str, str, str]:
    """Drop every box of the current image and persist the empty list.

    Returns:
        ``(canvas_html, boxes_html, image_label, status)``; placeholders when
        no images are loaded.
    """
    img, filename = self.get_current_image()
    if not img:
        return "<div>No images</div>", "", "Image: -/-", "No images"

    self.annotations[filename] = []
    self._save_annotations()

    # Render from a separate empty list, not the stored one.
    nothing = []
    canvas_html = self.generate_interactive_canvas(nothing)
    listing_html = self._format_boxes_html(nothing)
    label = self._current_image_label(filename)
    return canvas_html, listing_html, label, "✓ Cleared all boxes"
2025-12-23 17:38:43 -07:00
def auto_label_current(self, threshold: float = 0.5) -> tuple[str, str, str, str]:
    """Auto-label current image using loaded model.

    Detections at or above ``threshold`` are appended to the image's existing
    boxes and saved. If no model is loaded, or inference fails, the current
    image is re-rendered unchanged and the problem is reported in the status.

    Args:
        threshold: Confidence threshold handed to the model.

    Returns:
        ``(canvas_html, boxes_html, image_label, status)``; placeholders are
        returned only when there is no image to show.
    """
    img, filename = self.get_current_image()

    if not self.model:
        if not img:
            return "<div>No model loaded</div>", "", "Image: -/-", "❌ No model loaded"
        # Keep the current image on screen, as the inference-failure path
        # below does, instead of replacing the canvas with a placeholder.
        current = self.annotations.get(filename, [])
        return (
            self.generate_interactive_canvas(current),
            self._format_boxes_html(current),
            self._current_image_label(filename),
            "❌ No model loaded",
        )
    if not img:
        return "<div>No images</div>", "", "Image: -/-", "No images"

    try:
        # Run inference based on model type
        boxes = []
        if self.current_model_type == "rf-detr":
            # RF-DETR custom prediction: one result object indexed per detection.
            detections = self.model.predict(img, threshold=threshold)
            for i in range(len(detections)):
                x1, y1, x2, y2 = detections.xyxy[i]
                conf = float(detections.confidence[i]) if detections.confidence is not None else 1.0
                boxes.append({
                    "bbox": [float(x1), float(y1), float(x2), float(y2)],
                    "label": "knot",
                    "confidence": conf,
                    "source": "auto"
                })
        else:
            # Ultralytics models (RT-DETR, YOLOv6, YOLOX)
            results = self.model.predict(
                source=img,
                conf=threshold,
                save=False,
                verbose=False
            )
            for result in results:
                for box in result.boxes:
                    x1, y1, x2, y2 = box.xyxy[0].tolist()
                    boxes.append({
                        "bbox": [x1, y1, x2, y2],
                        "label": "knot",
                        "confidence": float(box.conf[0]),
                        "source": "auto"
                    })

        # Add to existing annotations
        self.annotations.setdefault(filename, []).extend(boxes)
        self._save_annotations()

        # Redraw
        img_html = self.generate_interactive_canvas(self.annotations[filename])
        boxes_html = self._format_boxes_html(self.annotations[filename])
        return img_html, boxes_html, self._current_image_label(filename), f"🤖 Auto-labeled: {len(boxes)} detections"
    except Exception as e:
        # GUI boundary: surface the error in the status line and keep the
        # image and its previous boxes visible.
        boxes = self.annotations.get(filename, [])
        img_html = self.generate_interactive_canvas(boxes)
        return img_html, self._format_boxes_html(boxes), self._current_image_label(filename), f"❌ Auto-label failed: {e}"
2025-12-23 16:17:19 -07:00
2025-12-23 17:38:43 -07:00
def save_canvas_changes(self, boxes_json: str) -> tuple[str, str, str, str]:
    """Auto-save changes made in the canvas.

    Args:
        boxes_json: JSON text holding the full list of boxes for the
            current image. An empty string means "nothing to save".

    Returns:
        ``(canvas_html, boxes_html, image_label, status)``. ``status`` is
        empty on success and an error message when the payload is not
        valid JSON or does not decode to a list; in that case the stored
        annotations are left untouched and re-rendered.
    """
    img, filename = self.get_current_image()
    if not img:
        return "<div>No images</div>", "", "Image: -/-", "No images"

    status = ""
    boxes = None
    if boxes_json:
        try:
            parsed = json.loads(boxes_json)
        except json.JSONDecodeError:
            parsed = None
        # Only a list is a usable annotation set. Storing e.g. ``null`` or
        # an object would break every later ``.extend()`` / iteration over
        # this image's boxes.
        if isinstance(parsed, list):
            boxes = parsed
            self.annotations[filename] = boxes
            self._save_annotations()
        else:
            status = "❌ Invalid canvas data"

    if boxes is None:
        # Nothing (valid) received: show what is already stored.
        boxes = self.annotations.get(filename, [])

    img_html = self.generate_interactive_canvas(boxes)
    boxes_html = self._format_boxes_html(boxes)
    return img_html, boxes_html, self._current_image_label(filename), status
def _save_annotations(self) -> None:
    """Save annotations to the JSON file at ``self.ann_file``.

    The data is serialized before the file is opened: opening with ``"w"``
    truncates immediately, so serializing straight into the handle would
    wipe the previously saved annotations whenever a value turns out not
    to be JSON-serializable.

    Raises:
        TypeError: If ``self.annotations`` contains a value ``json`` cannot
            serialize (the existing file is left as it was).
    """
    payload = json.dumps(self.annotations, indent=2)
    with self.ann_file.open("w", encoding="utf-8") as f:
        f.write(payload)
def export_to_coco(self, output_path: Path) -> str:
    """Export annotations to COCO format.

    Every image in ``self.image_paths`` gets an ``images`` entry (its index
    is the image id), whether or not it has boxes. Stored boxes are
    ``[x1, y1, x2, y2]`` corners and are written as COCO
    ``[x, y, width, height]``.

    Args:
        output_path: JSON file to write; overwritten if it exists.

    Returns:
        A status message with the number of exported annotations.
    """
    coco_data = {
        "images": [],
        "annotations": [],
        "categories": [{"id": 0, "name": "knot", "supercategory": "defect"}],
    }
    ann_id = 0
    for img_id, img_path in enumerate(self.image_paths):
        filename = img_path.name
        # Only the dimensions are needed; close the file right away so a
        # large directory does not accumulate open handles.
        with Image.open(img_path) as img:
            width, height = img.size
        coco_data["images"].append({
            "id": img_id,
            "file_name": filename,
            "width": width,
            "height": height,
        })

        # Add annotations
        for box in self.annotations.get(filename, []):
            if isinstance(box, dict) and "bbox" in box:
                x1, y1, x2, y2 = box["bbox"]
                score = box.get("confidence", 1.0)
            else:
                # Backward/experimental compatibility: [x1, y1, x2, y2]
                x1, y1, x2, y2 = box
                score = 1.0
            w = x2 - x1
            h = y2 - y1
            coco_data["annotations"].append({
                "id": ann_id,
                "image_id": img_id,
                "category_id": 0,
                "bbox": [x1, y1, w, h],
                "area": w * h,
                "iscrowd": 0,
                "score": score,
            })
            ann_id += 1

    with output_path.open("w") as f:
        json.dump(coco_data, f, indent=2)
    return f"✓ Exported {len(coco_data['annotations'])} annotations to {output_path}"
2025-12-23 12:53:52 -07:00
def get_model_path_from_display(self, model_display: str) -> Path | None:
"""Get the actual model path from a display name."""
if not hasattr(self, 'available_models') or not self.available_models:
return None
for model in self.available_models:
if model['display'] == model_display:
return model['path']
return None
def create_ui(app: AnnotationApp) -> gr.Blocks:
    """Create the Gradio UI and wire its events to ``app``.

    Layout: a collapsed Settings accordion (images directory, model
    selection/loading), then the annotation canvas with navigation and
    auto-label controls on the left and status, box list, manual box entry
    and COCO export on the right.

    Args:
        app: Application object whose methods back every event handler.

    Returns:
        The assembled, not yet launched, ``gr.Blocks``.
    """
    with gr.Blocks(title="Knot Annotation Tool") as demo:
        gr.Markdown("""
# Wood Knot Annotation Tool
**Label -> Auto-Label -> Export**
- Manually annotate images or use **Auto-Label** with your trained model
- Export annotations to COCO format for training
- Use separate training and deployment scripts for model development
""")

        # Settings section at the top
        with gr.Accordion("Settings", open=False):
            with gr.Row():
                with gr.Column():
                    images_dir_input = gr.Textbox(
                        label="Images Directory",
                        value=str(app.images_dir),
                        placeholder="/path/to/images"
                    )
                    load_images_btn = gr.Button("📁 Load Images Directory")
                    dir_info = gr.Textbox(
                        label="Current Directory",
                        value=app.get_current_dir_info(),
                        interactive=False
                    )
                with gr.Column():
                    # Quick Model Selector
                    model_selector = gr.Dropdown(
                        choices=app.get_available_models_list(),
                        value=None,
                        label="🚀 Quick Model Switcher",
                        info="Select from available trained models (refresh with scan)",
                        allow_custom_value=True
                    )
                    quick_load_btn = gr.Button("⚡ Load Selected Model", variant="primary")

                    # Manual Model Loading
                    model_weights_input = gr.Textbox(
                        label="Model Weights Path",
                        value=str(app.current_model_path) if app.current_model_path else "",
                        placeholder="runs/training/checkpoint_best_total.pth"
                    )
                    model_type_dropdown = gr.Dropdown(
                        choices=["Auto-detect", "rf-detr", "rt-detr", "yolov6", "yolox"],
                        value="Auto-detect",
                        label="Model Type",
                        info="Auto-detect will try to determine from file path"
                    )
                    with gr.Row():
                        load_model_btn = gr.Button("🤖 Load Model Weights")
                        scan_models_btn = gr.Button("🔍 Scan for Models")
                    model_info = gr.Textbox(
                        label="Current Model",
                        value=app.get_current_model_info(),
                        interactive=False
                    )

        with gr.Row():
            with gr.Column(scale=3):
                image_display = gr.HTML(
                    label="Current Image",
                    value="<div style='width: 800px; height: 400px; border: 1px solid #ccc; display: flex; align-items: center; justify-content: center; color: #666;'>Load images from Settings to start annotating</div>",
                    js_on_load=CANVAS_JS_ON_LOAD,
                )
                with gr.Row():
                    prev_btn = gr.Button("⬅️ Previous")
                    next_btn = gr.Button("Next ➡️")
                    auto_label_btn = gr.Button("🤖 Auto-Label", variant="primary")
                # Hidden textbox to store canvas boxes data
                canvas_boxes_data = gr.Textbox(visible=False, elem_id="canvas-boxes-data")
                with gr.Row():
                    threshold_slider = gr.Slider(0.1, 0.9, DEFAULT_DETECTION_THRESHOLD, label="Detection Threshold")
            with gr.Column(scale=1):
                image_index_text = gr.Textbox(label="Image", lines=1, interactive=False)
                info_text = gr.Textbox(label="Status", lines=2)
                boxes_html = gr.HTML(label="Annotations")
                # Hidden channel for per-box delete requests; -1 means "none".
                delete_box_index = gr.Number(visible=False, value=-1)

                gr.Markdown("### Manual Annotation")
                with gr.Row():
                    x1_input = gr.Number(label="x1", value=100)
                    y1_input = gr.Number(label="y1", value=100)
                with gr.Row():
                    x2_input = gr.Number(label="x2", value=200)
                    y2_input = gr.Number(label="y2", value=200)
                add_box_btn = gr.Button(" Add Box")
                delete_btn = gr.Button("🗑️ Delete Last")
                clear_btn = gr.Button("❌ Clear All")

                gr.Markdown("### Export Annotations")
                export_path = gr.Textbox(
                    label="Export Path",
                    value="annotations_coco.json"
                )
                export_btn = gr.Button("Export COCO")
                export_result = gr.Textbox(label="Export Result", lines=1)

        # The four components every image/box handler refreshes together.
        view_outputs = [image_display, boxes_html, image_index_text, info_text]

        # Event handlers
        def on_load():
            return app.load_image("current")

        def refresh_model_choices():
            # A bare list returned to a Dropdown output is applied as its
            # *value*; returning a Dropdown with ``choices`` set replaces
            # the selectable options, which is what a refresh needs.
            return gr.Dropdown(choices=app.get_available_models_list())

        def on_delete_box(idx):
            # -1 (the sentinel) or an empty number means there is nothing
            # to delete: leave every output as it is rather than blanking
            # the canvas with None values.
            if idx is None or idx < 0:
                return gr.skip(), gr.skip(), gr.skip(), gr.skip()
            return app.delete_box_by_index(int(idx))

        # Settings handlers
        load_images_btn.click(
            app.load_new_images_dir,
            inputs=[images_dir_input],
            outputs=view_outputs
        ).then(
            lambda: (app.get_current_dir_info(), app.get_current_model_info()),
            outputs=[dir_info, model_info]
        )
        load_model_btn.click(
            app.load_new_model,
            inputs=[model_weights_input, model_type_dropdown],
            outputs=[model_info]
        ).then(
            refresh_model_choices,
            outputs=[model_selector]
        )
        scan_models_btn.click(
            app.scan_for_models,
            outputs=[model_info]
        ).then(
            refresh_model_choices,
            outputs=[model_selector]
        )
        quick_load_btn.click(
            app.load_model_by_index,
            inputs=[model_selector],
            outputs=[model_info]
        ).then(
            refresh_model_choices,
            outputs=[model_selector]
        )

        # Navigation and auto-labeling
        prev_btn.click(
            lambda: app.load_image("prev"),
            outputs=view_outputs
        )
        next_btn.click(
            lambda: app.load_image("next"),
            outputs=view_outputs
        )
        auto_label_btn.click(
            lambda t: app.auto_label_current(t),
            inputs=[threshold_slider],
            outputs=view_outputs
        )

        # Auto-save when canvas changes
        canvas_boxes_data.change(
            app.save_canvas_changes,
            inputs=[canvas_boxes_data],
            outputs=view_outputs
        )

        # Delete box handler (called from HTML button clicks via JS)
        delete_box_index.change(
            on_delete_box,
            inputs=[delete_box_index],
            outputs=view_outputs
        )

        # Manual box editing
        add_box_btn.click(
            app.add_box_manual,
            inputs=[x1_input, y1_input, x2_input, y2_input],
            outputs=view_outputs
        )
        delete_btn.click(
            app.delete_last_box,
            outputs=view_outputs
        )
        clear_btn.click(
            app.clear_boxes,
            outputs=view_outputs
        )

        export_btn.click(
            lambda path: app.export_to_coco(Path(path)),
            inputs=[export_path],
            outputs=[export_result]
        )

        # Load first image on start
        demo.load(on_load, outputs=view_outputs)

    return demo
def main() -> None:
    """Parse CLI options, build the app and UI, and launch the server.

    Paths given on the command line (or taken from the config defaults)
    that do not exist are reported and dropped, so the GUI still starts
    and the user can pick them from the Settings panel.
    """
    parser = argparse.ArgumentParser(description="Simple annotation GUI with auto-labeling")
    parser.add_argument(
        "--images-dir",
        type=Path,
        default=Path(DEFAULT_IMAGES_DIR) if DEFAULT_IMAGES_DIR else None,
        help="Default directory with images (can be changed in GUI)"
    )
    parser.add_argument(
        "--model-weights",
        type=Path,
        default=Path(DEFAULT_MODEL_WEIGHTS) if DEFAULT_MODEL_WEIGHTS else None,
        help="Default trained model for auto-labeling (can be changed in GUI)"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=DEFAULT_PORT,
        help="Port to run the GUI on"
    )
    parser.add_argument(
        "--host",
        default="0.0.0.0",
        help="Interface to bind to (default 0.0.0.0 = all interfaces; "
             "use 127.0.0.1 to keep the tool local)"
    )
    args = parser.parse_args()

    # Validate paths if provided
    if args.images_dir and not args.images_dir.exists():
        print(f"⚠️ Warning: Images directory not found: {args.images_dir}")
        print("You can load a different directory from the GUI Settings")
        args.images_dir = None
    if args.model_weights and not args.model_weights.exists():
        print(f"⚠️ Warning: Model weights not found: {args.model_weights}")
        print("You can load different weights from the GUI Settings")
        args.model_weights = None

    # Create app
    app = AnnotationApp(args.images_dir, args.model_weights)

    # Scan for available models on startup
    app.scan_for_models(return_info=False)

    # Create and launch UI
    demo = create_ui(app)

    print(f"\n{'='*60}")
    print("🚀 Starting annotation tool...")
    if args.images_dir:
        print(f"📁 Default images: {args.images_dir} ({len(app.image_paths)} images)")
    else:
        print("📁 No default images - load directory from Settings")
    if app.model:
        print(f"🤖 Model: Loaded from {args.model_weights}")
    else:
        print("⚠️ No model loaded - load from Settings or train one")
    print("💡 You can change images directory and model weights from the Settings panel")
    print(f"{'='*60}\n")

    # Combine canvas JS with delete button handler
    combined_js = CANVAS_GLOBAL_JS + r"""
// Wire delete buttons to hidden number input
window.deleteBox = function(index) {
const hiddenInput = document.querySelector('input[type=number][style*=display], input[type=number].\\!hidden');
if (hiddenInput) {
hiddenInput.value = index;
hiddenInput.dispatchEvent(new Event('input', { bubbles: true }));
// Reset after triggering
setTimeout(() => { hiddenInput.value = -1; }, 100);
}
};
"""

    demo.launch(
        server_name=args.host,
        server_port=args.port,
        js=combined_js,
        share=False
    )
# Script entry point: start the GUI only when run directly, not on import.
if __name__ == "__main__":
    main()