concept-saliency-maps's People

Contributors

lenbrocki, ncchung

Forkers

ncchung kayoyin

concept-saliency-maps's Issues

concept_vectors.npy

Hi,

How can I create the concept_vectors.npy file which is being used in the Dot products and histograms section?
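The script that produced concept_vectors.npy is not shown on this page, so its exact recipe is unconfirmed. A common way to build a concept vector, assumed here, is to take the mean latent encoding of examples that show the concept and subtract the mean latent encoding of examples that don't. The vector has to live in the same latent space as the encodings it is later dotted with. The dict-keyed-by-attribute layout, the attribute name "Smiling" and the helper names below are hypothetical; the notebook indexes `concept_vectors[attr]`, but the exact structure it expects may differ.

```python
import os
import tempfile

import numpy as np


def concept_vector(latents: np.ndarray, has_concept: np.ndarray) -> np.ndarray:
    """Mean latent of examples showing the concept minus mean latent of the rest (assumed recipe)."""
    mask = np.asarray(has_concept, dtype=bool)
    if mask.all() or not mask.any():
        raise ValueError("need examples both with and without the concept")
    return latents[mask].mean(axis=0) - latents[~mask].mean(axis=0)


def build_concept_vectors(latents, labels_by_attr):
    """labels_by_attr: hypothetical dict of attribute name -> boolean array of shape (n,)."""
    return {name: concept_vector(latents, mask) for name, mask in labels_by_attr.items()}


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    latents = rng.normal(size=(100, 8))  # stand-in for encoder outputs, shape (n, latent_dim)
    smiling = rng.random(100) > 0.5      # stand-in binary labels for one attribute
    vectors = build_concept_vectors(latents, {"Smiling": smiling})
    path = os.path.join(tempfile.mkdtemp(), "concept_vectors.npy")
    np.save(path, vectors)               # a dict is stored as a pickled 0-d object array
    loaded = np.load(path, allow_pickle=True).item()
    print(loaded["Smiling"].shape)       # (8,)
```

Because the dict is pickled, loading it needs `allow_pickle=True` followed by `.item()` to get the dict back.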

Fetch argument None has invalid type <class 'NoneType'>

Hi,

When running DeepExplain as shown below, I run into the following error. Any suggestions? Please let me know if you need any further info. Thanks!

import keras
sess = K.get_session()
print('sess: ',sess)
from ConceptSaliencyMaps.deepexplain.tensorflow import DeepExplain
from ConceptSaliencyMaps.deepexplain.utils import preprocess

list_files = []
all_files = train_files + test_files
for file_name in files_max:
    for file_name2 in all_files:
        if file_name in file_name2:
            list_files.append(file_name2)
            
test_set2 = zfish_age(list_files, path_to_save = path_to_augmented, test=True, transform = True, new_channel=new_channel, new_size_frame=size_frame, 
                     verbose=False)
test_generator2 = data.DataLoader(test_set2,batch_size=1,
                                       shuffle=False,
                                       num_workers=20)

input_img = keras.Input(shape=(50, 128, 128)) 

with DeepExplain(session=sess, graph=sess.graph) as de:
    with torch.no_grad():
        for i, d in enumerate(test_generator2): 
            xis, _, _, labels_name = d
            print('labels_name: {}'.format(labels_name))
                
            input_tensor = input_img
            img_array = xis.reshape([1,50,128,128])
            ris, zis = model(xis.to(device))
            print('zis.shape: ',zis.shape) # torch.Size([1, 256])
            latents = reducer.transform(zis.cpu().detach())
            print('latents.shape: ',latents.shape) # (1, 2)
            method = 'guidedbp'

            concept_score = [K.sum(latents*i) for i in concept_vectors[attr]]
            attributions_guided = [de.explain(method, i, input_tensor, img_array) for i in concept_score]


Error: 
TypeError                                 Traceback (most recent call last)
<ipython-input-169-177871cfe4fc> in <module>
     73 
     74             concept_score = [K.sum(latents*i) for i in concept_vectors[attr]]
---> 75             attributions_guided = [de.explain(method, i, input_tensor, img_array) for i in concept_score]

<ipython-input-169-177871cfe4fc> in <listcomp>(.0)
     73 
     74             concept_score = [K.sum(latents*i) for i in concept_vectors[attr]]
---> 75             attributions_guided = [de.explain(method, i, input_tensor, img_array) for i in concept_score]

../ConceptSaliencyMaps/deepexplain/tensorflow/methods.py in explain(self, method, T, X, xs, **kwargs)
    733         _ENABLED_METHOD_CLASS = method_class
    734         method = _ENABLED_METHOD_CLASS(T, X, xs, self.session, self.keras_phase_placeholder, **kwargs)
--> 735         result = method.run()
    736         if issubclass(_ENABLED_METHOD_CLASS, GradientBasedMethod) and _GRAD_OVERRIDE_CHECKFLAG == 0:
    737             warnings.warn('DeepExplain detected you are trying to use an attribution method that requires '

../ConceptSaliencyMaps/deepexplain/tensorflow/methods.py in run(self)
    463         for alpha in list(np.linspace(1. / self.steps, 1.0, self.steps)):
    464             xs_mod = [xs * alpha for xs in self.xs] if self.has_multiple_inputs else self.xs * alpha
--> 465             _attr = self.session_run(attributions, xs_mod)
    466             if gradient is None: gradient = _attr
    467             else: gradient = [g + a for g, a in zip(gradient, _attr)]

../ConceptSaliencyMaps/deepexplain/tensorflow/methods.py in session_run(self, T, xs)
     94         if self.keras_learning_phase is not None:
     95             feed_dict[self.keras_learning_phase] = 0
---> 96         return self.session.run(T, feed_dict)
     97 
     98     def _set_check_baseline(self):

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    954     try:
    955       result = self._run(None, fetches, feed_dict, options_ptr,
--> 956                          run_metadata_ptr)
    957       if run_metadata:
    958         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1163     # Create a fetch handler to take care of the structure of fetches.
   1164     fetch_handler = _FetchHandler(
-> 1165         self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
   1166 
   1167     # Run request and get response.

..lib/python3.7/site-packages/tensorflow_core/python/client/session.py in __init__(self, graph, fetches, feeds, feed_handles)
    472     """
    473     with graph.as_default():
--> 474       self._fetch_mapper = _FetchMapper.for_fetch(fetches)
    475     self._fetches = []
    476     self._targets = []

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in for_fetch(fetch)
    264     elif isinstance(fetch, (list, tuple)):
    265       # NOTE(touts): This is also the code path for namedtuples.
--> 266       return _ListFetchMapper(fetch)
    267     elif isinstance(fetch, collections_abc.Mapping):
    268       return _DictFetchMapper(fetch)

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in __init__(self, fetches)
    373     """
    374     self._fetch_type = type(fetches)
--> 375     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
    376     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
    377 

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in <listcomp>(.0)
    373     """
    374     self._fetch_type = type(fetches)
--> 375     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
    376     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
    377 

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in for_fetch(fetch)
    261     if fetch is None:
    262       raise TypeError('Fetch argument %r has invalid type %r' %
--> 263                       (fetch, type(fetch)))
    264     elif isinstance(fetch, (list, tuple)):
    265       # NOTE(touts): This is also the code path for namedtuples.

TypeError: Fetch argument None has invalid type <class 'NoneType'>
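A likely cause, judging from the code: DeepExplain differentiates the target T with respect to X inside one TensorFlow graph, but here `latents` comes from a PyTorch model followed by `reducer.transform` on a detached CPU tensor. `K.sum(latents*i)` is therefore a constant with no path back to `input_tensor` (a fresh `keras.Input`), TensorFlow returns None as the gradient, and `session.run` fails on that None fetch. Since the encoder is PyTorch, one workaround, sketched here under stated assumptions, is to compute the map with `torch.autograd` directly. This is a swapped-in technique: it gives the plain input gradient, not guided backpropagation. Two further assumptions: the concept vector lives in the encoder's latent space (256-d here), not the 2-D reducer output, because the reducer step is not tracked by autograd; and the gradient must not be taken under `torch.no_grad()`. The helper name and the toy encoder are hypothetical.

```python
from typing import Callable

import torch
import torch.nn as nn


def concept_saliency(encode: Callable[[torch.Tensor], torch.Tensor],
                     x: torch.Tensor,
                     concept: torch.Tensor) -> torch.Tensor:
    """Plain-gradient concept saliency: d(z . c)/dx with z = encode(x) (hypothetical helper)."""
    x = x.detach().clone().requires_grad_(True)  # fresh leaf tensor to differentiate against
    with torch.enable_grad():                    # re-enables autograd even inside torch.no_grad()
        z = encode(x)                            # latent, shape (batch, latent_dim)
        score = (z * concept).sum()              # concept score, summed over the batch
    (grad,) = torch.autograd.grad(score, x)
    return grad


torch.manual_seed(0)
encoder = nn.Sequential(nn.Flatten(), nn.Linear(2 * 4 * 4, 8)).eval()  # hypothetical toy encoder
x = torch.randn(1, 2, 4, 4)
concept = torch.randn(8)
saliency = concept_saliency(encoder, x, concept)
print(saliency.shape)  # torch.Size([1, 2, 4, 4])
```

For a model that returns `(ris, zis)` as in the code above, the encoder could be passed as something like `lambda t: model(t)[1]`, with the concept vector as a tensor on the same device.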

TypeError: unhashable type: 'numpy.ndarray'

Hello,

Thanks for your work. I am trying to use it on the VAE of Latplan, and I tried this:

attributions = de.explain("guidedbp", latent, latent, train_subdata_batch_cache)

The second argument should be a concept_score, but I don't understand exactly what you mean by that (Latplan's encoder encodes an image into a latent vector, that's it). Could you explain it to me, please?
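Going by the `K.sum(latents*i)` line in the DeepExplain issue on this page, the concept score appears to be the dot product between the encoder's latent vector and a concept vector, and the concept saliency map is the gradient of that score with respect to the input image. Written out, with $\mathbf{z}(x)$ the encoder output for image $x$ and $\mathbf{c}$ the concept vector (notation assumed here):

```latex
s_{\mathbf{c}}(x) = \mathbf{c}^{\top}\mathbf{z}(x) = \sum_{j} c_j\, z_j(x),
\qquad
M_{\mathbf{c}}(x) = \frac{\partial s_{\mathbf{c}}(x)}{\partial x}
```

So the score is a single scalar per image computed from the latent, not the latent itself.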

With the code above I get the following error:
TypeError: unhashable type: 'numpy.ndarray'

latent and train_subdata_batch_cache are both NumPy arrays with shape (400, 2, 28, 28, 3), where 400 is the batch size, 2 is the number of images per example (each example is a "sequence" of two images), and 28x28x3 is the image size.

Thanks a lot!

Aymeric
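The full traceback isn't shown, but the message fits a simple mechanism. The signature in the traceback of the DeepExplain issue on this page is `explain(self, method, T, X, xs, **kwargs)`, which likely expects T and X to be TensorFlow tensors from the model graph and xs the NumPy batch fed into X. The input tensor then serves as a key of a TF1 `feed_dict`, and a NumPy array cannot be a dict key because it is unhashable. The snippet reproduces only that Python-level failure with NumPy; it does not call DeepExplain, and `try_feed` is a hypothetical helper.

```python
import numpy as np


def try_feed(key, value):
    """Build a one-entry feed_dict-style mapping; return the error text if the key is unhashable."""
    try:
        _ = {key: value}  # TF1 feed_dict keys are graph tensors, which are hashable
    except TypeError as err:
        return str(err)
    return "ok"


batch = np.zeros((400, 2, 28, 28, 3), dtype=np.float32)  # same shape as in the report
print(try_feed(batch, batch))           # unhashable type: 'numpy.ndarray'
print(try_feed("input_tensor", batch))  # ok (a hashable stand-in for a graph tensor)
```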
