# saw_mill_knot_detection/annotation_gui.py
# Snapshot: 2025-12-23 16:55:59 -07:00 (1659 lines, 68 KiB, Python).
# NOTE: the repository viewer flagged this file as containing ambiguous Unicode
# characters; the non-ASCII symbols used in user-facing status strings
# (✓, ⚠, emoji) appear to be intentional.

"""
Simple customizable annotation GUI with auto-labeling support.
Built with Gradio - easy to modify and extend.
Run: python annotation_gui.py
To set default paths, edit config.py
"""
from __future__ import annotations
import argparse
import html
import json
import subprocess
import threading
from pathlib import Path
from typing import Any
import gradio as gr
from PIL import Image, ImageDraw
# Project settings come from config.py when it exists. If it is missing, or
# lacks any one of these names, every setting falls back to the built-in
# default below.
try:
    from config import (
        DEFAULT_IMAGES_DIR, DEFAULT_MODEL_WEIGHTS, DEFAULT_PORT,
        DEFAULT_DETECTION_THRESHOLD, DEFAULT_TRAIN_EPOCHS,
        DEFAULT_BATCH_SIZE, DEFAULT_LEARNING_RATE, DEFAULT_MODEL_SIZE
    )
except ImportError:
    # Paths: nothing preselected.
    DEFAULT_IMAGES_DIR = DEFAULT_MODEL_WEIGHTS = None
    # Web server and auto-labeling.
    DEFAULT_PORT = 7860
    DEFAULT_DETECTION_THRESHOLD = 0.5
    # Training hyperparameters.
    DEFAULT_TRAIN_EPOCHS, DEFAULT_BATCH_SIZE = 20, 4
    DEFAULT_LEARNING_RATE = 1e-4
    DEFAULT_MODEL_SIZE = "small"
# Gradio 6 sanitizes <script> tags inside gr.HTML content, so the canvas drawing
# code lives in the component's js_on_load hook instead (`element` is the
# component root Gradio passes in). The script expects the markup produced by
# AnnotationApp.generate_interactive_canvas (#annotation-canvas, #annotation-img,
# #annotation-initial-boxes). It mirrors the edited box list, as JSON, into the
# text field inside the element with id "canvas-boxes-data". Box coordinates are
# in the canvas's intrinsic pixel space; mouse positions are converted from
# on-screen (CSS) pixels to that space.
CANVAS_JS_ON_LOAD = r"""
(() => {
  const root = element;
  if (!root) return;
  const canvas = root.querySelector('#annotation-canvas');
  const imgEl = root.querySelector('#annotation-img');
  const initialBoxesEl = root.querySelector('#annotation-initial-boxes');
  if (!canvas || !imgEl || !initialBoxesEl) return;

  const ctx = canvas.getContext('2d');
  // Intrinsic canvas resolution: box coordinates live in this space, even when
  // CSS shows the canvas at a smaller on-screen size.
  const displayWidth = canvas.width;
  const displayHeight = canvas.height;
  const HANDLE_SIZE = 6;    // corner-handle half-size, in on-screen pixels
  const MIN_BOX_SIZE = 10;  // minimum box edge, in on-screen pixels

  let boxes = [];
  try {
    const raw = initialBoxesEl.value || initialBoxesEl.textContent || '[]';
    boxes = JSON.parse(raw);
    if (!Array.isArray(boxes)) boxes = [];
  } catch (_) {
    boxes = [];
  }

  // Gradio puts elem_id on the component wrapper rather than on the form field,
  // so resolve the real <textarea>/<input> before writing to it.
  const hiddenHost = document.getElementById('canvas-boxes-data');
  let hiddenInput = null;
  if (hiddenHost) {
    hiddenInput = hiddenHost.matches('textarea, input')
      ? hiddenHost
      : hiddenHost.querySelector('textarea, input');
  }
  const syncHidden = () => {
    if (!hiddenInput) return;
    hiddenInput.value = JSON.stringify(boxes);
    hiddenInput.dispatchEvent(new Event('input', { bubbles: true }));
  };
  syncHidden();

  let isDragging = false;
  let dragStart = null;
  let selectedCorner = null;
  let selectedBoxIndex = -1;
  let creatingBox = false;
  let createStart = null;

  // Canvas pixels per on-screen pixel (1 when the canvas is shown at native size).
  const pxScale = () => {
    const w = canvas.getBoundingClientRect().width;
    return w > 0 ? displayWidth / w : 1;
  };
  const clamp = (v, lo, hi) => Math.min(Math.max(v, lo), hi);
  // Convert a pointer event to canvas coordinates, clamped to the image area.
  const toCanvasPoint = (e) => {
    const rect = canvas.getBoundingClientRect();
    const sx = rect.width > 0 ? displayWidth / rect.width : 1;
    const sy = rect.height > 0 ? displayHeight / rect.height : 1;
    return {
      x: clamp((e.clientX - rect.left) * sx, 0, displayWidth),
      y: clamp((e.clientY - rect.top) * sy, 0, displayHeight),
    };
  };

  function redraw() {
    const s = pxScale();
    const handle = HANDLE_SIZE * s;
    ctx.clearRect(0, 0, displayWidth, displayHeight);
    // Base image is rendered via <img> below the canvas.
    boxes.forEach((box) => {
      const [x1, y1, x2, y2] = box.bbox;
      const label = box.label || 'knot';
      const conf = typeof box.confidence === 'number' ? box.confidence : 1.0;
      ctx.strokeStyle = 'red';
      ctx.lineWidth = 3 * s;
      ctx.strokeRect(x1, y1, x2 - x1, y2 - y1);
      ctx.fillStyle = 'red';
      for (const [cx, cy] of [[x1, y1], [x2, y1], [x1, y2], [x2, y2]]) {
        ctx.fillRect(cx - handle, cy - handle, handle * 2, handle * 2);
      }
      ctx.font = `${Math.max(1, Math.round(16 * s))}px Arial`;
      const text = conf < 1.0 ? `${label} ${conf.toFixed(2)}` : label;
      ctx.fillText(text, x1, y1 - 5 * s);
    });
    if (creatingBox && createStart && dragStart) {
      ctx.strokeStyle = 'blue';
      ctx.lineWidth = 2 * s;
      ctx.setLineDash([5 * s, 5 * s]);
      const x = Math.min(createStart.x, dragStart.x);
      const y = Math.min(createStart.y, dragStart.y);
      const w = Math.abs(createStart.x - dragStart.x);
      const h = Math.abs(createStart.y - dragStart.y);
      ctx.strokeRect(x, y, w, h);
      ctx.setLineDash([]);
    }
  }

  function getCornerAt(x, y) {
    const handle = HANDLE_SIZE * pxScale();
    for (let i = 0; i < boxes.length; i++) {
      const [x1, y1, x2, y2] = boxes[i].bbox;
      const corners = [
        { x: x1, y: y1, type: 'top-left' },
        { x: x2, y: y1, type: 'top-right' },
        { x: x1, y: y2, type: 'bottom-left' },
        { x: x2, y: y2, type: 'bottom-right' },
      ];
      for (const corner of corners) {
        if (Math.abs(x - corner.x) <= handle && Math.abs(y - corner.y) <= handle) {
          return { boxIndex: i, corner: corner.type, pos: corner };
        }
      }
    }
    return null;
  }

  // Ensure we don't double-bind if Gradio reuses the DOM node.
  if (canvas.dataset.bound === '1') {
    redraw();
    return;
  }
  canvas.dataset.bound = '1';
  // Let touch/pen drags draw boxes instead of scrolling the page.
  canvas.style.touchAction = 'none';

  // If the <img> fails to load, draw a message on the canvas.
  imgEl.addEventListener('error', () => {
    const s = pxScale();
    ctx.clearRect(0, 0, displayWidth, displayHeight);
    ctx.fillStyle = '#ffcccc';
    ctx.fillRect(0, 0, displayWidth, displayHeight);
    ctx.fillStyle = 'black';
    ctx.font = `${Math.max(1, Math.round(16 * s))}px Arial`;
    ctx.fillText('Image failed to load', 10 * s, 30 * s);
  }, { once: true });

  // Initial draw
  redraw();

  canvas.addEventListener('pointerdown', (e) => {
    if (e.button !== 0) return;
    const { x, y } = toCanvasPoint(e);
    // Capture the pointer so move/up events still arrive when the drag leaves
    // the canvas; otherwise a release outside would leave the drag active.
    try {
      canvas.setPointerCapture(e.pointerId);
    } catch (_) {
      // Best-effort: without capture, dragging still works inside the canvas.
    }
    selectedCorner = getCornerAt(x, y);
    if (selectedCorner) {
      isDragging = true;
      selectedBoxIndex = selectedCorner.boxIndex;
      canvas.style.cursor = 'move';
    } else {
      creatingBox = true;
      createStart = { x, y };
      dragStart = { x, y };
      canvas.style.cursor = 'crosshair';
    }
  });

  canvas.addEventListener('pointermove', (e) => {
    const { x, y } = toCanvasPoint(e);
    if (isDragging && selectedCorner) {
      const box = boxes[selectedBoxIndex];
      const minSize = MIN_BOX_SIZE * pxScale();
      if (selectedCorner.corner === 'top-left') {
        box.bbox[0] = Math.min(x, box.bbox[2] - minSize);
        box.bbox[1] = Math.min(y, box.bbox[3] - minSize);
      } else if (selectedCorner.corner === 'top-right') {
        box.bbox[2] = Math.max(x, box.bbox[0] + minSize);
        box.bbox[1] = Math.min(y, box.bbox[3] - minSize);
      } else if (selectedCorner.corner === 'bottom-left') {
        box.bbox[0] = Math.min(x, box.bbox[2] - minSize);
        box.bbox[3] = Math.max(y, box.bbox[1] + minSize);
      } else if (selectedCorner.corner === 'bottom-right') {
        box.bbox[2] = Math.max(x, box.bbox[0] + minSize);
        box.bbox[3] = Math.max(y, box.bbox[1] + minSize);
      }
      syncHidden();
      redraw();
      return;
    }
    if (creatingBox && createStart) {
      dragStart = { x, y };
      redraw();
      return;
    }
    const corner = getCornerAt(x, y);
    canvas.style.cursor = corner ? 'move' : 'crosshair';
  });

  // End the current gesture; commit=false discards a half-drawn box.
  const finishInteraction = (commit) => {
    if (commit && creatingBox && createStart && dragStart) {
      const minSize = MIN_BOX_SIZE * pxScale();
      const x1 = Math.min(createStart.x, dragStart.x);
      const y1 = Math.min(createStart.y, dragStart.y);
      const x2 = Math.max(createStart.x, dragStart.x);
      const y2 = Math.max(createStart.y, dragStart.y);
      if (x2 - x1 > minSize && y2 - y1 > minSize) {
        boxes.push({
          bbox: [x1, y1, x2, y2],
          label: 'knot',
          confidence: 1.0,
          source: 'manual',
        });
        syncHidden();
      }
    }
    isDragging = false;
    creatingBox = false;
    selectedCorner = null;
    selectedBoxIndex = -1;
    createStart = null;
    dragStart = null;
    canvas.style.cursor = 'crosshair';
    // Always redraw so a rejected (too small) preview rectangle is cleared.
    redraw();
  };
  canvas.addEventListener('pointerup', () => finishInteraction(true));
  canvas.addEventListener('pointercancel', () => finishInteraction(false));
})();
"""
class AnnotationApp:
def __init__(self, images_dir: Path | None = None, model_weights: Path | None = None):
self.images_dir = images_dir if images_dir else Path.cwd()
self.current_model_path = model_weights
self.current_model_type = None # Track model type: 'rf-detr', 'rt-detr', 'yolov6', 'yolox'
self.available_models = [] # List of discovered models for quick switching
self.image_paths = []
self.current_idx = 0
self.annotations = {} # image_name -> list of boxes
self.model = None
self.training_process = None
self.training_thread = None
self.training_status = "Not training"
# Load images if directory provided
if images_dir and images_dir.exists():
self._load_images(images_dir)
if model_weights and model_weights.exists():
self._load_model(model_weights)
def _load_images(self, images_dir: Path):
"""Load images from directory."""
self.images_dir = images_dir
self.image_paths = sorted(
list(images_dir.glob("*.jpg")) + list(images_dir.glob("*.png"))
)
self.current_idx = 0
# Load existing annotations if present
self.ann_file = images_dir / "annotations.json"
if self.ann_file.exists():
with self.ann_file.open("r") as f:
self.annotations = json.load(f)
else:
self.annotations = {}
return f"✓ Loaded {len(self.image_paths)} images from {images_dir}"
def find_best_weights(self, directory: Path) -> tuple[Path | None, str | None]:
    """Pick the weights file to load from a training-run directory.

    Search order (first hit wins):
      1. ``checkpoint_best_total.pth`` -> "rf-detr"
      2. ``weights/best.pt`` -> type from the directory name, default "rt-detr"
      3. ``training/weights/best.pt`` -> type from the directory name, default "yolox"
      4. ``best.pt`` -> "rt-detr"
      5. any ``*.pth`` file, then any ``*.pt`` file, preferring names that
         contain "best". The type comes from ``_guess_model_type_from_path``.
    Step 5 sorts candidates by path, so the choice is deterministic.

    Returns:
        ``(weights_path, model_type)``, or ``(None, None)`` when ``directory``
        is not an existing directory or contains no weights.
    """
    if not directory.is_dir():
        return None, None

    def type_from_dir_name(default: str) -> str:
        # The file layout alone does not say which Ultralytics architecture
        # produced best.pt, so fall back to markers in the run directory name.
        dir_name = directory.name.lower()
        for marker, model_type in (("rtdetr", "rt-detr"), ("yolov6", "yolov6"), ("yolox", "yolox")):
            if marker in dir_name:
                return model_type
        return default

    rf_detr_weights = directory / "checkpoint_best_total.pth"
    if rf_detr_weights.exists():
        return rf_detr_weights, "rf-detr"

    ultralytics_weights = directory / "weights" / "best.pt"
    if ultralytics_weights.exists():
        return ultralytics_weights, type_from_dir_name("rt-detr")

    training_weights = directory / "training" / "weights" / "best.pt"
    if training_weights.exists():
        return training_weights, type_from_dir_name("yolox")

    direct_best = directory / "best.pt"
    if direct_best.exists():
        return direct_best, "rt-detr"

    # .pth files are listed before .pt files; sorting each group makes the
    # pick independent of filesystem listing order.
    candidates = sorted(directory.glob("*.pth")) + sorted(directory.glob("*.pt"))
    if not candidates:
        return None, None
    preferred = [path for path in candidates if "best" in path.name.lower()]
    chosen = preferred[0] if preferred else candidates[0]
    return chosen, self._guess_model_type_from_path(chosen)
