Comments (19)
excuse me?
from hed-unet.
Look, I'm essentially providing support for free here, so kindly accept the fact that I don't work weekends 😄
Though I haven't tried applying the framework to multi-class data, there is no theoretical reason why it shouldn't work.
I would recommend the following changes to the existing code to enable multi-class training:
train.py#L80:
sobel = (-F.max_pool2d(-mask, 3, 1, padding=1) != F.max_pool2d(mask, 3, 1, padding=1)).float()
config.yml:
Add output_channels: <NumberOfClasses+1> to the model_args block.
def multiclass_loss(y_hat_log, y):
    loss_seg = F.cross_entropy(y_hat_log[:, :-1], y[:, :1])
    loss_edge = auto_weight_bce(y_hat_log[:, -1:], y[:, 1:])
    return 0.5 * (loss_seg + loss_edge)
And then change loss_functions.py and config.py to use this multiclass_loss.
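Not part of the original thread; here is a minimal, self-contained sketch of how these pieces fit together on dummy tensors. It assumes auto_weight_bce is a BCE-with-logits variant (the repo's real implementation weights the terms automatically; a plain stand-in is used here), and it passes the class map to F.cross_entropy as an integer index tensor of shape [N, H, W] (via y[:, 0].long()), since that is the target format cross_entropy expects:

```python
import torch
import torch.nn.functional as F

def auto_weight_bce(y_hat_log, y):
    # Stand-in for the repo's auto-weighted BCE; plain BCE-with-logits here.
    return F.binary_cross_entropy_with_logits(y_hat_log, y)

def multiclass_loss(y_hat_log, y):
    # y_hat_log: [N, K+1, H, W] -- K class logits plus one edge logit
    # y:         [N, 2, H, W]   -- integer class map plus binary edge map
    loss_seg = F.cross_entropy(y_hat_log[:, :-1], y[:, 0].long())
    loss_edge = auto_weight_bce(y_hat_log[:, -1:], y[:, 1:])
    return 0.5 * (loss_seg + loss_edge)

# Integer mask with 3 classes (0 = background)
mask = torch.randint(0, 3, (1, 1, 8, 8)).float()
# Edge target: pixels where a 3x3 min-pool and max-pool disagree, i.e. class boundaries
edge = (-F.max_pool2d(-mask, 3, 1, padding=1) != F.max_pool2d(mask, 3, 1, padding=1)).float()
y = torch.cat([mask, edge], dim=1)    # [1, 2, 8, 8]
y_hat_log = torch.randn(1, 4, 8, 8)   # 3 class logits + 1 edge logit
print(multiclass_loss(y_hat_log, y).item())
```

The min-pool/max-pool disagreement is exactly the train.py#L80 change above: a pixel is an edge pixel whenever its 3x3 neighborhood contains more than one class.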
What colors should be used for the masks of the new category data, and how should the data be read? Can you share a code example?
def __getitem__(self, index):
    img_path = self.img_files[index]
    mask_path = self.mask_files[index]
    data = cv2.imread(img_path)       # read the image (OpenCV; PIL works too)
    label = cv2.imread(mask_path, 0)  # read the label as grayscale
    return torch.from_numpy(data).float() / 255, torch.from_numpy(label > 0).float()
Good point, you will also need to change your data loading. It all depends on how your training masks are stored. Could you share an example?
Good morning, please check your email. Have you received my sample image and instructions?
This is the data-reading function I originally used to train on a single class of data:
def __getitem__(self, index):
    img_path = self.img_files[index]
    mask_path = self.mask_files[index]
    data = cv2.imread(img_path)
    label = cv2.imread(mask_path, 0)        # grayscale
    label = label[:, :, None]
    data = np.transpose(data, (2, 0, 1))    # [3, 224, 224]
    label = np.transpose(label, (2, 0, 1))  # [1, 224, 224]
    return torch.from_numpy(data).float() / 255., torch.from_numpy(label > 0).float()
I'm afraid I have not received any email from you.
As shown in the figure below, how should we handle the two types of data, and their masks and data-reading code?
You can change the blue mask in the picture below to any color for your training. It's just a demo picture; it will be changed for my formal training.
I've also attached the original images and mask images for these two pictures. You can try to modify them in your code, or tell me where to modify it. Thank you for your selfless dedication to academia!
Alright, as a quick side remark, it would be best practice to save masks in a lossless format like .png, as JPEG compression artifacts will give you a hard time otherwise.
Regarding converting the RGB colors to an enumerated segmentation mask, I recommend adapting the rgb_to_mask function from torchgeo (microsoft/torchgeo/torchgeo/datasets/utils.py#L655-L677). There, you would pass your loaded data in HWC order and a list of RGB colors that you want to map to class indices, e.g. passing colors = [(0, 0, 0), (255, 255, 255), (0, 0, 255)] would map black to 0, white to 1 and blue to 2. You will need to adapt this to the exact colors that you are using.
Also, note that cv2.imread() defaults to reading channels in BGR order, not RGB.
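Not from the thread; a minimal NumPy sketch of that color-to-index conversion (my own reimplementation, not torchgeo's exact code), including the BGR-to-RGB flip needed after cv2.imread:

```python
import numpy as np

def rgb_to_mask(rgb, colors):
    """Map an HWC uint8 RGB image to an HW array of class indices."""
    h, w, _ = rgb.shape
    mask = np.zeros((h, w), dtype=np.int64)
    for index, color in enumerate(colors):
        # Pixels matching this color in all three channels get this class index
        mask[(rgb == np.array(color)).all(axis=-1)] = index
    return mask

colors = [(0, 0, 0), (255, 255, 255), (0, 0, 255)]  # black, white, blue
bgr = np.zeros((2, 2, 3), dtype=np.uint8)           # channel order as returned by cv2.imread
bgr[0, 0] = (255, 0, 0)      # blue, written in BGR order
bgr[0, 1] = (255, 255, 255)  # white
rgb = bgr[..., ::-1]         # BGR -> RGB before the color lookup
print(rgb_to_mask(rgb, colors))
# [[2 1]
#  [0 0]]
```

Unmatched pixels silently stay class 0 in this sketch; torchgeo's version is the reference for more careful handling.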
Hope this helps :)
I wrote the code like this to read the two different types of data. Among the values returned in the label, 1.0 is the first class, 0.5 is the second class, and 0 is the black background. Do you think this is OK? I'm not very familiar with this.
def __getitem__(self, index):  # each image contains at most one class
    img_path = self.img_files[index]
    mask_path = self.mask_files[index]
    if 'second' in img_path:  # second-class data loader
        data = cv2.imread(img_path)
        label = cv2.imread(mask_path, 0)
        label = label[:, :, None]
        cv2.imshow('img', data)
        cv2.imshow('label', label)
        cv2.waitKey()  # debug
        data = np.transpose(data, (2, 0, 1))    # [3, 224, 224]
        label = np.transpose(label, (2, 0, 1))  # [1, 224, 224]
        data = torch.from_numpy(data).float() / 255.
        label = torch.from_numpy((label == 122) / 2.).float()  # foreground: 0.5, background: 0
        return data, label
    else:
        # first class: white mask
        data = cv2.imread(img_path)
        label = cv2.imread(mask_path, 0)
        label = label[:, :, None]
        data = np.transpose(data, (2, 0, 1))    # [3, 224, 224]
        label = np.transpose(label, (2, 0, 1))  # [1, 224, 224]
        data = torch.from_numpy(data).float() / 255.
        label = torch.from_numpy(label > 0).float()  # foreground: 1, background: 0
        return data, label
Hi, good morning. I'm modifying it for multi-class now.
In the data-reading part, the label becomes one-hot. Code:
def mask_to_onehot(mask, palette):
    """
    Converts a segmentation mask (H, W, C) to (H, W, K), where the last dim is a
    one-hot encoding vector, C is usually 1 or 3, and K is the number of classes.
    """
    semantic_map = []
    for colour in palette:
        equality = np.equal(mask, colour)
        class_map = np.all(equality, axis=-1)
        semantic_map.append(class_map)
    semantic_map = np.stack(semantic_map, axis=-1).astype(np.float32)
    return semantic_map
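As a quick sanity check of what this returns (not from the thread; a tiny 2x2 mask with the same 0/122/255 palette that is used below, with the function repeated so the snippet is self-contained):

```python
import numpy as np

def mask_to_onehot(mask, palette):
    # As defined in the comment above: (H, W, C) mask -> (H, W, K) one-hot map.
    semantic_map = []
    for colour in palette:
        class_map = np.all(np.equal(mask, colour), axis=-1)
        semantic_map.append(class_map)
    return np.stack(semantic_map, axis=-1).astype(np.float32)

# A 2x2 grayscale mask with one channel, pixel values 0 / 122 / 255
mask = np.array([[[0], [122]],
                 [[255], [0]]], dtype=np.uint8)
onehot = mask_to_onehot(mask, palette=[[0], [122], [255]])
print(onehot.shape)  # (2, 2, 3)
print(onehot[0, 1])  # [0. 1. 0.]  -- pixel value 122 maps to channel 1
```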
def __getitem__(self, index):  # each image contains at most one class
    img_path = self.img_files[index]
    mask_path = self.mask_files[index]
    palette = [[0], [122], [255]]  # one-hot color palette
    if 'second' in img_path:  # second-class data loader
        data = cv2.imread(img_path)
        label = cv2.imread(mask_path, 0)
        label = label[:, :, None]
        label = mask_to_onehot(label, palette)
        data = np.transpose(data, (2, 0, 1))    # [3, 224, 224]
        label = np.transpose(label, (2, 0, 1))  # [3, 224, 224]
        data = torch.from_numpy(data).float() / 255.
        label = torch.from_numpy((label == 122) / 2.).float()  # foreground: 0.5, background: 0
        return data, label
    else:
        # first class: white mask
        data = cv2.imread(img_path)
        label = cv2.imread(mask_path, 0)
        label = label[:, :, None]
        label = mask_to_onehot(label, palette)
        data = np.transpose(data, (2, 0, 1))    # [3, 224, 224]
        label = np.transpose(label, (2, 0, 1))  # [3, 224, 224]
        data = torch.from_numpy(data).float() / 255.
        label = torch.from_numpy(label > 0).float()  # foreground: 1, background: 0
        return data, label
Then, calculating with the loss function above, it reported mismatched dimensions:
y_hat_log[:, :-1] shape: [1, 2, 256, 256], y[:, :-1] shape: [1, 1, 256, 256]
So I modified it to: loss_seg = F.cross_entropy(y_hat_log[:, :-1], y[:, :2])
When executing the showexample function, seg_pred, edge_pred = torch.sigmoid(prediction) failed with a wrong dimension: torch.sigmoid(prediction).shape is [3, 256, 256], but only two variables were receiving the return value, so I changed it to seg_pred, edge_pred, nothing = torch.sigmoid(prediction).
Now it can run normally, but the label display in log/figures/ is not right.
It's running!
After changing to multi-class, the inference results displayed with OpenCV also need to be modified; I'm working on that now.
Hello, I put four pictures up front: first_class.jpg, first_class_gt.jpg, second_class.jpg, and second_class_gt.jpg.
Can you use these four pictures to train a multi-class HED-UNet? It can be paid. Thank you for your efforts.
Hello, excuse me?
I wrote the code like this to read the two different types of data. Among the values returned in the label, 1.0 is the first class, 0.5 is the second class, and 0 is the black background. Do you think this is OK? I'm not very familiar with this.
No, you should have integer labels, for example 0=background, 1=class1, 2=class2.
You should no longer be using sigmoid, but softmax instead. For visualization purposes, you'll want to extract the predicted class using argmax.
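Not from the repo; a short sketch of that change, assuming the channel layout from the loss above (class logits first, one edge logit last) and a per-image prediction like the one in showexample:

```python
import torch
import torch.nn.functional as F

# Dummy network output: 3 class logits + 1 edge logit per pixel
prediction = torch.randn(4, 256, 256)

class_logits, edge_logit = prediction[:-1], prediction[-1:]
probs = F.softmax(class_logits, dim=0)  # per-pixel class probabilities
seg_pred = probs.argmax(dim=0)          # [256, 256] integer class map: 0, 1 or 2
edge_pred = torch.sigmoid(edge_logit)   # the binary edge channel can keep sigmoid

print(seg_pred.shape, edge_pred.shape)
```

For visualization, seg_pred can then be mapped back to a color per class index; argmax over the raw logits gives the same class map as argmax over the softmax, so the softmax is only needed when you want actual probabilities.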
Can you use these four pictures to train a multi-class HED-UNet? It can be paid. Thank you for your efforts.
Four pictures are not nearly enough to train a deep learning model; in my experience you would need at least ~1000 samples.