This repository was archived by the owner on Jun 15, 2026. It is now read-only.

Simplify lq.math implementations #677

Open
lgeiger wants to merge 2 commits into main from simplify-lq-math

@lgeiger lgeiger commented Jun 4, 2021

Member

This PR simplifies the implementations of lq.math to make them more readable. Currently this would break Compute Engine, since we rely on the implementation details there. We could either fix this using @tf.function(experimental_implements=...), which would require us to drop support for older TensorFlow versions, or we could update the patterns in Compute Engine to handle this implementation properly, which shouldn't be hard.

I briefly evaluated the performance impact of this change on a T4 GPU in this notebook (although only on one input size). It shows performance improvements in some cases and regressions in others, although I am not sure whether this would be noticeable in a real-world model. I am happy to run benchmarks if you think this change might impact full model training performance.

This PR is still WIP for now to explore this further.

@lgeiger lgeiger added the breaking-change Changes that will break user code label Jun 4, 2021
@lgeiger lgeiger requested a review from AdamHillier June 4, 2021 12:51
@Tombana

Tombana commented Jun 7, 2021

Contributor

> We could either fix this using @tf.function(experimental_implements=...) which would require us dropping support for older TensorFlow version, or we could update the patterns in Compute Engine ..

Minor comment: @tf.function(experimental_implements=...) doesn't make it convert; it only adds a tag that the converter can read. You'd still need to add the pattern to Compute Engine, so you might as well make it a proper pattern.

@lgeiger

lgeiger commented Jun 7, 2021

Member Author

> Minor comment: @tf.function(experimental_implements=...) doesn't make it convert, it only adds a tag that the converter can read, you'd still need to add the pattern to compute engine, so you might as well make it a proper pattern.

Yes, you are right. The idea behind it was that we would then be implementation-independent, but I agree that a proper pattern is much easier, and we can still keep it backwards compatible.

@AdamHillier

Contributor

I agree that matching directly against the pattern in LCE would be good. If it doesn't cause any issues, it might be nice to still keep the experimental_implements so that if we want to change the converter implementation to use it at some point in the future then that 'tag' will already be there in existing versions of Larq.

How do you think we should manage the scenario where somebody uses a new version of Larq but an old version of the converter? I guess on init we could try and detect if LCE is installed, detect the version, and print some kind of warning message if it's <= 0.6?

@lgeiger

lgeiger commented Jun 17, 2021

Member Author

> How do you think we should manage the scenario where somebody uses a new version of Larq but an old version of the converter? I guess on init we could try and detect if LCE is installed, detect the version, and print some kind of warning message if it's <= 0.6?

Yes, I think that should work well. That's how TensorFlow Addons handles it as well.

@lgeiger lgeiger force-pushed the simplify-lq-math branch from 39d15e6 to e24265d Compare July 13, 2021 09:52
@lgeiger

lgeiger commented Jul 13, 2021

Member Author

I added an LCE version check to __init__ in 3f16474 now that LCE 0.6.1 has support for this.

I'd be happy to move this check to setup.py though, in case a warning during import is too verbose.
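
For readers following the thread, an import-time version check of this kind could look roughly like the sketch below. It is not the code from 3f16474: the helper names `parse_version` and `check_lce_version` are hypothetical, the distribution name `larq-compute-engine` and the 0.6.1 threshold are assumptions taken from the discussion above, and only the Python standard library is used.

```python
import warnings
from importlib import metadata

# Assumption from the thread: LCE 0.6.1 is the first release that
# recognises the simplified lq.math implementations.
MIN_LCE_VERSION = (0, 6, 1)


def parse_version(version):
    """Turn the leading numeric parts of a version string into a 3-tuple.

    "0.6.1" -> (0, 6, 1), "0.6" -> (0, 6, 0), "0.6.0rc1" -> (0, 6, 0)
    """
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for char in piece:
            if not char.isdigit():
                break
            digits += char
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def check_lce_version(installed=None):
    """Warn if an installed converter is older than MIN_LCE_VERSION.

    Returns True when the converter is new enough or not installed at all,
    and False (after emitting a warning) when it is too old.
    """
    if installed is None:
        try:
            # Hypothetical lookup by distribution name.
            installed = metadata.version("larq-compute-engine")
        except metadata.PackageNotFoundError:
            return True  # Nothing to warn about without a converter.
    if parse_version(installed) < MIN_LCE_VERSION:
        required = ".".join(str(part) for part in MIN_LCE_VERSION)
        warnings.warn(
            f"Found Larq Compute Engine {installed}, but this Larq version "
            f"expects >= {required} for model conversion."
        )
        return False
    return True
```

Calling `check_lce_version()` once from the package's `__init__` would give the import-time warning discussed here; passing a version string explicitly is only there to make the sketch easy to test.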

@lgeiger lgeiger force-pushed the simplify-lq-math branch from e24265d to 3f16474 Compare July 13, 2021 09:54
Comment thread larq/math.py
Comment thread larq/utils.py
@lgeiger lgeiger force-pushed the simplify-lq-math branch from eb4c852 to 810bede Compare July 13, 2021 10:59

@AdamHillier AdamHillier left a comment

Contributor

Awesome :)

I guess we should run a quick sanity check model conversion with LCE 0.6.1 before merging, but this looks great.

@lgeiger

lgeiger commented Jul 13, 2021

Member Author

> I guess we should run a quick sanity check model conversion with LCE 0.6.1 before merging, but this looks great.

I will run a sanity check later today or tomorrow to be safe!

@simonmaurer

simonmaurer commented Jul 28, 2021

@lgeiger @AdamHillier @Tombana: Just a suggestion, since I'm dealing with larq.quantizers.SteSign and larq.quantizers.SteHeaviside model conversions.
larq.math.sign and larq.math.heaviside could use:

tf.where(x >= 0., 1., -1.)

instead of

tf.where(x >= 0, tf.ones_like(x), -tf.ones_like(x))

This would result in fewer ops.

@lgeiger

lgeiger commented Jul 28, 2021

Member Author

@simonmaurer Did you do any performance profiling of your proposed solution?

One problem with using tf.where(x >= 0., 1., -1.) is that this might change the data type (e.g. it will always return float32, even if x is float16 or int8). This means we would need to use something like tf.where(x >= 0, tf.constant(1, dtype=x.dtype), tf.constant(-1, dtype=x.dtype)) or do some casting, which wouldn't reduce the number of ops.
I briefly profiled some of the options on a T4 GPU in this notebook, but please let me know if I am missing something here.
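
The dtype concern can be reproduced with a small sketch. The three function names below are hypothetical and only mirror the variants mentioned in this thread; they are not the actual larq.math code, and the sketch assumes TensorFlow 2.x in eager mode.

```python
import tensorflow as tf


def sign_scalar_branches(x):
    # Python float branches are converted to float32 tensors,
    # regardless of the dtype of x.
    return tf.where(x >= 0, 1.0, -1.0)


def sign_typed_branches(x):
    # Scalar constant created in the dtype of x, so the result keeps it.
    one = tf.constant(1, dtype=x.dtype)
    return tf.where(x >= 0, one, -one)


def sign_ones_like(x):
    # Variant with full-size branches built from x itself.
    ones = tf.ones_like(x)
    return tf.where(x >= 0, ones, -ones)


x_half = tf.constant([-1.5, 0.0, 2.0], dtype=tf.float16)
x_int8 = tf.constant([-3, 0, 5], dtype=tf.int8)

print(sign_scalar_branches(x_half).dtype)  # float32: input dtype is lost
print(sign_typed_branches(x_half).dtype)   # float16
print(sign_ones_like(x_half).dtype)        # float16
print(sign_typed_branches(x_int8).dtype)   # int8
```

This only illustrates the dtype behaviour; it says nothing about which variant is faster or produces fewer ops after conversion.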

@simonmaurer

simonmaurer commented Jul 28, 2021


@lgeiger Ah, I see, that is indeed good reasoning. I haven't done any performance tests yet.
I implicitly assumed you were working with tf.float32 tensors, since this is the data type that LceDequantize outputs in the end.
But thinking about it again, LceDequantize also handles dequantization to tf.int8; otherwise there would be no need for the additional Dequantize op at the model output. So definitely a good point.

4 participants