removed gradio

GUI_README.md (148 lines removed)
@@ -1,148 +0,0 @@
# Custom Annotation GUI

A simple, **fully customizable** annotation tool built with Gradio (pure Python).

## Features

✅ **Auto-labeling** with your trained RF-DETR model
✅ **Manual annotation** by entering box coordinates
✅ **Edit/delete** annotations easily
✅ **Navigation** between images
✅ **Export** to COCO JSON format
✅ **100% Python** - easy to modify and extend

## Quick Start

### 1. Install dependencies

```bash
/home/dillon/_code/saw_mill_knot_detection/.venv/bin/python -m pip install "gradio>=4.0.0"
```

(The version specifier is quoted so the shell does not treat `>` as a redirection.)

### 2. Run the GUI

**With auto-labeling (requires trained model):**

```bash
/home/dillon/_code/saw_mill_knot_detection/.venv/bin/python annotation_gui.py \
    --images-dir /path/to/images \
    --model-weights runs/knot_rfdetr_medium/checkpoint_best_total.pth
```

**Manual annotation only:**

```bash
/home/dillon/_code/saw_mill_knot_detection/.venv/bin/python annotation_gui.py \
    --images-dir /path/to/images
```

### 3. Open in browser

Opens automatically at http://localhost:7860
## Usage

1. **Auto-Label**: Click "🤖 Auto-Label" to detect knots with your model
2. **Adjust threshold**: Lower = more detections, Higher = only confident ones
3. **Manual boxes**: Enter coordinates (x1, y1, x2, y2) and click "➕ Add Box"
4. **Delete mistakes**: Click "🗑️ Delete Last" to remove the last box
5. **Navigate**: Use "Previous" / "Next" buttons
6. **Export**: Click "💾 Export COCO" when done
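The exporter writes each box's corner coordinates into COCO's `[x, y, width, height]` layout. A minimal sketch of that conversion (the `xyxy_to_coco` helper name is illustrative, not part of the tool):

```python
def xyxy_to_coco(bbox):
    """Convert a corner-format [x1, y1, x2, y2] box to COCO's [x, y, w, h] plus area."""
    x1, y1, x2, y2 = bbox
    w, h = x2 - x1, y2 - y1
    return [x1, y1, w, h], w * h

bbox, area = xyxy_to_coco([10.0, 20.0, 110.0, 70.0])
print(bbox, area)  # [10.0, 20.0, 100.0, 50.0] 5000.0
```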
## Customization Examples

### Add keyboard shortcuts

```python
# In create_ui(), add (illustrative pseudocode - Gradio has no built-in
# keyboard-shortcut attribute, so a real version needs a custom JS listener):
image_display.keyboard_shortcuts = {
    "d": delete_btn.click,  # Press 'd' to delete
    "n": next_btn.click,    # Press 'n' for next
}
```

### Add interactive drawing

```python
# Replace manual coordinates with an image annotator component
# (requires the third-party gradio_image_annotation package):
from gradio_image_annotation import image_annotator

annotator = image_annotator(
    label="Draw boxes",
    type="numpy"
)
```

### Change box colors by confidence

```python
# In draw_boxes_on_image():
color = "green" if conf > 0.8 else "yellow" if conf > 0.5 else "red"
draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
```

### Add multiple label classes

```python
# Add a dropdown:
label_choice = gr.Dropdown(
    choices=["knot", "crack", "hole"],
    value="knot",
    label="Label Type"
)

# Update the box dict:
box = {
    "bbox": [x1, y1, x2, y2],
    "label": label_choice_value,  # from the dropdown
    "confidence": 1.0
}
```

### Save checkpoints automatically

```python
# In _save_annotations(), add:
import shutil
backup_path = self.ann_file.with_suffix('.backup.json')
shutil.copy(self.ann_file, backup_path)
```

### Add image filters/preprocessing

```python
# Add before annotation:
def preprocess_image(img: Image.Image) -> Image.Image:
    from PIL import ImageEnhance
    enhancer = ImageEnhance.Contrast(img)
    return enhancer.enhance(1.5)  # Increase contrast
```

## File Structure

```
annotation_gui.py
├── AnnotationApp              # Main logic (easy to extend)
│   ├── auto_label_current()   # Modify for different models
│   ├── add_box_manual()       # Customize annotation format
│   ├── export_to_coco()       # Change export format
│   └── draw_boxes_on_image()  # Customize visualization
└── create_ui()                # Gradio interface (add components)
```

## Advantages vs Label Studio

| Feature | Custom GUI | Label Studio |
|---------|-----------|--------------|
| **Modify code** | ✅ Easy (pure Python) | ❌ Complex (React + Python) |
| **Add features** | ✅ ~10-50 lines | ❌ Hundreds of lines |
| **Custom models** | ✅ Direct integration | ⚠️ Need ONNX export |
| **Learning curve** | ✅ Simple Gradio | ⚠️ Larger codebase |
| **Setup** | ✅ pip install | ⚠️ Docker/complex |

## Troubleshooting

**Port already in use:**
```bash
python annotation_gui.py --images-dir /path --port 7861
```

**Model not loading:**
- Check that the weights path exists
- Verify it's a valid checkpoint file
- Try without `--model-weights` for manual-only mode

**Need more features?**
- Check the Gradio docs: https://www.gradio.app/docs/
- Add custom components easily
- Fork and modify the code freely!
README.md (40 lines changed)
@@ -1,12 +1,12 @@
# Saw Mill Knot Detection

-This repository contains a complete wood defect detection system with a web-based annotation GUI and separate training/deployment scripts. Supports multiple model frameworks (RF-DETR, RT-DETR, YOLOv6, YOLOX) and is optimized for deployment on OAK-D cameras.
+This repository contains a complete wood defect detection system with a Tkinter-based annotation GUI and separate training/deployment scripts. Supports multiple model frameworks (RF-DETR, RT-DETR, YOLOv6, YOLOX) and is optimized for deployment on OAK-D cameras.

## 🎯 Project Overview

- **Models**: RF-DETR, RT-DETR, YOLOv6, YOLOX (all MIT/Apache 2.0 licensed)
- **Dataset**: 20,276 wood surface defect images
-- **Annotation GUI**: Gradio-based web interface for manual annotation
+- **Annotation GUI**: Tkinter desktop app for manual annotation
- **Training Scripts**: Separate Python scripts for model training
- **Deployment**: OAK-D camera optimization with OpenVINO conversion
- **License**: All models free for commercial use
@@ -47,21 +47,16 @@ pip install -r requirements.txt

### 2. Run the Annotation GUI

-The repository includes an automated script that handles virtual environment activation:

```bash
-# Run the GUI (automatically detects and activates venv/conda environment)
-./run_gui.sh

-# Or run manually
-source .venv/bin/activate  # or conda activate your_env
-python annotation_gui.py
+./run_tk_gui.sh --images-dir IMAGE/
+# or
+python tk_annotation_gui.py --images-dir IMAGE/
```

-Install dependencies:
+Auto-label requires Ultralytics for YOLO/RT-DETR weights:

```bash
pip install -U pip
-pip install ultralytics gradio rfdetr
+pip install ultralytics
```

### 2. Setup Datasets
@@ -78,24 +73,13 @@ python setup_datasets.py  # Creates dataset_coco/ and updates configs

### 3. Launch Annotation GUI

-```bash
-python annotation_gui.py
-```

-Tkinter version (new):

```bash
python tk_annotation_gui.py
# or
./run_tk_gui.sh
```

-Open http://localhost:7860 in your browser to access the web-based annotation interface with:
-- Image navigation with index display
-- Auto-labeling with trained models
-- Manual annotation tools with delete buttons
-- Real-time result visualization
-- Export to COCO format
+The Tkinter GUI supports image navigation, autosave annotations, and optional auto-label.

### 4. Train Models
@@ -127,7 +111,8 @@ python convert_for_deployment.py --model runs/training/weights/best.pt --output

```
saw_mill_knot_detection/
-├── annotation_gui.py          # Gradio web interface for annotation
+├── tk_annotation_gui.py       # Tkinter desktop annotation GUI
+├── run_tk_gui.sh              # Convenience launcher
├── train_model.py             # Unified training script for all frameworks
├── convert_for_deployment.py  # Model conversion for OAK-D deployment
├── TRAINING_README.md         # Detailed training and deployment guide
@@ -174,7 +159,7 @@ saw_mill_knot_detection/

### Annotation GUI Features

-The Gradio-based annotation interface provides:
+The Tkinter annotation GUI provides:

- **Image Navigation**: Browse through dataset with current index display
- **Auto-Labeling**: One-click defect detection using trained YOLOX model
@@ -267,4 +252,3 @@ This project uses the Kaggle Wood Surface Defects dataset. Please refer to the o

- Kaggle for providing the wood surface defects dataset
- Ultralytics for the YOLO framework
-- Gradio for the web interface framework
@@ -14,9 +14,9 @@ RT-DETR (Real-Time Detection Transformer) is Apache 2.0 licensed - **free for co

### 1. Annotate Images

-Use the annotation GUI:
+Use the Tkinter annotation GUI:
```bash
-.venv/bin/python annotation_gui.py
+.venv/bin/python tk_annotation_gui.py --images-dir IMAGE/
```

- Load your images from Settings
@@ -25,24 +25,14 @@ Use the annotation GUI:

### 2. Train Model

-From the GUI:
-1. Go to **Training** tab
-2. Click "Prepare Dataset" (creates train/valid/test splits)
-3. Select **RT-DETR** framework
-4. Choose model size:
-   - `nano` (r18): Fastest, 30-40 FPS on OAK
-   - `small` (r34): Balanced
-   - `medium` (r50): More accurate
-   - `base` (l): Best accuracy, slower
-5. Click "Start Training"

-Or from command line:
+Train from the command line:
```bash
-.venv/bin/python train_rtdetr.py \
-    --dataset-dir dataset_prepared \
-    --model rtdetr-r18 \
-    --epochs 100 \
-    --batch-size 8
+.venv/bin/python train_model.py \
+    --framework rtdetr \
+    --dataset dataset_prepared \
+    --output runs/rtdetr_training \
+    --model-size small \
+    --epochs 100
```

### 3. Test Model
TRAINING_README.md (new file, 121 lines)
@@ -0,0 +1,121 @@
# Training and Deployment Scripts

This directory contains separate scripts for training models and converting them for deployment, extracted from the annotation GUI.

## Training Script

### `train_model.py`

Train object detection models for wood knot detection.

**Supported frameworks:**
- RF-DETR (MIT license)
- RT-DETR (Apache 2.0 license)
- YOLOv6 (MIT license)
- YOLOX (MIT license)

**Usage:**

```bash
# Basic usage
python train_model.py --framework rtdetr --dataset dataset_prepared --output runs/training

# Full options
python train_model.py \
    --framework rtdetr \
    --dataset dataset_prepared \
    --output runs/training \
    --model-size small \
    --epochs 20 \
    --batch-size 4 \
    --lr 0.001 \
    --prepare-dataset \
    --images-dir IMAGE \
    --annotations annotations.json
```
**Options:**
- `--framework`: Model framework (`rfdetr`, `rtdetr`, `yolov6`, `yolox`)
- `--dataset`: Path to prepared dataset directory
- `--output`: Output directory for trained model
- `--model-size`: Model size/variant (nano, small, medium, base)
- `--epochs`: Number of training epochs
- `--batch-size`: Batch size for training
- `--lr`: Learning rate
- `--prepare-dataset`: Prepare dataset from annotations first
- `--images-dir`: Images directory (for `--prepare-dataset`)
- `--annotations`: Annotations file (for `--prepare-dataset`)
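The options above map naturally onto a standard `argparse` interface. A minimal sketch of how they might be declared (illustrative only - the real `train_model.py` may define them differently):

```python
import argparse

# Illustrative parser mirroring the documented flags; defaults here are
# assumptions taken from the "Full options" example above.
parser = argparse.ArgumentParser(description="Train a wood knot detector")
parser.add_argument("--framework", required=True,
                    choices=["rfdetr", "rtdetr", "yolov6", "yolox"])
parser.add_argument("--dataset", help="Path to prepared dataset directory")
parser.add_argument("--output", default="runs/training")
parser.add_argument("--model-size", default="small",
                    choices=["nano", "small", "medium", "base"])
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--batch-size", type=int, default=4)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--prepare-dataset", action="store_true")
parser.add_argument("--images-dir")
parser.add_argument("--annotations")

args = parser.parse_args(["--framework", "rtdetr", "--dataset", "dataset_prepared"])
print(args.framework, args.model_size, args.epochs)  # rtdetr small 20
```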
## Deployment Conversion Script

### `convert_for_deployment.py`

Convert trained models for OAK-D deployment.

**Supported conversions:**
- ONNX export
- OpenVINO IR export
- Model optimization for edge devices

**Usage:**

```bash
# Basic usage
python convert_for_deployment.py --model runs/training/weights/best.pt --output oak_d_deployment

# Full options
python convert_for_deployment.py \
    --model runs/training/weights/best.pt \
    --output oak_d_deployment \
    --img-size 640 \
    --framework auto
```

**Options:**
- `--model`: Path to trained model weights (.pt file)
- `--output`: Output directory for converted models
- `--img-size`: Input image size for the model (320, 416, 512, 640, 800, 1024)
- `--framework`: Model framework (auto-detect if not specified)

## Workflow

1. **Annotate images** using the Tkinter GUI (`./run_tk_gui.sh` or `python tk_annotation_gui.py`)
2. **Export annotations** to COCO format from the GUI
3. **Prepare dataset** (optional, can be done by the training script):
   ```bash
   python train_model.py --prepare-dataset --images-dir IMAGE --annotations annotations.json --dataset dataset_prepared
   ```
4. **Train model**:
   ```bash
   python train_model.py --framework rtdetr --dataset dataset_prepared --output runs/training
   ```
5. **Convert for deployment**:
   ```bash
   python convert_for_deployment.py --model runs/training/weights/best.pt --output oak_d_deployment
   ```
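When scripting the pipeline end to end, the three commands above can be built and driven from Python. A hedged sketch (paths and flags copied from this README; the `workflow_commands` wrapper is hypothetical, not part of the repository):

```python
import shlex

def workflow_commands(dataset: str = "dataset_prepared",
                      output: str = "runs/training") -> list[list[str]]:
    """Build the documented prepare/train/convert commands (hypothetical helper)."""
    prepare = ["python", "train_model.py", "--prepare-dataset",
               "--images-dir", "IMAGE", "--annotations", "annotations.json",
               "--dataset", dataset]
    train = ["python", "train_model.py", "--framework", "rtdetr",
             "--dataset", dataset, "--output", output]
    convert = ["python", "convert_for_deployment.py",
               "--model", f"{output}/weights/best.pt",
               "--output", "oak_d_deployment"]
    return [prepare, train, convert]

for cmd in workflow_commands():
    print(shlex.join(cmd))  # each step as a copy-pasteable shell line
```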
## Next Steps After Conversion

After running `convert_for_deployment.py`:

1. **Test OpenVINO Model** (optional):
   ```bash
   python -c "from openvino.runtime import Core; core = Core(); model = core.read_model('model.xml'); print('✓ Model loaded')"
   ```

2. **Convert to RVC compiled format** (recommended by Luxonis):
   - Online: HubAI conversion (fastest setup)
   - Offline: ModelConverter (requires Docker)
   - Docs: https://docs.luxonis.com/software-v3/ai-inference/conversion/

3. **Deploy to OAK-D**:
   - Use the DepthAI Python API
   - Or use OAK-D examples with your blob

## Tips

- **Nano models** work best on edge devices
- If you quantize, use real calibration images for best accuracy
- Test the inference speed vs accuracy trade-off
- All models are MIT/Apache 2.0 licensed - free for commercial use!
annotation_gui.py (1270 lines, diff truncated because it is too large)
@@ -1,708 +0,0 @@
"""
Simple customizable annotation GUI with auto-labeling support.

Built with Gradio - easy to modify and extend.
Run: python annotation_gui.py --images-dir /path/to/images

To set default paths, edit config.py
"""

from __future__ import annotations

import argparse
import json
import subprocess
import threading
from pathlib import Path
from typing import Any

import gradio as gr
import numpy as np
from PIL import Image, ImageDraw

# Try to load config, use fallbacks if not available
try:
    from config import (
        DEFAULT_IMAGES_DIR, DEFAULT_MODEL_WEIGHTS, DEFAULT_PORT,
        DEFAULT_DETECTION_THRESHOLD, DEFAULT_TRAIN_EPOCHS,
        DEFAULT_BATCH_SIZE, DEFAULT_LEARNING_RATE, DEFAULT_MODEL_SIZE
    )
except ImportError:
    DEFAULT_IMAGES_DIR = None
    DEFAULT_MODEL_WEIGHTS = None
    DEFAULT_PORT = 7860
    DEFAULT_DETECTION_THRESHOLD = 0.5
    DEFAULT_TRAIN_EPOCHS = 20
    DEFAULT_BATCH_SIZE = 4
    DEFAULT_LEARNING_RATE = 1e-4
    DEFAULT_MODEL_SIZE = "small"

class AnnotationApp:
    def __init__(self, images_dir: Path | None = None, model_weights: Path | None = None):
        self.images_dir = images_dir if images_dir else Path.cwd()
        self.current_model_path = model_weights
        self.image_paths = []
        self.current_idx = 0
        self.annotations = {}  # image_name -> list of boxes
        self.model = None
        self.training_process = None
        self.training_thread = None
        self.training_status = "Not training"

        # Load images if directory provided
        if images_dir and images_dir.exists():
            self._load_images(images_dir)

        if model_weights and model_weights.exists():
            self._load_model(model_weights)

    def _load_images(self, images_dir: Path):
        """Load images from directory."""
        self.images_dir = images_dir
        self.image_paths = sorted(
            list(images_dir.glob("*.jpg")) + list(images_dir.glob("*.png"))
        )
        self.current_idx = 0

        # Load existing annotations if present
        self.ann_file = images_dir / "annotations.json"
        if self.ann_file.exists():
            with self.ann_file.open("r") as f:
                self.annotations = json.load(f)
        else:
            self.annotations = {}

        return f"✓ Loaded {len(self.image_paths)} images from {images_dir}"
    def _load_model(self, weights_path: Path):
        """Load RF-DETR model for auto-labeling."""
        try:
            from rfdetr import RFDETRBase
            print(f"Loading model from {weights_path}...")
            self.model = RFDETRBase(pretrain_weights=str(weights_path))
            self.current_model_path = weights_path
            print("✓ Model loaded")
            return f"✓ Model loaded from {weights_path.name}"
        except Exception as e:
            error_msg = f"⚠ Could not load model: {e}"
            print(error_msg)
            self.model = None
            return error_msg

    def load_new_model(self, weights_path: str) -> str:
        """Load a new model from the GUI."""
        path = Path(weights_path)
        if not path.exists():
            return f"❌ File not found: {weights_path}"

        return self._load_model(path)
    def load_new_images_dir(self, images_dir: str) -> tuple[Image.Image | None, str, str]:
        """Load a new images directory from the GUI."""
        path = Path(images_dir)
        if not path.exists():
            return None, "", f"❌ Directory not found: {images_dir}"

        if not path.is_dir():
            return None, "", f"❌ Not a directory: {images_dir}"

        result = self._load_images(path)

        # Load first image
        if self.image_paths:
            img, filename = self.get_current_image()
            boxes = self.annotations.get(filename, [])
            img_with_boxes = self.draw_boxes_on_image(img, boxes) if boxes else img
            boxes_text = self._format_boxes_text(boxes)
            info = f"{result}\nImage 1/{len(self.image_paths)}: {filename}"
            return img_with_boxes, boxes_text, info
        else:
            return None, "", f"{result}\n⚠️ No .jpg or .png images found in directory"

    def get_current_model_info(self) -> str:
        """Get info about currently loaded model."""
        if self.model and self.current_model_path:
            return f"📦 Loaded: {self.current_model_path}"
        elif self.model:
            return "📦 Model loaded (pretrained)"
        else:
            return "⚠️ No model loaded"

    def get_current_dir_info(self) -> str:
        """Get info about current images directory."""
        return f"📁 {self.images_dir} ({len(self.image_paths)} images)"
    def get_current_image(self) -> tuple[Image.Image, str]:
        """Get current image and filename."""
        if not self.image_paths:
            return None, ""
        path = self.image_paths[self.current_idx]
        img = Image.open(path).convert("RGB")
        return img, path.name

    def draw_boxes_on_image(self, img: Image.Image, boxes: list[dict]) -> Image.Image:
        """Draw bounding boxes on image."""
        img_draw = img.copy()
        draw = ImageDraw.Draw(img_draw)

        for box in boxes:
            x1, y1, x2, y2 = box["bbox"]
            label = box.get("label", "knot")
            conf = box.get("confidence", 1.0)

            # Draw box
            draw.rectangle([x1, y1, x2, y2], outline="red", width=3)

            # Draw label
            text = f"{label} {conf:.2f}" if conf < 1.0 else label
            draw.text((x1, y1 - 20), text, fill="red")

        return img_draw
    def auto_label_current(self, threshold: float = 0.5) -> tuple[Image.Image, str, str]:
        """Auto-label current image with model."""
        if not self.model:
            return self.get_current_image()[0], "", "⚠ No model loaded"

        img, filename = self.get_current_image()
        if not img:
            return None, "", "No images"

        # Run inference
        detections = self.model.predict(img, threshold=threshold)

        # Convert to our format
        boxes = []
        for i in range(len(detections)):
            xyxy = detections.xyxy[i].tolist()
            conf = float(detections.confidence[i]) if detections.confidence is not None else 1.0
            boxes.append({
                "bbox": xyxy,
                "label": "knot",
                "confidence": conf,
                "source": "auto"
            })

        # Save
        self.annotations[filename] = boxes
        self._save_annotations()

        # Draw
        img_with_boxes = self.draw_boxes_on_image(img, boxes)
        info = f"✓ Auto-labeled: {len(boxes)} boxes detected"
        boxes_text = self._format_boxes_text(boxes)

        return img_with_boxes, boxes_text, info

    def _format_boxes_text(self, boxes: list[dict]) -> str:
        """Format boxes for display."""
        if not boxes:
            return "No annotations"

        lines = []
        for i, box in enumerate(boxes):
            x1, y1, x2, y2 = box["bbox"]
            conf = box.get("confidence", 1.0)
            source = box.get("source", "manual")
            lines.append(f"{i}: [{x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f}] conf={conf:.2f} ({source})")

        return "\n".join(lines)
    def load_image(self, direction: str = "current") -> tuple[Image.Image, str, str]:
        """Load image (current/next/prev)."""
        if direction == "next":
            self.current_idx = min(self.current_idx + 1, len(self.image_paths) - 1)
        elif direction == "prev":
            self.current_idx = max(self.current_idx - 1, 0)

        img, filename = self.get_current_image()
        if not img:
            return None, "", "No images"

        # Load existing annotations
        boxes = self.annotations.get(filename, [])
        img_with_boxes = self.draw_boxes_on_image(img, boxes) if boxes else img
        boxes_text = self._format_boxes_text(boxes)
        info = f"Image {self.current_idx + 1}/{len(self.image_paths)}: {filename}"

        return img_with_boxes, boxes_text, info
    def add_box_manual(self, x1: int, y1: int, x2: int, y2: int) -> tuple[Image.Image, str, str]:
        """Manually add a bounding box."""
        img, filename = self.get_current_image()
        if not img:
            return None, "", "No images"

        # Add box
        box = {
            "bbox": [float(x1), float(y1), float(x2), float(y2)],
            "label": "knot",
            "confidence": 1.0,
            "source": "manual"
        }

        if filename not in self.annotations:
            self.annotations[filename] = []
        self.annotations[filename].append(box)
        self._save_annotations()

        # Redraw
        boxes = self.annotations[filename]
        img_with_boxes = self.draw_boxes_on_image(img, boxes)
        boxes_text = self._format_boxes_text(boxes)
        info = f"✓ Added box: {len(boxes)} total"

        return img_with_boxes, boxes_text, info
    def delete_last_box(self) -> tuple[Image.Image, str, str]:
        """Delete the last box from current image."""
        img, filename = self.get_current_image()
        if not img:
            return None, "", "No images"

        if filename in self.annotations and self.annotations[filename]:
            self.annotations[filename].pop()
            self._save_annotations()

        # Redraw
        boxes = self.annotations.get(filename, [])
        img_with_boxes = self.draw_boxes_on_image(img, boxes) if boxes else img
        boxes_text = self._format_boxes_text(boxes)
        info = f"✓ Deleted last box: {len(boxes)} remaining"

        return img_with_boxes, boxes_text, info

    def clear_boxes(self) -> tuple[Image.Image, str, str]:
        """Clear all boxes from current image."""
        img, filename = self.get_current_image()
        if not img:
            return None, "", "No images"

        self.annotations[filename] = []
        self._save_annotations()

        boxes_text = "No annotations"
        info = "✓ Cleared all boxes"

        return img, boxes_text, info
    def _save_annotations(self):
        """Save annotations to JSON file."""
        with self.ann_file.open("w") as f:
            json.dump(self.annotations, f, indent=2)

    def export_to_coco(self, output_path: Path):
        """Export annotations to COCO format."""
        coco_data = {
            "images": [],
            "annotations": [],
            "categories": [{"id": 0, "name": "knot", "supercategory": "defect"}]
        }

        ann_id = 0
        for img_id, img_path in enumerate(self.image_paths):
            filename = img_path.name
            img = Image.open(img_path)
            width, height = img.size

            coco_data["images"].append({
                "id": img_id,
                "file_name": filename,
                "width": width,
                "height": height
            })

            # Add annotations
            boxes = self.annotations.get(filename, [])
            for box in boxes:
                x1, y1, x2, y2 = box["bbox"]
                w = x2 - x1
                h = y2 - y1

                coco_data["annotations"].append({
                    "id": ann_id,
                    "image_id": img_id,
                    "category_id": 0,
                    "bbox": [x1, y1, w, h],
                    "area": w * h,
                    "iscrowd": 0,
                    "score": box.get("confidence", 1.0)
                })
                ann_id += 1

        with output_path.open("w") as f:
            json.dump(coco_data, f, indent=2)

        return f"✓ Exported {len(coco_data['annotations'])} annotations to {output_path}"
    def prepare_training_dataset(self, output_dir: Path, train_split: float = 0.8, valid_split: float = 0.1):
        """Prepare dataset in RF-DETR format (train/valid/test splits)."""
        output_dir.mkdir(parents=True, exist_ok=True)

        # Create splits
        import random
        annotated_images = [img for img in self.image_paths if img.name in self.annotations and self.annotations[img.name]]

        if len(annotated_images) < 10:
            return f"⚠️ Need at least 10 annotated images, have {len(annotated_images)}"

        random.shuffle(annotated_images)
        n = len(annotated_images)
        train_n = int(n * train_split)
        valid_n = int(n * valid_split)

        splits = {
            "train": annotated_images[:train_n],
            "valid": annotated_images[train_n:train_n + valid_n],
            "test": annotated_images[train_n + valid_n:]
        }

        # Create directories and copy images
        import shutil
        for split_name, split_images in splits.items():
            split_dir = output_dir / split_name
            split_dir.mkdir(exist_ok=True)

            # Prepare COCO JSON for this split
            coco_data = {
                "images": [],
                "annotations": [],
                "categories": [{"id": 0, "name": "knot", "supercategory": "defect"}]
            }

            ann_id = 0
            for img_id, img_path in enumerate(split_images):
                # Copy image
                dest = split_dir / img_path.name
                shutil.copy2(img_path, dest)

                # Add to COCO
                img = Image.open(img_path)
                width, height = img.size

                coco_data["images"].append({
                    "id": img_id,
                    "file_name": img_path.name,
                    "width": width,
                    "height": height
                })

                # Add annotations
                boxes = self.annotations.get(img_path.name, [])
                for box in boxes:
                    x1, y1, x2, y2 = box["bbox"]
                    w = x2 - x1
                    h = y2 - y1

                    coco_data["annotations"].append({
                        "id": ann_id,
                        "image_id": img_id,
                        "category_id": 0,
                        "bbox": [x1, y1, w, h],
                        "area": w * h,
                        "iscrowd": 0
                    })
                    ann_id += 1

            # Save COCO JSON
            with (split_dir / "_annotations.coco.json").open("w") as f:
                json.dump(coco_data, f, indent=2)

        return f"✓ Dataset prepared: {len(splits['train'])} train, {len(splits['valid'])} valid, {len(splits['test'])} test"
def start_training(self, dataset_dir: str, output_dir: str, model_size: str,
|
||||
epochs: int, batch_size: int, lr: float, progress=gr.Progress()):
|
||||
"""Start training in background."""
|
||||
dataset_path = Path(dataset_dir)
|
||||
output_path = Path(output_dir)
|
||||
|
||||
if not dataset_path.exists():
|
||||
return "❌ Dataset directory not found"
|
||||
|
||||
if self.training_process and self.training_process.poll() is None:
|
||||
return "⚠️ Training already in progress"
|
||||
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Build training command
|
||||
venv_python = Path(__file__).parent / ".venv/bin/python"
|
||||
train_script = Path(__file__).parent / "train_rfdetr.py"
|
||||
|
||||
cmd = [
|
||||
str(venv_python),
|
||||
str(train_script),
|
||||
"--dataset-dir", str(dataset_path),
|
||||
"--output-dir", str(output_path),
|
||||
"--model", model_size,
|
||||
"--epochs", str(epochs),
|
||||
"--batch-size", str(batch_size),
|
||||
"--grad-accum-steps", str(max(1, 16 // batch_size)),
|
||||
"--lr", str(lr)
|
||||
]
|
||||
|
||||
# Start training process
|
||||
log_file = output_path / "training.log"
|
||||
self.training_status = f"🚀 Starting training..."
|
||||
|
||||
def run_training():
|
||||
try:
|
||||
with log_file.open("w") as f:
|
||||
self.training_process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=f,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True
|
||||
)
|
||||
self.training_status = f"⏳ Training in progress (PID: {self.training_process.pid})"
|
||||
self.training_process.wait()
|
||||
|
||||
if self.training_process.returncode == 0:
|
||||
self.training_status = "✅ Training completed successfully!"
|
||||
# Reload model with new weights
|
||||
best_weights = output_path / "checkpoint_best_total.pth"
|
||||
if best_weights.exists():
|
||||
self._load_model(best_weights)
|
||||
else:
|
||||
self.training_status = f"❌ Training failed (exit code {self.training_process.returncode})"
|
||||
except Exception as e:
|
||||
self.training_status = f"❌ Error: {e}"
|
||||
|
||||
self.training_thread = threading.Thread(target=run_training, daemon=True)
|
||||
self.training_thread.start()
|
||||
|
||||
return f"✓ Training started! Check {log_file} for progress"
|
||||
|
||||
def get_training_status(self):
|
||||
"""Get current training status."""
|
||||
return self.training_status
|
||||
|
||||
def stop_training(self):
|
||||
"""Stop the training process."""
|
||||
if self.training_process and self.training_process.poll() is None:
|
||||
self.training_process.terminate()
|
||||
self.training_status = "⏹️ Training stopped by user"
|
||||
return "✓ Training process terminated"
|
||||
return "⚠️ No training in progress"
            json.dump(coco_data, f, indent=2)

        return f"✓ Exported {len(coco_data['annotations'])} annotations to {output_path}"

def create_ui(app: AnnotationApp) -> gr.Blocks:
    """Create Gradio UI."""

    with gr.Blocks(title="Knot Annotation Tool") as demo:
        gr.Markdown("""
        # 🪵 Wood Knot Annotation Tool
        **Label -> Train -> Auto-Label -> Repeat**

        - Manually annotate images or use **Auto-Label** with your trained model
        - Export and prepare dataset for training
        - Train RF-DETR directly from this GUI
        - Use trained model to auto-label more images
        """)

        with gr.Row():
            with gr.Column(scale=3):
                image_display = gr.Image(label="Current Image", type="pil")

                with gr.Row():
                    prev_btn = gr.Button("⬅️ Previous")
                    next_btn = gr.Button("Next ➡️")
                    auto_label_btn = gr.Button("🤖 Auto-Label", variant="primary")

                with gr.Row():
                    threshold_slider = gr.Slider(0.1, 0.9, DEFAULT_DETECTION_THRESHOLD, label="Detection Threshold")

            with gr.Column(scale=1):
                info_text = gr.Textbox(label="Status", lines=2)
                boxes_text = gr.Textbox(label="Annotations", lines=10)

                gr.Markdown("### Manual Annotation")
                with gr.Row():
                    x1_input = gr.Number(label="x1", value=100)
                    y1_input = gr.Number(label="y1", value=100)
                with gr.Row():
                    x2_input = gr.Number(label="x2", value=200)
                    y2_input = gr.Number(label="y2", value=200)

                add_box_btn = gr.Button("➕ Add Box")
                delete_btn = gr.Button("🗑️ Delete Last")
                clear_btn = gr.Button("❌ Clear All")

                gr.Markdown("### Export")
                export_path = gr.Textbox(
                    label="Output Path",
                    value="annotations_coco.json"
                )
                export_btn = gr.Button("💾 Export COCO")
                export_result = gr.Textbox(label="Export Result")

        # Event handlers
        def on_load():
            return app.load_image("current")

        # Settings handlers
        load_images_btn.click(
            app.load_new_images_dir,
            inputs=[images_dir_input],
            outputs=[image_display, boxes_text, info_text]
        ).then(
            lambda: (app.get_current_dir_info(), app.get_current_model_info()),
            outputs=[dir_info, model_info]
        )

        load_model_btn.click(
            app.load_new_model,
            inputs=[model_weights_input],
            outputs=[model_info]
        )

        prev_btn.click(
            lambda: app.load_image("prev"),
            outputs=[image_display, boxes_text, info_text]
        )

        next_btn.click(
            lambda: app.load_image("next"),
            outputs=[image_display, boxes_text, info_text]
        )

        auto_label_btn.click(
            lambda t: app.auto_label_current(t),
            inputs=[threshold_slider],
            outputs=[image_display, boxes_text, info_text]
        )

        add_box_btn.click(
            app.add_box_manual,
            inputs=[x1_input, y1_input, x2_input, y2_input],
            outputs=[image_display, boxes_text, info_text]
        )

        delete_btn.click(
            app.delete_last_box,
            outputs=[image_display, boxes_text, info_text]
        )

        clear_btn.click(
            app.clear_boxes,
            outputs=[image_display, boxes_text, info_text]
        )

        # Training handlers
        prep_btn.click(
            lambda out, train, valid: app.prepare_training_dataset(Path(out), train, valid),
            inputs=[dataset_prep_dir, train_split, valid_split],
            outputs=[prep_result]
        )

        start_train_btn.click(
            app.start_training,
            inputs=[train_dataset_dir, train_output_dir, model_size, epochs, batch_size, learning_rate],
            outputs=[training_status]
        )

        stop_train_btn.click(
            app.stop_training,
            outputs=[training_status]
        )

        refresh_status_btn.click(
            app.get_training_status,
            outputs=[training_status]
        )

        # Load first image on start
        demo.load(on_load, outputs=[image_display, boxes_text, info_text])

    return demo


def main():
    parser = argparse.ArgumentParser(description="Simple annotation GUI with auto-labeling")
    parser.add_argument("--images-dir", type=Path,
                        help="Default directory with images (can be changed in GUI)")
    parser.add_argument("--model-weights", type=Path,
                        help="Default trained model for auto-labeling (can be changed in GUI)")
    parser.add_argument("--port", type=int, default=7860, help="Port for web interface")
    args = parser.parse_args()

    # Validate paths if provided
    if args.images_dir and not args.images_dir.exists():
        print(f"⚠️ Warning: Images directory not found: {args.images_dir}")
        print("You can load a different directory from the GUI Settings")
        args.images_dir = None

    if args.model_weights and not args.model_weights.exists():
        print(f"⚠️ Warning: Model weights not found: {args.model_weights}")
        print("You can load different weights from the GUI Settings")
        args.model_weights = None

    app = AnnotationApp(args.images_dir, args.model_weights)
    demo = create_ui(app)

    print(f"\n{'='*60}")
    print("🚀 Starting annotation tool...")
    if args.images_dir:
        print(f"📁 Default images: {args.images_dir} ({len(app.image_paths)} images)")
    else:
        print("📁 No default images - load directory from Settings")
    if app.model:
        print(f"🤖 Model: Loaded from {args.model_weights}")
    else:
        print("⚠️ No model loaded - load from Settings or train one")
    print("💡 You can change images directory and model weights from the Settings panel")
    print(f"{'='*60}\n")

    demo.launch(
        server_name="0.0.0.0",
        server_port=args.port,
        share=False
    )


if __name__ == "__main__":
    main()
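The removed `start_training` above ran the trainer as a background subprocess with stdout/stderr redirected into a log file, wrapped in a daemon thread so the GUI stays responsive. That pattern can be sketched standalone; the `run_logged` helper, the log path, and the dummy command are illustrative, not part of the original file:

```python
import subprocess
import sys
import tempfile
import threading
from pathlib import Path

def run_logged(cmd: list[str], log_file: Path) -> int:
    """Run cmd, redirect stdout+stderr into log_file, and return the exit code."""
    with log_file.open("w") as f:
        proc = subprocess.Popen(cmd, stdout=f, stderr=subprocess.STDOUT, text=True)
        proc.wait()  # blocks, which is why the GUI wraps this in a thread
    return proc.returncode

# A daemon thread keeps the UI loop free while the process writes to the log.
log = Path(tempfile.mkdtemp()) / "demo_training.log"
worker = threading.Thread(
    target=run_logged,
    args=([sys.executable, "-c", "print('epoch 1/10')"], log),
    daemon=True,
)
worker.start()
worker.join()
```

Polling `proc.poll()` (as `stop_training` does) distinguishes a live process (`None`) from a finished one (an exit code).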
247
convert_for_deployment.py
Executable file
@ -0,0 +1,247 @@
#!/usr/bin/env python3
"""
Convert trained models for OAK-D deployment.

Supports conversion to ONNX and OpenVINO formats for edge deployment.

Usage:
    python convert_for_deployment.py --model runs/training/weights/best.pt --output oak_d_deployment
"""

import argparse
import shutil
from pathlib import Path


def convert_rfdetr_for_oak(model_path: Path, output_dir: Path, img_size: int = 640):
    """Convert RF-DETR model for OAK-D deployment."""
    output_dir.mkdir(parents=True, exist_ok=True)

    try:
        from export_rtdetr_oak import export_rfdetr_onnx

        # Export to ONNX
        onnx_path = output_dir / "model.onnx"
        export_rfdetr_onnx(str(model_path), str(onnx_path), img_size)

        # Convert ONNX to OpenVINO
        import subprocess
        xml_path = output_dir / "model.xml"
        bin_path = output_dir / "model.bin"

        cmd = [
            "mo", "--input_model", str(onnx_path),
            "--output_dir", str(output_dir),
            "--model_name", "model",
            "--input_shape", f"[1,3,{img_size},{img_size}]",
            "--data_type", "FP16"
        ]

        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            return f"❌ OpenVINO conversion failed: {result.stderr}"

        return f"✓ RF-DETR exported for OAK-D!\n📁 Output: {output_dir}\n🔗 Next: Convert ONNX → RVC using HubAI (online) or ModelConverter (offline)."

    except Exception as e:
        return f"❌ RF-DETR conversion failed: {e}"


def convert_rtdetr_for_oak(model_path: Path, output_dir: Path, img_size: int = 640):
    """Convert RT-DETR model for OAK-D deployment."""
    output_dir.mkdir(parents=True, exist_ok=True)

    try:
        from export_rtdetr_oak import export_rtdetr_openvino

        # Export to OpenVINO directly
        xml_path, bin_path = export_rtdetr_openvino(str(model_path), str(output_dir), img_size)

        return f"✓ RT-DETR exported for OAK-D!\n📁 Output: {output_dir}\n🔗 Next: Convert .xml/.bin → blob using blobconverter.luxonis.com"

    except Exception as e:
        return f"❌ RT-DETR conversion failed: {e}"


def convert_yolov6_for_oak(model_path: Path, output_dir: Path, img_size: int = 640):
    """Convert YOLOv6 model for OAK-D deployment."""
    output_dir.mkdir(parents=True, exist_ok=True)

    try:
        from export_onnx import export_yolov6_onnx

        # Export to ONNX
        onnx_path = export_yolov6_onnx(str(model_path), str(output_dir), img_size)

        # Convert ONNX to OpenVINO
        import subprocess
        xml_path = output_dir / "model.xml"
        bin_path = output_dir / "model.bin"

        cmd = [
            "mo", "--input_model", str(onnx_path),
            "--output_dir", str(output_dir),
            "--model_name", "model",
            "--input_shape", f"[1,3,{img_size},{img_size}]",
            "--data_type", "FP16",
            "--reverse_input_channels"
        ]

        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            return f"❌ OpenVINO conversion failed: {result.stderr}"

        return f"✓ YOLOv6 exported for OAK-D!\n📁 Output: {output_dir}\n🔗 Next: Convert .xml/.bin → blob using blobconverter.luxonis.com"

    except Exception as e:
        return f"❌ YOLOv6 conversion failed: {e}"


def convert_yolox_for_oak(model_path: Path, output_dir: Path, img_size: int = 640):
    """Convert YOLOX model for OAK-D deployment."""
    output_dir.mkdir(parents=True, exist_ok=True)

    try:
        from export_onnx import export_yolox_onnx

        # Export to ONNX
        onnx_path = export_yolox_onnx(str(model_path), str(output_dir), img_size)

        # Convert ONNX to OpenVINO
        import subprocess
        xml_path = output_dir / "model.xml"
        bin_path = output_dir / "model.bin"

        cmd = [
            "mo", "--input_model", str(onnx_path),
            "--output_dir", str(output_dir),
            "--model_name", "model",
            "--input_shape", f"[1,3,{img_size},{img_size}]",
            "--data_type", "FP16",
            "--reverse_input_channels"
        ]

        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            return f"❌ OpenVINO conversion failed: {result.stderr}"

        return f"✓ YOLOX exported for OAK-D!\n📁 Output: {output_dir}\n🔗 Next: Convert .xml/.bin → blob using blobconverter.luxonis.com"

    except Exception as e:
        return f"❌ YOLOX conversion failed: {e}"


def detect_model_type(model_path: Path) -> str:
    """Detect model type from file path or contents."""
    path_str = str(model_path).lower()

    # Check path patterns
    if 'rf-detr' in path_str or 'rfdetr' in path_str:
        return 'rf-detr'
    elif 'rt-detr' in path_str or 'rtdetr' in path_str:
        return 'rt-detr'
    elif 'yolov6' in path_str:
        return 'yolov6'
    elif 'yolox' in path_str:
        return 'yolox'

    # Try to detect from file contents
    try:
        import torch
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)

        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']

            # Check for RT-DETR patterns
            if any('rtdetr' in key.lower() for key in state_dict.keys()):
                return 'rt-detr'

            # Check for RF-DETR patterns
            if any('rf_detr' in key.lower() for key in state_dict.keys()):
                return 'rf-detr'

            # Check for YOLO patterns
            if any('yolo' in key.lower() for key in state_dict.keys()):
                if any('v6' in key.lower() for key in state_dict.keys()):
                    return 'yolov6'
                else:
                    return 'yolox'

    except Exception:
        pass

    # Default fallback
    return 'yolox'


def main():
    parser = argparse.ArgumentParser(description="Convert trained models for OAK-D deployment")
    parser.add_argument(
        '--model',
        type=Path,
        required=True,
        help='Path to trained model weights (.pt file)'
    )
    parser.add_argument(
        '--output',
        type=Path,
        default=Path('oak_d_deployment'),
        help='Output directory for converted models'
    )
    parser.add_argument(
        '--img-size',
        type=int,
        default=640,
        choices=[320, 416, 512, 640, 800, 1024],
        help='Input image size for the model'
    )
    parser.add_argument(
        '--framework',
        choices=['auto', 'rf-detr', 'rt-detr', 'yolov6', 'yolox'],
        default='auto',
        help='Model framework (auto-detect if not specified)'
    )

    args = parser.parse_args()

    # Auto-detect framework if not specified
    if args.framework == 'auto':
        framework = detect_model_type(args.model)
        print(f"Auto-detected framework: {framework}")
    else:
        framework = args.framework

    print(f"Converting {framework.upper()} model for OAK-D deployment...")
    print(f"Model: {args.model}")
    print(f"Output: {args.output}")
    print(f"Image size: {args.img_size}")

    # Convert based on framework
    if framework == 'rf-detr':
        result = convert_rfdetr_for_oak(args.model, args.output, args.img_size)
    elif framework == 'rt-detr':
        result = convert_rtdetr_for_oak(args.model, args.output, args.img_size)
    elif framework == 'yolov6':
        result = convert_yolov6_for_oak(args.model, args.output, args.img_size)
    elif framework == 'yolox':
        result = convert_yolox_for_oak(args.model, args.output, args.img_size)
    else:
        result = f"❌ Unsupported framework: {framework}"

    print(result)

    if "✓" in result:
        print("\n📋 Next steps:")
        print("1. Test OpenVINO model (optional):")
        print("   python -c \"from openvino.runtime import Core; core = Core(); model = core.read_model('model.xml'); print('✓ Model loaded')\"")
        print("2. Convert to RVC compiled format:")
        print("   - Online: HubAI conversion (fastest setup)")
        print("   - Offline: ModelConverter (requires Docker)")
        print("   - Docs: https://docs.luxonis.com/software-v3/ai-inference/conversion/")
        print("3. Deploy to OAK-D using DepthAI Python API")
if __name__ == "__main__":
    main()
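The path-based half of `detect_model_type` in the new script can be exercised on its own. This sketch mirrors its string checks only (the torch checkpoint-key fallback is omitted); the standalone function name is illustrative:

```python
def detect_model_type_from_path(path_str: str) -> str:
    """Mirror of the path-pattern checks: first match wins, yolox is the fallback."""
    s = path_str.lower()
    if "rf-detr" in s or "rfdetr" in s:
        return "rf-detr"
    if "rt-detr" in s or "rtdetr" in s:
        return "rt-detr"
    if "yolov6" in s:
        return "yolov6"
    if "yolox" in s:
        return "yolox"
    return "yolox"  # default fallback, as in the script

print(detect_model_type_from_path("runs/knot_rfdetr_medium/best.pth"))  # rf-detr
```

Note the ordering matters: "rfdetr" is tested before "rtdetr", so an ambiguous name containing both resolves to RF-DETR.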
@ -5,5 +5,4 @@ onnx_graphsurgeon>=0.5.0
# optional but useful for visualization / quick sanity checks
supervision>=0.27.0
Pillow>=10.0.0
# For the custom annotation GUI
gradio>=4.0.0
# GUI is Tkinter-based (standard library)
96
run_gui.sh
@ -1,96 +0,0 @@
#!/bin/bash

# Script to run the annotation GUI with automatic virtual environment detection

# Get the directory of this script
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

# Change to the script directory
cd "$SCRIPT_DIR"

# Function to check if a command exists
command_exists() {
    command -v "$1" >/dev/null 2>&1
}

# Check for Python
if ! command_exists python; then
    echo "Error: Python is not installed or not in PATH"
    exit 1
fi

# Check for virtual environment
VENV_DIR=""
if [ -d "venv" ]; then
    VENV_DIR="venv"
elif [ -d ".venv" ]; then
    VENV_DIR=".venv"
elif [ -d "env" ]; then
    VENV_DIR="env"
fi

# Check for conda environment
CONDA_ENV=""
if command_exists conda; then
    # Check if we're already in a conda environment
    if [ -n "$CONDA_DEFAULT_ENV" ]; then
        CONDA_ENV="$CONDA_DEFAULT_ENV"
    else
        # Try to find a conda environment with the project name
        PROJECT_NAME=$(basename "$SCRIPT_DIR")
        if conda env list | grep -q "^$PROJECT_NAME "; then
            CONDA_ENV="$PROJECT_NAME"
        fi
    fi
fi

# Activate virtual environment
if [ -n "$VENV_DIR" ]; then
    echo "Activating virtual environment: $VENV_DIR"
    source "$VENV_DIR/bin/activate"
elif [ -n "$CONDA_ENV" ]; then
    echo "Activating conda environment: $CONDA_ENV"
    conda activate "$CONDA_ENV"
else
    echo "Warning: No virtual environment found. Using system Python."
    echo "Consider creating a virtual environment with:"
    echo "  python -m venv venv"
    echo "  source venv/bin/activate"
    echo "  pip install -r requirements.txt"
fi

# Check if requirements are installed
echo "Checking if dependencies are installed..."
python -c "
import sys
try:
    import gradio
    import torch
    import PIL
    print('✓ Core dependencies are installed')
except ImportError as e:
    print(f'✗ Missing dependency: {e}')
    print('Installing requirements...')
    import subprocess
    result = subprocess.run([sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt'],
                            capture_output=True, text=True)
    if result.returncode != 0:
        print('Failed to install requirements:')
        print(result.stderr)
        sys.exit(1)
    else:
        print('✓ Requirements installed successfully')
"

# Run the GUI
echo "Starting annotation GUI..."
python annotation_gui.py "$@"

# Deactivate virtual environment if activated
if [ -n "$VENV_DIR" ] || [ -n "$CONDA_ENV" ]; then
    if [ -n "$VENV_DIR" ]; then
        deactivate
    elif [ -n "$CONDA_ENV" ]; then
        conda deactivate
    fi
fi
11
run_tk_gui.sh
Executable file
@ -0,0 +1,11 @@
#!/usr/bin/env bash
set -euo pipefail

cd "$(dirname "$0")"

if [ -f .venv/bin/activate ]; then
    # shellcheck disable=SC1091
    source .venv/bin/activate
fi

python tk_annotation_gui.py "$@"
645
tk_annotation_gui.py
Normal file
@ -0,0 +1,645 @@
#!/usr/bin/env python3
"""Tkinter-based annotation GUI.

This is a standalone GUI for manual bounding-box annotation that writes
`annotations.json` in the same format used by the project:

    {
        "image.jpg": [
            {"bbox": [x1, y1, x2, y2], "label": "knot", "confidence": 1.0, "source": "manual"},
            ...
        ],
        ...
    }

This project uses the Tkinter GUI as the annotation interface.

Run:
    python tk_annotation_gui.py

Optional:
    python tk_annotation_gui.py --images-dir IMAGE/

Controls:
    - Click-drag on the image to create a box
    - Double-click a box entry to delete it
    - Prev/Next to navigate

Notes:
    - Boxes are stored in ORIGINAL image pixel coordinates.
    - The displayed image is scaled to fit the canvas; coordinates are converted.
"""
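The annotation schema in the docstring above can be round-tripped with a short sketch; the `load_annotations` helper name is illustrative (the app reads the file inline), but the JSON shape matches the documented format:

```python
import json
import tempfile
from pathlib import Path

def load_annotations(ann_file: Path) -> dict:
    """Read the {filename: [box, ...]} mapping; a missing or malformed file yields {}."""
    if not ann_file.exists():
        return {}
    try:
        data = json.loads(ann_file.read_text())
    except json.JSONDecodeError:
        return {}
    return data if isinstance(data, dict) else {}

# Round-trip one entry in the documented schema.
sample = {
    "image.jpg": [
        {"bbox": [10.0, 20.0, 110.0, 220.0], "label": "knot",
         "confidence": 1.0, "source": "manual"},
    ]
}
ann_path = Path(tempfile.mkdtemp()) / "annotations.json"
ann_path.write_text(json.dumps(sample, indent=2))

box = load_annotations(ann_path)["image.jpg"][0]
x1, y1, x2, y2 = box["bbox"]
print(x2 - x1, y2 - y1)  # 100.0 200.0 (width/height in original image pixels)
```

The `source` field distinguishes hand-drawn boxes (`"manual"`) from model predictions (`"auto"`), which the auto-label code below appends to the same list.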
from __future__ import annotations

import argparse
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import tkinter as tk
from tkinter import ttk

from PIL import Image, ImageTk

# Defaults
DEFAULT_IMAGES_DIR = "IMAGE/"
DEFAULT_MODEL_WEIGHTS = ""
ANNOTATION_CATEGORIES = ["knot"]
DEFAULT_DETECTION_THRESHOLD = 0.5

try:
    import config as _cfg

    DEFAULT_IMAGES_DIR = getattr(_cfg, "DEFAULT_IMAGES_DIR", DEFAULT_IMAGES_DIR)
    DEFAULT_MODEL_WEIGHTS = getattr(_cfg, "DEFAULT_MODEL_WEIGHTS", DEFAULT_MODEL_WEIGHTS)
    ANNOTATION_CATEGORIES = getattr(_cfg, "ANNOTATION_CATEGORIES", ANNOTATION_CATEGORIES)
    DEFAULT_DETECTION_THRESHOLD = float(getattr(_cfg, "DEFAULT_DETECTION_THRESHOLD", DEFAULT_DETECTION_THRESHOLD))
except Exception:
    pass


@dataclass
class DisplayTransform:
    scale: float
    offset_x: float
    offset_y: float
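`DisplayTransform` carries everything needed to map between original-image pixels and canvas pixels (the docstring notes boxes are stored in original coordinates). A sketch of the two conversions; the helper function names are illustrative, and the dataclass is redefined locally only so the sketch is self-contained:

```python
from dataclasses import dataclass

@dataclass
class DisplayTransform:  # mirrors the dataclass defined in the file
    scale: float
    offset_x: float
    offset_y: float

def image_to_canvas(t: DisplayTransform, x: float, y: float) -> tuple[float, float]:
    """Original-image pixels -> canvas pixels."""
    return x * t.scale + t.offset_x, y * t.scale + t.offset_y

def canvas_to_image(t: DisplayTransform, cx: float, cy: float) -> tuple[float, float]:
    """Canvas pixels -> original-image pixels (inverse of the above)."""
    return (cx - t.offset_x) / t.scale, (cy - t.offset_y) / t.scale

# Image shown at half size, letterboxed 100 px right and 50 px down.
t = DisplayTransform(scale=0.5, offset_x=100.0, offset_y=50.0)
print(image_to_canvas(t, 200.0, 200.0))  # (200.0, 150.0)
print(canvas_to_image(t, 200.0, 150.0))  # (200.0, 200.0)
```

Mouse events arrive in canvas coordinates, so the app applies the inverse mapping before storing a box.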
|
||||
|
||||
|
||||
class TkAnnotationApp:
|
||||
def __init__(self, root: tk.Tk, images_dir: Path):
|
||||
self.root = root
|
||||
self.root.title("Wood Knot Annotation Tool (Tkinter)")
|
||||
|
||||
self.images_dir = images_dir
|
||||
self.image_paths: list[Path] = []
|
||||
self.current_idx: int = 0
|
||||
|
||||
self.ann_file: Path = self.images_dir / "annotations.json"
|
||||
self.annotations: dict[str, list[dict[str, Any]]] = {}
|
||||
|
||||
self.current_image: Image.Image | None = None
|
||||
self.current_image_path: Path | None = None
|
||||
self.current_photo: ImageTk.PhotoImage | None = None
|
||||
|
||||
self.transform: DisplayTransform | None = None
|
||||
|
||||
self._draw_start: tuple[float, float] | None = None
|
||||
self._preview_rect_id: int | None = None
|
||||
|
||||
self.model: Any | None = None
|
||||
self.model_type: str | None = None # rf-detr | rt-detr | yolov6 | yolox
|
||||
self.model_path: Path | None = None
|
||||
|
||||
self.model_path_var = tk.StringVar(value=str(DEFAULT_MODEL_WEIGHTS) if DEFAULT_MODEL_WEIGHTS else "")
|
||||
self.model_type_var = tk.StringVar(value="auto")
|
||||
self.threshold_var = tk.DoubleVar(value=float(DEFAULT_DETECTION_THRESHOLD))
|
||||
|
||||
self.label_var = tk.StringVar(value=(ANNOTATION_CATEGORIES[0] if ANNOTATION_CATEGORIES else "knot"))
|
||||
|
||||
self._build_ui()
|
||||
self._load_images_dir(self.images_dir)
|
||||
|
||||
# ------------------------- UI -------------------------
|
||||
|
||||
def _build_ui(self) -> None:
|
||||
container = ttk.Frame(self.root, padding=8)
|
||||
container.grid(row=0, column=0, sticky="nsew")
|
||||
|
||||
self.root.rowconfigure(0, weight=1)
|
||||
self.root.columnconfigure(0, weight=1)
|
||||
|
||||
# Top controls
|
||||
top = ttk.Frame(container)
|
||||
top.grid(row=0, column=0, columnspan=2, sticky="ew", pady=(0, 8))
|
||||
top.columnconfigure(1, weight=1)
|
||||
|
||||
ttk.Label(top, text="Images dir:").grid(row=0, column=0, sticky="w")
|
||||
self.images_dir_var = tk.StringVar(value=str(self.images_dir))
|
||||
self.images_dir_entry = ttk.Entry(top, textvariable=self.images_dir_var)
|
||||
self.images_dir_entry.grid(row=0, column=1, sticky="ew", padx=6)
|
||||
ttk.Button(top, text="Load", command=self._on_load_dir).grid(row=0, column=2, sticky="ew")
|
||||
|
||||
self.index_label = ttk.Label(top, text="Image: -/-")
|
||||
self.index_label.grid(row=0, column=3, sticky="e", padx=(10, 0))
|
||||
|
||||
ttk.Separator(container, orient="horizontal").grid(row=1, column=0, columnspan=2, sticky="ew", pady=(0, 8))
|
||||
|
||||
# Left: Canvas
|
||||
left = ttk.Frame(container)
|
||||
left.grid(row=2, column=0, sticky="nsew", padx=(0, 8))
|
||||
container.rowconfigure(2, weight=1)
|
||||
container.columnconfigure(0, weight=3)
|
||||
|
||||
nav = ttk.Frame(left)
|
||||
nav.grid(row=0, column=0, sticky="ew", pady=(0, 6))
|
||||
nav.columnconfigure(2, weight=1)
|
||||
|
||||
ttk.Button(nav, text="Prev", command=self.prev_image).grid(row=0, column=0, sticky="w")
|
||||
ttk.Button(nav, text="Next", command=self.next_image).grid(row=0, column=1, sticky="w", padx=(6, 0))
|
||||
self.status_label = ttk.Label(nav, text="", foreground="#444")
|
||||
self.status_label.grid(row=0, column=2, sticky="w", padx=(10, 0))
|
||||
|
||||
self.canvas = tk.Canvas(left, width=1200, height=800, bg="#111", highlightthickness=0)
|
||||
self.canvas.grid(row=1, column=0, sticky="nsew")
|
||||
left.rowconfigure(1, weight=1)
|
||||
left.columnconfigure(0, weight=1)
|
||||
|
||||
self.canvas.bind("<ButtonPress-1>", self._on_mouse_down)
|
||||
self.canvas.bind("<B1-Motion>", self._on_mouse_move)
|
||||
self.canvas.bind("<ButtonRelease-1>", self._on_mouse_up)
|
||||
|
||||
# Right: boxes list + controls
|
||||
right = ttk.Frame(container)
|
||||
right.grid(row=2, column=1, sticky="nsew")
|
||||
container.columnconfigure(1, weight=1)
|
||||
|
||||
# Model controls
|
||||
model_frame = ttk.LabelFrame(right, text="Auto-Label", padding=8)
|
||||
model_frame.grid(row=0, column=0, columnspan=2, sticky="ew")
|
||||
model_frame.columnconfigure(1, weight=1)
|
||||
|
||||
ttk.Label(model_frame, text="Weights:").grid(row=0, column=0, sticky="w")
|
||||
self.model_entry = ttk.Entry(model_frame, textvariable=self.model_path_var)
|
||||
self.model_entry.grid(row=0, column=1, sticky="ew", padx=(6, 0))
|
||||
|
||||
ttk.Label(model_frame, text="Type:").grid(row=1, column=0, sticky="w", pady=(6, 0))
|
||||
self.model_type_menu = ttk.OptionMenu(
|
||||
model_frame,
|
||||
self.model_type_var,
|
||||
self.model_type_var.get(),
|
||||
"auto",
|
||||
"rf-detr",
|
||||
"rt-detr",
|
||||
"yolov6",
|
||||
"yolox",
|
||||
)
|
||||
self.model_type_menu.grid(row=1, column=1, sticky="ew", padx=(6, 0), pady=(6, 0))
|
||||
|
||||
ttk.Label(model_frame, text="Threshold:").grid(row=2, column=0, sticky="w", pady=(6, 0))
|
||||
self.threshold_scale = ttk.Scale(model_frame, from_=0.05, to=0.95, variable=self.threshold_var)
|
||||
self.threshold_scale.grid(row=2, column=1, sticky="ew", padx=(6, 0), pady=(6, 0))
|
||||
|
||||
model_buttons = ttk.Frame(model_frame)
|
||||
model_buttons.grid(row=3, column=0, columnspan=2, sticky="ew", pady=(8, 0))
|
||||
model_buttons.columnconfigure(0, weight=1)
|
||||
model_buttons.columnconfigure(1, weight=1)
|
||||
ttk.Button(model_buttons, text="Load Model", command=self.load_model).grid(row=0, column=0, sticky="ew")
|
||||
ttk.Button(model_buttons, text="Auto-Label Current", command=self.auto_label_current).grid(row=0, column=1, sticky="ew", padx=(6, 0))
|
||||
|
||||
self.model_status = ttk.Label(model_frame, text="No model loaded")
|
||||
self.model_status.grid(row=4, column=0, columnspan=2, sticky="w", pady=(6, 0))
|
||||
|
||||
ttk.Label(right, text="Label:").grid(row=1, column=0, sticky="w", pady=(10, 0))
|
||||
self.label_menu = ttk.OptionMenu(right, self.label_var, self.label_var.get(), *ANNOTATION_CATEGORIES)
|
||||
self.label_menu.grid(row=1, column=1, sticky="ew", padx=(6, 0), pady=(10, 0))
|
||||
right.columnconfigure(1, weight=1)
|
||||
|
||||
ttk.Label(right, text="Annotations:").grid(row=2, column=0, columnspan=2, sticky="w", pady=(10, 4))
|
||||
|
||||
self.box_list = tk.Listbox(right, height=18)
|
||||
self.box_list.grid(row=3, column=0, columnspan=2, sticky="nsew")
|
||||
right.rowconfigure(3, weight=1)
|
||||
|
||||
self.box_list.bind("<Double-Button-1>", self._on_box_double_click)
|
||||
|
||||
buttons = ttk.Frame(right)
|
||||
buttons.grid(row=4, column=0, columnspan=2, sticky="ew", pady=(6, 0))
|
||||
ttk.Button(buttons, text="Delete Selected", command=self.delete_selected_box).grid(row=0, column=0, sticky="ew")
|
||||
ttk.Button(buttons, text="Clear All", command=self.clear_all_boxes).grid(row=0, column=1, sticky="ew", padx=(6, 0))
|
||||
|
||||
# Make buttons frame expand reasonably
|
||||
buttons.columnconfigure(0, weight=1)
|
||||
buttons.columnconfigure(1, weight=1)
|
||||
|
||||
# ------------------------- Model loading / auto-label -------------------------
|
||||
|
||||
def _guess_model_type_from_path(self, path: Path) -> str:
|
||||
s = str(path).lower()
|
||||
if "rf" in s or "checkpoint" in s or s.endswith(".pth"):
|
||||
return "rf-detr"
|
||||
if "rtdetr" in s or "rt-detr" in s:
|
||||
return "rt-detr"
|
||||
if "yolov6" in s:
|
||||
return "yolov6"
|
||||
if "yolox" in s:
|
||||
return "yolox"
|
||||
# Default to ultralytics RT-DETR if ambiguous
|
||||
return "rt-detr"
|
||||
|
||||
def load_model(self) -> None:
|
||||
raw = self.model_path_var.get().strip()
|
||||
if not raw:
|
||||
self._set_model_status("No weights path provided")
|
||||
return
|
||||
|
||||
weights_path = Path(raw).expanduser()
|
||||
if not weights_path.exists():
|
||||
self._set_model_status(f"File not found: {weights_path}")
|
||||
return
|
||||
|
||||
selected = (self.model_type_var.get() or "auto").strip().lower()
|
||||
model_type = self._guess_model_type_from_path(weights_path) if selected == "auto" else selected
|
||||
|
||||
try:
|
||||
# RF-DETR optional
|
||||
if model_type == "rf-detr":
|
||||
from rfdetr import RFDETRNano
|
||||
|
||||
self.model = RFDETRNano(pretrain_weights=str(weights_path))
|
||||
else:
|
||||
if model_type == "rt-detr":
|
||||
from ultralytics import RTDETR
|
||||
|
||||
self.model = RTDETR(str(weights_path))
|
||||
else:
|
||||
from ultralytics import YOLO
|
||||
|
||||
self.model = YOLO(str(weights_path))
|
||||
|
||||
self.model_type = model_type
|
||||
self.model_path = weights_path
|
||||
self._set_model_status(f"Loaded: {weights_path.name} ({model_type})")
|
||||
except Exception as e:
|
||||
self.model = None
|
||||
self.model_type = None
|
||||
self.model_path = None
|
||||
self._set_model_status(f"Load failed: {e}")
|
||||
|
||||
    def auto_label_current(self) -> None:
        if self.current_image_path is None:
            return
        if self.model is None or self.model_type is None:
            self._set_model_status("No model loaded")
            return

        threshold = float(self.threshold_var.get())
        img_path = self.current_image_path

        try:
            new_boxes: list[dict[str, Any]] = []

            if self.model_type == "rf-detr":
                # RF-DETR expects a PIL image
                if self.current_image is None:
                    return
                detections = self.model.predict(self.current_image, threshold=threshold)
                for i in range(len(detections)):
                    xyxy = detections.xyxy[i]
                    conf = float(detections.confidence[i]) if detections.confidence is not None else 1.0
                    x1, y1, x2, y2 = xyxy
                    new_boxes.append(
                        {
                            "bbox": [float(x1), float(y1), float(x2), float(y2)],
                            "label": "knot",
                            "confidence": conf,
                            "source": "auto",
                        }
                    )
            else:
                # Ultralytics models
                results = self.model.predict(source=str(img_path), conf=threshold, save=False, verbose=False)
                for result in results:
                    for box in result.boxes:
                        x1, y1, x2, y2 = box.xyxy[0].tolist()
                        conf = float(box.conf[0])
                        label = "knot"
                        try:
                            cls = int(box.cls[0])
                            if hasattr(self.model, "names") and cls in self.model.names:
                                label = str(self.model.names[cls])
                        except Exception:
                            pass

                        new_boxes.append(
                            {
                                "bbox": [float(x1), float(y1), float(x2), float(y2)],
                                "label": label,
                                "confidence": conf,
                                "source": "auto",
                            }
                        )

            # Match legacy behavior: append auto boxes to existing
            key = img_path.name
            self.annotations.setdefault(key, [])
            self.annotations[key].extend(new_boxes)
            self._save_annotations()

            self._refresh_box_list()
            self._redraw_boxes()
            self._set_model_status(f"Auto-labeled: {len(new_boxes)}")
        except Exception as e:
            self._set_model_status(f"Auto-label failed: {e}")

    def _set_model_status(self, msg: str) -> None:
        self.model_status.config(text=msg)

    # ------------------------- Data load/save -------------------------
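The RF-DETR branch above indexes a supervision-style detections object (`.xyxy`, `.confidence`). A minimal sketch of that conversion loop, using `FakeDetections` as a hypothetical stand-in (the real object comes from `model.predict()`):

```python
from dataclasses import dataclass


@dataclass
class FakeDetections:
    """Hypothetical stand-in mimicking the .xyxy / .confidence indexing."""
    xyxy: list        # N boxes as [x1, y1, x2, y2]
    confidence: list  # N scores

    def __len__(self) -> int:
        return len(self.xyxy)


def detections_to_boxes(detections, label: str = "knot") -> list:
    """Mirror the RF-DETR conversion loop in auto_label_current."""
    boxes = []
    for i in range(len(detections)):
        x1, y1, x2, y2 = detections.xyxy[i]
        conf = float(detections.confidence[i]) if detections.confidence is not None else 1.0
        boxes.append({
            "bbox": [float(x1), float(y1), float(x2), float(y2)],
            "label": label,
            "confidence": conf,
            "source": "auto",
        })
    return boxes


dets = FakeDetections(xyxy=[[10.0, 20.0, 50.0, 60.0]], confidence=[0.87])
```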
    def _load_images_dir(self, images_dir: Path) -> None:
        images_dir = images_dir.expanduser().resolve()
        if not images_dir.exists() or not images_dir.is_dir():
            self._set_status(f"Invalid images dir: {images_dir}")
            return

        self.images_dir = images_dir
        self.ann_file = self.images_dir / "annotations.json"

        self.image_paths = sorted(
            list(images_dir.glob("*.jpg")) + list(images_dir.glob("*.png")) + list(images_dir.glob("*.jpeg"))
        )
        self.current_idx = 0

        # Load annotations (if present)
        self.annotations = {}
        if self.ann_file.exists():
            try:
                with self.ann_file.open("r") as f:
                    data = json.load(f)
                if isinstance(data, dict):
                    self.annotations = data
            except Exception as e:
                self._set_status(f"Failed to load annotations.json: {e}")

        if not self.image_paths:
            self._set_status("No images found")
            self._clear_canvas()
            self._update_index_label()
            self._refresh_box_list()
            return

        self._set_status("")
        self.load_current_image()

    def _save_annotations(self) -> None:
        # Ensure we always have an entry for the current image
        if self.current_image_path is not None:
            key = self.current_image_path.name
            self.annotations.setdefault(key, [])

        try:
            with self.ann_file.open("w") as f:
                json.dump(self.annotations, f, indent=2)
        except Exception as e:
            self._set_status(f"Failed to save annotations: {e}")

    # ------------------------- Navigation -------------------------
    def prev_image(self) -> None:
        if not self.image_paths:
            return
        self.current_idx = max(0, self.current_idx - 1)
        self.load_current_image()

    def next_image(self) -> None:
        if not self.image_paths:
            return
        self.current_idx = min(len(self.image_paths) - 1, self.current_idx + 1)
        self.load_current_image()

    def load_current_image(self) -> None:
        if not self.image_paths:
            return

        self.current_image_path = self.image_paths[self.current_idx]

        try:
            img = Image.open(self.current_image_path).convert("RGB")
        except Exception as e:
            self._set_status(f"Failed to open image: {e}")
            return

        self.current_image = img
        self._update_index_label()

        # Ensure annotation list exists
        self.annotations.setdefault(self.current_image_path.name, [])

        self._render_image_and_boxes()
        self._refresh_box_list()

    def _update_index_label(self) -> None:
        total = len(self.image_paths)
        if total == 0:
            self.index_label.config(text="Image: -/-")
            return
        filename = self.image_paths[self.current_idx].name
        self.index_label.config(text=f"Image {self.current_idx + 1}/{total}: {filename}")

    # ------------------------- Canvas rendering -------------------------

    def _clear_canvas(self) -> None:
        self.canvas.delete("all")
        self.current_photo = None
        self.transform = None

    def _render_image_and_boxes(self) -> None:
        self._clear_canvas()
        if self.current_image is None:
            return

        canvas_w = int(self.canvas.winfo_width())
        canvas_h = int(self.canvas.winfo_height())
        # If not yet realized, fall back to the configured size
        if canvas_w <= 2:
            canvas_w = int(self.canvas["width"])
        if canvas_h <= 2:
            canvas_h = int(self.canvas["height"])

        orig_w, orig_h = self.current_image.size
        scale = min(canvas_w / orig_w, canvas_h / orig_h)
        scale = max(scale, 1e-6)

        disp_w = int(orig_w * scale)
        disp_h = int(orig_h * scale)

        offset_x = (canvas_w - disp_w) / 2
        offset_y = (canvas_h - disp_h) / 2

        disp_img = self.current_image.resize((disp_w, disp_h), Image.Resampling.BILINEAR)
        self.current_photo = ImageTk.PhotoImage(disp_img)

        # Draw image
        self.canvas.create_image(offset_x, offset_y, anchor="nw", image=self.current_photo)
        self.transform = DisplayTransform(scale=scale, offset_x=offset_x, offset_y=offset_y)

        # Draw boxes
        self._redraw_boxes()

    def _redraw_boxes(self) -> None:
        self.canvas.delete("box")
        if self.current_image_path is None or self.transform is None:
            return

        boxes = self.annotations.get(self.current_image_path.name, []) or []
        for i, box in enumerate(boxes):
            bbox = box.get("bbox") if isinstance(box, dict) else None
            if not bbox or len(bbox) != 4:
                continue
            x1, y1, x2, y2 = bbox
            dx1, dy1 = self._img_to_disp(x1, y1)
            dx2, dy2 = self._img_to_disp(x2, y2)
            self.canvas.create_rectangle(dx1, dy1, dx2, dy2, outline="#00FF66", width=2, tags=("box", f"box_{i}"))

    def _img_to_disp(self, x: float, y: float) -> tuple[float, float]:
        assert self.transform is not None
        return (x * self.transform.scale + self.transform.offset_x, y * self.transform.scale + self.transform.offset_y)

    def _disp_to_img(self, x: float, y: float) -> tuple[float, float]:
        assert self.transform is not None
        ix = (x - self.transform.offset_x) / self.transform.scale
        iy = (y - self.transform.offset_y) / self.transform.scale
        if self.current_image is None:
            return ix, iy
        w, h = self.current_image.size
        ix = min(max(ix, 0.0), float(w))
        iy = min(max(iy, 0.0), float(h))
        return ix, iy
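The two helpers above apply an affine letterbox transform (uniform scale, centered offsets). A self-contained sketch of the same math, with a standalone `DisplayTransform` stand-in for the dataclass used by the app, verifies the image-to-display round trip:

```python
from dataclasses import dataclass


@dataclass
class DisplayTransform:
    scale: float
    offset_x: float
    offset_y: float


def fit_transform(img_w: int, img_h: int, canvas_w: int, canvas_h: int) -> DisplayTransform:
    # Same letterbox math as _render_image_and_boxes: uniform scale, centered.
    scale = max(min(canvas_w / img_w, canvas_h / img_h), 1e-6)
    disp_w, disp_h = int(img_w * scale), int(img_h * scale)
    return DisplayTransform(scale, (canvas_w - disp_w) / 2, (canvas_h - disp_h) / 2)


def img_to_disp(t: DisplayTransform, x: float, y: float) -> tuple[float, float]:
    return x * t.scale + t.offset_x, y * t.scale + t.offset_y


def disp_to_img(t: DisplayTransform, x: float, y: float) -> tuple[float, float]:
    return (x - t.offset_x) / t.scale, (y - t.offset_y) / t.scale


# A 2000x1000 image on an 800x600 canvas: scale 0.4, centered vertically.
t = fit_transform(2000, 1000, 800, 600)
dx, dy = img_to_disp(t, 500.0, 250.0)
ix, iy = disp_to_img(t, dx, dy)
```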
    # ------------------------- Mouse interactions -------------------------

    def _on_mouse_down(self, event: tk.Event) -> None:
        if self.current_image is None or self.current_image_path is None or self.transform is None:
            return
        self._draw_start = (event.x, event.y)
        if self._preview_rect_id is not None:
            self.canvas.delete(self._preview_rect_id)
            self._preview_rect_id = None

    def _on_mouse_move(self, event: tk.Event) -> None:
        if self._draw_start is None or self.current_image is None or self.transform is None:
            return

        x0, y0 = self._draw_start
        x1, y1 = event.x, event.y

        if self._preview_rect_id is not None:
            self.canvas.delete(self._preview_rect_id)

        self._preview_rect_id = self.canvas.create_rectangle(
            x0, y0, x1, y1, outline="#FFCC00", width=2, dash=(4, 2)
        )

    def _on_mouse_up(self, event: tk.Event) -> None:
        if self._draw_start is None or self.current_image is None or self.current_image_path is None or self.transform is None:
            self._draw_start = None
            return

        x0, y0 = self._draw_start
        x1, y1 = event.x, event.y
        self._draw_start = None

        if self._preview_rect_id is not None:
            self.canvas.delete(self._preview_rect_id)
            self._preview_rect_id = None

        # Convert to image coords
        ix0, iy0 = self._disp_to_img(x0, y0)
        ix1, iy1 = self._disp_to_img(x1, y1)

        x_min, x_max = sorted([ix0, ix1])
        y_min, y_max = sorted([iy0, iy1])

        # Ignore tiny drags
        if (x_max - x_min) < 2 or (y_max - y_min) < 2:
            return

        new_box = {
            "bbox": [float(x_min), float(y_min), float(x_max), float(y_max)],
            "label": self.label_var.get() or "knot",
            "confidence": 1.0,
            "source": "manual",
        }

        boxes = self.annotations.setdefault(self.current_image_path.name, [])
        boxes.append(new_box)
        self._save_annotations()

        self._refresh_box_list()
        self._redraw_boxes()

    # ------------------------- Box list actions -------------------------
    def _refresh_box_list(self) -> None:
        self.box_list.delete(0, tk.END)
        if self.current_image_path is None:
            return
        boxes = self.annotations.get(self.current_image_path.name, []) or []
        for idx, box in enumerate(boxes):
            if not isinstance(box, dict) or "bbox" not in box:
                continue
            x1, y1, x2, y2 = box["bbox"]
            label = str(box.get("label", "knot"))
            src = str(box.get("source", "manual"))
            conf = box.get("confidence", 1.0)
            self.box_list.insert(
                tk.END,
                f"[x] {idx}: {label} ({src}, {conf:.3f}) ({x1:.1f},{y1:.1f})-({x2:.1f},{y2:.1f})",
            )

    def _selected_box_index(self) -> int | None:
        sel = self.box_list.curselection()
        if not sel:
            return None
        # Listbox entries are displayed in the same order as the boxes
        return int(sel[0])

    def delete_selected_box(self) -> None:
        if self.current_image_path is None:
            return
        idx = self._selected_box_index()
        if idx is None:
            return

        boxes = self.annotations.get(self.current_image_path.name, []) or []
        if 0 <= idx < len(boxes):
            del boxes[idx]
            self._save_annotations()
            self._refresh_box_list()
            self._redraw_boxes()

    def _on_box_double_click(self, _event: tk.Event) -> None:
        self.delete_selected_box()

    def clear_all_boxes(self) -> None:
        if self.current_image_path is None:
            return
        self.annotations[self.current_image_path.name] = []
        self._save_annotations()
        self._refresh_box_list()
        self._redraw_boxes()

    # ------------------------- Misc -------------------------
    def _set_status(self, msg: str) -> None:
        self.status_label.config(text=msg)

    def _on_load_dir(self) -> None:
        path = Path(self.images_dir_var.get().strip())
        self._load_images_dir(path)


def main() -> None:
    parser = argparse.ArgumentParser(description="Tkinter annotation GUI")
    parser.add_argument(
        "--images-dir",
        type=Path,
        default=Path(DEFAULT_IMAGES_DIR) if DEFAULT_IMAGES_DIR else Path("IMAGE/"),
        help="Directory containing images and annotations.json",
    )
    args = parser.parse_args()

    root = tk.Tk()
    app = TkAnnotationApp(root, args.images_dir)

    # Re-render on first layout so scaling is correct
    def after_layout() -> None:
        if app.current_image is not None:
            app._render_image_and_boxes()

    root.after(50, after_layout)
    root.mainloop()


if __name__ == "__main__":
    main()
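For reference, the on-disk format that both the GUI above and `train_model.py` below consume is a plain JSON mapping from image filename to a list of box dicts. A minimal sketch of the schema (the `board_001.jpg` / `board_002.jpg` filenames are illustrative, not from the repository):

```python
import json

# Illustrative annotations.json content; filenames are made up.
annotations = {
    "board_001.jpg": [
        {"bbox": [34.0, 50.5, 120.0, 142.25], "label": "knot",
         "confidence": 0.91, "source": "auto"},
        {"bbox": [200.0, 10.0, 260.0, 80.0], "label": "knot",
         "confidence": 1.0, "source": "manual"},
    ],
    "board_002.jpg": [],  # visited but unannotated images keep an empty list
}

# The GUI writes with indent=2 and reads the same structure back.
text = json.dumps(annotations, indent=2)
roundtrip = json.loads(text)
```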
317
train_model.py
Executable file
@ -0,0 +1,317 @@
#!/usr/bin/env python3
"""
Training script for wood knot detection models.

Supports: RF-DETR, RT-DETR, YOLOv6, YOLOX.
All models are MIT/Apache 2.0 licensed - free for commercial use.

Usage:
    python train_model.py --framework rt-detr --dataset dataset_prepared --output runs/training
"""

import argparse
import json
import shutil
from pathlib import Path

def prepare_training_dataset(
    images_dir: Path,
    annotations_file: Path,
    output_dir: Path,
    train_split: float = 0.8,
    valid_split: float = 0.1,
) -> str:
    """Prepare a dataset with train/valid/test splits and YOLO-format labels."""
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load annotations
    with annotations_file.open() as f:
        annotations = json.load(f)

    # Collect annotated images that actually exist on disk
    image_files = []
    for img_name in annotations.keys():
        img_path = images_dir / img_name
        if img_path.exists():
            image_files.append(img_name)

    if not image_files:
        return "❌ No annotated images found"

    # Shuffle and split deterministically
    import random
    random.seed(42)
    random.shuffle(image_files)

    n_total = len(image_files)
    n_train = int(n_total * train_split)
    n_valid = int(n_total * valid_split)
    n_test = n_total - n_train - n_valid

    splits = {
        'train': image_files[:n_train],
        'valid': image_files[n_train:n_train + n_valid],
        'test': image_files[n_train + n_valid:],
    }

    # Create directories
    for split in ['train', 'valid', 'test']:
        (output_dir / split / 'images').mkdir(parents=True, exist_ok=True)
        (output_dir / split / 'labels').mkdir(parents=True, exist_ok=True)

    # Copy images and create labels
    for split, img_names in splits.items():
        for img_name in img_names:
            # Copy image
            src_img = images_dir / img_name
            dst_img = output_dir / split / 'images' / img_name
            shutil.copy2(src_img, dst_img)

            # Convert annotations to YOLO format
            boxes = annotations.get(img_name, [])
            if boxes:
                img_width, img_height = get_image_size(src_img)
                yolo_lines = []
                for box in boxes:
                    if isinstance(box, dict) and "bbox" in box:
                        x1, y1, x2, y2 = box["bbox"]
                    else:
                        x1, y1, x2, y2 = box
                    # Normalized center x, center y, width, height
                    x_center = (x1 + x2) / 2 / img_width
                    y_center = (y1 + y2) / 2 / img_height
                    width = (x2 - x1) / img_width
                    height = (y2 - y1) / img_height
                    yolo_lines.append(f"0 {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")

                # Save label file (handles .jpg/.jpeg/.png alike)
                label_file = output_dir / split / 'labels' / (Path(img_name).stem + '.txt')
                with label_file.open('w') as f:
                    f.write('\n'.join(yolo_lines))

    # Create data.yaml for YOLO models
    data_yaml = f"""train: {output_dir}/train/images
val: {output_dir}/valid/images
test: {output_dir}/test/images

nc: 1
names: ['knot']
"""

    with (output_dir / 'data.yaml').open('w') as f:
        f.write(data_yaml)

    return f"✓ Dataset prepared: {n_train} train, {n_valid} valid, {n_test} test images"

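A worked instance of the normalization above (numbers chosen for illustration): a 100x200-pixel box with top-left corner (50, 100) in a 1000x500 image.

```python
def xyxy_to_yolo(x1: float, y1: float, x2: float, y2: float,
                 img_w: int, img_h: int, cls: int = 0) -> str:
    """Same math as prepare_training_dataset: normalized center and size."""
    x_center = (x1 + x2) / 2 / img_w
    y_center = (y1 + y2) / 2 / img_h
    width = (x2 - x1) / img_w
    height = (y2 - y1) / img_h
    return f"{cls} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"


# Box (50, 100)-(150, 300) in a 1000x500 image:
# center (100, 200) -> (0.1, 0.4); size (100, 200) -> (0.1, 0.4)
line = xyxy_to_yolo(50, 100, 150, 300, img_w=1000, img_h=500)
```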
def get_image_size(img_path: Path) -> tuple[int, int]:
    """Get image dimensions as (width, height)."""
    from PIL import Image
    with Image.open(img_path) as img:
        return img.size

def _run_training(
    module_name: str,
    framework_label: str,
    dataset_dir: Path,
    output_dir: Path,
    model_size: str,
    epochs: int,
    batch_size: int,
    lr: float,
) -> str:
    """Invoke a framework-specific training script's main() with CLI-style args."""
    output_dir.mkdir(parents=True, exist_ok=True)

    try:
        import importlib
        import sys

        train_main = importlib.import_module(module_name).main
        sys.argv = [
            f'{module_name}.py',
            '--data', str(dataset_dir / 'data.yaml'),
            '--output', str(output_dir),
            '--model', model_size,
            '--epochs', str(epochs),
            '--batch', str(batch_size),
            '--lr', str(lr),
        ]
        train_main()
        return f"✓ {framework_label} training completed"
    except Exception as e:
        return f"❌ {framework_label} training failed: {e}"


def train_rfdetr(dataset_dir: Path, output_dir: Path, model_size: str, epochs: int, batch_size: int, lr: float) -> str:
    """Train an RF-DETR model."""
    return _run_training('train_rfdetr', 'RF-DETR', dataset_dir, output_dir, model_size, epochs, batch_size, lr)


def train_rtdetr(dataset_dir: Path, output_dir: Path, model_size: str, epochs: int, batch_size: int, lr: float) -> str:
    """Train an RT-DETR model."""
    return _run_training('train_rtdetr', 'RT-DETR', dataset_dir, output_dir, model_size, epochs, batch_size, lr)


def train_yolov6(dataset_dir: Path, output_dir: Path, model_size: str, epochs: int, batch_size: int, lr: float) -> str:
    """Train a YOLOv6 model."""
    return _run_training('train_yolov6', 'YOLOv6', dataset_dir, output_dir, model_size, epochs, batch_size, lr)


def train_yolox(dataset_dir: Path, output_dir: Path, model_size: str, epochs: int, batch_size: int, lr: float) -> str:
    """Train a YOLOX model."""
    return _run_training('train_yolox', 'YOLOX', dataset_dir, output_dir, model_size, epochs, batch_size, lr)

def main():
    parser = argparse.ArgumentParser(description="Train object detection models for wood knot detection")
    parser.add_argument(
        '--framework',
        choices=['rf-detr', 'rt-detr', 'yolov6', 'yolox'],
        default='rt-detr',
        help='Model framework to train'
    )
    parser.add_argument(
        '--dataset',
        type=Path,
        required=True,
        help='Path to prepared dataset directory'
    )
    parser.add_argument(
        '--output',
        type=Path,
        default=Path('runs/training'),
        help='Output directory for trained model'
    )
    parser.add_argument(
        '--model-size',
        choices=['nano', 'small', 'medium', 'base'],
        default='small',
        help='Model size/variant'
    )
    parser.add_argument(
        '--epochs',
        type=int,
        default=20,
        help='Number of training epochs'
    )
    parser.add_argument(
        '--batch-size',
        type=int,
        default=4,
        help='Batch size for training'
    )
    parser.add_argument(
        '--lr',
        type=float,
        default=1e-4,
        help='Learning rate'
    )
    parser.add_argument(
        '--prepare-dataset',
        action='store_true',
        help='Prepare dataset from annotations first'
    )
    parser.add_argument(
        '--images-dir',
        type=Path,
        default=Path('IMAGE'),
        help='Images directory (for --prepare-dataset)'
    )
    parser.add_argument(
        '--annotations',
        type=Path,
        default=Path('annotations.json'),
        help='Annotations file (for --prepare-dataset)'
    )

    args = parser.parse_args()

    if args.prepare_dataset:
        print("Preparing dataset...")
        result = prepare_training_dataset(
            args.images_dir,
            args.annotations,
            args.dataset
        )
        print(result)
        if "❌" in result:
            return

    print(f"Training {args.framework.upper()} model...")
    print(f"Dataset: {args.dataset}")
    print(f"Output: {args.output}")
    print(f"Model size: {args.model_size}")
    print(f"Epochs: {args.epochs}")
    print(f"Batch size: {args.batch_size}")
    print(f"Learning rate: {args.lr}")

    # Train based on framework (choices above guarantee one branch matches)
    if args.framework == 'rf-detr':
        result = train_rfdetr(args.dataset, args.output, args.model_size, args.epochs, args.batch_size, args.lr)
    elif args.framework == 'rt-detr':
        result = train_rtdetr(args.dataset, args.output, args.model_size, args.epochs, args.batch_size, args.lr)
    elif args.framework == 'yolov6':
        result = train_yolov6(args.dataset, args.output, args.model_size, args.epochs, args.batch_size, args.lr)
    else:
        result = train_yolox(args.dataset, args.output, args.model_size, args.epochs, args.batch_size, args.lr)

    print(result)


if __name__ == "__main__":
    main()
Block a user