Files
saw_mill_knot_detection/annotation_gui.py
2025-12-23 18:12:01 -07:00

1271 lines
52 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Simple customizable annotation GUI with auto-labeling support.
Built with Gradio - easy to modify and extend.
Run: python annotation_gui.py
To set default paths, edit config.py
"""
from __future__ import annotations
import argparse
import html
import json
import subprocess
import threading
from pathlib import Path
from typing import Any
import gradio as gr
from PIL import Image, ImageDraw
# Optional user overrides live in config.py (see the module docstring). Each
# setting is looked up individually, so a config.py that defines only some of
# them still takes effect; anything missing -- or a missing config module --
# falls back to the defaults below. (An all-or-nothing `from config import ...`
# raised ImportError on the first missing name and discarded every setting.)
try:
    import config as _config
except ImportError:
    _config = None

DEFAULT_IMAGES_DIR = getattr(_config, "DEFAULT_IMAGES_DIR", None)
DEFAULT_MODEL_WEIGHTS = getattr(_config, "DEFAULT_MODEL_WEIGHTS", None)
DEFAULT_PORT = getattr(_config, "DEFAULT_PORT", 7860)
DEFAULT_DETECTION_THRESHOLD = getattr(_config, "DEFAULT_DETECTION_THRESHOLD", 0.5)
DEFAULT_TRAIN_EPOCHS = getattr(_config, "DEFAULT_TRAIN_EPOCHS", 20)
DEFAULT_BATCH_SIZE = getattr(_config, "DEFAULT_BATCH_SIZE", 4)
DEFAULT_LEARNING_RATE = getattr(_config, "DEFAULT_LEARNING_RATE", 1e-4)
DEFAULT_MODEL_SIZE = getattr(_config, "DEFAULT_MODEL_SIZE", "small")
# Per-component hook for the annotation canvas HTML.
#
# Gradio 6 sanitizes <script> tags inside gr.HTML content, and js_on_load only
# runs when the component is first mounted, not when its HTML is later updated.
# The real canvas logic therefore lives in CANVAS_GLOBAL_JS, which (per the
# original design note) is installed once as a page-level script via
# demo.launch(js=...). That script registers window.__initAnnotationCanvas;
# this snippet just hands the mounted element to it when it already exists.
# NOTE(review): `element` is presumably injected by Gradio's js_on_load
# wrapper -- confirm against the Gradio version in use.
CANVAS_JS_ON_LOAD = r"""
(() => {
if (window.__initAnnotationCanvas) {
window.__initAnnotationCanvas(element);
}
})();
"""
CANVAS_GLOBAL_JS = r"""
(() => {
function init(root) {
if (!root) return;
const canvas = root.querySelector('#annotation-canvas');
const imgEl = root.querySelector('#annotation-img');
const initialBoxesEl = root.querySelector('#annotation-initial-boxes');
if (!canvas || !imgEl || !initialBoxesEl) return;
// Ensure we don't double-bind if Gradio reuses the DOM node.
if (canvas.dataset.bound === '1') {
// Still redraw in case boxes were updated.
if (canvas.__redraw) canvas.__redraw();
return;
}
canvas.dataset.bound = '1';
const ctx = canvas.getContext('2d');
const displayWidth = canvas.width;
const displayHeight = canvas.height;
let boxes = [];
try {
const raw = initialBoxesEl.value || initialBoxesEl.textContent || '[]';
boxes = JSON.parse(raw);
if (!Array.isArray(boxes)) boxes = [];
} catch (_) {
boxes = [];
}
const hiddenRoot = document.getElementById('canvas-boxes-data');
const hiddenInput = hiddenRoot
? (hiddenRoot.querySelector('textarea, input') || hiddenRoot)
: null;
const syncHidden = () => {
if (!hiddenInput) return;
if (!('value' in hiddenInput)) return;
hiddenInput.value = JSON.stringify(boxes);
hiddenInput.dispatchEvent(new Event('input', { bubbles: true }));
};
syncHidden();
let isDragging = false;
let dragStart = null;
let selectedCorner = null;
let selectedBoxIndex = -1;
let creatingBox = false;
let createStart = null;
function redraw() {
ctx.clearRect(0, 0, displayWidth, displayHeight);
// Base image is rendered via <img> below the canvas.
boxes.forEach((box) => {
const [x1, y1, x2, y2] = box.bbox;
const label = box.label || 'knot';
const conf = box.confidence || 1.0;
ctx.strokeStyle = 'red';
ctx.lineWidth = 3;
ctx.strokeRect(x1, y1, x2 - x1, y2 - y1);
const handleSize = 6;
ctx.fillStyle = 'red';
ctx.fillRect(x1 - handleSize, y1 - handleSize, handleSize * 2, handleSize * 2);
ctx.fillRect(x2 - handleSize, y1 - handleSize, handleSize * 2, handleSize * 2);
ctx.fillRect(x1 - handleSize, y2 - handleSize, handleSize * 2, handleSize * 2);
ctx.fillRect(x2 - handleSize, y2 - handleSize, handleSize * 2, handleSize * 2);
ctx.fillStyle = 'red';
ctx.font = '16px Arial';
const text = conf < 1.0 ? `${label} ${conf.toFixed(2)}` : label;
ctx.fillText(text, x1, y1 - 5);
});
if (creatingBox && createStart && dragStart) {
ctx.strokeStyle = 'blue';
ctx.lineWidth = 2;
ctx.setLineDash([5, 5]);
const x = Math.min(createStart.x, dragStart.x);
const y = Math.min(createStart.y, dragStart.y);
const w = Math.abs(createStart.x - dragStart.x);
const h = Math.abs(createStart.y - dragStart.y);
ctx.strokeRect(x, y, w, h);
ctx.setLineDash([]);
}
}
canvas.__redraw = redraw;
function getCornerAt(x, y) {
const handleSize = 6;
for (let i = 0; i < boxes.length; i++) {
const [x1, y1, x2, y2] = boxes[i].bbox;
const corners = [
{ x: x1, y: y1, type: 'top-left' },
{ x: x2, y: y1, type: 'top-right' },
{ x: x1, y: y2, type: 'bottom-left' },
{ x: x2, y: y2, type: 'bottom-right' },
];
for (const corner of corners) {
if (
x >= corner.x - handleSize &&
x <= corner.x + handleSize &&
y >= corner.y - handleSize &&
y <= corner.y + handleSize
) {
return { boxIndex: i, corner: corner.type, pos: corner };
}
}
}
return null;
}
imgEl.addEventListener(
'error',
() => {
ctx.clearRect(0, 0, displayWidth, displayHeight);
ctx.fillStyle = '#ffcccc';
ctx.fillRect(0, 0, displayWidth, displayHeight);
ctx.fillStyle = 'black';
ctx.font = '16px Arial';
ctx.fillText('Image failed to load', 10, 30);
},
{ once: true }
);
redraw();
canvas.addEventListener('mousedown', (e) => {
const rect = canvas.getBoundingClientRect();
const x = (e.clientX - rect.left) * (displayWidth / rect.width);
const y = (e.clientY - rect.top) * (displayHeight / rect.height);
selectedCorner = getCornerAt(x, y);
if (selectedCorner) {
isDragging = true;
selectedBoxIndex = selectedCorner.boxIndex;
canvas.style.cursor = 'move';
} else {
creatingBox = true;
createStart = { x, y };
dragStart = { x, y };
canvas.style.cursor = 'crosshair';
}
});
canvas.addEventListener('mousemove', (e) => {
const rect = canvas.getBoundingClientRect();
const x = (e.clientX - rect.left) * (displayWidth / rect.width);
const y = (e.clientY - rect.top) * (displayHeight / rect.height);
if (isDragging && selectedCorner) {
const box = boxes[selectedBoxIndex];
if (selectedCorner.corner === 'top-left') {
box.bbox[0] = Math.min(x, box.bbox[2] - 10);
box.bbox[1] = Math.min(y, box.bbox[3] - 10);
} else if (selectedCorner.corner === 'top-right') {
box.bbox[2] = Math.max(x, box.bbox[0] + 10);
box.bbox[1] = Math.min(y, box.bbox[3] - 10);
} else if (selectedCorner.corner === 'bottom-left') {
box.bbox[0] = Math.min(x, box.bbox[2] - 10);
box.bbox[3] = Math.max(y, box.bbox[1] + 10);
} else if (selectedCorner.corner === 'bottom-right') {
box.bbox[2] = Math.max(x, box.bbox[0] + 10);
box.bbox[3] = Math.max(y, box.bbox[1] + 10);
}
syncHidden();
redraw();
return;
}
if (creatingBox && createStart) {
dragStart = { x, y };
redraw();
return;
}
const corner = getCornerAt(x, y);
canvas.style.cursor = corner ? 'move' : 'crosshair';
});
canvas.addEventListener('mouseup', () => {
if (creatingBox && createStart && dragStart) {
const x1 = Math.min(createStart.x, dragStart.x);
const y1 = Math.min(createStart.y, dragStart.y);
const x2 = Math.max(createStart.x, dragStart.x);
const y2 = Math.max(createStart.y, dragStart.y);
const w = x2 - x1;
const h = y2 - y1;
if (w > 10 && h > 10) {
boxes.push({ bbox: [x1, y1, x2, y2], label: 'knot', confidence: 1.0, source: 'manual' });
} else {
// Click without drag: create a default-size box around the click.
const size = 120;
const cx = createStart.x;
const cy = createStart.y;
const bx1 = Math.max(0, cx - size / 2);
const by1 = Math.max(0, cy - size / 2);
const bx2 = Math.min(displayWidth, cx + size / 2);
const by2 = Math.min(displayHeight, cy + size / 2);
boxes.push({ bbox: [bx1, by1, bx2, by2], label: 'knot', confidence: 1.0, source: 'manual' });
}
syncHidden();
redraw();
}
isDragging = false;
creatingBox = false;
selectedCorner = null;
selectedBoxIndex = -1;
createStart = null;
dragStart = null;
canvas.style.cursor = 'crosshair';
});
}
window.__initAnnotationCanvas = init;
function scan() {
document.querySelectorAll('[data-annotation-root="1"]').forEach((root) => init(root));
}
const obs = new MutationObserver(() => scan());
obs.observe(document.documentElement, { childList: true, subtree: true });
window.addEventListener('load', () => scan());
scan();
setTimeout(() => scan(), 0);
setTimeout(() => scan(), 100);
setTimeout(() => scan(), 250);
setTimeout(() => scan(), 500);
setTimeout(() => scan(), 1000);
setTimeout(() => scan(), 2000);
})();
"""
class AnnotationApp:
    """State and actions behind the knot-annotation UI.

    Tracks the images of one directory, their box annotations (persisted to
    ``annotations.json`` in that directory) and an optional detector used for
    auto-labelling. Each box is a dict ``{"bbox": [x1, y1, x2, y2], "label":
    str, "confidence": float, "source": str}``. Stored coordinates are in
    original-image pixels. The browser canvas may show a downscaled copy of the
    image (see ``MAX_DISPLAY_SIZE``), so boxes are scaled on the way to the
    canvas and scaled back when the canvas reports edits.

    View-returning methods produce a 4-tuple
    ``(canvas_html, boxes_html, image_label, status)``.
    """

    # Largest (width, height) drawn in the browser; bigger images are downscaled.
    MAX_DISPLAY_SIZE = (1200, 800)
    # View tuple returned whenever there is no current image to show.
    _NO_IMAGES = ("<div>No images</div>", "", "Image: -/-", "No images")

    def __init__(self, images_dir: Path | None = None, model_weights: Path | None = None):
        """Create the app, optionally loading an image directory and model weights.

        Args:
            images_dir: Directory of images to annotate; ignored if it does not exist.
            model_weights: Detector weights for auto-labelling; ignored if missing.
        """
        self.images_dir = images_dir if images_dir else Path.cwd()
        self.current_model_path = model_weights
        self.current_model_type = None  # 'rf-detr', 'rt-detr', 'yolov6' or 'yolox'
        self.available_models = []  # discovered models for quick switching
        self.image_paths = []
        self.current_idx = 0
        self.annotations = {}  # image filename -> list of box dicts
        # Set by _load_images. Until a directory is loaded there is nowhere to
        # save, and _save_annotations checks for None.
        self.ann_file = None
        self.model = None
        if images_dir and images_dir.exists():
            self._load_images(images_dir)
        if model_weights and model_weights.exists():
            self._load_model(model_weights)

    # ------------------------------------------------------------------ images

    def _load_images(self, images_dir: Path) -> str:
        """Switch to ``images_dir`` and load its images and saved annotations.

        Picks up ``*.jpg`` and ``*.png`` files (sorted by path) and resets the
        cursor to the first image. Reads ``annotations.json`` from the same
        directory if it exists.

        Returns:
            A status message for the UI.
        """
        self.images_dir = images_dir
        self.image_paths = sorted([*images_dir.glob("*.jpg"), *images_dir.glob("*.png")])
        self.current_idx = 0
        self.ann_file = images_dir / "annotations.json"
        if self.ann_file.exists():
            with self.ann_file.open("r", encoding="utf-8") as f:
                self.annotations = json.load(f)
        else:
            self.annotations = {}
        return f"✓ Loaded {len(self.image_paths)} images from {images_dir}"

    # ------------------------------------------------------------------ models

    @staticmethod
    def _type_from_dir_name(dir_name: str, default: str) -> str:
        """Map a run-directory name to a model family, or return ``default``."""
        name = dir_name.lower()
        for token, model_type in (("rtdetr", "rt-detr"), ("yolov6", "yolov6"), ("yolox", "yolox")):
            if token in name:
                return model_type
        return default

    def find_best_weights(self, directory: Path) -> tuple[Path | None, str | None]:
        """Locate the most likely weights file in a training-run directory.

        Candidates are checked in this order:

        1. ``checkpoint_best_total.pth`` (RF-DETR).
        2. ``weights/best.pt`` and ``training/weights/best.pt`` (Ultralytics).
           The family comes from the directory name, defaulting to RT-DETR
           and YOLOX respectively.
        3. ``best.pt`` (assumed RT-DETR).
        4. Any ``*.pth`` / ``*.pt`` file, preferring names containing "best".

        Returns:
            ``(path, model_type)``, or ``(None, None)`` when nothing is found.
        """
        if not directory.exists():
            return None, None
        rf_detr_weights = directory / "checkpoint_best_total.pth"
        if rf_detr_weights.exists():
            return rf_detr_weights, "rf-detr"
        ultralytics_layouts = (
            (directory / "weights" / "best.pt", "rt-detr"),
            (directory / "training" / "weights" / "best.pt", "yolox"),
        )
        for weights, default_type in ultralytics_layouts:
            if weights.exists():
                return weights, self._type_from_dir_name(directory.name, default_type)
        direct_best = directory / "best.pt"
        if direct_best.exists():
            return direct_best, "rt-detr"  # default assumption
        # Sorted so the choice does not depend on filesystem listing order.
        candidates = sorted(directory.glob("*.pth")) + sorted(directory.glob("*.pt"))
        if not candidates:
            return None, None
        best = next((f for f in candidates if "best" in f.name.lower()), candidates[0])
        return best, self._guess_model_type_from_path(best)

    def _guess_model_type_from_path(self, path: Path) -> str:
        """Guess the detector family from a weights path.

        Explicit family names win, checked in order: rfdetr / rf-detr / rf_detr,
        then rtdetr, yolov6, yolox. Failing that, a path mentioning
        "checkpoint", or with a standalone "rf" component, is treated as
        RF-DETR. Anything else defaults to "rt-detr".
        """
        import re

        path_str = str(path).lower()
        # Whole tokens only: a plain substring test for "rf" matched words such
        # as "surface" or "perf" and sent ordinary paths to RF-DETR.
        tokens = set(re.split(r"[^a-z0-9]+", path_str))
        if any(name in path_str for name in ("rfdetr", "rf-detr", "rf_detr")):
            return "rf-detr"
        if "rtdetr" in path_str:
            return "rt-detr"
        if "yolov6" in path_str:
            return "yolov6"
        if "yolox" in path_str:
            return "yolox"
        if "checkpoint" in path_str or "rf" in tokens:
            return "rf-detr"
        return "rt-detr"

    @staticmethod
    def _build_rfdetr(weights_path: Path, device):
        """Construct an RF-DETR model from ``weights_path`` and move it to ``device``.

        A training checkpoint (a dict with a ``"model"`` state dict containing
        the backbone position embeddings) is tried against each model size with
        ``strict=False``. Any other file is passed to ``RFDETRNano`` as
        ``pretrain_weights``.
        NOTE(review): confirm the rfdetr wrapper classes expose
        load_state_dict/state_dict/to; if they do not, every size attempt fails
        and the "Could not load" error is raised.
        """
        import torch
        from rfdetr import RFDETRBase, RFDETRMedium, RFDETRNano, RFDETRSmall

        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from trusted sources.
        checkpoint = torch.load(weights_path, map_location="cpu", weights_only=False)
        state_dict = checkpoint["model"] if "model" in checkpoint else None
        marker = "backbone.0.encoder.encoder.embeddings.position_embeddings"
        if state_dict is None or not any(marker in key for key in state_dict.keys()):
            # Direct weights file
            return RFDETRNano(pretrain_weights=str(weights_path)).to(device)
        sizes = (
            ("nano", RFDETRNano),
            ("small", RFDETRSmall),
            ("medium", RFDETRMedium),
            ("base", RFDETRBase),
        )
        for size_name, model_class in sizes:
            try:
                model = model_class()
                # strict=False tolerates size mismatches; accept the first size
                # for which at least some keys matched.
                missing_keys, _unexpected = model.load_state_dict(state_dict, strict=False)
                if len(missing_keys) < len(model.state_dict()):
                    print(f"✓ Loaded RF-DETR {size_name} model (with {len(missing_keys)} missing keys)")
                    return model.to(device)
            except Exception:
                continue
        raise RuntimeError("Could not load checkpoint with any RF-DETR model size")

    @staticmethod
    def _build_ultralytics(weights_path: Path, model_type: str, device):
        """Construct an Ultralytics RT-DETR (for "rt-detr") or YOLO model."""
        if model_type == "rt-detr":
            from ultralytics import RTDETR
            model = RTDETR(str(weights_path))
        else:
            from ultralytics import YOLO
            model = YOLO(str(weights_path))
        if hasattr(model, "to"):
            model = model.to(device)
        return model

    def _remember_model(self, weights_path: Path, model_type: str) -> None:
        """Add a loaded model to the quick-switch list unless already present."""
        if any(m["path"] == weights_path for m in self.available_models):
            return
        self.available_models.append({
            "path": weights_path,
            "type": model_type,
            "dir": weights_path.parent.name,
            "display": f"Custom: {weights_path.name} ({model_type.upper()})",
        })

    def _load_model(self, weights_path: Path, model_type: str | None = None) -> str:
        """Load a detector for auto-labelling.

        Args:
            weights_path: Weights/checkpoint file.
            model_type: 'rf-detr', 'rt-detr', 'yolov6' or 'yolox'. When None it
                is guessed from the path.

        Returns:
            A status message. On failure the model is cleared and the error
            is reported in the message instead of raised.
        """
        try:
            import torch

            if model_type is None:
                model_type = self._guess_model_type_from_path(weights_path)
            print(f"Loading {model_type} model from {weights_path}...")
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"Using device: {device}")
            if model_type == "rf-detr":
                self.model = self._build_rfdetr(weights_path, device)
            else:
                # RT-DETR, YOLOv6 and YOLOX all go through Ultralytics.
                self.model = self._build_ultralytics(weights_path, model_type, device)
            self.current_model_path = weights_path
            self.current_model_type = model_type
            self._remember_model(weights_path, model_type)
            print("✓ Model loaded")
            return f"{model_type.upper()} model loaded from {weights_path.name}"
        except Exception as e:
            error_msg = f"⚠ Could not load {model_type or 'model'}: {e}"
            print(error_msg)
            self.model = None
            self.current_model_type = None
            return error_msg

    def load_new_model(self, weights_path: str, model_type: str = "Auto-detect") -> str:
        """Load a model from a path typed in the GUI.

        "Auto-detect" defers to path-based guessing. Any other value is passed
        through as the model type.
        """
        path = Path(weights_path)
        if not path.exists():
            return f"❌ File not found: {weights_path}"
        resolved_type = None if model_type == "Auto-detect" else model_type
        return self._load_model(path, resolved_type)

    def load_model_from_directory(self, directory_path: str) -> str:
        """Load the best weights found in ``directory_path`` (see find_best_weights)."""
        path = Path(directory_path)
        if not path.exists():
            return f"❌ Directory not found: {directory_path}"
        if not path.is_dir():
            return f"❌ Not a directory: {directory_path}"
        weights_path, detected_type = self.find_best_weights(path)
        if weights_path is None:
            return f"❌ No model weights found in {directory_path}"
        return self._load_model(weights_path, detected_type)

    def scan_for_models(self, return_info: bool = True) -> str:
        """Rebuild ``available_models`` from the run folders under ./runs.

        Returns a human-readable listing, or "" when ``return_info`` is False.
        """
        runs_dir = Path("runs")
        found = []
        if runs_dir.exists():
            # Sorted so the dropdown order is stable between scans.
            for subdir in sorted(runs_dir.iterdir()):
                if not subdir.is_dir():
                    continue
                weights_path, model_type = self.find_best_weights(subdir)
                if weights_path:
                    found.append({
                        "path": weights_path,
                        "type": model_type,
                        "dir": subdir.name,
                        "display": f"{subdir.name} ({model_type.upper()})",
                    })
        self.available_models = found
        if not return_info:
            return ""
        if not found:
            return "❌ No trained models found in 'runs/' directory"
        lines = ["📂 Available Models:"]
        for i, model in enumerate(found, 1):
            lines.append(f"{i}. {model['dir']} -> {model['path'].name} ({model['type'].upper()})")
        lines.append("\n💡 Use the Model Selector dropdown above to quickly switch models")
        return "\n".join(lines)

    def get_available_models_list(self) -> list:
        """Return display names for the model dropdown, scanning if the list is empty."""
        if not self.available_models:
            self.scan_for_models(return_info=False)
        if not self.available_models:
            return ["No models found - click '🔍 Scan for Models'"]
        return [model["display"] for model in self.available_models]

    def load_model_by_index(self, model_display: str) -> str:
        """Load the model whose dropdown display name is ``model_display``."""
        if not getattr(self, "available_models", None):
            return "❌ No models available. Click '🔍 Scan for Models' first."
        for model in self.available_models:
            if model["display"] == model_display:
                return self._load_model(model["path"], model["type"])
        return f"❌ Model '{model_display}' not found"

    def load_new_images_dir(self, images_dir: str) -> tuple[str, str, str, str]:
        """Switch to a new images directory from the GUI and show its first image."""
        path = Path(images_dir)
        if not path.exists():
            return "<div>Directory not found</div>", "", "Image: -/-", f"❌ Directory not found: {images_dir}"
        if not path.is_dir():
            return "<div>Not a directory</div>", "", "Image: -/-", f"❌ Not a directory: {images_dir}"
        result = self._load_images(path)
        if not self.image_paths:
            return "<div>No images found</div>", "", "Image: -/-", f"{result}\n⚠️ No .jpg or .png images found in directory"
        filename = self.image_paths[self.current_idx].name
        return self._render(filename, self.annotations.get(filename, []), result)

    def get_current_model_info(self) -> str:
        """Describe the currently loaded model for the UI."""
        if self.model is not None and self.current_model_path:
            type_info = f" ({self.current_model_type.upper()})" if self.current_model_type else ""
            return f"📦 Loaded: {self.current_model_path}{type_info}"
        if self.model is not None:
            return "📦 Model loaded (pretrained)"
        return "⚠️ No model loaded"

    def get_current_dir_info(self) -> str:
        """Describe the current images directory and its image count."""
        return f"📁 {self.images_dir} ({len(self.image_paths)} images)"

    # ---------------------------------------------------------------- rendering

    def _current_image_label(self, filename: str) -> str:
        """Stable image index display (kept separate from status messages)."""
        if not filename or not self.image_paths:
            return "Image: -/-"
        return f"Image {self.current_idx + 1}/{len(self.image_paths)}: {filename}"

    def get_current_image(self) -> tuple[Image.Image, str]:
        """Return the current image (RGB) and its filename, or ``(None, "")``."""
        if not self.image_paths:
            return None, ""
        path = self.image_paths[self.current_idx]
        # convert() returns a new image, so the source file can be closed here.
        with Image.open(path) as src:
            img = src.convert("RGB")
        return img, path.name

    def _display_size(self, width: int, height: int) -> tuple[int, int]:
        """Return the canvas size used to show an image of ``width`` x ``height``.

        Images larger than ``MAX_DISPLAY_SIZE`` in either dimension are shrunk,
        preserving the aspect ratio; smaller ones are shown as-is.
        """
        max_width, max_height = self.MAX_DISPLAY_SIZE
        if width > max_width or height > max_height:
            ratio = min(max_width / width, max_height / height)
            return int(width * ratio), int(height * ratio)
        return width, height

    @staticmethod
    def _scale_boxes(boxes: list, sx: float, sy: float, ndigits: int | None = None) -> list:
        """Return box copies with x coordinates scaled by ``sx`` and y by ``sy``.

        With ``sx == sy == 1`` the boxes are returned unchanged (in a new list).
        Entries that are not ``{"bbox": ...}`` dicts are passed through as-is.
        When ``ndigits`` is given, scaled coordinates are rounded to that many
        decimals. That keeps repeated canvas round trips from drifting.
        """
        if sx == 1 and sy == 1:
            return list(boxes)
        scaled = []
        for box in boxes:
            if not (isinstance(box, dict) and "bbox" in box):
                scaled.append(box)
                continue
            x1, y1, x2, y2 = box["bbox"]
            coords = [x1 * sx, y1 * sy, x2 * sx, y2 * sy]
            if ndigits is not None:
                coords = [round(c, ndigits) for c in coords]
            scaled.append({**box, "bbox": coords})
        return scaled

    def _render(self, filename: str, boxes: list, status: str = "") -> tuple[str, str, str, str]:
        """Build the standard (canvas_html, boxes_html, image_label, status) view."""
        return (
            self.generate_interactive_canvas(boxes),
            self._format_boxes_html(boxes),
            self._current_image_label(filename),
            status,
        )

    def draw_boxes_on_image(self, img: Image.Image, boxes: list[dict]) -> Image.Image:
        """Return a copy of ``img`` with boxes, corner handles and labels drawn in red."""
        img_draw = img.copy()
        draw = ImageDraw.Draw(img_draw)
        handle_size = 6
        for box in boxes:
            x1, y1, x2, y2 = box["bbox"]
            label = box.get("label", "knot")
            conf = box.get("confidence", 1.0)
            draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
            # Corner handles mirror the ones the interactive canvas draws.
            for cx, cy in ((x1, y1), (x2, y1), (x1, y2), (x2, y2)):
                draw.rectangle(
                    [cx - handle_size, cy - handle_size, cx + handle_size, cy + handle_size],
                    fill="red",
                )
            text = f"{label} {conf:.2f}" if conf < 1.0 else label
            draw.text((x1, y1 - 20), text, fill="red")
        return img_draw

    def generate_interactive_canvas(self, boxes: list[dict] | None = None) -> str:
        """Build the HTML for the interactive annotation canvas.

        The image is embedded as a base64 PNG, downscaled to ``_display_size``.
        ``boxes`` (default: the stored boxes of the current image) are mapped
        into that display space and embedded as escaped JSON for the page script.
        No <script> tag is emitted because Gradio sanitizes it; the canvas
        logic runs from CANVAS_JS_ON_LOAD / CANVAS_GLOBAL_JS.
        """
        import base64
        from io import BytesIO

        img, filename = self.get_current_image()
        if img is None:
            return "<div>No image loaded</div>"
        if boxes is None:
            boxes = self.annotations.get(filename) or []

        orig_width, orig_height = img.size
        display_width, display_height = self._display_size(orig_width, orig_height)
        if (display_width, display_height) != (orig_width, orig_height):
            img = img.resize((display_width, display_height), Image.Resampling.LANCZOS)
        # Stored boxes are in original-image pixels; the canvas draws in its own
        # (possibly downscaled) pixel space.
        display_boxes = self._scale_boxes(boxes, display_width / orig_width, display_height / orig_height)

        buffer = BytesIO()
        img.save(buffer, format="PNG", optimize=True)
        img_base64 = base64.b64encode(buffer.getvalue()).decode()
        boxes_escaped = html.escape(json.dumps(display_boxes))
        # The initial-boxes textarea sits inside the data-annotation-root element
        # so the page script can find it from that root.
        return f"""