def _guess_model_type_from_path(self, path: Path) -> str:
"""Guess model type from file path."""
path_str = str(path).lower()
if "rf" in path_str or "checkpoint" in path_str:
return "rf-detr"
elif "rtdetr" in path_str:
return "rt-detr"
elif "yolov6" in path_str:
return "yolov6"
elif "yolox" in path_str:
return "yolox"
else:
return "rt-detr" # Default
def _load_model(self, weights_path: Path, model_type: str | None = None):
    """Load a detector for auto-labeling and make it the active model.

    Args:
        weights_path: Weights/checkpoint file to load.
        model_type: "rf-detr", or an Ultralytics family ("rt-detr" uses
            ultralytics.RTDETR; anything else uses ultralytics.YOLO). When None,
            the type is guessed from the path via _guess_model_type_from_path.

    Returns:
        A status string. On success it records ``current_model_path`` and
        ``current_model_type``, and adds an entry to ``available_models`` if the
        path is not already listed. On any exception it returns a "⚠ Could not
        load" message and resets ``self.model`` and ``self.current_model_type``
        to None, so a previously loaded model is discarded.
    """
    try:
        import torch
        if model_type is None:
            model_type = self._guess_model_type_from_path(weights_path)
        print(f"Loading {model_type} model from {weights_path}...")
        # Prefer the GPU when CUDA is available.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")
        if model_type == "rf-detr":
            # RF-DETR uses its own model classes; the checkpoint does not say
            # which size it was trained with, so sizes may be tried in turn.
            from rfdetr import RFDETRBase, RFDETRMedium, RFDETRNano, RFDETRSmall
            # NOTE(review): weights_only=False unpickles arbitrary objects from
            # the file; only load checkpoints from trusted sources.
            checkpoint = torch.load(weights_path, map_location='cpu', weights_only=False)
            if 'model' in checkpoint:
                # Training checkpoint: the state dict is stored under 'model'.
                state_dict = checkpoint['model']
                # Only probe sizes when the state dict has the backbone
                # position-embedding key; otherwise treat it as plain weights.
                if any('backbone.0.encoder.encoder.embeddings.position_embeddings' in key for key in state_dict.keys()):
                    # Probe sizes smallest-first; the first that accepts the weights wins.
                    models_to_try = [
                        ("nano", RFDETRNano),
                        ("small", RFDETRSmall),
                        ("medium", RFDETRMedium),
                        ("base", RFDETRBase)
                    ]
                    for size_name, model_class in models_to_try:
                        try:
                            # NOTE(review): assumes RFDETR* instances expose
                            # torch.nn.Module-style load_state_dict()/state_dict()/to();
                            # confirm against the installed rfdetr version. If they
                            # don't, every attempt raises, is skipped, and the
                            # for/else below raises.
                            self.model = model_class()
                            # strict=False tolerates key mismatches between sizes.
                            missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
                            # Accept this size as soon as at least one key matched.
                            # This is a weak test: a wrong size that shares any
                            # key is also accepted.
                            if len(missing_keys) < len(self.model.state_dict()): # Some keys matched
                                print(f"✓ Loaded RF-DETR {size_name} model (with {len(missing_keys)} missing keys)")
                                # Move to GPU if available
                                self.model = self.model.to(device)
                                break
                        except Exception as e:
                            continue
                    else:
                        # for/else: no size was accepted (no `break`).
                        raise Exception("Could not load checkpoint with any RF-DETR model size")
                else:
                    # No recognizable backbone keys: let RFDETRNano load the file itself.
                    self.model = RFDETRNano(pretrain_weights=str(weights_path))
                    self.model = self.model.to(device)
            else:
                # Not a training checkpoint (no 'model' key): load as pretrain weights.
                self.model = RFDETRNano(pretrain_weights=str(weights_path))
                self.model = self.model.to(device)
        else:
            # RT-DETR, YOLOv6, YOLOX all use Ultralytics
            if model_type == "rt-detr":
                from ultralytics import RTDETR
                self.model = RTDETR(str(weights_path))
            else:
                from ultralytics import YOLO
                self.model = YOLO(str(weights_path))
            # Explicitly move the model to the selected device when the object
            # supports .to().
            if hasattr(self.model, 'to'):
                self.model = self.model.to(device)
        self.current_model_path = weights_path
        self.current_model_type = model_type
        # Register the model for the quick-switch dropdown (deduplicated by path).
        model_display = f"Custom: {weights_path.name} ({model_type.upper()})"
        existing_model = next((m for m in self.available_models if m['path'] == weights_path), None)
        if not existing_model:
            self.available_models.append({
                "path": weights_path,
                "type": model_type,
                "dir": weights_path.parent.name,
                "display": model_display
            })
        print("✓ Model loaded")
        return f"{model_type.upper()} model loaded from {weights_path.name}"
    except Exception as e:
        # Any failure leaves the app with no active model.
        error_msg = f"⚠ Could not load {model_type or 'model'}: {e}"
        print(error_msg)
        self.model = None
        self.current_model_type = None
        return error_msg
