Giter VIP home page Giter VIP logo

Comments (4)

rhaps0dy avatar rhaps0dy commented on September 1, 2024

Hi, I've made this Python 3 compatible for my own use. It didn't take very long.

Here is a diff. It's not fully tested yet.

diff --git a/compare_gan/src/eval_gan_lib.py b/compare_gan/src/eval_gan_lib.py
index becba4d..8d2594b 100644
--- a/compare_gan/src/eval_gan_lib.py
+++ b/compare_gan/src/eval_gan_lib.py
@@ -74,7 +74,7 @@ def GetAllTrainingParams():
   for gan_type in SUPPORTED_GANS:
     for dataset in ["mnist", "fashion-mnist", "cifar10", "celeba"]:
       p = params.GetParameters(gan_type, "wide")
-      all_params.update(list(p.keys()))
+      all_params.update(p.keys())
   logging.info("All training parameter exported: %s", sorted(all_params))
   return sorted(all_params)
 
diff --git a/compare_gan/src/gan_lib.py b/compare_gan/src/gan_lib.py
index 316cc52..ee41e7c 100644
--- a/compare_gan/src/gan_lib.py
+++ b/compare_gan/src/gan_lib.py
@@ -142,7 +142,7 @@ def create_gan(gan_type, dataset, dataset_content, options,
 def profile_context(tfprofile_dir):
   if "enable_tf_profile" in FLAGS and FLAGS.enable_tf_profile:
     with tf.contrib.tfprof.ProfileContext(
-        tfprofile_dir, trace_steps=list(range(100, 200, 1)), dump_steps=[200]):
+        tfprofile_dir, trace_steps=range(100, 200, 1), dump_steps=[200]):
       yield
   else:
     yield