<div style="display: inline-block; border: 1px solid #ccc; padding: 5px;">
<div style="font-size: 12px; color: #666; margin-bottom: 4px;">Canvas Size: {display_width}x{display_height}</div>
<div data-annotation-root="1" style="position: relative; width: {display_width}px; height: {display_height}px;">
<textarea id="annotation-initial-boxes" style="display:none;">{boxes_escaped}</textarea>
<img id="annotation-img" src="data:image/png;base64,{img_base64}"
style="position:absolute; left:0; top:0; width:{display_width}px; height:{display_height}px; z-index: 1;" />
<canvas id="annotation-canvas" width="{display_width}" height="{display_height}"
style="position:absolute; left:0; top:0; width:{display_width}px; height:{display_height}px; cursor: crosshair; background: transparent; z-index: 2; pointer-events: auto;"></canvas>
</div>
</div>
"""

    def _format_boxes_html(self, boxes: list[dict]) -> str:
        """Format boxes as an HTML list with per-box delete buttons."""
        if not boxes:
            return "<div style='color: #999; font-style: italic;'>No annotations</div>"
        rows = []
        for i, box in enumerate(boxes):
            x1, y1, x2, y2 = box["bbox"]
            conf = box.get("confidence", 1.0)
            # "source" can come from user-editable JSON; escape it before embedding.
            source = html.escape(str(box.get("source", "manual")))
            rows.append(
                f"<div style='margin: 4px 0; padding: 4px; background: #f5f5f5; border-radius: 3px; display: flex; justify-content: space-between; align-items: center;'>"
                f"<span style='flex: 1;'>{i}: [{x1:.0f},{y1:.0f},{x2:.0f},{y2:.0f}] {conf:.2f} ({source})</span>"
                f"<button onclick='window.deleteBox({i})' style='background: #ff4444; color: white; border: none; padding: 2px 8px; border-radius: 3px; cursor: pointer; margin-left: 8px;'>✕</button>"
                f"</div>"
            )
        return "<div style='font-family: monospace; font-size: 13px;'>" + "".join(rows) + "</div>"

    def _parse_boxes_text(self, text: str) -> list[dict] | None:
        """Parse a JSON list of box dicts (from the canvas or an edit textbox).

        Returns:
            ``None`` for empty text, invalid JSON or a non-list payload.
            Otherwise the cleaned list. Entries without a 4-element numeric
            ``bbox``, or with a non-numeric confidence, are dropped. Survivors
            get float coordinates and default label / confidence / source.
        """
        if not text:
            return None
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            return None
        if not isinstance(data, list):
            return None
        cleaned: list[dict] = []
        for item in data:
            if not isinstance(item, dict):
                continue
            bbox = item.get("bbox")
            if not (isinstance(bbox, list) and len(bbox) == 4):
                continue
            try:
                x1, y1, x2, y2 = (float(v) for v in bbox)
                confidence = float(item.get("confidence", 1.0))
            except (TypeError, ValueError, OverflowError):
                continue
            cleaned.append({
                "bbox": [x1, y1, x2, y2],
                "label": item.get("label", "knot"),
                "confidence": confidence,
                "source": item.get("source", "manual"),
            })
        return cleaned

    # ----------------------------------------------------------------- actions

    def load_image(self, direction: str = "current") -> tuple[str, str, str, str]:
        """Move to the "next"/"prev" image (clamped), or stay on "current", and render it."""
        if self.image_paths:
            last_idx = len(self.image_paths) - 1
            if direction == "next":
                self.current_idx = min(self.current_idx + 1, last_idx)
            elif direction == "prev":
                self.current_idx = max(self.current_idx - 1, 0)
        img, filename = self.get_current_image()
        if img is None:
            return self._NO_IMAGES
        return self._render(filename, self.annotations.get(filename, []))

    def add_box_manual(self, x1: int, y1: int, x2: int, y2: int) -> tuple[str, str, str, str]:
        """Append a manual box (original-image pixel coordinates) and save."""
        img, filename = self.get_current_image()
        if img is None:
            return self._NO_IMAGES
        box = {
            "bbox": [float(x1), float(y1), float(x2), float(y2)],
            "label": "knot",
            "confidence": 1.0,
            "source": "manual",
        }
        boxes = self.annotations.setdefault(filename, [])
        boxes.append(box)
        self._save_annotations()
        return self._render(filename, boxes, f"✓ Added box: {len(boxes)} total")

    def delete_last_box(self) -> tuple[str, str, str, str]:
        """Remove the most recently added box of the current image, if any, and save."""
        img, filename = self.get_current_image()
        if img is None:
            return self._NO_IMAGES
        if self.annotations.get(filename):
            self.annotations[filename].pop()
            self._save_annotations()
        boxes = self.annotations.get(filename, [])
        return self._render(filename, boxes, f"✓ Deleted last box: {len(boxes)} remaining")

    def delete_box_by_index(self, index: int) -> tuple[str, str, str, str]:
        """Delete the box at ``index`` on the current image and save.

        ``index`` is coerced with ``int()`` because UI number widgets may
        deliver floats. Non-numeric or out-of-range values leave the boxes
        untouched and report an error.
        """
        img, filename = self.get_current_image()
        if img is None:
            return self._NO_IMAGES
        boxes = self.annotations.get(filename, [])
        try:
            idx = int(index)
        except (TypeError, ValueError):
            idx = -1
        if 0 <= idx < len(boxes):
            boxes.pop(idx)
            self.annotations[filename] = boxes
            self._save_annotations()
            return self._render(filename, boxes, f"✓ Deleted box {idx}")
        return self._render(filename, boxes, "❌ Invalid box index")

    def clear_boxes(self) -> tuple[str, str, str, str]:
        """Remove all boxes from the current image and save."""
        img, filename = self.get_current_image()
        if img is None:
            return self._NO_IMAGES
        self.annotations[filename] = []
        self._save_annotations()
        return self._render(filename, [], "✓ Cleared all boxes")

    @staticmethod
    def _auto_box(x1, y1, x2, y2, confidence: float) -> dict:
        """Build a box dict for a model detection."""
        return {
            "bbox": [float(x1), float(y1), float(x2), float(y2)],
            "label": "knot",
            "confidence": confidence,
            "source": "auto",
        }

    def _detect(self, img: Image.Image, threshold: float) -> list[dict]:
        """Run the loaded model on ``img`` and return detections as box dicts."""
        if self.current_model_type == "rf-detr":
            detections = self.model.predict(img, threshold=threshold)
            confidences = detections.confidence
            boxes = []
            for i, (x1, y1, x2, y2) in enumerate(detections.xyxy):
                conf = float(confidences[i]) if confidences is not None else 1.0
                boxes.append(self._auto_box(x1, y1, x2, y2, conf))
            return boxes
        # Ultralytics models (RT-DETR, YOLOv6, YOLOX)
        results = self.model.predict(source=img, conf=threshold, save=False, verbose=False)
        boxes = []
        for result in results:
            for box in result.boxes:
                x1, y1, x2, y2 = box.xyxy[0].tolist()
                boxes.append(self._auto_box(x1, y1, x2, y2, float(box.conf[0])))
        return boxes

    def auto_label_current(self, threshold: float = 0.5) -> tuple[str, str, str, str]:
        """Run the loaded model on the current image and append its detections.

        Detections are added to the image's existing boxes (not replacing them)
        and saved. Without a model, the current image stays on screen and only
        the status reports the problem. Inference errors are also reported in
        the status.
        """
        img, filename = self.get_current_image()
        if self.model is None:
            if img is None:
                return "<div>No model loaded</div>", "", "Image: -/-", "❌ No model loaded"
            return self._render(filename, self.annotations.get(filename, []), "❌ No model loaded")
        if img is None:
            return self._NO_IMAGES
        try:
            boxes = self._detect(img, threshold)
            self.annotations.setdefault(filename, []).extend(boxes)
            self._save_annotations()
            return self._render(filename, self.annotations[filename], f"🤖 Auto-labeled: {len(boxes)} detections")
        except Exception as e:
            return self._render(filename, self.annotations.get(filename, []), f"❌ Auto-label failed: {e}")

    def save_canvas_changes(self, boxes_json: str) -> tuple[str, str, str, str]:
        """Store the box list reported by the canvas for the current image.

        ``boxes_json`` is in canvas (display) pixels and is scaled back to
        original-image pixels before saving. Empty input keeps the stored boxes.
        Input that is not a JSON list keeps them too and reports an error.
        """
        img, filename = self.get_current_image()
        if img is None:
            return self._NO_IMAGES
        if not boxes_json:
            return self._render(filename, self.annotations.get(filename, []))
        parsed = self._parse_boxes_text(boxes_json)
        if parsed is None:
            return self._render(filename, self.annotations.get(filename, []), "❌ Invalid canvas data")
        display_width, display_height = self._display_size(img.width, img.height)
        # Rounding keeps a save -> render -> save round trip at a fixed point.
        boxes = self._scale_boxes(
            parsed, img.width / display_width, img.height / display_height, ndigits=2
        )
        self.annotations[filename] = boxes
        self._save_annotations()
        return self._render(filename, boxes)

    def _save_annotations(self) -> None:
        """Persist all annotations to ``annotations.json`` in the image directory.

        The JSON is written to a sibling ``.tmp`` file and then swapped into
        place. An interrupted write therefore leaves the previous file intact.
        Does nothing if no image directory has been loaded yet.
        """
        ann_file = getattr(self, "ann_file", None)
        if ann_file is None:
            return
        tmp_file = ann_file.with_name(ann_file.name + ".tmp")
        with tmp_file.open("w", encoding="utf-8") as f:
            json.dump(self.annotations, f, indent=2)
        tmp_file.replace(ann_file)

    def export_to_coco(self, output_path: Path) -> str:
        """Export all annotations to a COCO-style JSON file.

        Each image gets an entry with its pixel size. Boxes are converted from
        [x1, y1, x2, y2] to COCO [x, y, w, h], all with category id 0 ("knot"),
        and their confidence is kept as "score". Legacy boxes stored as bare
        [x1, y1, x2, y2] lists are accepted with score 1.0.

        Returns:
            A status message.
        """
        output_path = Path(output_path)
        coco_data = {
            "images": [],
            "annotations": [],
            "categories": [{"id": 0, "name": "knot", "supercategory": "defect"}],
        }
        ann_id = 0
        for img_id, img_path in enumerate(self.image_paths):
            filename = img_path.name
            # Only the header is needed for the size; close the file right away.
            with Image.open(img_path) as img:
                width, height = img.size
            coco_data["images"].append({
                "id": img_id,
                "file_name": filename,
                "width": width,
                "height": height,
            })
            for box in self.annotations.get(filename, []):
                if isinstance(box, dict) and "bbox" in box:
                    x1, y1, x2, y2 = box["bbox"]
                    score = box.get("confidence", 1.0)
                else:
                    # Backward/experimental compatibility: [x1, y1, x2, y2]
                    x1, y1, x2, y2 = box
                    score = 1.0
                w = x2 - x1
                h = y2 - y1
                coco_data["annotations"].append({
                    "id": ann_id,
                    "image_id": img_id,
                    "category_id": 0,
                    "bbox": [x1, y1, w, h],
                    "area": w * h,
                    "iscrowd": 0,
                    "score": score,
                })
                ann_id += 1
        with output_path.open("w") as f:
            json.dump(coco_data, f, indent=2)
        return f"✓ Exported {len(coco_data['annotations'])} annotations to {output_path}"

    def get_model_path_from_display(self, model_display: str) -> Path | None:
        """Return the weights path for a dropdown display name, or None."""
        for model in getattr(self, "available_models", None) or []:
            if model["display"] == model_display:
                return model["path"]
        return None
def create_ui(app: AnnotationApp) -> gr.Blocks:
    """Build the Gradio interface for the annotation tool.

    The interface has these parts:

    - a Settings accordion for loading an image directory and model weights;
    - the interactive canvas with navigation and auto-label controls;
    - a manual box editor;
    - a COCO export panel.

    Every control is wired to the corresponding ``AnnotationApp`` method.

    Args:
        app: Application state whose methods back every UI event.

    Returns:
        The configured (not yet launched) ``gr.Blocks`` demo.
    """
    with gr.Blocks(title="Knot Annotation Tool") as demo:
        gr.Markdown("""
