Neural Network Showdown: TensorFlow vs PyTorch

Why compare TensorFlow and PyTorch? Python has become the language of choice for most AI and ML engineers, and TensorFlow and PyTorch are the two Python libraries that have done the most to accelerate the use of neural networks. This post compares them so you can make up your own mind about which is more appropriate for your next ML or data science project.

The past few years have seen a huge surge of interest in Artificial Intelligence (AI) and Machine Learning (ML) technologies, which are being used to build more human-like applications. With the advent of artificial neural networks and deep learning, more complex data and more sophisticated models have become accessible to the everyday programmer.

With a traditional ML system, engineers had to decide which features a model should learn from and extract them by hand. This meant the data had to be structured and each relevant value explicitly represented. That posed a problem for visual tasks such as image recognition, where useful features are hard to define manually.

With artificial neural networks, the model itself is responsible for finding patterns and links within the raw data. This works well with image, audio, video and even text data. Artificial neural networks loosely mimic the function of biological neural networks:

Real vs artificial neurons

TensorFlow vs PyTorch: History

Both libraries are open source and released under licenses suitable for commercial projects: Apache 2.0 for TensorFlow and a BSD-style license for PyTorch.

TensorFlow was developed by the Google Brain team and open-sourced in 2015. Google uses it for both research and production.

PyTorch, on the other hand, was developed primarily by Facebook's AI Research lab, based on the popular Lua-based Torch framework, and initially served as a GPU-accelerated replacement for NumPy. In early 2018, Caffe2 (Convolutional Architecture for Fast Feature Embedding) was merged into PyTorch. Caffe2 was a framework geared toward production deployment, so the merger broadened PyTorch's focus from research to production.

TensorFlow vs PyTorch: Prevalence

Until recently, no other deep learning library could compete in the same class as TensorFlow. It's heavily used, has great community and forum support, and gets steady promotion from Google itself. The TensorFlow team has also released TensorFlow Lite, which runs models on mobile devices. To speed up processing with TensorFlow, there is specialized hardware: the Tensor Processing Unit (TPU) is available on Google Cloud Platform. There is also the Edge TPU, a small ASIC designed to run TensorFlow Lite models on edge devices such as Google's Coral boards and USB accelerators.

Recently, Google launched the Machine Learning Crash Course (MLCC), which aims to equip developers with practical artificial intelligence and machine learning fundamentals – for free!

In the challenger’s corner is PyTorch, which will feel familiar to most Python developers. It can act as a powerful replacement for NumPy, which is the industry-standard, general-purpose array-processing package. In fact, NumPy is most likely the first library anyone interested in data science or machine learning comes across. Since PyTorch has a very similar interface to NumPy, Python developers can easily migrate to it.
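To give a feel for how close the two interfaces are, here is a small sketch. It performs the same array operations in NumPy and in PyTorch and then converts between the two. It assumes both numpy and torch are installed, and the arrays are arbitrary illustrations.

```python
import numpy as np
import torch

# The same 2x2 matrix in NumPy and in PyTorch
a_np = np.array([[1.0, 2.0], [3.0, 4.0]])
a_pt = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)

# Matrix product, transpose and mean are spelled almost identically
np_result = a_np @ a_np.T + a_np.mean()
pt_result = a_pt @ a_pt.T + a_pt.mean()

# Converting between the two libraries is a one-liner in each direction
from_numpy = torch.from_numpy(a_np)  # shares memory with a_np
back_to_numpy = pt_result.numpy()

print(np_result)
print(back_to_numpy)
```

Both prints show the same matrix, which is why NumPy users usually find PyTorch code easy to read.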

TensorFlow 1.x, on the other hand, at first appears to be designed around some peculiar logic, with concepts like placeholders and sessions. As a result, the learning curve for TensorFlow can be quite steep. This is one of the major reasons PyTorch is gaining momentum.

TensorFlow vs PyTorch: Technical Differences

Dynamic Computational Graphs

Deep learning frameworks use computational graphs to define the order of computations that need to be performed in an artificial neural network. Where PyTorch really shines is its use of dynamic computational graphs, rather than the static graphs used by TensorFlow 1.x.

A static graph must be fully defined before you can run anything through the model. For example, in TensorFlow 1.x the entire computation graph has to be built up front and only then executed. This is tedious and does not lend itself to quick prototyping. With PyTorch, by contrast, the graph is built on the fly as your code runs, so you can define and manipulate it as you go. This greatly increases developer productivity and is especially helpful when handling variable-length inputs in Recurrent Neural Networks (RNNs).
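As an illustration of a dynamic graph handling variable-length input, here is a minimal PyTorch sketch. The class name TinyRNN and the layer sizes are hypothetical. It steps a standard nn.RNNCell over a sequence with an ordinary Python loop, so each call records a graph exactly as long as its input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinyRNN(nn.Module):
    """Runs an RNN cell over a sequence of any length, one time step at a time."""

    def __init__(self, input_size=3, hidden_size=5):
        super().__init__()
        self.hidden_size = hidden_size
        self.cell = nn.RNNCell(input_size, hidden_size)

    def forward(self, sequence):
        # sequence has shape (seq_len, input_size); seq_len may differ on every call
        h = torch.zeros(1, self.hidden_size)
        for step in sequence:  # plain Python loop: the graph is recorded as it runs
            h = self.cell(step.unsqueeze(0), h)
        return h


model = TinyRNN()
short_out = model(torch.randn(2, 3))  # 2 time steps
long_out = model(torch.randn(7, 3))   # 7 time steps, same model, nothing to recompile
long_out.sum().backward()             # gradients flow back through all 7 steps
print(short_out.shape, long_out.shape)
```

The same model object accepts both sequence lengths without redefining anything, which is the flexibility the paragraph above describes.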

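Fortunately, TensorFlow has since added support for dynamic computation graphs. It first did so through the TensorFlow Fold library, announced in 2017, and then through eager execution, which arrived in TensorFlow 1.5 in early 2018 and became the default in TensorFlow 2.0.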
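For comparison, here is a small sketch of the same define-by-run style in TensorFlow itself. It assumes TensorFlow 2.x, where operations execute eagerly by default, and the function name power_and_gradient is hypothetical. tf.GradientTape records whichever operations actually run, so a loop whose length is chosen at call time still yields the correct gradient.

```python
import tensorflow as tf


def power_and_gradient(x_value, n):
    """Compute x**n by repeated multiplication, plus its derivative with respect to x."""
    x = tf.Variable(x_value)
    with tf.GradientTape() as tape:
        y = tf.constant(1.0)
        for _ in range(n):  # loop length decided at run time; each op runs immediately
            y = y * x
    return y, tape.gradient(y, x)


y, dy_dx = power_and_gradient(2.0, 4)
print(float(y), float(dy_dx))  # 16.0 32.0
```

Calling it with a different n builds a different computation on the fly, with no graph to redefine.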

Saving and Loading Models

Both libraries save and load models well. PyTorch has a simple API that saves all of a model's weights (its "state dict") for easy reproduction, and you can also pickle the entire model.
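Here is a minimal sketch of the weights-only approach using torch.save, torch.load and load_state_dict. The nn.Linear model is an arbitrary stand-in. An in-memory buffer replaces a file path such as "or_gate.pt" so the example leaves nothing on disk.

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(2, 1)  # any nn.Module is saved the same way

# Save only the learned parameters (the state dict)
buffer = io.BytesIO()  # stands in for a file path such as "or_gate.pt"
torch.save(model.state_dict(), buffer)

# Later: rebuild the same architecture and load the saved weights into it
buffer.seek(0)
restored = nn.Linear(2, 1)
restored.load_state_dict(torch.load(buffer))
restored.eval()  # switch to inference mode before making predictions

same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)
```

Saving the whole model with torch.save(model, path) pickles the object instead. That approach requires the model's class definition to be importable wherever you load it, which is why the state-dict approach is often preferred.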

TensorFlow also handles save/load extremely well. The entire model can be saved as a protocol buffer, including parameters and operations. This feature also supports saving your model in one language, and then loading it in another language like C++ or Java, which can be critical for deployment stacks where Python is not an option.
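Here is a small sketch of this workflow in TensorFlow 2.x using tf.saved_model.save and tf.saved_model.load. The Scale module and its fixed weight are hypothetical. The export directory holds a saved_model.pb protocol buffer plus a variables/ folder, and that directory format is what TensorFlow's non-Python loaders consume.

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Toy module: multiplies its input by a stored weight."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(3.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return self.w * x


export_dir = os.path.join(tempfile.mkdtemp(), "scale_model")
tf.saved_model.save(Scale(), export_dir)  # writes saved_model.pb plus a variables/ folder

restored = tf.saved_model.load(export_dir)  # the Python class is not needed to reload
print(restored(tf.constant([1.0, 2.0])).numpy())  # [3. 6.]
```

Because the graph and weights are stored together, the restored object works even in a process that never defined Scale.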

Deployment

A common interface for serving an AI/ML model is a REST API, and for Python models a lightweight Flask server is the usual way to provide one. Either library can easily be wrapped this way.
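Here is a minimal sketch of such a wrapper, assuming Flask is installed. The predict function is a hypothetical stand-in for a trained TensorFlow or PyTorch model's inference call. The /predict route and the {"inputs": ...} JSON shape are also illustrative assumptions.

```python
from flask import Flask, jsonify, request

app = Flask(__name__)


def predict(features):
    # Hypothetical stand-in for model inference: OR of each input row
    return [int(any(row)) for row in features]


@app.route("/predict", methods=["POST"])
def predict_endpoint():
    payload = request.get_json()  # e.g. {"inputs": [[0, 1], [1, 1]]}
    return jsonify({"predictions": predict(payload["inputs"])})
```

To serve it locally you could call app.run(port=5000) and POST JSON to /predict. Swapping frameworks only changes the body of predict, not the HTTP layer.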

For mobile and embedded deployments, TensorFlow is by far the best way to go. With tools like TensorFlow Lite, it's very easy to integrate models into Android and even iOS apps.
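Here is a rough sketch of the conversion step, assuming TensorFlow 2.x. The ORGate module with hand-picked weights is a hypothetical stand-in for a trained model, and the file names are illustrative. The model is exported as a SavedModel with an explicit signature, and tf.lite.TFLiteConverter.from_saved_model then turns it into the compact .tflite file that a mobile app bundles.

```python
import os
import tempfile

import tensorflow as tf


class ORGate(tf.Module):
    """Hypothetical stand-in for a trained model: one sigmoid unit with fixed weights."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable([[10.0], [10.0]])
        self.b = tf.Variable([-5.0])

    @tf.function(input_signature=[tf.TensorSpec(shape=[1, 2], dtype=tf.float32)])
    def __call__(self, x):
        return tf.sigmoid(tf.matmul(x, self.w) + self.b)


workdir = tempfile.mkdtemp()
export_dir = os.path.join(workdir, "or_gate_savedmodel")
module = ORGate()
tf.saved_model.save(module, export_dir, signatures=module.__call__.get_concrete_function())

# Convert the SavedModel into a TensorFlow Lite flatbuffer
converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
tflite_model = converter.convert()  # bytes

tflite_path = os.path.join(workdir, "or_gate.tflite")
with open(tflite_path, "wb") as f:
    f.write(tflite_model)
```

The resulting or_gate.tflite file is what you would ship inside an Android or iOS project and run with the TensorFlow Lite runtime.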

Another great feature is TensorFlow Serving. Models become outdated over time and need to be retrained with new data. TensorFlow Serving allows old models to be swapped out with new ones without bringing the entire service down.

Practical Differences

To see how TensorFlow and PyTorch differ in practice, let’s look at a simple example of a neural network that mimics the function of an OR gate.

An OR gate returns "true" (output Z = 1) if at least one input signal (X or Y) is true. Here's the truth table:

X Y Z
0 0 0
0 1 1
1 0 1
1 1 1

We'll train our model on the inputs (x_data) and outputs (y_data) from the first three rows of the OR gate truth table. The objective for the neural network will be to predict the output for the unseen input (1, 1).

TensorFlow vs PyTorch: Model Creation

First, we’ll look at how to model the OR gate with TensorFlow.

TensorFlow Model
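The following is a minimal sketch of such a model, written with the TensorFlow 1.x graph API as exposed through tf.compat.v1 so that it also runs under TensorFlow 2.x. The hidden-layer size (4), the sigmoid cross-entropy loss, the learning rate (0.5) and the 2,000 training steps are assumptions, not the author's original values.

```python
import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # build a static graph, TF 1.x style
tf.compat.v1.set_random_seed(0)

# Training data: the first three rows of the OR truth table
x_data = np.array([[0, 0], [0, 1], [1, 0]], dtype=np.float32)
y_data = np.array([[0], [1], [1]], dtype=np.float32)

# Placeholders: graph inputs that receive data only when the graph is run
X = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name="X")
Y = tf.compat.v1.placeholder(tf.float32, shape=[None, 1], name="Y")

# Constant: a fixed value baked into the graph
learning_rate = tf.constant(0.5, name="learning_rate")

# Variables: persistent, trainable state (weights and biases)
W1 = tf.compat.v1.Variable(tf.random.normal([2, 4]), name="W1")
b1 = tf.compat.v1.Variable(tf.zeros([4]), name="b1")
W2 = tf.compat.v1.Variable(tf.random.normal([4, 1]), name="W2")
b2 = tf.compat.v1.Variable(tf.zeros([1]), name="b2")

# Input layer -> sigmoid hidden layer -> sigmoid output layer
hidden = tf.sigmoid(tf.matmul(X, W1) + b1)
logits = tf.matmul(hidden, W2) + b2
output = tf.sigmoid(logits)

loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=Y, logits=logits))
train_op = tf.compat.v1.train.GradientDescentOptimizer(learning_rate).minimize(loss)

# Nothing has been computed yet: the graph only executes inside a session
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    feed = {X: x_data, Y: y_data}
    initial_loss = sess.run(loss, feed_dict=feed)
    for _ in range(2000):
        sess.run(train_op, feed_dict=feed)
    final_loss = sess.run(loss, feed_dict=feed)
    prediction = sess.run(output, feed_dict={X: np.array([[1, 1]], dtype=np.float32)})

print(f"loss: {initial_loss:.3f} -> {final_loss:.3f}, prediction for (1, 1): {prediction[0, 0]:.3f}")
```

Notice that defining the graph and running it are two separate phases. The second phase, the session, is where TensorFlow 1.x differs most from everyday Python.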

The TensorFlow version of the model is a sigmoid neural network with one input, one hidden, and one output layer. And here's where the TensorFlow quirkiness kicks in, with terms like placeholders, variables, constants, and more. Let's look at what each of them means:

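  • Placeholder: An input slot in the graph. It holds no data until you feed values into it when the graph is run.
  • Variables: Shared, persistent state that your program updates, such as a model's weights and biases.
  • Constants: Tensors whose values are fixed when they are created and cannot be changed.

```python
import tensorflow as tf
tf.compat.v1.disable_eager_execution()  # TF 2.x: switch back to 1.x-style graphs and sessions
sess = tf.compat.v1.Session()  # tf.Session() in TF 1.x
init = tf.compat.v1.global_variables_initializer()  # replaces deprecated tf.initialize_all_variables()
```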

This is where the actual training happens. The static computational graph is defined first and then executed inside the session. Once training is over, the session gives you access to the trained model, which can be used to make predictions or measure its accuracy.

Now, let’s look at how to model our OR gate with PyTorch.

PyTorch Model
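A minimal PyTorch sketch of an equivalent OR-gate network follows. The class name ORGateNet is hypothetical. The hidden-layer size (4), binary cross-entropy loss, plain SGD with a learning rate of 0.5 and 2,000 steps are assumptions, not the author's original code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Training data: the first three rows of the OR truth table
x_data = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
y_data = torch.tensor([[0.0], [1.0], [1.0]])


class ORGateNet(nn.Module):
    """Input layer -> sigmoid hidden layer -> sigmoid output layer."""

    def __init__(self, hidden_size=4):
        super().__init__()
        self.hidden = nn.Linear(2, hidden_size)
        self.output = nn.Linear(hidden_size, 1)

    def forward(self, x):
        h = torch.sigmoid(self.hidden(x))
        return torch.sigmoid(self.output(h))


model = ORGateNet()
loss_fn = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

with torch.no_grad():
    initial_loss = loss_fn(model(x_data), y_data).item()

# A plain Python training loop: no placeholders, no sessions
for _ in range(2000):
    optimizer.zero_grad()
    loss = loss_fn(model(x_data), y_data)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    final_loss = loss_fn(model(x_data), y_data).item()
    prediction = model(torch.tensor([[1.0, 1.0]]))

print(f"loss: {initial_loss:.3f} -> {final_loss:.3f}, prediction for (1, 1): {prediction.item():.3f}")
```

Every line runs immediately as ordinary Python, so you can print or debug intermediate tensors at any point in the loop.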

If you're familiar with Python, you can instantly see the difference in the PyTorch code. There are no placeholders or sessions, and operations run as ordinary Python code, so it feels very "Pythonic".

TensorFlow vs PyTorch: Conclusion

For Python developers just getting started with deep learning, PyTorch may offer a shorter ramp-up time. In fact, ease of use is one of the key reasons a recent study found PyTorch gaining more acceptance in academia than TensorFlow. If you prefer TensorFlow but ease of use is a concern, I'd recommend having a look at Keras, which provides a high-level wrapper around TensorFlow and ships with it as tf.keras. Keras dramatically simplifies TensorFlow implementation for a wide range of use cases.
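To show how much Keras hides, here is a sketch of an OR-gate network written with tf.keras. It assumes TensorFlow 2.x. The layer sizes, optimizer settings and epoch count are illustrative assumptions.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)

# Training data: the first three rows of the OR truth table
x_data = np.array([[0, 0], [0, 1], [1, 0]], dtype=np.float32)
y_data = np.array([[0], [1], [1]], dtype=np.float32)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(2,)),
    tf.keras.layers.Dense(4, activation="sigmoid"),  # hidden layer
    tf.keras.layers.Dense(1, activation="sigmoid"),  # output layer
])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.5), loss="binary_crossentropy")

# fit() runs the whole training loop; no graphs or sessions to manage
history = model.fit(x_data, y_data, epochs=500, verbose=0)
prediction = model.predict(np.array([[1, 1]], dtype=np.float32), verbose=0)
print(history.history["loss"][0], history.history["loss"][-1], prediction)
```

The model definition, loss and training loop collapse into a few declarative calls, which is the simplification the paragraph above refers to.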

Another recent survey of over 1,600 developers working with ML and data science found that PyTorch may be the preferred framework for data analysis and ad hoc modeling, whereas TensorFlow is relied on to produce production models.

My personal preference is TensorFlow. It has much better community support and stable deployment strategies. And with the recent release of version 2.0, its feature set is catching up with the innovations offered by PyTorch. Though the learning curve is steep with TensorFlow, the payoff is worth it. 

Which one do you use, and why? Comment down below.