Minimal reproducer:
```python
# torch==2.3.1
# ultralytics==8.2.35
import torch
from ultralytics.models.yolo import YOLO

torch.manual_seed(42)


def run_yolo(torch_fx, inputs):
    yolo_model = YOLO("yolov8n")
    model = yolo_model.model
    if torch_fx:
        model = torch.compile(model)
    return model(inputs)[0]


if __name__ == "__main__":
    inputs = torch.rand((1, 3, 640, 640))
    print("Run Torch model...")
    torch_t = run_yolo(torch_fx=False, inputs=inputs)
    print("Run Torch FX model...")
    fx_t = run_yolo(torch_fx=True, inputs=inputs)

    abs_diff = torch.abs(torch_t - fx_t)
    idx = torch.argmax(abs_diff)
    print(f"argmax idx: {idx}")
    print(f"torch value: {torch_t.view(-1)[idx]}")
    print(f"torch FX value: {fx_t.view(-1)[idx]}")
    print(f"abs diff: {abs_diff.view(-1)[idx]}")
    print(f"torch.quantile(abs_diff, 0.96) {torch.quantile(abs_diff, 0.96)}")
```
Output:

```
Run Torch model...
Run Torch FX model...
argmax idx: 25132
torch value: 490.80194091796875
torch FX value: 855.9827270507812
abs diff: 365.1807861328125
torch.quantile(abs_diff, 0.96) 2.0144500732421875
```
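Absolute differences can be hard to interpret when output magnitudes vary as widely as YOLO's box coordinates and confidences do. As a supplementary check, a small sketch (plain `torch`, no `ultralytics` dependency; `max_relative_error` is a hypothetical helper, not part of either library) of putting the reported gap in relative terms:

```python
import torch


def max_relative_error(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> float:
    """Largest element-wise |a - b| / (|a| + eps)."""
    return ((a - b).abs() / (a.abs() + eps)).max().item()


# The worst-case pair reported above: a ~365-unit gap on a ~490-unit
# eager output is roughly a 74% relative error, far beyond the
# rounding-level noise one would expect from kernel fusion alone.
eager_out = torch.tensor([490.80194091796875])
compiled_out = torch.tensor([855.9827270507812])
print(max_relative_error(eager_out, compiled_out))
```

A relative error this large on even one element usually points at a real behavioral change rather than accumulated floating-point reordering.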
@daniil-lyakhov hi there,
Thank you for providing the minimal reproducible example and the detailed information about the metrics degradation you're seeing with the `torch.compile` model on the COCO128 dataset.
It appears that you've identified a significant difference in the validation metrics between the standard PyTorch model and the Torch FX compiled model. This discrepancy is indeed concerning and warrants further investigation.
Steps to Investigate:

1. **Verify Versions:** Ensure you are using the latest versions of both `torch` and `ultralytics`. The versions you mentioned (`torch==2.3.1` and `ultralytics==8.2.35`) are quite recent, but it's always worth double-checking for new updates or patches that might address this issue.

2. **Model Consistency Check:** The minimal example you provided shows a significant difference in output values between the standard and compiled models, which suggests the compilation process might be altering the model's behavior. To diagnose this further, compare intermediate outputs (e.g., feature maps) at various layers of the model for both the standard and compiled versions; this can help pinpoint where the discrepancy begins.

3. **Validation Loop:** As you noted, the `val` method does not currently use the optimized model inside the validation loop. You can modify the validation loop to use the compiled model directly, ensuring that the same model is being evaluated:

   ```python
   def validate(
       model,
       data_loader: torch.utils.data.DataLoader,
       validator: Validator,
   ) -> Tuple[Dict, int, int]:
       with torch.no_grad():
           for batch in data_loader:
               batch = validator.preprocess(batch)
               preds = model(batch["img"])
               preds = validator.postprocess(preds)
               validator.update_metrics(preds, batch)
       stats = validator.get_stats()
       return stats, validator.seen, validator.nt_per_class.sum()
   ```

4. **Precision and Stability:** The differences in precision and stability between the standard and compiled models could be due to various factors, including numerical stability issues introduced during compilation. You might want to experiment with the different compilation settings and flags `torch.compile` provides to see if they mitigate the issue.
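The layer-by-layer comparison suggested above can be done with forward hooks. A minimal sketch, assuming a toy `nn.Sequential` as a stand-in for the YOLO backbone (`capture_activations` is a hypothetical helper; in practice the second call would use `torch.compile(model)`):

```python
import torch
import torch.nn as nn


def capture_activations(model: nn.Module, x: torch.Tensor) -> dict:
    """Run one forward pass, recording each leaf module's output tensor."""
    acts, handles = {}, []
    for name, module in model.named_modules():
        if not list(module.children()):  # leaf modules only
            def hook(mod, inp, out, name=name):
                if isinstance(out, torch.Tensor):
                    acts[name] = out.detach()
            handles.append(module.register_forward_hook(hook))
    with torch.no_grad():
        model(x)
    for h in handles:
        h.remove()
    return acts


# Toy stand-in for the detection model; swap the second call's argument
# for torch.compile(model) to locate where the outputs start to diverge.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SiLU(), nn.Conv2d(8, 4, 3))
x = torch.rand(1, 3, 64, 64)
reference = capture_activations(model, x)
candidate = capture_activations(model, x)  # replace with the compiled model here
for name in reference:
    diff = (reference[name] - candidate[name]).abs().max().item()
    print(f"{name}: max abs diff = {diff:.3e}")
```

The first layer whose max difference jumps by orders of magnitude is a good place to start bisecting.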
Example Code for Validation with Compiled Model:
Here's an example of how you can modify the validation loop to use the compiled model:
```python
def main(torch_fx):
    yolo_model = YOLO("yolov8n")
    model_type = "torch"
    model = yolo_model.model
    if torch_fx:
        model = torch.compile(model)
        model_type = "FX"
    print(f"FP32 {model_type} model validation results:")
    validator, data_loader = prepare_validation(yolo_model, "coco128.yaml")
    stats, total_images, total_objects = validate(model, tqdm(data_loader), validator)
    print_statistics(stats, total_images, total_objects)
```
Next Steps:
- Run the modified validation loop with the compiled model and compare the results.
- Check for any updates to `torch` and `ultralytics` that might address this issue.
- Experiment with different compilation settings to see if they affect the model's performance and accuracy.
If the issue persists, please let us know, and we can further investigate potential causes and solutions.
Thank you for your patience and for bringing this to our attention. We look forward to resolving this issue with your help.