# Wood Knot Annotation Tool
**Label -> Auto-Label -> Export**
- Manually annotate images or use **Auto-Label** with your trained model
- Export annotations to COCO format for training
- Use separate training and deployment scripts for model development
""")
        # Settings section at the top
        with gr.Accordion("Settings", open=False):
            with gr.Row():
                with gr.Column():
                    images_dir_input = gr.Textbox(
                        label="Images Directory",
                        value=str(app.images_dir),
                        placeholder="/path/to/images"
                    )
                    load_images_btn = gr.Button("📁 Load Images Directory")
                    dir_info = gr.Textbox(label="Current Directory", value=app.get_current_dir_info(), interactive=False)
                with gr.Column():
                    # Quick Model Selector
                    model_selector = gr.Dropdown(
                        choices=app.get_available_models_list(),
                        value=None,
                        label="🚀 Quick Model Switcher",
                        info="Select from available trained models (refresh with scan)",
                        allow_custom_value=True
                    )
                    quick_load_btn = gr.Button("⚡ Load Selected Model", variant="primary")
                    # Manual Model Loading
                    model_weights_input = gr.Textbox(
                        label="Model Weights Path",
                        value=str(app.current_model_path) if app.current_model_path else "",
                        placeholder="runs/training/checkpoint_best_total.pth"
                    )
                    model_type_dropdown = gr.Dropdown(
                        choices=["Auto-detect", "rf-detr", "rt-detr", "yolov6", "yolox"],
                        value="Auto-detect",
                        label="Model Type",
                        info="Auto-detect will try to determine from file path"
                    )
                    with gr.Row():
                        load_model_btn = gr.Button("🤖 Load Model Weights")
                        scan_models_btn = gr.Button("🔍 Scan for Models")
                    model_info = gr.Textbox(label="Current Model", value=app.get_current_model_info(), interactive=False)
        with gr.Row():
            with gr.Column(scale=3):
                image_display = gr.HTML(
                    label="Current Image",
                    value="<div style='width: 800px; height: 400px; border: 1px solid #ccc; display: flex; align-items: center; justify-content: center; color: #666;'>Load images from Settings to start annotating</div>",
                    js_on_load=CANVAS_JS_ON_LOAD,
                )
                with gr.Row():
                    prev_btn = gr.Button("⬅️ Previous")
                    next_btn = gr.Button("Next ➡️")
                    auto_label_btn = gr.Button("🤖 Auto-Label", variant="primary")
                # Hidden textbox to store canvas boxes data
                canvas_boxes_data = gr.Textbox(visible=False, elem_id="canvas-boxes-data")
                with gr.Row():
                    threshold_slider = gr.Slider(0.1, 0.9, DEFAULT_DETECTION_THRESHOLD, label="Detection Threshold")
            with gr.Column(scale=1):
                image_index_text = gr.Textbox(label="Image", lines=1, interactive=False)
                info_text = gr.Textbox(label="Status", lines=2)
                boxes_html = gr.HTML(label="Annotations")
                delete_box_index = gr.Number(visible=False, value=-1)
                gr.Markdown("### Manual Annotation")
                with gr.Row():
                    x1_input = gr.Number(label="x1", value=100)
                    y1_input = gr.Number(label="y1", value=100)
                with gr.Row():
                    x2_input = gr.Number(label="x2", value=200)
                    y2_input = gr.Number(label="y2", value=200)
                add_box_btn = gr.Button(" Add Box")
                delete_btn = gr.Button("🗑️ Delete Last")
                clear_btn = gr.Button("❌ Clear All")
                gr.Markdown("### Export Annotations")
                export_path = gr.Textbox(
                    label="Export Path",
                    value="annotations_coco.json"
                )
                export_btn = gr.Button("Export COCO")
                export_result = gr.Textbox(label="Export Result", lines=1)

        # Every handler that changes the current image refreshes these four.
        view_outputs = [image_display, boxes_html, image_index_text, info_text]

        # Event handlers
        def on_load():
            return app.load_image("current")

        def on_delete_box(idx):
            # A negative value (the -1 sentinel) or an empty field (None) is not
            # a delete request. Return "no change" updates: returning None
            # would blank the canvas, box list and status.
            if idx is None or idx < 0:
                return tuple(gr.update() for _ in view_outputs)
            return app.delete_box_by_index(int(idx))

        def on_export(path):
            # Path("") would resolve to the working directory and fail with a
            # confusing error, so reject blank input up front.
            path = (path or "").strip()
            if not path:
                return "❌ Please enter an export path"
            try:
                return app.export_to_coco(Path(path))
            except OSError as exc:
                # Unreadable images / unwritable destination: report in the UI.
                return f"❌ Export failed: {exc}"

        # Settings handlers
        load_images_btn.click(
            app.load_new_images_dir,
            inputs=[images_dir_input],
            outputs=view_outputs
        ).then(
            lambda: (app.get_current_dir_info(), app.get_current_model_info()),
            outputs=[dir_info, model_info]
        )
        load_model_btn.click(
            app.load_new_model,
            inputs=[model_weights_input, model_type_dropdown],
            outputs=[model_info]
        ).then(
            app.get_available_models_list,
            outputs=[model_selector]
        )
        scan_models_btn.click(
            app.scan_for_models,
            outputs=[model_info]
        ).then(
            app.get_available_models_list,
            outputs=[model_selector]
        )
        quick_load_btn.click(
            app.load_model_by_index,
            inputs=[model_selector],
            outputs=[model_info]
        ).then(
            app.get_available_models_list,
            outputs=[model_selector]
        )
        prev_btn.click(
            lambda: app.load_image("prev"),
            outputs=view_outputs
        )
        next_btn.click(
            lambda: app.load_image("next"),
            outputs=view_outputs
        )
        auto_label_btn.click(
            lambda t: app.auto_label_current(t),
            inputs=[threshold_slider],
            outputs=view_outputs
        )
        # Auto-save when canvas changes
        canvas_boxes_data.change(
            app.save_canvas_changes,
            inputs=[canvas_boxes_data],
            outputs=view_outputs
        )
        # Delete box handler (called from HTML button clicks via JS)
        delete_box_index.change(
            on_delete_box,
            inputs=[delete_box_index],
            outputs=view_outputs
        )
        add_box_btn.click(
            app.add_box_manual,
            inputs=[x1_input, y1_input, x2_input, y2_input],
            outputs=view_outputs
        )
        delete_btn.click(
            app.delete_last_box,
            outputs=view_outputs
        )
        clear_btn.click(
            app.clear_boxes,
            outputs=view_outputs
        )
        export_btn.click(
            on_export,
            inputs=[export_path],
            outputs=[export_result]
        )
        # Load first image on start
        demo.load(on_load, outputs=view_outputs)
    return demo
def main():
    """Command-line entry point: parse options, build the app, launch the UI.

    Image-directory and model-weight paths come from the command line or
    ``config.py``. A path that does not exist is reported and dropped, so the
    tool still starts; a directory or model can then be chosen from the
    Settings panel.
    """
    parser = argparse.ArgumentParser(description="Simple annotation GUI with auto-labeling")
    parser.add_argument(
        "--images-dir",
        type=Path,
        default=Path(DEFAULT_IMAGES_DIR) if DEFAULT_IMAGES_DIR else None,
        help="Default directory with images (can be changed in GUI)"
    )
    parser.add_argument(
        "--model-weights",
        type=Path,
        default=Path(DEFAULT_MODEL_WEIGHTS) if DEFAULT_MODEL_WEIGHTS else None,
        help="Default trained model for auto-labeling (can be changed in GUI)"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=DEFAULT_PORT,
        help="Port to run the GUI on"
    )
    # 0.0.0.0 makes the tool reachable from the whole network (the UI can read
    # and write arbitrary paths). Use --host 127.0.0.1 for local-only access.
    parser.add_argument(
        "--host",
        default="0.0.0.0",
        help="Interface to bind the server to (default: 0.0.0.0, all interfaces)"
    )
    args = parser.parse_args()
    # Validate paths if provided
    if args.images_dir and not args.images_dir.exists():
        print(f"⚠️ Warning: Images directory not found: {args.images_dir}")
        print("You can load a different directory from the GUI Settings")
        args.images_dir = None
    if args.model_weights and not args.model_weights.exists():
        print(f"⚠️ Warning: Model weights not found: {args.model_weights}")
        print("You can load different weights from the GUI Settings")
        args.model_weights = None
    # Create app
    app = AnnotationApp(args.images_dir, args.model_weights)
    # Scan for available models on startup
    app.scan_for_models(return_info=False)
    # Create and launch UI
    demo = create_ui(app)
    print(f"\n{'='*60}")
    print("🚀 Starting annotation tool...")
    if args.images_dir:
        print(f"📁 Default images: {args.images_dir} ({len(app.image_paths)} images)")
    else:
        print("📁 No default images - load directory from Settings")
    if app.model:
        print(f"🤖 Model: Loaded from {args.model_weights}")
    else:
        print("⚠️ No model loaded - load from Settings or train one")
    print("💡 You can change images directory and model weights from the Settings panel")
    print(f"{'='*60}\n")
    # Combine canvas JS with delete button handler
    combined_js = CANVAS_GLOBAL_JS + r"""
// Wire delete buttons to hidden number input
window.deleteBox = function(index) {
const hiddenInput = document.querySelector('input[type=number][style*=display], input[type=number].\\!hidden');
if (hiddenInput) {
hiddenInput.value = index;
hiddenInput.dispatchEvent(new Event('input', { bubbles: true }));
// Reset after triggering
setTimeout(() => { hiddenInput.value = -1; }, 100);
}
};
"""
    demo.launch(
        server_name=args.host,
        server_port=args.port,
        js=combined_js,
        share=False
    )
# Script entry point: launch the GUI only when run directly (python annotation_gui.py).
if __name__ == "__main__":
    main()