diff --git a/compare_gan/src/gans/resnet_architecture_test.py b/compare_gan/src/gans/resnet_architecture_test.py
index c08e7b7..dc62261 100644
--- a/compare_gan/src/gans/resnet_architecture_test.py
+++ b/compare_gan/src/gans/resnet_architecture_test.py
@@ -14,11 +14,11 @@
 # limitations under the License.
 
 """Tests for Resnet architectures."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-
 from compare_gan.src.gans import resnet_architecture as resnet_arch
 import tensorflow as tf
 
@@ -43,8 +43,8 @@ class ResnetArchitectureTest(tf.test.TestCase):
   def testResnet5GeneratorRuns(self):
     generator_128 = TestResnet5GeneratorShape(128)
     generator_64 = TestResnet5GeneratorShape(64)
-    self.assertEqual(generator_128[0], generator_128[1])
-    self.assertEqual(generator_64[0], generator_64[1])
+    self.assertEquals(generator_128[0], generator_128[1])
+    self.assertEquals(generator_64[0], generator_64[1])
 
   def testResnet5DiscriminatorRuns(self):
     config = tf.ConfigProto(allow_soft_placement=True)
@@ -57,7 +57,7 @@ class ResnetArchitectureTest(tf.test.TestCase):
           reuse=False)
       tf.global_variables_initializer().run()
       output = sess.run([out])
-      self.assertEqual(output[0].shape, (batch_size, 1))
+      self.assertEquals(output[0].shape, (batch_size, 1))
 
   def testResnet107GeneratorRuns(self):
     config = tf.ConfigProto(allow_soft_placement=True)
@@ -70,7 +70,7 @@ class ResnetArchitectureTest(tf.test.TestCase):
           noise=z, is_training=True, reuse=False, colors=3)
       tf.global_variables_initializer().run()
       output = sess.run([g])
-      self.assertEqual(output[0].shape, (batch_size, 128, 128, 3))
+      self.assertEquals(output[0].shape, (batch_size, 128, 128, 3))
 
   def testResnet107DiscriminatorRuns(self):
     config = tf.ConfigProto(allow_soft_placement=True)
@@ -83,7 +83,7 @@ class ResnetArchitectureTest(tf.test.TestCase):
           discriminator_normalization="spectral_norm", reuse=False)
       tf.global_variables_initializer().run()
       output = sess.run([out])
-      self.assertEqual(output[0].shape, (batch_size, 1))
+      self.assertEquals(output[0].shape, (batch_size, 1))
 
 if __name__ == "__main__":
   tf.test.main()
diff --git a/compare_gan/src/generate_tasks_lib.py b/compare_gan/src/generate_tasks_lib.py
index 39d0850..7c21b9f 100644
--- a/compare_gan/src/generate_tasks_lib.py
+++ b/compare_gan/src/generate_tasks_lib.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 """Generate tasks for comparing GANs."""
+
 from __future__ import absolute_import
 from __future__ import division
 
@@ -143,7 +144,7 @@ def TestGansWithPenaltyNewDatasets(architecture):
 def GetDefaultParams(gan_params):
   """Return the default params for a GAN (=the ones used in the paper)."""
   ret = {}
-  for param_name, param_info in gan_params.items():
+  for param_name, param_info in gan_params.iteritems():
     ret[param_name] = param_info.default
   return ret
 
diff --git a/compare_gan/src/params.py b/compare_gan/src/params.py
index e637226..7ba2ee6 100644
--- a/compare_gan/src/params.py
+++ b/compare_gan/src/params.py
@@ -19,6 +19,7 @@ We define the default GAN parameters with respect to the datasets and the
 training hyperparameters. The hyperparameters used by the respective authors
 are also added to the set.
 """
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
diff --git a/compare_gan/src/params_test.py b/compare_gan/src/params_test.py
index b32a0ee..34c1ce7 100644
--- a/compare_gan/src/params_test.py
+++ b/compare_gan/src/params_test.py
@@ -14,11 +14,11 @@
 # limitations under the License.
 
 """Tests for compare_gan.params."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-
 from compare_gan.src import params
 import tensorflow as tf
 
@@ -27,10 +27,10 @@ class ParamsTest(tf.test.TestCase):
 
   def testParameterRanges(self):
     training_parameters = params.GetParameters("WGAN", "wide")
-    self.assertEqual(len(list(training_parameters.keys())), 5)
+    self.assertEqual(len(training_parameters.keys()), 5)
 
     training_parameters = params.GetParameters("BEGAN", "wide")
-    self.assertEqual(len(list(training_parameters.keys())), 6)
+    self.assertEqual(len(training_parameters.keys()), 6)
 
 
 if __name__ == "__main__":
diff --git a/compare_gan/src/simple_task_pb2.py b/compare_gan/src/simple_task_pb2.py
index 90a4699..fa58ed5 100644
--- a/compare_gan/src/simple_task_pb2.py
+++ b/compare_gan/src/simple_task_pb2.py
@@ -28,7 +28,7 @@ from google.protobuf import descriptor_pb2
 DESCRIPTOR = _descriptor.FileDescriptor(
   name='src/simple_task.proto',
   package='compare_gan',
-  serialized_pb=b'\n\x15src/simple_task.proto\x12\x0b\x63ompare_gan\"t\n\rTaskDimension\x12\x11\n\tparameter\x18\x01 \x01(\t\x12\x14\n\x0cstring_value\x18\x02 \x01(\t\x12\x11\n\tint_value\x18\x03 \x01(\x05\x12\x13\n\x0b\x66loat_value\x18\x04 \x01(\x02\x12\x12\n\nbool_value\x18\x05 \x01(\x08\"C\n\x04Task\x12\x0b\n\x03num\x18\x01 \x01(\x05\x12.\n\ndimensions\x18\x07 \x03(\x0b\x32\x1a.compare_gan.TaskDimension')
+  serialized_pb='\n\x15src/simple_task.proto\x12\x0b\x63ompare_gan\"t\n\rTaskDimension\x12\x11\n\tparameter\x18\x01 \x01(\t\x12\x14\n\x0cstring_value\x18\x02 \x01(\t\x12\x11\n\tint_value\x18\x03 \x01(\x05\x12\x13\n\x0b\x66loat_value\x18\x04 \x01(\x02\x12\x12\n\nbool_value\x18\x05 \x01(\x08\"C\n\x04Task\x12\x0b\n\x03num\x18\x01 \x01(\x05\x12.\n\ndimensions\x18\x07 \x03(\x0b\x32\x1a.compare_gan.TaskDimension')
 
 
 
@@ -43,14 +43,14 @@ _TASKDIMENSION = _descriptor.Descriptor(
     _descriptor.FieldDescriptor(
       name='parameter', full_name='compare_gan.TaskDimension.parameter', index=0,
       number=1, type=9, cpp_type=9, label=1,
-      has_default_value=False, default_value="",
+      has_default_value=False, default_value=u"",
       message_type=None, enum_type=None, containing_type=None,
       is_extension=False, extension_scope=None,
       options=None),
     _descriptor.FieldDescriptor(
       name='string_value', full_name='compare_gan.TaskDimension.string_value', index=1,
       number=2, type=9, cpp_type=9, label=1,
-      has_default_value=False, default_value="",
+      has_default_value=False, default_value=u"",
       message_type=None, enum_type=None, containing_type=None,
       is_extension=False, extension_scope=None,
       options=None),
@@ -127,12 +127,14 @@ _TASK.fields_by_name['dimensions'].message_type = _TASKDIMENSION
 DESCRIPTOR.message_types_by_name['TaskDimension'] = _TASKDIMENSION
 DESCRIPTOR.message_types_by_name['Task'] = _TASK
 
-class TaskDimension(_message.Message, metaclass=_reflection.GeneratedProtocolMessageType):
+class TaskDimension(_message.Message):
+  __metaclass__ = _reflection.GeneratedProtocolMessageType
   DESCRIPTOR = _TASKDIMENSION
 
   # @@protoc_insertion_point(class_scope:compare_gan.TaskDimension)
 
-class Task(_message.Message, metaclass=_reflection.GeneratedProtocolMessageType):
+class Task(_message.Message):
+  __metaclass__ = _reflection.GeneratedProtocolMessageType
   DESCRIPTOR = _TASK
 
   # @@protoc_insertion_point(class_scope:compare_gan.Task)
diff --git a/compare_gan/src/task_utils.py b/compare_gan/src/task_utils.py
index 4c45646..b88c155 100644
--- a/compare_gan/src/task_utils.py
+++ b/compare_gan/src/task_utils.py
@@ -64,7 +64,7 @@ def UnrollCalls(function, kwargs):
     b x c in [2,4] x [5, 6].
   """
   res = []
-  for key, value in sorted(kwargs.items()):
+  for key, value in sorted(kwargs.iteritems()):
     assert not isinstance(key, tuple)
     if isinstance(value, list):
       for v in value:
@@ -104,7 +104,7 @@ def MakeDimensions(dim_dict, extra_dims=None, base_task=None):
     dim_dict = copy.copy(dim_dict)
     dim_dict.update(extra_dims)
   dim_dict = collections.OrderedDict(sorted(dim_dict.items()))
-  for key, value in dim_dict.items():
+  for key, value in dim_dict.iteritems():
     if key in ("_proto", "_prefix"):
       # We skip the special keys.
       continue

from compare_gan.

eyaler avatar eyaler commented on September 1, 2024

You also need to change cPickle to pickle in gilbo.py.

from compare_gan.

eyaler avatar eyaler commented on September 1, 2024

also in gilbo.py change:

uninitialized = plist(tf.report_uninitialized_variables().eval())

to

uninitialized = plist(tf.report_uninitialized_variables().eval().astype(str))

from compare_gan.

Marvin182 avatar Marvin182 commented on September 1, 2024

Thank you for your feedback. We released an update and all code should now be Python 3 compatible. Please open a new ticket if there are still issues.

from compare_gan.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games that makes everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.