def load_new_model(self, weights_path: str, model_type: str = "Auto-detect") -> str:
    """Load a weights file chosen in the GUI.

    Args:
        weights_path: Path typed or picked by the user.
        model_type: Dropdown value. "Auto-detect" means guess from the path.
            Any other value is passed through unchanged as the internal type.

    Returns:
        An error message when the file is missing, otherwise _load_model's status.
    """
    candidate = Path(weights_path)
    if not candidate.exists():
        return f"❌ File not found: {weights_path}"
    # The dropdown values already match the internal type names; only
    # "Auto-detect" needs translating (None = let _load_model guess).
    resolved_type = None if model_type == "Auto-detect" else model_type
    return self._load_model(candidate, resolved_type)
def load_model_from_directory(self, directory_path: str) -> str:
    """Find the best weights in a run directory (via find_best_weights) and load them.

    Returns:
        An error message for a missing path, a non-directory, or a directory
        without weights; otherwise _load_model's status message.
    """
    directory = Path(directory_path)
    if not directory.exists():
        return f"❌ Directory not found: {directory_path}"
    if not directory.is_dir():
        return f"❌ Not a directory: {directory_path}"
    found_weights, found_type = self.find_best_weights(directory)
    if found_weights is None:
        return f"❌ No model weights found in {directory_path}"
    return self._load_model(found_weights, found_type)
def scan_for_models(self, return_info: bool = True) -> str:
    """Rebuild ``self.available_models`` from the run directories under ./runs.

    Each immediate subdirectory of ``runs`` (relative to the current working
    directory) is passed to find_best_weights. Subdirectories are visited in
    sorted order, so the listing and dropdown order are stable. The previous
    ``available_models`` list is replaced.

    Args:
        return_info: When False, only refresh the list and return "".

    Returns:
        A human-readable listing, or a message when nothing was found.
    """
    runs_dir = Path("runs")
    found = []
    # is_dir(): a stray file named "runs" would make iterdir() raise.
    if runs_dir.is_dir():
        for subdir in sorted(p for p in runs_dir.iterdir() if p.is_dir()):
            weights_path, model_type = self.find_best_weights(subdir)
            if weights_path:
                found.append({
                    "path": weights_path,
                    "type": model_type,
                    "dir": subdir.name,
                    "display": f"{subdir.name} ({model_type.upper()})"
                })
    # Store available models for quick access
    self.available_models = found
    if not return_info:
        return ""
    if not found:
        return "❌ No trained models found in 'runs/' directory"
    lines = ["📂 Available Models:"]
    # Separate the run directory from the weights file name so the line reads
    # "run_dir → best.pt" rather than the run-together "run_dirbest.pt".
    lines.extend(
        f"{i}. {model['dir']} → {model['path'].name} ({model['type'].upper()})"
        for i, model in enumerate(found, 1)
    )
    lines.append("\n💡 Use the Model Selector dropdown above to quickly switch models")
    return "\n".join(lines)
def get_available_models_list(self) -> list:
    """Return the display names for the model dropdown.

    If no models are known yet, runs a silent scan_for_models first. When that
    still finds nothing, returns a single placeholder entry asking the user to scan.
    """
    if not self.available_models:
        self.scan_for_models(return_info=False)  # refreshes self.available_models
    names = [entry['display'] for entry in self.available_models]
    return names or ["No models found - click '🔍 Scan for Models'"]
def load_model_by_index(self, model_display: str) -> str:
    """Load the entry of ``available_models`` whose display name matches.

    Returns:
        _load_model's status for the first match, or an error message when the
        list is empty or no entry has that display name.
    """
    known = getattr(self, 'available_models', None)
    if not known:
        return "❌ No models available. Click '🔍 Scan for Models' first."
    match = next((entry for entry in known if entry['display'] == model_display), None)
    if match is None:
        return f"❌ Model '{model_display}' not found"
    return self._load_model(match['path'], match['type'])
def load_new_images_dir(self, images_dir: str) -> tuple[str, str, str]:
    """Switch to another images folder chosen in the GUI and show its first image.

    Returns:
        (canvas HTML, formatted box list, status line). The status line is the
        _load_images summary followed by the usual per-image line, or a warning
        when the folder contains no images.
    """
    folder = Path(images_dir)
    if not folder.exists():
        return "<div>Directory not found</div>", "", f"❌ Directory not found: {images_dir}"
    if not folder.is_dir():
        return "<div>Not a directory</div>", "", f"❌ Not a directory: {images_dir}"

    summary = self._load_images(folder)  # also resets current_idx to 0
    if not self.image_paths:
        return "<div>No images found</div>", "", f"{summary}\n⚠️ No .jpg or .png images found in directory"

    # Render the (first) current image exactly as normal navigation does.
    canvas_html, boxes_text, position = self.load_image()
    return canvas_html, boxes_text, f"{summary}\n{position}"
def get_current_model_info(self) -> str:
    """Summarize the active model for the status panel.

    Shows the weights path (plus the upper-cased type, when known) if a model is
    loaded from a path, a generic "pretrained" line if a model is loaded without
    a recorded path, and a warning when no model is loaded.
    """
    if not self.model:
        return "⚠️ No model loaded"
    if not self.current_model_path:
        return "📦 Model loaded (pretrained)"
    type_suffix = f" ({self.current_model_type.upper()})" if self.current_model_type else ""
    return f"📦 Loaded: {self.current_model_path}{type_suffix}"
def get_current_dir_info(self) -> str:
    """Describe the active images folder and how many images were loaded from it."""
    image_count = len(self.image_paths)
    return f"📁 {self.images_dir} ({image_count} images)"
def get_current_image(self) -> tuple[Image.Image, str]:
    """Return the image at ``current_idx`` (converted to RGB) and its file name.

    Returns ``(None, "")`` when no images are loaded.
    """
    if not self.image_paths:
        return None, ""
    image_path = self.image_paths[self.current_idx]
    rgb_image = Image.open(image_path).convert("RGB")
    return rgb_image, image_path.name
