libtilt's People

Contributors

alisterburt, jdickerson95, mchaillet, rsanchezgarc

libtilt's Issues

add 2D/3D rotations

Not sure yet of the best path for implementation; probably best to start with a real-space implementation using torch.nn.functional.grid_sample and accept the interpolation artifacts.
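A minimal real-space sketch of this idea, assuming square (2D) or cubic (3D) images and rotation matrices that act on grid_sample's normalised (x, y[, z]) coordinates. `rotate_real_space` and `rotation_2d` are hypothetical names, not part of libtilt. The sketch uses `torch.nn.functional.affine_grid`, which maps output coordinates to input coordinates, so the inverse (transposed) rotation is what gets passed in.

```python
import math

import torch
import torch.nn.functional as F


def rotate_real_space(images: torch.Tensor, rotation_matrices: torch.Tensor) -> torch.Tensor:
    """Rotate (b, h, w) or (b, d, h, w) images about their centre (hypothetical helper).

    rotation_matrices: (b, 2, 2) or (b, 3, 3), acting on grid_sample's normalised
    (x, y) or (x, y, z) coordinates. Assumes square/cubic images; otherwise the
    per-axis normalisation turns the rotation into a shear.
    """
    ndim = rotation_matrices.shape[-1]
    batch = images.shape[0]
    theta = torch.zeros(batch, ndim, ndim + 1, dtype=images.dtype, device=images.device)
    # affine_grid maps output coords -> input coords, so use R^-1 = R^T (no translation)
    theta[:, :, :ndim] = rotation_matrices.transpose(-1, -2).to(images.dtype)
    x = images.unsqueeze(1)  # add a channel dimension: (b, 1, ...)
    grid = F.affine_grid(theta, list(x.shape), align_corners=False)
    # bilinear (trilinear for 3D) interpolation; samples outside the box become zero
    rotated = F.grid_sample(x, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
    return rotated.squeeze(1)


def rotation_2d(angle_degrees: float) -> torch.Tensor:
    """2x2 counter-clockwise rotation matrix in (x, y)."""
    a = math.radians(angle_degrees)
    return torch.tensor([[math.cos(a), -math.sin(a)], [math.sin(a), math.cos(a)]])
```

For example, `rotate_real_space(images, rotation_2d(30.0).unsqueeze(0))` rotates a single 2D image by 30 degrees; the bilinear interpolation is where the artifacts mentioned above come from.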

alignment API

Do we want a simple, general-purpose rigid-body image alignment API? It feels like it would be a useful thing to have in the toolbox.

cc @McHaillet
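As a possible starting point, here is a sketch of only the translational part of rigid-body alignment, using FFT-based circular cross-correlation with integer-pixel precision. `estimate_shift_2d` is a hypothetical name; a full rigid-body API would also need a search over rotations.

```python
import torch


def estimate_shift_2d(reference: torch.Tensor, image: torch.Tensor) -> tuple[int, int]:
    """Integer (dy, dx) such that torch.roll(image, (dy, dx), dims=(-2, -1)) best matches reference.

    Both inputs are single 2D images of the same shape.
    """
    h, w = reference.shape[-2:]
    # circular cross-correlation via the convolution theorem
    xcorr = torch.fft.ifft2(torch.fft.fft2(reference) * torch.conj(torch.fft.fft2(image))).real
    dy, dx = divmod(int(torch.argmax(xcorr)), w)
    # peaks in the upper half of each axis correspond to negative shifts
    if dy > h // 2:
        dy -= h
    if dx > w // 2:
        dx -= w
    return dy, dx
```

Returning the shift to apply (rather than the shift that was observed) keeps the result directly usable with `torch.roll`; sub-pixel precision would need peak interpolation or upsampling on top of this.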

FSC implementation is no longer correct

At some point, the FSC implementation here gave identical results to RELION. After a bunch of debugging this evening, it appears to have diverged recently.
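When tracking down a regression like this, a deliberately simple reference can help. The sketch below computes an FSC over spherical shells of the full (not rfft) FFT, assuming cubic volumes and shells one Fourier pixel wide from DC up to Nyquist. It is not guaranteed to reproduce RELION's exact shell binning, and `fsc_reference` is a hypothetical name.

```python
import torch


def fsc_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Fourier shell correlation between two cubic volumes (hypothetical reference helper).

    Shells are one Fourier pixel wide and run from DC (shell 0) to Nyquist (shell n // 2).
    """
    if a.shape != b.shape or len(set(a.shape)) != 1:
        raise ValueError("expected two cubic volumes of identical shape")
    n = a.shape[-1]
    dft_a = torch.fft.fftn(a).flatten()
    dft_b = torch.fft.fftn(b).flatten()
    # |k| in cycles/pixel for every voxel of the full, unshifted FFT
    freqs = [torch.fft.fftfreq(size, device=a.device) for size in a.shape]
    grids = torch.meshgrid(*freqs, indexing="ij")
    radius = torch.sqrt(sum(g ** 2 for g in grids)).flatten()
    shell = torch.round(radius * n).long()
    n_shells = n // 2 + 1
    keep = shell < n_shells  # discard the corners beyond Nyquist
    shell, fa, fb = shell[keep], dft_a[keep], dft_b[keep]

    def shell_sum(values: torch.Tensor) -> torch.Tensor:
        out = torch.zeros(n_shells, dtype=values.dtype, device=values.device)
        return out.index_add_(0, shell, values)

    numerator = shell_sum((fa * fb.conj()).real)
    denominator = torch.sqrt(shell_sum(fa.abs() ** 2) * shell_sum(fb.abs() ** 2))
    return numerator / denominator
```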

Defining generic wrapper decorators

I have been experimenting with different decorators to speed up the code (e.g., functools.lru_cache, torch.compile, ...), and I think we should consider a generic mechanism that lets the user optionally activate and deactivate these decorators. Something like this:

from functools import lru_cache

import torch

decorators_blacklist = ["torch.compile"]  # This should be stored in some global config, potentially accessible via an environment variable

def decorator_manager(*decorators):
    def wrapper(func):
        for name, decorator in reversed(decorators):
            if name not in decorators_blacklist:
                func = decorator(func)
        return func
    return wrapper


@decorator_manager(("lru_cache", lru_cache(maxsize=32)), ("torch.no_grad", torch.no_grad()), ("torch.compile", torch.compile))
def torch_op(i):
    return torch.rand(i)

torch_op(1)
print(torch_op(2))

This way, the programmer can choose which decorators are potentially useful for a given function, and the user can disable those that are not useful to them.

What do you think?

Do you already have some way of defining/modifying config variables?
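One possible answer to the config question, sketched with only the standard library: read the blacklist from an environment variable when each function is decorated. The variable name `LIBTILT_DISABLED_DECORATORS` and the helper `disabled_decorators` are hypothetical, not existing libtilt settings.

```python
import os
from functools import lru_cache


def disabled_decorators() -> set[str]:
    """Names listed, comma-separated, in the hypothetical LIBTILT_DISABLED_DECORATORS env var."""
    raw = os.environ.get("LIBTILT_DISABLED_DECORATORS", "")
    return {name.strip() for name in raw.split(",") if name.strip()}


def decorator_manager(*decorators):
    """Apply (name, decorator) pairs bottom-up, skipping any name the user has disabled."""
    def wrapper(func):
        disabled = disabled_decorators()  # read once, when the function is decorated
        for name, decorator in reversed(decorators):
            if name not in disabled:
                func = decorator(func)
        return func
    return wrapper
```

A user could then run `LIBTILT_DISABLED_DECORATORS=torch.compile,lru_cache python script.py`. Because the variable is read at decoration time (usually import time), changing it after the module has been imported does not affect functions that are already defined.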

sample dft 2d

As pointed out by @McHaillet, it would be great to have a 2D DFT sampling function like the 3D one to enable Fourier-space common-line extraction 🙂

Disagreement between relion_project and libtilt.projection.project_fourier

Hi,

I have compared the output of libtilt.projection.project_fourier against that of relion_project, and I am seeing disagreements that are difficult to spot with the naked eye but easy to see when computing the difference.

I used this code:

import os
import mrcfile
import torch
from scipy.spatial.transform import Rotation as R
from libtilt.projection import project_fourier
from starstack import ParticlesStarSet
from matplotlib import pyplot as plt


dirname = os.path.expanduser("~/cryo/data/preAlignedParticles/EMPIAR-10166/data")
particleIdx = 0

vol = mrcfile.read(os.path.join(dirname,"allparticles_reconstruct.mrc"))
gt_projections = ParticlesStarSet(os.path.join(dirname, "projections/proj_noCTF.star"))

anglesDegs, xyShiftAngs = gt_projections.getPose(particleIdx)
rotation_matrices = R.from_euler("ZYZ", anglesDegs, degrees=True).as_matrix()

vol = torch.FloatTensor(vol)
rotation_matrices = torch.FloatTensor(rotation_matrices).unsqueeze(0)
proj = project_fourier(
    volume = vol,
    rotation_matrices = rotation_matrices,
    rotation_matrix_zyx = False,
    pad = True,
)

img =  gt_projections[particleIdx][0]

proj = (proj-proj.mean())/proj.std()
img = (img-img.mean())/img.std()


diff = torch.abs(proj -img)
print(diff.sum())

f, axes = plt.subplots(1, 3)
axes[0].imshow(proj.cpu().squeeze(0).numpy(), cmap="gray", label="libtilt")
axes[0].set_title("libtilt")
axes[1].imshow(img, cmap="gray", label="relion")
axes[1].set_title("relion")
axes[2].imshow(diff.cpu().squeeze(0).numpy(), cmap="gray", label="abs diff")
axes[2].set_title("abs diff")

plt.show()

and this RELION command:

scipion run relion_project --i allparticles_reconstruct.mrc  --ang 1000particles.star --o projections/proj_noCTF 

Do you have any idea why the results are not the same? Could it perhaps be the definition of the centre of coordinates?
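One way to test the centre-of-coordinates hypothesis is to shift one projection by a small (possibly sub-pixel) amount and check whether the difference drops. Note also that the script above reads xyShiftAngs but never applies it, so if relion_project applies the origin offsets stored in the STAR file, they would show up in the difference too. The helper below uses the Fourier shift theorem; `fourier_shift_2d` is a hypothetical name, and converting xyShiftAngs from ångströms to pixels (dividing by the pixel size) and its sign convention are assumptions to check against RELION.

```python
import math

import torch


def fourier_shift_2d(image: torch.Tensor, shift: tuple[float, float]) -> torch.Tensor:
    """Shift the last two dims of a real image by (dy, dx) pixels with periodic boundaries."""
    h, w = image.shape[-2:]
    ky = torch.fft.fftfreq(h, device=image.device)[:, None]   # cycles/pixel along y
    kx = torch.fft.rfftfreq(w, device=image.device)[None, :]  # half spectrum along x
    # Fourier shift theorem: f(n - s)  <->  F(k) * exp(-2*pi*i*k*s)
    phase = torch.exp(-2j * math.pi * (shift[0] * ky + shift[1] * kx))
    return torch.fft.irfft2(torch.fft.rfft2(image) * phase, s=(h, w))
```

For integer shifts this reproduces `torch.roll`; a half-pixel shift such as `(0.5, 0.5)` is the kind of offset a different choice of box centre would produce.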

extract_central_slices_rfft performance

Hi,
I have been profiling the function extract_central_slices_rfft and found that the conjugate_mask operations account for a large fraction of the total execution time:

grid[conjugate_mask] *= -1

projections[conjugate_mask] = torch.conj(projections[conjugate_mask])

These are the results of my profiler (with CUDA_LAUNCH_BLOCKING=1 set to force synchronous execution).

Function: extract_central_slices_rfft at line 58

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    58                                           @profile
    59                                           def extract_central_slices_rfft(
    60                                               dft: torch.Tensor,
    61                                               image_shape: tuple[int, int, int],
    62                                               rotation_matrices: torch.Tensor,
    63                                               rotation_matrix_zyx: bool,
    64                                           ):
    65                                               """Extract central slice from an fftshifted rfft."""
    66                                               # generate grid of DFT sample frequencies for a central slice in the xy-plane
    67                                               # these are a coordinate grid for the DFT
    68      1050    7847185.4   7473.5     42.5      grid = rotated_central_slice_grid(
    69       525        391.5      0.7      0.0          image_shape=image_shape,
    70       525        246.1      0.5      0.0          rotation_matrices=rotation_matrices,
    71       525        151.7      0.3      0.0          rotation_matrix_zyx=rotation_matrix_zyx,
    72       525        378.2      0.7      0.0          rfft=True,
    73       525        114.8      0.2      0.0          fftshift=True,
    74       525        821.4      1.6      0.0          device=dft.device,
    75                                               )  # (..., h, w, 3)
    76                                           
    77                                               # flip coordinates in redundant half transform
    78       525     213000.9    405.7      1.2      conjugate_mask = grid[..., 2] < 0
    79                                               # conjugate_mask = einops.repeat(conjugate_mask, '... -> ... 3') #This operation does not compile
    80       525      29983.9     57.1      0.2      conjugate_mask = conjugate_mask.unsqueeze(-1).expand(*[-1] * len(conjugate_mask.shape), 3) #This does compile
    81       525    3829907.7   7295.1     20.7      grid[conjugate_mask] *= -1 #This is super slower. masked_scatter_ seems 15% faster, still slow #TODO: This is the cornercase
    82                                               # grid.masked_scatter_(conjugate_mask, -1 * grid.masked_select(conjugate_mask))
    83                                           
    84       525      20220.4     38.5      0.1      conjugate_mask = conjugate_mask[..., 0]  # un-repeat
    85                                           
    86                                               # convert frequencies to array coordinates and sample from DFT
    87      1050    2115877.3   2015.1     11.5      grid = fftfreq_to_dft_coordinates(
    88       525        566.8      1.1      0.0          frequencies=grid,
    89       525        631.4      1.2      0.0          image_shape=image_shape,
    90       525        417.4      0.8      0.0          rfft=True
    91                                               )
    92       525    2975877.7   5668.3     16.1      projections = sample_dft_3d(dft=dft, coordinates=grid)  # (..., h, w) rfft
    93                                           
    94                                               # take complex conjugate of values from redundant half transform
    95       525    1429438.3   2722.7      7.7      projections[conjugate_mask] = torch.conj(projections[conjugate_mask]) #This is slower
    96                                               # projections.masked_scatter_(conjugate_mask, torch.conj(projections.masked_select(conjugate_mask)))
    97       525       1142.1      2.2      0.0      return projections



The following changes speed up the code a bit:

grid.masked_scatter_(conjugate_mask, -1 * grid.masked_select(conjugate_mask))
projections.masked_scatter_(conjugate_mask, torch.conj(projections.masked_select(conjugate_mask)))
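Another option, not profiled here, is to avoid boolean indexing altogether with torch.where. It evaluates both branches elementwise and keeps tensor shapes static, whereas boolean indexing produces data-dependent shapes, which may also make it friendlier to torch.compile. Whether it is actually faster than masked_scatter_ would need to be measured; the helper names below are hypothetical.

```python
import torch


def flip_redundant_half(grid: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Negate (..., 3) frequency vectors whose last component is negative; also return the mask."""
    conjugate_mask = grid[..., 2] < 0
    # broadcast the (...,) mask over the coordinate dimension instead of expanding it
    flipped = torch.where(conjugate_mask.unsqueeze(-1), -grid, grid)
    return flipped, conjugate_mask


def conjugate_masked(projections: torch.Tensor, conjugate_mask: torch.Tensor) -> torch.Tensor:
    """Complex-conjugate the samples that came from the redundant half transform."""
    return torch.where(conjugate_mask, torch.conj(projections), projections)
```

Unlike the in-place versions, these return new tensors, so the caller would reassign `grid` and `projections`.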

unit testing device-agnostic code

Some of the code fails when running on CUDA devices, and I was wondering whether we could integrate device testing into the unit tests.

To highlight a few things: backproject_fourier does not run with CUDA tensors, and I noticed @rsanchezgarc also pointed out some issues in PR #54.

Running all the tests with tensors on both the CPU and a CUDA device would of course solve this. However, if we add automated testing for PRs in the future, the default GitHub-hosted runners do not provide GPUs. Secondly, a nice thing about PyTorch is that it allows development on a CPU-only system while keeping the code portable to GPUs.

I looked online, because I imagined more people need this, and found that torch supports a 'meta' device: https://pytorch.org/torchdistx/latest/fake_tensor.html. Also see the discussion in pytorch/pytorch#61654. I don't know whether it's fully supported throughout PyTorch, but I could play around with it and see how applicable it is.

For example though, the following produces appropriate errors:

>>> import torch
>>> a = torch.zeros((10,10))  # default will initialize on 'cpu'
>>> b = torch.zeros((10,10), device='meta')
>>> c = a * b
RuntimeError: Tensor on device meta is not on the expected device cpu!
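A sketch of how the meta device could be used in CPU-only tests: run a function on meta inputs and check that the result stays on the meta device. Any tensor created inside the function without `device=...` defaults to the CPU, so combining it with a meta input should raise the same kind of error as above. `assert_stays_on_device`, `scale_good` and `scale_bad` are hypothetical names, and not every PyTorch operation has a meta implementation, so some functions may fail for that reason instead.

```python
import torch


def assert_stays_on_device(fn, *shapes, device="meta"):
    """Call fn on empty tensors of the given shapes and check the output device."""
    inputs = [torch.empty(shape, device=device) for shape in shapes]
    output = fn(*inputs)
    expected = torch.device(device).type
    assert output.device.type == expected, f"output on {output.device}, expected {expected}"
    return output


def scale_good(x: torch.Tensor) -> torch.Tensor:
    weights = torch.ones(x.shape[-1], device=x.device)  # follows the input's device
    return x * weights


def scale_bad(x: torch.Tensor) -> torch.Tensor:
    weights = torch.ones(x.shape[-1])  # bug: always allocated on the CPU
    return x * weights
```

`assert_stays_on_device(scale_good, (4, 8))` passes without allocating any real memory, while `assert_stays_on_device(scale_bad, (4, 8))` is expected to raise a device-mismatch RuntimeError like the one in the example above. Passing `device="cuda"` runs the same check on a GPU machine.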
