Comments (15)
As shown in lines 124-125 of u2net_train.py: inputs = inputs.type(torch.FloatTensor), labels = labels.type(torch.FloatTensor). Inputs, labels, and predictions are all FloatTensor. If you have changed the training code, please make sure you have converted your mask to FloatTensor. To use it for binary segmentation, just replace the path of the training data with your own data.
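As an illustration of that conversion, a minimal sketch (the mask values here are hypothetical, assuming a 0/255 grayscale mask loaded from an image):

```python
import torch

# Hypothetical uint8 mask as loaded from a grayscale image (values 0 or 255).
mask = torch.tensor([[0, 255], [255, 0]], dtype=torch.uint8)

# Scale to 0/1 and cast to FloatTensor so it matches the prediction dtype.
labels = (mask / 255.0).type(torch.FloatTensor)
print(labels.dtype)  # torch.float32
```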
from u-2-net.
@NathanUA I am trying to use it with FastAI. How can I convert the output to obtain 0 for unselected pixels and 1 for selected pixels?
Just threshold it: "prediction = prediction > 0.5" should work.
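For instance, a minimal sketch of that thresholding (the probability values here are made up):

```python
import torch

# Hypothetical probability map in [0, 1], as produced after the sigmoid.
prediction = torch.tensor([[0.1, 0.7], [0.9, 0.3]])

# Threshold at 0.5: pixels above become 1 (selected), the rest 0.
prediction = (prediction > 0.5).type(torch.uint8)
print(prediction)  # tensor([[0, 1], [1, 0]], dtype=torch.uint8)
```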
As shown in lines 124-125 of u2net_train.py: inputs = inputs.type(torch.FloatTensor), labels = labels.type(torch.FloatTensor). Inputs, labels, and predictions are all FloatTensor. If you have changed the training code, please make sure you have converted your mask to FloatTensor.
Thank you for all your help!!!
I updated my code to cast the mask to FloatTensor. Now I am getting the next error:
547 def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
548 print(labels_v.shape)
--> 549 loss0 = bce_loss(d0, labels_v)
ValueError: Target size (torch.Size([2, 1002, 1002])) must be the same as input size (torch.Size([2, 1, 1002, 1002]))
The dimensions of the two tensors d0 and labels_v should be the same. You have to reshape your ground truth with some kind of reshaping operation, e.g. np.reshape(gt, (2, 1, height, width)), or a similar operation in PyTorch.
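In PyTorch the same fix can be done with unsqueeze; a sketch, where the sizes just mirror the error message above:

```python
import torch

# Ground truth shaped (N, H, W), as in the ValueError above.
labels_v = torch.zeros(2, 1002, 1002)

# Insert the missing channel axis so the target matches the (N, 1, H, W) prediction.
labels_v = labels_v.unsqueeze(1)
print(labels_v.shape)  # torch.Size([2, 1, 1002, 1002])
```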
Now the loss is working. However, I don't understand what the output of the model is. I tried to use the threshold as you said.
But the dice metric is returning 0 even with a small loss, so I think I am converting the prediction into a mask incorrectly. Could you share a code example with me?
I am using the following code, @NathanUA:
d0, d1, d2, d3, d4, d5, d6 = self.model(*self.xb)
self.pred = d0.clone()
self.pred = F.sigmoid(self.pred)
self.pred = normPRED(self.pred)
self.pred = self.pred > 0.5
self.pred = self.pred.type(torch.uint8)
I think if you didn't change the definition of u2net, then please remove the F.sigmoid() call in your code, because there is already a sigmoid at the end of our model. It should work then.
I removed the F.sigmoid in the forward:
return d0, d1, d2, d3, d4, d5, d6
I also changed bce_loss to
bce_loss = nn.BCEWithLogitsLoss(size_average=True)
These changes are for allowing the use of torch.cuda.amp.autocast and mixed-precision training.
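A minimal sketch of why this works (the shapes are made up): BCEWithLogitsLoss takes raw logits and applies the sigmoid internally in a numerically stable way, which is what makes it safe under autocast, unlike BCELoss on pre-sigmoided outputs:

```python
import torch
import torch.nn as nn

# Raw logits straight from the network (no sigmoid) and float targets in [0, 1].
logits = torch.randn(2, 1, 4, 4)
target = torch.rand(2, 1, 4, 4)

# The sigmoid is fused into the loss, so no F.sigmoid is needed in forward().
bce_loss = nn.BCEWithLogitsLoss(reduction="mean")
loss = bce_loss(logits, target)
print(loss.item())
```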
All my code looks as follows:
class U2NETP(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super(U2NETP, self).__init__()
        # encoder
        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage6 = RSU4F(64, 16, 64)
        # decoder
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.outconv = nn.Conv2d(6, out_ch, 1)

    def forward(self, x):
        hx = x
        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)
        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)
        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)
        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)
        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)
        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)
        # decoder
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
        # side output
        d1 = self.side1(hx1d)
        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)
        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)
        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)
        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)
        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)
        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
        # return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
        # Not necessary for BCEWithLogitsLoss, which expects raw logits
        return d0, d1, d2, d3, d4, d5, d6

# size_average is deprecated in newer PyTorch; reduction='mean' is the equivalent
bce_loss = nn.BCEWithLogitsLoss(size_average=True)

def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    # add the channel axis so the targets match the (N, 1, H, W) predictions
    labels_v = labels_v.unsqueeze(1)
    loss0 = bce_loss(d0, labels_v)
    loss1 = bce_loss(d1, labels_v)
    loss2 = bce_loss(d2, labels_v)
    loss3 = bce_loss(d3, labels_v)
    loss4 = bce_loss(d4, labels_v)
    loss5 = bce_loss(d5, labels_v)
    loss6 = bce_loss(d6, labels_v)
    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
    return loss0, loss

def normPRED(d):
    # min-max normalize a prediction map to [0, 1]
    ma = torch.max(d)
    mi = torch.min(d)
    dn = (d - mi) / (ma - mi)
    return dn
One batch
def mixed_precision_one_batch(self, i, b):
    from torch.cuda.amp import autocast
    self.iter = i
    try:
        self._split(b)
        self('begin_batch')
        with autocast():
            d0, d1, d2, d3, d4, d5, d6 = self.model(*self.xb)
            self.pred = d0.clone()
            # self.pred = F.sigmoid(self.pred)
            self.pred = normPRED(self.pred)
            self.pred = self.pred > 0.5
            self.pred = self.pred.type(torch.uint8)
            aux = []
            for elem in self.yb:
                aux.append(elem.type(next(self.model.parameters()).dtype).to(next(self.model.parameters()).device))
            self.yb = tuple(aux)
            self('after_pred')
            if len(self.yb) == 0:
                return
            _, self.loss = self.loss_func(d0, d1, d2, d3, d4, d5, d6, *self.yb)
            self('after_loss')
        if not self.training:
            return
        self.scaler.scale(self.loss).backward()
        self('after_backward')
        del d0, d1, d2, d3, d4, d5, d6
        self.scaler.step(self.opt)
        self('after_step')
        self.opt.zero_grad()
    except CancelBatchException:
        self('after_cancel_batch')
    finally:
        self('after_batch')
Then you have to debug step by step. For example, output the pre-thresholded probability maps and see if they make sense. If yes, then try to debug the normalization and the thresholding function. A good way to debug is to output the intermediate variables and visually check the results. Best of luck.
Thank you for your help. I will debug it tomorrow.
The loss is working.
I don't know what is happening. Printing torch.max and torch.min of the sigmoid and of the normalization looks okay.
Could you post the code that you used for transforming the prediction into the black-and-white images that are in the readme, please?
Thank you very much for all your help!!!
Sorry, in our readme file all the maps are probability maps, not thresholded ones. I have already sent you the code: prediction = prediction > 0.5. My suggestion is to VISUALLY look at the prediction results. If they make sense, you can check whether your data format is correct and whether the dice computation function gets correctly matched inputs in the correct format. Best of luck!
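For completeness, a sketch of turning a probability map into a black-and-white mask (the values are invented and this is not the repo's exact saving code):

```python
import torch

# Hypothetical normalized probability map.
prediction = torch.tensor([[0.05, 0.60], [0.95, 0.40]])

# Threshold, then scale to 0/255 so it can be saved as an 8-bit grayscale image.
bw = (prediction > 0.5).type(torch.uint8) * 255
print(bw)  # tensor([[  0, 255], [255,   0]], dtype=torch.uint8)
```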
Debugging my code I found that the problem was with the dice metric: that metric was taking an argmax of the prediction. However, the output of this model is just one channel, so there is no need to apply that.
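A minimal single-channel dice sketch with no argmax (the helper name and the tensors are hypothetical):

```python
import torch

def dice_score(pred, target, eps=1e-7):
    # pred and target are binary (N, 1, H, W) tensors; for a one-channel
    # output there is nothing to argmax over, unlike a multi-class head.
    inter = (pred * target).sum()
    return (2.0 * inter + eps) / (pred.sum() + target.sum() + eps)

pred = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])
target = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])
print(dice_score(pred, target).item())  # ≈ 0.8
```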
This model throws this warning:
UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
"See the documentation of nn.Upsample for details.".format(mode))
How can I solve it?
Thank you very much for your help, @NathanUA.
I guess that is possibly because you are using a different version of PyTorch. This is just a warning. It shouldn't be a big issue if it doesn't impact the performance.
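One way to silence it, assuming the warning comes from the bilinear upsampling used by _upsample_like, is to pass align_corners explicitly (False reproduces the current default):

```python
import torch
import torch.nn.functional as F

src = torch.rand(1, 1, 8, 8)

# Specifying align_corners explicitly removes the UserWarning without
# changing the result (align_corners=False is the post-0.4.0 default).
up = F.interpolate(src, scale_factor=2, mode="bilinear", align_corners=False)
print(up.shape)  # torch.Size([1, 1, 16, 16])
```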