def draw_boxes_on_image(self, img: Image.Image, boxes: list[dict]) -> Image.Image:
    """Return a copy of ``img`` with every box drawn in red.

    Each box gets a 3px outline, four 12x12 corner handles, and a caption 20px
    above its top-left corner. The caption is ``"<label> <conf>"`` when the
    confidence is below 1.0, otherwise just the label (default "knot"). The
    input image is not modified.
    """
    annotated = img.copy()
    pen = ImageDraw.Draw(annotated)
    half = 6  # corner-handle half-size in pixels
    for box in boxes:
        x1, y1, x2, y2 = box["bbox"]
        label = box.get("label", "knot")
        conf = box.get("confidence", 1.0)
        pen.rectangle([x1, y1, x2, y2], outline="red", width=3)
        for cx, cy in ((x1, y1), (x2, y1), (x1, y2), (x2, y2)):
            pen.rectangle([cx - half, cy - half, cx + half, cy + half], fill="red")
        caption = f"{label} {conf:.2f}" if conf < 1.0 else label
        pen.text((x1, y1 - 20), caption, fill="red")
    return annotated
def generate_interactive_canvas(self, boxes: list[dict] | None = None) -> str:
    """Build the HTML for the interactive annotation canvas of the current image.

    The image is embedded as a base64 PNG, downscaled for display to fit within
    1200x800. The ``<canvas>`` keeps the original image resolution as its
    intrinsic size and is only shrunk with CSS. Everything the canvas script
    draws or records is therefore in original-image pixel coordinates, the same
    space as:
      * auto-label predictions, which run on the full-size image;
      * the COCO export, which reports the full-size width and height.
    Previously, large images got a display-sized canvas. Predicted boxes were
    misplaced, and hand-drawn boxes were saved in downscaled coordinates.

    NOTE(review): boxes drawn on the canvas by earlier versions of this tool, on
    images larger than 1200x800, were stored in display coordinates. They will
    now appear shrunk toward the top-left and may need rescaling.

    Args:
        boxes: Boxes to show. Defaults to the stored annotations of the image.

    Returns:
        An HTML string without ``<script>``; interactivity comes from
        CANVAS_JS_ON_LOAD.
    """
    img, filename = self.get_current_image()
    if not img:
        return "<div>No image loaded</div>"
    if boxes is None:
        boxes = self.annotations.get(filename, [])

    import base64
    from io import BytesIO

    image_width, image_height = img.width, img.height
    max_width, max_height = 1200, 800
    if image_width > max_width or image_height > max_height:
        ratio = min(max_width / image_width, max_height / image_height)
        display_width = int(image_width * ratio)
        display_height = int(image_height * ratio)
        # Only the embedded preview is downscaled (smaller payload); box
        # coordinates stay in full-resolution pixels.
        img = img.resize((display_width, display_height), Image.Resampling.LANCZOS)
    else:
        display_width, display_height = image_width, image_height

    buffer = BytesIO()
    img.save(buffer, format="PNG", optimize=True)
    img_base64 = base64.b64encode(buffer.getvalue()).decode()

    # Build HTML without <script> tags (Gradio sanitizes them).
    boxes_escaped = html.escape(json.dumps(boxes))
    # Canvas width/height attributes = intrinsic (original-image) resolution;
    # its CSS width/height = on-screen size. The JS maps pointer positions
    # from CSS pixels to intrinsic pixels.
    html_out = f"""
<div style="display: inline-block; border: 1px solid #ccc; padding: 5px;">
<div style="font-size: 12px; color: #666; margin-bottom: 4px;">Canvas Size: {display_width}x{display_height}</div>
<textarea id="annotation-initial-boxes" style="display:none;">{boxes_escaped}</textarea>
<div style="position: relative; width: {display_width}px; height: {display_height}px;">
<img id="annotation-img" src="data:image/png;base64,{img_base64}"
style="position:absolute; left:0; top:0; width:{display_width}px; height:{display_height}px;" />
<canvas id="annotation-canvas" width="{image_width}" height="{image_height}"
style="position:absolute; left:0; top:0; width:{display_width}px; height:{display_height}px; cursor: crosshair; background: transparent;"></canvas>
</div>
</div>
"""
    return html_out
def _format_boxes_text(self, boxes: list[dict]) -> str:
"""Format boxes for display."""
if not boxes:
return "No annotations"
lines = []
for i, box in enumerate(boxes):
x1, y1, x2, y2 = box["bbox"]
conf = box.get("confidence", 1.0)
source = box.get("source", "manual")
lines.append(f"{i}: [{x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f}] conf={conf:.2f} ({source})")
return "\n".join(lines)
def load_image(self, direction: str = "current") -> tuple[str, str, str]:
    """Move to the next/previous image (or stay) and render it with its boxes.

    Args:
        direction: "next" or "prev" to move one image, clamped to the valid
            range. Any other value (e.g. "current") stays put.

    Returns:
        (canvas HTML, formatted box list, status line), or a "No images"
        triple when nothing is loaded.
    """
    # Only navigate when there is something to navigate over. With no images,
    # min(idx + 1, len - 1) would set current_idx to -1, an invalid index.
    if self.image_paths:
        last_idx = len(self.image_paths) - 1
        if direction == "next":
            self.current_idx = min(self.current_idx + 1, last_idx)
        elif direction == "prev":
            self.current_idx = max(self.current_idx - 1, 0)
    img, filename = self.get_current_image()
    if not img:
        return "<div>No images</div>", "", "No images"
    # Load existing annotations
    boxes = self.annotations.get(filename, [])
    img_html = self.generate_interactive_canvas(boxes)
    boxes_text = self._format_boxes_text(boxes)
    info = f"Image {self.current_idx + 1}/{len(self.image_paths)}: {filename}"
    return img_html, boxes_text, info
def add_box_manual(self, x1: int, y1: int, x2: int, y2: int) -> tuple[str, str, str]:
    """Append a manually entered box to the current image and persist it.

    Corners may be given in any order. They are normalized so the stored bbox
    is always ``[min_x, min_y, max_x, max_y]``, the layout the canvas produces.
    Without this, the COCO export's ``x2 - x1`` width would go negative.

    Returns:
        (canvas HTML, formatted box list, status line).
    """
    img, filename = self.get_current_image()
    if not img:
        return "<div>No images</div>", "", "No images"
    left, right = sorted((float(x1), float(x2)))
    top, bottom = sorted((float(y1), float(y2)))
    boxes = self.annotations.setdefault(filename, [])
    boxes.append({
        "bbox": [left, top, right, bottom],
        "label": "knot",
        "confidence": 1.0,
        "source": "manual"
    })
    self._save_annotations()
    img_html = self.generate_interactive_canvas(boxes)
    boxes_text = self._format_boxes_text(boxes)
    info = f"✓ Added box: {len(boxes)} total | Image {self.current_idx + 1}/{len(self.image_paths)}: {filename}"
    return img_html, boxes_text, info
def delete_last_box(self) -> tuple[str, str, str]:
    """Remove the most recently added box from the current image.

    Annotations are saved only when a box was actually removed. When the image
    has no boxes, the status line says so instead of claiming a deletion.

    Returns:
        (canvas HTML, formatted box list, status line).
    """
    img, filename = self.get_current_image()
    if not img:
        return "<div>No images</div>", "", "No images"
    boxes = self.annotations.get(filename, [])
    position = f"Image {self.current_idx + 1}/{len(self.image_paths)}: {filename}"
    if boxes:
        boxes.pop()  # mutates the stored list in place
        self._save_annotations()
        info = f"✓ Deleted last box: {len(boxes)} remaining | {position}"
    else:
        info = f"⚠️ No boxes to delete | {position}"
    return self.generate_interactive_canvas(boxes), self._format_boxes_text(boxes), info
def clear_boxes(self) -> tuple[str, str, str]:
    """Remove every box from the current image and save the empty list.

    Returns:
        (canvas HTML, "No annotations", status line), or a "No images" triple.
    """
    img, filename = self.get_current_image()
    if not img:
        return "<div>No images</div>", "", "No images"
    self.annotations[filename] = []
    self._save_annotations()
    position = f"Image {self.current_idx + 1}/{len(self.image_paths)}: {filename}"
    return (
        self.generate_interactive_canvas([]),
        "No annotations",
        f"✓ Cleared all boxes | {position}",
    )
