Comments (3)

JanuszL commented on May 30, 2024

Hi @jpfeil,

Thank you for reaching out.
In your code, crops is a list, whereas DALI expects each pipeline output to be a plain type rather than a nested list. So in your case, you need to either unpack the list:
return *crops, labels
or concatenate everything into one tensor using the cat or stack operators (as long as the tensors have uniform shapes sample-wise).
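
A minimal plain-Python sketch of what the unpacking changes. The strings are hypothetical stand-ins for DALI outputs, and build_outputs / build_outputs_nested are illustrative names, not part of the DALI API:

```python
def build_outputs(crops, labels):
    # Star-unpacking flattens the list into the returned tuple,
    # so every crop becomes its own top-level output.
    return *crops, labels


def build_outputs_nested(crops, labels):
    # Without unpacking, the whole list is returned as one nested element.
    return crops, labels


crops = ["gt1", "gt2", "lt1"]  # hypothetical stand-ins for pipeline outputs
flat = build_outputs(crops, "labels")
nested = build_outputs_nested(crops, "labels")

print(flat)    # ('gt1', 'gt2', 'lt1', 'labels')
print(nested)  # (['gt1', 'gt2', 'lt1'], 'labels')
```

The flat form is the one that matches the advice above: four separate outputs instead of a list plus a label.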

jpfeil commented on May 30, 2024

Thank you, @JanuszL!

I've implemented the pipeline with your suggestions, but now I'm having issues with the iterator.

from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types

class Solarize:
    def __init__(self, threshold: int = 128) -> None:
        self._threshold = threshold

    def __call__(self, img):
        inverted_img = 255 - img
        mask = img >= self._threshold
        return mask * inverted_img + (True ^ mask) * img

solarize = Solarize()
            
@pipeline_def(seed=seed, batch_size=1, enable_conditionals=True)
def dino_dali_pipeline(image_dir, local_crops_number=8, device="mixed"):
    
    jpegs, _ = fn.readers.file(file_root=image_dir, random_shuffle=True)

    decoded_jpegs = fn.decoders.image(jpegs, device=device)

    cropped_jpegs = fn.crop(decoded_jpegs, crop=(16384, 16384))
    
    #
    # Global Transform 1
    #
    gt1 = fn.random_resized_crop(cropped_jpegs, size=224, random_aspect_ratio=(1.0, 1.0))

    ## Random Horizontal Flip
    coin = fn.random.coin_flip()
    if coin:
        gt1 = fn.flip(gt1, horizontal=1)

    ## Color Jitter
    coin = fn.random.coin_flip(probability=0.8)
    if coin:
        gt1 = fn.color_twist(gt1, brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)

    ## Random Grayscale
    coin = fn.random.coin_flip(probability=0.2)
    if coin:
        gt1 = fn.color_space_conversion(gt1, image_type=types.RGB, output_type=types.GRAY)

    ## Gaussian Blur
    gt1 = fn.gaussian_blur(gt1, sigma=(0.1, 2.0))

    ## Normalize 
    gt1 = fn.normalize(gt1)

    #
    # Global Transform 2
    #
    gt2 = fn.random_resized_crop(cropped_jpegs, size=224, random_aspect_ratio=(1.0, 1.0))

    ## Random Horizontal Flip
    coin = fn.random.coin_flip()
    if coin:
        gt2 = fn.flip(gt2, horizontal=1)

    ## Color Jitter
    coin = fn.random.coin_flip(probability=0.8)
    if coin:
        gt2 = fn.color_twist(gt2, brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)

    ## Random Grayscale
    coin = fn.random.coin_flip(probability=0.2)
    if coin:
        gt2 = fn.color_space_conversion(gt2, image_type=types.RGB, output_type=types.GRAY)

    ## Gaussian Blur
    coin = fn.random.coin_flip(probability=0.1)
    if coin:
        gt2 = fn.gaussian_blur(gt2, sigma=(0.1, 2.0))

    ## Solarize    
    coin = fn.random.coin_flip(probability=0.1)
    if coin:
        gt2 = fn.cast(solarize(gt2), dtype=types.UINT8)

    gt2 = fn.normalize(gt2)


    #
    # Local Transformations
    #
    
    crops = [gt1, gt2]

    for _ in range(local_crops_number):
        lt = fn.random_resized_crop(cropped_jpegs, size=96, random_aspect_ratio=(0.5, 0.5))

        ## Random Horizontal Flip
        coin = fn.random.coin_flip()
        if coin:
            lt = fn.flip(lt, horizontal=1)
    
        ## Color Jitter
        coin = fn.random.coin_flip(probability=0.8)
        if coin:
            lt = fn.color_twist(lt, brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)
    
        ## Random Grayscale
        coin = fn.random.coin_flip(probability=0.2)
        if coin:
            lt = fn.color_space_conversion(lt, image_type=types.RGB, output_type=types.GRAY)
    
        ## Gaussian Blur
        coin = fn.random.coin_flip(probability=0.5)
        if coin:
            lt = fn.gaussian_blur(lt, sigma=(0.1, 2.0))
    
        ## Normalize 
        lt = fn.normalize(lt)
        crops.append(lt)

    return *crops,

This works and I can get augmented images out of it. The only issue is that I'm used to the pytorch representation using floats whereas in DALI it usually represents the image as uint8, but hopefully that doesn't influence training that much.

The problem I have now is passing the pipeline to the iterator:

from nvidia.dali.plugin.pytorch import DALIGenericIterator

pipe = dino_dali_pipeline(image_dir, batch_size=4, num_threads=4, device_id=0)
pipe.build()

iterator = DALIGenericIterator(
    pipelines=pipe,
    output_map=["gt1", "gt2", "lt1", "lt2", "lt3", "lt4", "lt5", "lt6", "lt7", "lt8"],
)

for i, (batch,) in enumerate(iterator):
    print(batch)
    break

This will run for a little while and then it throws this error:

RuntimeError: [/opt/dali/dali/pipeline/data/tensor_list.cc:1012] Assert on "IsDenseTensor()" failed: The batch must be representable as a tensor - it must have uniform shape and be allocated in contiguous memory.
Stacktrace (88 entries):

This seems to be related to the global and local crops being different sizes. Is there a way to support this kind of data in DALI?

Thanks!

JanuszL commented on May 30, 2024

Hi @jpfeil,

> The only issue is that I'm used to the pytorch representation using floats whereas in DALI it usually represents the image as uint8

You can use the crop_mirror_normalize operator, passing float as the output type and 255 as the std, to scale from uint8 to a 0-1 float range.
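
The arithmetic behind that suggestion can be sketched in plain Python, assuming the normalization computes (input - mean) / std with no extra scale or shift. normalize_pixel is a hypothetical helper for illustration, not a DALI function:

```python
def normalize_pixel(value, mean=0.0, std=255.0):
    # Mean/std normalization: subtract the mean, then divide by std.
    # With mean=0 and std=255, a uint8 value in 0..255 maps into 0..1.
    return (value - mean) / std


pixels = [0, 128, 255]  # sample uint8 intensities
scaled = [normalize_pixel(p) for p in pixels]
print(scaled)
```

The darkest and brightest values map to 0.0 and 1.0, which matches the float range PyTorch image transforms usually produce.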

> This seems to be related to the global and local crops being different sizes. Is there a way to support this kind of data in DALI?

Yes, the iterator expects each batch of samples to be representable as a single tensor in which one of the dimensions is the batch size. In this case, you can either pad the samples so that they are uniform or try out the PyTorch DALIRaggedIterator.
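
A rough plain-Python sketch of the uniform-shape condition and of the target shape for padding. It works on shape tuples only, the example shapes are hypothetical, and is_uniform / padded_shape are illustrative helpers rather than DALI API:

```python
def is_uniform(shapes):
    # A batch can be stacked into one dense tensor only if
    # every sample in it has exactly the same shape.
    return len(set(shapes)) <= 1


def padded_shape(shapes):
    # Per-dimension maximum: the smallest common shape
    # that every sample could be padded up to.
    return tuple(max(dims) for dims in zip(*shapes))


batch_shapes = [(224, 224, 3), (96, 96, 3)]  # hypothetical HWC sample shapes
print(is_uniform(batch_shapes))    # False
print(padded_shape(batch_shapes))  # (224, 224, 3)
```

Padding every sample to the shape returned by padded_shape would make the batch dense; a ragged iterator avoids the padding by keeping samples as separate tensors.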