def auto_label_current(self, threshold: float = 0.5) -> tuple[str, str, str]:
    """Run the loaded detector on the current image and append its detections.

    RF-DETR models are called as ``predict(img, threshold=...)`` and read through
    ``xyxy``/``confidence``. Every other model type goes through the Ultralytics
    ``predict(source=..., conf=...)`` API. Each detection is stored as a "knot"
    box with ``source="auto"``, added after any existing boxes, and saved.

    Returns:
        (canvas HTML, formatted box list, status line). Any exception during
        inference or saving is reported in the status line, and the image's
        current boxes are re-rendered.
    """
    if not self.model:
        return "<div>No model loaded</div>", "", "❌ No model loaded"
    img, filename = self.get_current_image()
    if not img:
        return "<div>No images</div>", "", "No images"
    try:
        detected = []
        if self.current_model_type == "rf-detr":
            detections = self.model.predict(img, threshold=threshold)
            for idx in range(len(detections)):
                coords = detections.xyxy[idx]
                score = float(detections.confidence[idx]) if detections.confidence is not None else 1.0
                x1, y1, x2, y2 = coords
                detected.append({
                    "bbox": [float(x1), float(y1), float(x2), float(y2)],
                    "label": "knot",
                    "confidence": score,
                    "source": "auto"
                })
        else:
            predictions = self.model.predict(source=img, conf=threshold, save=False, verbose=False)
            for prediction in predictions:
                for hit in prediction.boxes:
                    x1, y1, x2, y2 = hit.xyxy[0].tolist()
                    score = float(hit.conf[0])
                    detected.append({
                        "bbox": [x1, y1, x2, y2],
                        "label": "knot",
                        "confidence": score,
                        "source": "auto"
                    })
        # New detections are appended, not merged, with whatever is already there.
        current = self.annotations.setdefault(filename, [])
        current.extend(detected)
        self._save_annotations()
        canvas_html = self.generate_interactive_canvas(current)
        boxes_text = self._format_boxes_text(current)
        position = f"Image {self.current_idx + 1}/{len(self.image_paths)}: {filename}"
        return canvas_html, boxes_text, f"🤖 Auto-labeled: {len(detected)} detections | {position}"
    except Exception as exc:
        existing = self.annotations.get(filename, [])
        canvas_html = self.generate_interactive_canvas(existing)
        return canvas_html, self._format_boxes_text(existing), f"❌ Auto-label failed: {exc}"
def save_canvas_changes(self, boxes_json: str) -> tuple[str, str, str]:
    """Persist the box list edited in the canvas for the current image.

    Args:
        boxes_json: JSON text mirrored from the canvas. It must be a list of
            objects, each holding a ``"bbox"`` of four numbers. Empty text means
            "nothing to save".

    Returns:
        (canvas HTML, formatted box list, status line). Malformed input
        (bad JSON, a non-list, or entries without a 4-number bbox) is reported
        as "Invalid boxes data", and the stored annotations are left untouched.
    """
    img, filename = self.get_current_image()
    if not img:
        return "<div>No images</div>", "", "No images"

    def is_valid_box(item) -> bool:
        # Matches what the rest of the app relies on: box["bbox"] unpacks to 4 numbers.
        if not isinstance(item, dict):
            return False
        bbox = item.get("bbox")
        return (
            isinstance(bbox, list)
            and len(bbox) == 4
            and all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in bbox)
        )

    position = f"Image {self.current_idx + 1}/{len(self.image_paths)}: {filename}"
    boxes = self.annotations.get(filename, [])
    if not boxes_json:
        info = f"✓ No changes to save | {position}"
    else:
        try:
            parsed = json.loads(boxes_json)
        except json.JSONDecodeError:
            parsed = None
        # Validate before saving, so a bad payload is never written to disk
        # (where it would later break formatting and the COCO export).
        if isinstance(parsed, list) and all(is_valid_box(item) for item in parsed):
            boxes = parsed
            self.annotations[filename] = boxes
            self._save_annotations()
            info = f"✓ Saved {len(boxes)} boxes | {position}"
        else:
            info = f"❌ Invalid boxes data | {position}"
    return self.generate_interactive_canvas(boxes), self._format_boxes_text(boxes), info
def _save_annotations(self):
"""Save annotations to JSON file."""
with self.ann_file.open("w") as f:
json.dump(self.annotations, f, indent=2)
def export_to_coco(self, output_path: Path):
    """Export every loaded image and its boxes to a single COCO JSON file.

    Images get ids in ``image_paths`` order; all boxes use category 0 ("knot").
    Stored ``[x1, y1, x2, y2]`` boxes become COCO ``[x, y, w, h]``. Corners in
    either order are accepted, so width and height are never negative. Each
    annotation also carries the box's ``confidence`` as ``score``.

    Args:
        output_path: Destination file, as a Path or str.

    Returns:
        A status message with the number of exported annotations.
    """
    output_path = Path(output_path)
    coco_data = {
        "images": [],
        "annotations": [],
        "categories": [{"id": 0, "name": "knot", "supercategory": "defect"}]
    }
    ann_id = 0
    for img_id, img_path in enumerate(self.image_paths):
        filename = img_path.name
        # The context manager closes the file handle Image.open keeps.
        with Image.open(img_path) as img:
            width, height = img.size
        coco_data["images"].append({
            "id": img_id,
            "file_name": filename,
            "width": width,
            "height": height
        })
        for box in self.annotations.get(filename, []):
            x1, y1, x2, y2 = box["bbox"]
            left, top = min(x1, x2), min(y1, y2)
            w, h = abs(x2 - x1), abs(y2 - y1)
            coco_data["annotations"].append({
                "id": ann_id,
                "image_id": img_id,
                "category_id": 0,
                "bbox": [left, top, w, h],
                "area": w * h,
                "iscrowd": 0,
                "score": box.get("confidence", 1.0)
            })
            ann_id += 1
    with output_path.open("w", encoding="utf-8") as f:
        json.dump(coco_data, f, indent=2)
    return f"✓ Exported {len(coco_data['annotations'])} annotations to {output_path}"
def prepare_training_dataset(self, output_dir: Path, train_split: float = 0.8, valid_split: float = 0.1,
                             seed: int | None = None):
    """Write a train/valid/test COCO dataset from the annotated images.

    Only images with at least one box are used, and at least 10 are required.
    They are shuffled and split by ``train_split`` and ``valid_split``; the rest
    go to "test". Each split directory receives copies of its images and an
    ``_annotations.coco.json`` with per-split ids. Corners in either order
    produce a non-negative COCO width and height.

    Args:
        output_dir: Dataset root (Path or str), created if missing.
        train_split: Fraction of images for "train".
        valid_split: Fraction of images for "valid".
        seed: Optional seed for a reproducible split. When None, the global
            ``random`` module is used, as before.

    Returns:
        A status message (or a warning when there are fewer than 10 annotated images).
    """
    import random
    import shutil

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    annotated_images = [img for img in self.image_paths if self.annotations.get(img.name)]
    if len(annotated_images) < 10:
        return f"⚠️ Need at least 10 annotated images, have {len(annotated_images)}"

    if seed is None:
        random.shuffle(annotated_images)
    else:
        random.Random(seed).shuffle(annotated_images)
    n = len(annotated_images)
    train_n = int(n * train_split)
    valid_n = int(n * valid_split)
    splits = {
        "train": annotated_images[:train_n],
        "valid": annotated_images[train_n:train_n + valid_n],
        "test": annotated_images[train_n + valid_n:]
    }

    for split_name, split_images in splits.items():
        split_dir = output_dir / split_name
        split_dir.mkdir(exist_ok=True)
        coco_data = {
            "images": [],
            "annotations": [],
            "categories": [{"id": 0, "name": "knot", "supercategory": "defect"}]
        }
        ann_id = 0
        for img_id, img_path in enumerate(split_images):
            shutil.copy2(img_path, split_dir / img_path.name)
            # The context manager closes the file handle Image.open keeps.
            with Image.open(img_path) as img:
                width, height = img.size
            coco_data["images"].append({
                "id": img_id,
                "file_name": img_path.name,
                "width": width,
                "height": height
            })
            for box in self.annotations.get(img_path.name, []):
                x1, y1, x2, y2 = box["bbox"]
                left, top = min(x1, x2), min(y1, y2)
                w, h = abs(x2 - x1), abs(y2 - y1)
                coco_data["annotations"].append({
                    "id": ann_id,
                    "image_id": img_id,
                    "category_id": 0,
                    "bbox": [left, top, w, h],
                    "area": w * h,
                    "iscrowd": 0
                })
                ann_id += 1
        with (split_dir / "_annotations.coco.json").open("w", encoding="utf-8") as f:
            json.dump(coco_data, f, indent=2)
    return f"✓ Dataset prepared: {len(splits['train'])} train, {len(splits['valid'])} valid, {len(splits['test'])} test"
def start_training(self, framework: str, dataset_dir: str, output_dir: str, model_size: str,
                   epochs: int, batch_size: int, lr: float, progress=gr.Progress()):
    """Launch a training script for ``framework`` as a background subprocess.

    The script ``train_<name>.py`` next to this file is run with the project's
    ``.venv/bin/python`` when it exists, otherwise with the current interpreter.
    stdout/stderr go to ``<output_dir>/training.log``. A daemon thread waits for
    the process, updates ``self.training_status`` and, on success, loads the
    best weights the run produced.

    Args:
        framework: "RF-DETR", "RT-DETR", "YOLOv6" or "YOLOX".
        dataset_dir: Dataset root, passed as ``--dataset-dir``.
        output_dir: Run directory, passed as ``--output-dir`` (created if missing).
        model_size: "nano", "small", "medium" or "base"; mapped per framework,
            with a framework-specific fallback for other values.
        epochs, batch_size, lr: Forwarded to the training script.
        progress: Gradio progress tracker (unused).

    Returns:
        A user-facing status message.
    """
    import sys

    dataset_path = Path(dataset_dir)
    output_path = Path(output_dir)
    if not dataset_path.exists():
        return "❌ Dataset directory not found"
    if self.training_process and self.training_process.poll() is None:
        return "⚠️ Training already in progress"

    # framework -> (script, size map, fallback model arg, model type for reload,
    #               best-weights path relative to output dir, extra CLI args
    #               inserted between --batch-size and --lr)
    frameworks = {
        "RF-DETR": (
            "train_rfdetr.py",
            {"nano": "nano", "small": "small", "medium": "medium", "base": "base"},
            "medium", "rf-detr", ("checkpoint_best_total.pth",),
            ["--grad-accum-steps", "2"],
        ),
        "RT-DETR": (
            "train_rtdetr.py",
            {"nano": "rtdetr-r18", "small": "rtdetr-r34", "medium": "rtdetr-r50", "base": "rtdetr-l"},
            "rtdetr-r18", "rt-detr", ("weights", "best.pt"), [],
        ),
        "YOLOv6": (
            "train_yolov6.py",
            {"nano": "yolov6n", "small": "yolov6s", "medium": "yolov6m", "base": "yolov6l"},
            "yolov6n", "yolov6", ("weights", "best.pt"), [],
        ),
        "YOLOX": (
            "train_yolox.py",
            {"nano": "yolox-nano", "small": "yolox-s", "medium": "yolox-m", "base": "yolox-l"},
            "yolox-nano", "yolox", ("weights", "best.pt"), [],
        ),
    }
    if framework not in frameworks:
        return f"❌ Unknown framework: {framework}"
    script_name, size_map, fallback_model, model_type, weights_rel, extra_args = frameworks[framework]

    project_dir = Path(__file__).parent
    venv_python = project_dir / ".venv/bin/python"
    # Fall back to the running interpreter when there is no POSIX-style venv
    # (e.g. on Windows or a system-wide install).
    python_exe = str(venv_python) if venv_python.exists() else sys.executable
    cmd = [
        python_exe,
        str(project_dir / script_name),
        "--dataset-dir", str(dataset_path),
        "--output-dir", str(output_path),
        "--model", size_map.get(model_size, fallback_model),
        "--epochs", str(epochs),
        "--batch-size", str(batch_size),
        *extra_args,
        "--lr", str(lr),
    ]

    output_path.mkdir(parents=True, exist_ok=True)
    log_file = output_path / "training.log"
    self.training_status = f"🚀 Starting {framework} training..."
    try:
        # Spawn here rather than in the worker thread, so self.training_process
        # is set before we return. A second click then sees the running
        # process, and stop_training can reach it. The child inherits the log
        # handle, so ours can be closed right away.
        with log_file.open("w") as log:
            process = subprocess.Popen(cmd, stdout=log, stderr=subprocess.STDOUT, text=True)
    except (OSError, ValueError) as e:
        self.training_status = f"❌ Error: {e}"
        return f"❌ Could not start training: {e}"
    self.training_process = process
    self.training_status = f"⏳ Training in progress (PID: {process.pid})"

    def wait_for_training():
        """Wait for the subprocess and publish the outcome in training_status."""
        try:
            returncode = process.wait()
            if returncode == 0:
                self.training_status = "✅ Training completed successfully!"
                best_weights = output_path.joinpath(*weights_rel)
                if best_weights.exists():
                    self._load_model(best_weights, model_type)
            elif returncode < 0:
                # Negative codes mean the child was killed by a signal (POSIX),
                # e.g. by stop_training(); don't report that as a failure.
                self.training_status = f"⏹️ Training stopped (signal {-returncode})"
            else:
                self.training_status = f"❌ Training failed (exit code {returncode})"
        except Exception as e:
            self.training_status = f"❌ Error: {e}"

    self.training_thread = threading.Thread(target=wait_for_training, daemon=True)
    self.training_thread.start()
    return f"✓ Training started! Check {log_file} for progress"
def get_training_status(self):
    """Return the most recent human-readable training status message.

    The message is written by the background thread started in
    ``start_training`` and by ``stop_training``; this accessor only reads it
    back so the UI can display it.
    """
    current_status = self.training_status
    return current_status
def stop_training(self):
    """Terminate the running training subprocess, if there is one.

    Returns:
        A status string for the UI: a confirmation when a live process was
        sent ``terminate()``, otherwise a "nothing to stop" warning.

    NOTE(review): the background thread in ``start_training`` is still waiting
    on this process; when it sees the non-zero exit code it will overwrite
    ``training_status`` with its "Training failed" message.
    """
    proc = self.training_process
    is_running = bool(proc) and proc.poll() is None
    if not is_running:
        return "⚠️ No training in progress"

    # terminate() only sends the signal; it does not wait for the process to exit.
    proc.terminate()
    self.training_status = "⏹️ Training stopped by user"
    return "✓ Training process terminated"
def get_model_path_from_display(self, model_display: str) -> Path | None:
"""Get the actual model path from a display name."""
if not hasattr(self, 'available_models') or not self.available_models:
return None
for model in self.available_models:
if model['display'] == model_display:
return model['path']
return None
def export_for_oak_d(self, model_display: str, output_dir: str = "oak_d_export", img_size: int = 640):
    """Export a trained model to ONNX (and OpenVINO when possible) for OAK-D deployment.

    Args:
        model_display: Label of the model as shown in the model dropdown; it is
            resolved to a weights path via ``get_model_path_from_display``.
        output_dir: Directory that receives the exported files (created if missing).
        img_size: Input size passed as ``imgsz`` to the Ultralytics exporter.
            It is not used for RF-DETR, whose ``export()`` is called without arguments.

    Returns:
        A human-readable status string for the UI. Every code path returns a
        string; failures are reported in the message instead of being raised.
    """
    import shutil
    from contextlib import suppress

    try:
        # Convert display name to actual path
        weights_path = self.get_model_path_from_display(model_display)
        if not weights_path:
            return f"❌ Model '{model_display}' not found. Try clicking '🔍 Scan for Models' first."
        output_path = Path(output_dir)
        if not weights_path.exists():
            return f"❌ Model weights not found at: {weights_path}"
        output_path.mkdir(parents=True, exist_ok=True)

        model_type = self._guess_model_type_from_path(weights_path)
        print(f"Exporting {model_type} model for OAK-D...")

        if model_type == "rf-detr":
            # NOTE(review): RFDETRBase is used whatever size the checkpoint was
            # trained with; confirm this is correct for nano/small/medium runs.
            from rfdetr import RFDETRBase
            model = RFDETRBase(pretrain_weights=str(weights_path))
            model.export()  # Creates output/model.onnx
            onnx_source = Path("output/model.onnx")
            if not onnx_source.exists():
                return "❌ ONNX export failed"
            # shutil.move also works when output_dir is on another filesystem,
            # where Path.rename would raise.
            shutil.move(str(onnx_source), str(output_path / "rf_detr_model.onnx"))
            return f"✓ RF-DETR exported for OAK-D!\n📁 Output: {output_path}\n🔗 Next: Convert ONNX to blob using blobconverter.luxonis.com"

        # Ultralytics models (RT-DETR, YOLOv6, YOLOX)
        if model_type == "rt-detr":
            from ultralytics import RTDETR
            model = RTDETR(str(weights_path))
        else:
            from ultralytics import YOLO
            model = YOLO(str(weights_path))

        onnx_path = model.export(
            format="onnx",
            imgsz=img_size,
            simplify=True,
            opset=11,  # OAK-compatible opset
        )
        if Path(onnx_path).exists():
            shutil.move(str(onnx_path), str(output_path / f"{model_type}_model.onnx"))

        def onnx_only_result(reason: str) -> str:
            """Build the status message used when only the ONNX export is usable."""
            docker_hint = ""
            if shutil.which("docker") is None:
                docker_hint = "\n⚠️ Docker not found (needed for offline conversion via ModelConverter)."
            return (
                f"{model_type.upper()} exported to ONNX!\n"
                f"📁 Output: {output_path}\n"
                f"🔗 Next: Convert ONNX → RVC using HubAI (online) or ModelConverter (offline).\n"
                f"Docs: https://docs.luxonis.com/software-v3/ai-inference/conversion/\n"
                f"💡 Offline conversion: Use Luxonis ModelConverter with Docker\n"
                f"{reason}"
                f"{docker_hint}"
            )

        # Try to export to OpenVINO if available
        try:
            openvino_path = model.export(
                format="openvino",
                imgsz=img_size,
                half=False,  # Use FP32 for better compatibility
            )
            openvino_dir = Path(openvino_path)
            if not openvino_dir.exists():
                # Without this return the method would fall through and return
                # None, leaving the UI status box empty.
                return onnx_only_result(f"⚠️ OpenVINO export produced no output at: {openvino_path}")
            for file in openvino_dir.glob("*"):
                if file.is_file():
                    shutil.move(str(file), str(output_path / file.name))
            # A leftover sub-directory keeps the folder non-empty. That is not an
            # OpenVINO failure, so do not let rmdir() turn it into one.
            with suppress(OSError):
                openvino_dir.rmdir()
            return f"{model_type.upper()} exported for OAK-D!\n📁 Output: {output_path}\n🔗 Next: Convert .xml/.bin to blob using blobconverter.luxonis.com"
        except Exception as e:
            # OpenVINO not available (or its output could not be moved): the
            # ONNX file is still usable, so report a partial success.
            return onnx_only_result(f"⚠️ OpenVINO export not available: {e}")
    except Exception as e:
        return f"❌ Export failed: {str(e)}"
def create_ui(app: AnnotationApp) -> gr.Blocks:
    """Build the Gradio interface for annotation, training and OAK-D export.

    Args:
        app: Application object whose methods back every widget and event.

    Returns:
        The configured (not yet launched) ``gr.Blocks`` demo.
    """
    with gr.Blocks(title="Knot Annotation Tool") as demo:
        gr.Markdown("""
# 🪵 Wood Knot Annotation Tool
**Label → Train → Auto-Label → Repeat**
- Manually annotate images or use **Auto-Label** with your trained model
- Export and prepare dataset for training
- Train **RF-DETR, RT-DETR, YOLOv6, or YOLOX** (all free for commercial use!)
- Optimized for OAK-D camera deployment
- Use trained model to auto-label more images
""")
        # Settings section at the top
        with gr.Accordion("⚙️ Settings", open=False):
            with gr.Row():
                with gr.Column():
                    images_dir_input = gr.Textbox(
                        label="Images Directory",
                        value=str(app.images_dir),
                        placeholder="/path/to/images"
                    )
                    load_images_btn = gr.Button("📁 Load Images Directory")
                    dir_info = gr.Textbox(label="Current Directory", value=app.get_current_dir_info(), interactive=False)
                with gr.Column():
                    # Quick Model Selector
                    model_selector = gr.Dropdown(
                        choices=app.get_available_models_list(),
                        value=None,
                        label="🚀 Quick Model Switcher",
                        info="Select from available trained models (refresh with scan)",
                        allow_custom_value=True
                    )
                    quick_load_btn = gr.Button("⚡ Load Selected Model", variant="primary")
                    # Manual Model Loading
                    model_weights_input = gr.Textbox(
                        label="Model Weights Path",
                        value=str(app.current_model_path) if app.current_model_path else "",
                        placeholder="runs/training/checkpoint_best_total.pth"
                    )
                    model_type_dropdown = gr.Dropdown(
                        choices=["Auto-detect", "rf-detr", "rt-detr", "yolov6", "yolox"],
                        value="Auto-detect",
                        label="Model Type",
                        info="Auto-detect will try to determine from file path"
                    )
                    with gr.Row():
                        load_model_btn = gr.Button("🤖 Load Model Weights")
                        scan_models_btn = gr.Button("🔍 Scan for Models")
                    model_info = gr.Textbox(label="Current Model", value=app.get_current_model_info(), interactive=False)
        with gr.Row():
            with gr.Column(scale=3):
                image_display = gr.HTML(
                    label="Current Image",
                    value="<div style='width: 800px; height: 400px; border: 1px solid #ccc; display: flex; align-items: center; justify-content: center; color: #666;'>Load images from Settings to start annotating</div>",
                    js_on_load=CANVAS_JS_ON_LOAD,
                )
                with gr.Row():
                    prev_btn = gr.Button("⬅️ Previous")
                    next_btn = gr.Button("Next ➡️")
                    auto_label_btn = gr.Button("🤖 Auto-Label", variant="primary")
                    save_canvas_btn = gr.Button("💾 Save Canvas Changes")
                # Hidden textbox to store canvas boxes data
                canvas_boxes_data = gr.Textbox(visible=False, elem_id="canvas-boxes-data")
                with gr.Row():
                    threshold_slider = gr.Slider(0.1, 0.9, DEFAULT_DETECTION_THRESHOLD, label="Detection Threshold")
            with gr.Column(scale=1):
                info_text = gr.Textbox(label="Status", lines=2)
                boxes_text = gr.Textbox(label="Annotations", lines=10)
                gr.Markdown("### Manual Annotation")
                with gr.Row():
                    x1_input = gr.Number(label="x1", value=100)
                    y1_input = gr.Number(label="y1", value=100)
                with gr.Row():
                    x2_input = gr.Number(label="x2", value=200)
                    y2_input = gr.Number(label="y2", value=200)
                add_box_btn = gr.Button(" Add Box")
                delete_btn = gr.Button("🗑️ Delete Last")
                clear_btn = gr.Button("❌ Clear All")
                gr.Markdown("### Export & Training")
                export_path = gr.Textbox(
                    label="Export Path",
                    value="annotations_coco.json"
                )
                export_btn = gr.Button("💾 Export COCO")
                export_result = gr.Textbox(label="Export Result", lines=1)
        # Training tab
        with gr.Tab("🎯 Training"):
            gr.Markdown("""
### Train Object Detection Model
**Choose your framework:**
- **RF-DETR** (MIT): Custom transformer, high accuracy
- **RT-DETR** (Apache 2.0): Ultralytics transformer, great accuracy
- **YOLOv6** (MIT): Fast, proven on OAK cameras
- **YOLOX** (MIT): Similar to YOLOv6, slight differences
**All MIT/Apache 2.0 licensed - free for commercial use!** ✅
**Steps:**
1. Annotate at least 50-100 images in the Annotation tab
2. Click "Prepare Dataset" to create train/valid/test splits
3. Select your framework and configure training parameters
4. Click "Start Training" (runs in background)
5. After training, export for OAK-D deployment
""")
            with gr.Row():
                with gr.Column():
                    dataset_prep_dir = gr.Textbox(
                        label="Dataset Output Directory",
                        value="dataset_prepared"
                    )
                    train_split = gr.Slider(0.5, 0.9, 0.8, label="Train Split Ratio")
                    valid_split = gr.Slider(0.05, 0.3, 0.1, label="Valid Split Ratio")
                    prep_btn = gr.Button("📦 Prepare Dataset", variant="secondary")
                    prep_result = gr.Textbox(label="Preparation Result", lines=2)
                with gr.Column():
                    gr.Markdown("### Training Configuration")
                    model_framework = gr.Dropdown(
                        choices=["RF-DETR", "RT-DETR", "YOLOv6", "YOLOX"],
                        value="RT-DETR",
                        label="Model Framework",
                        info="All MIT/Apache 2.0 licensed - free for commercial use. Optimized for OAK cameras."
                    )
                    train_dataset_dir = gr.Textbox(
                        label="Dataset Directory",
                        value="dataset_prepared"
                    )
                    train_output_dir = gr.Textbox(
                        label="Output Directory",
                        value="runs/gui_training"
                    )
                    model_size = gr.Dropdown(
                        choices=["nano", "small", "medium", "base"],
                        value=DEFAULT_MODEL_SIZE,
                        label="Model Size"
                    )
                    epochs = gr.Slider(5, 100, DEFAULT_TRAIN_EPOCHS, step=5, label="Epochs")
                    batch_size = gr.Slider(1, 16, DEFAULT_BATCH_SIZE, step=1, label="Batch Size")
                    learning_rate = gr.Number(value=DEFAULT_LEARNING_RATE, label="Learning Rate")
                    with gr.Row():
                        start_train_btn = gr.Button("🚀 Start Training", variant="primary")
                        stop_train_btn = gr.Button("⏹️ Stop Training", variant="stop")
                        refresh_status_btn = gr.Button("🔄 Refresh Status")
                    training_status = gr.Textbox(
                        label="Training Status",
                        value="Not training",
                        lines=3
                    )
                    gr.Markdown("""
**Note**: Training runs in the background. You can continue annotating while training.
Check the training log file for detailed progress.
""")
        # OAK-D Deployment tab
        with gr.Tab("🚀 OAK-D Deployment"):
            gr.Markdown("""
### Deploy Trained Model to OAK-D Camera
Convert your trained model to work with the **OAK-D 4 Pro** camera for real-time edge inference.
**Supported Models**: RF-DETR, RT-DETR, YOLOv6, YOLOX
**Process**:
1. Select a trained model from your runs/ directory
2. Export to ONNX and OpenVINO formats
3. Convert OpenVINO model to blob for OAK-D
4. Deploy blob to your OAK-D camera
""")
            with gr.Row():
                with gr.Column():
                    oak_model_selector = gr.Dropdown(
                        choices=app.get_available_models_list(),
                        value=None,
                        label="Select Trained Model",
                        info="Choose a model from your training runs",
                        allow_custom_value=True
                    )
                    oak_output_dir = gr.Textbox(
                        label="Output Directory",
                        value="oak_d_deployment",
                        placeholder="oak_d_deployment"
                    )
                    oak_img_size = gr.Dropdown(
                        choices=[320, 416, 512, 640, 800, 1024],
                        value=640,
                        label="Image Size",
                        info="Input size for the model (should match training)"
                    )
                    with gr.Row():
                        oak_scan_btn = gr.Button("🔍 Scan for Models")
                        oak_export_btn = gr.Button("🚀 Export for OAK-D", variant="primary")
                    oak_status = gr.Textbox(
                        label="Export Status",
                        value="Ready to export",
                        lines=4
                    )
                with gr.Column():
                    gr.Markdown("""
### 📋 Deployment Instructions
**After Export:**
1. **Test OpenVINO Model** (optional):
```bash
python -c "from openvino.runtime import Core; core = Core(); model = core.read_model('model.xml'); print('✓ Model loaded')"
```
2. **Convert to RVC compiled format** (recommended by Luxonis):
- Online: HubAI conversion (fastest setup)
- Offline: ModelConverter (requires Docker)
- Docs: https://docs.luxonis.com/software-v3/ai-inference/conversion/
3. **Deploy to OAK-D**:
- Use DepthAI Python API
- Or use OAK-D examples with your blob
### 💡 Tips
- **Nano models** work best on edge devices
- If you quantize, use real calibration images for best accuracy
- Test inference speed vs accuracy trade-off
""")

        # Event handlers
        def on_load():
            return app.load_image("current")

        def refresh_model_choices():
            """Push the current model list into BOTH model dropdowns.

            A plain list returned to a Dropdown output would replace its
            *value*, not its *choices*. Returning a new ``gr.Dropdown`` that
            sets only ``choices`` updates the option list; per the Gradio docs,
            properties that are not set (such as the current selection) are
            kept.
            """
            choices = app.get_available_models_list()
            return gr.Dropdown(choices=choices), gr.Dropdown(choices=choices)

        model_dropdowns = [model_selector, oak_model_selector]

        # Settings handlers
        load_images_btn.click(
            app.load_new_images_dir,
            inputs=[images_dir_input],
            outputs=[image_display, boxes_text, info_text]
        ).then(
            lambda: (app.get_current_dir_info(), app.get_current_model_info()),
            outputs=[dir_info, model_info]
        )
        load_model_btn.click(
            app.load_new_model,
            inputs=[model_weights_input, model_type_dropdown],
            outputs=[model_info]
        ).then(
            refresh_model_choices,
            outputs=model_dropdowns
        )
        scan_models_btn.click(
            app.scan_for_models,
            outputs=[model_info]
        ).then(
            refresh_model_choices,
            outputs=model_dropdowns
        )
        quick_load_btn.click(
            app.load_model_by_index,
            inputs=[model_selector],
            outputs=[model_info]
        ).then(
            refresh_model_choices,
            outputs=model_dropdowns
        )
        prev_btn.click(
            lambda: app.load_image("prev"),
            outputs=[image_display, boxes_text, info_text]
        )
        next_btn.click(
            lambda: app.load_image("next"),
            outputs=[image_display, boxes_text, info_text]
        )
        auto_label_btn.click(
            lambda t: app.auto_label_current(t),
            inputs=[threshold_slider],
            outputs=[image_display, boxes_text, info_text]
        )
        # The js hook reads the canvas state straight from the DOM, because the
        # hidden textbox is filled by the canvas script rather than by Gradio.
        save_canvas_btn.click(
            app.save_canvas_changes,
            inputs=[canvas_boxes_data],
            outputs=[image_display, boxes_text, info_text],
            js="""() => {
const hiddenInput = document.getElementById('canvas-boxes-data');
if (hiddenInput) {
return hiddenInput.value;
}
return '';
}"""
        )
        add_box_btn.click(
            app.add_box_manual,
            inputs=[x1_input, y1_input, x2_input, y2_input],
            outputs=[image_display, boxes_text, info_text]
        )
        delete_btn.click(
            app.delete_last_box,
            outputs=[image_display, boxes_text, info_text]
        )
        clear_btn.click(
            app.clear_boxes,
            outputs=[image_display, boxes_text, info_text]
        )
        export_btn.click(
            lambda path: app.export_to_coco(Path(path)),
            inputs=[export_path],
            outputs=[export_result]
        )
        # Training handlers
        prep_btn.click(
            lambda out, train, valid: app.prepare_training_dataset(Path(out), train, valid),
            inputs=[dataset_prep_dir, train_split, valid_split],
            outputs=[prep_result]
        )
        start_train_btn.click(
            app.start_training,
            inputs=[model_framework, train_dataset_dir, train_output_dir, model_size, epochs, batch_size, learning_rate],
            outputs=[training_status]
        )
        stop_train_btn.click(
            app.stop_training,
            outputs=[training_status]
        )
        refresh_status_btn.click(
            app.get_training_status,
            outputs=[training_status]
        )
        # OAK-D Deployment handlers
        oak_scan_btn.click(
            app.scan_for_models,
            outputs=[oak_status]
        ).then(
            refresh_model_choices,
            outputs=model_dropdowns
        )
        oak_export_btn.click(
            app.export_for_oak_d,
            inputs=[oak_model_selector, oak_output_dir, oak_img_size],
            outputs=[oak_status]
        )
        # Load first image on start
        demo.load(on_load, outputs=[image_display, boxes_text, info_text])
    return demo
def main():
    """Parse command-line options, build the annotation app and serve the UI.

    Default paths that do not exist are reported and dropped, not treated as
    fatal, so the user can choose replacements from the Settings panel.
    ``--host`` defaults to ``0.0.0.0`` (all interfaces), which matches the
    previous hard-coded behaviour.
    """
    parser = argparse.ArgumentParser(description="Simple annotation GUI with auto-labeling")
    parser.add_argument(
        "--images-dir",
        type=Path,
        default=Path(DEFAULT_IMAGES_DIR) if DEFAULT_IMAGES_DIR else None,
        help="Default directory with images (can be changed in GUI)"
    )
    parser.add_argument(
        "--model-weights",
        type=Path,
        default=Path(DEFAULT_MODEL_WEIGHTS) if DEFAULT_MODEL_WEIGHTS else None,
        help="Default trained model for auto-labeling (can be changed in GUI)"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=DEFAULT_PORT,
        help="Port to run the GUI on"
    )
    # The UI can start training subprocesses and read arbitrary server paths,
    # so let operators restrict the bind address (e.g. 127.0.0.1).
    parser.add_argument(
        "--host",
        default="0.0.0.0",
        help="Interface to bind the GUI server to (use 127.0.0.1 to keep it local)"
    )
    args = parser.parse_args()

    # Validate paths if provided
    if args.images_dir and not args.images_dir.exists():
        print(f"⚠️ Warning: Images directory not found: {args.images_dir}")
        print("You can load a different directory from the GUI Settings")
        args.images_dir = None
    if args.model_weights and not args.model_weights.exists():
        print(f"⚠️ Warning: Model weights not found: {args.model_weights}")
        print("You can load different weights from the GUI Settings")
        args.model_weights = None

    app = AnnotationApp(args.images_dir, args.model_weights)
    # Populate the model dropdowns before the UI is built.
    app.scan_for_models(return_info=False)
    demo = create_ui(app)

    print(f"\n{'='*60}")
    print("🚀 Starting annotation tool...")
    if args.images_dir:
        print(f"📁 Default images: {args.images_dir} ({len(app.image_paths)} images)")
    else:
        print("📁 No default images - load directory from Settings")
    if app.model:
        print(f"🤖 Model: Loaded from {args.model_weights}")
    else:
        print("⚠️ No model loaded - load from Settings or train one")
    print("💡 You can change images directory and model weights from the Settings panel")
    print(f"{'='*60}\n")

    demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=False
    )
# Script entry point: start the GUI only when the file is executed directly
# (e.g. `python annotation_gui.py`), not when it is imported as a module.
if __name__ == "__main__":
    main()