How to Calculate NDVI with Python: A Practical Guide for Agricultural Scenarios

Introduction: Why Every Agricultural Scientist Should Know NDVI

If you work in agriculture, chances are you’ve heard the term NDVI. Normalized Difference Vegetation Index — a mouthful, but one of the most powerful and widely-used tools in modern agricultural science.

Think of NDVI as a crop health report card, derived entirely from satellite images. In one number — ranging from -1 to +1 — it tells you whether your field is thriving, stressed, or bare.

But here’s the gap I see all the time among agricultural researchers in India: they understand NDVI conceptually, but they depend entirely on paid software like ERDAS Imagine or commercial platforms to compute it. That’s expensive, slow, and inflexible.

With Python, you can calculate NDVI for any field, anywhere in India, for free — in under 20 lines of code.

In this post, I’ll show you exactly how, using real Sentinel-2 satellite data. By the end, you’ll be able to:

  • Understand what NDVI measures and why it matters for agricultural research
  • Download and read multispectral satellite imagery in Python
  • Calculate NDVI using the rasterio and numpy libraries
  • Visualise NDVI maps with matplotlib
  • Apply NDVI analysis to real agricultural problems (crop monitoring, drought assessment, yield estimation)

I use this workflow regularly in my own research. Let’s get into it.


What is NDVI and Why Does It Matter?

NDVI stands for Normalized Difference Vegetation Index. It was developed by NASA scientists in the 1970s and remains the most widely used vegetation index in the world today.

The formula is simple:

NDVI = (NIR - Red) / (NIR + Red)

Where:

  • NIR = Near-Infrared band reflectance
  • Red = Red band reflectance

Healthy green vegetation absorbs red light for photosynthesis and strongly reflects near-infrared light. Stressed or sparse vegetation reflects more red and less near-infrared, so the gap between the two bands narrows. This contrast is what NDVI captures.
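
To make the formula concrete, here is a minimal sketch that applies it to a few hand-picked reflectance pairs. The reflectance values are illustrative assumptions, not measurements from a real scene, and the function name ndvi_from_reflectance is my own.

```python
import numpy as np

def ndvi_from_reflectance(nir, red):
    """Apply NDVI = (NIR - Red) / (NIR + Red) element-wise."""
    nir = np.asarray(nir, dtype=float)
    red = np.asarray(red, dtype=float)
    return (nir - red) / (nir + red)

# Illustrative (assumed) reflectance values, expressed as fractions of 1
surfaces = ["dense crop canopy", "sparse crop", "bare soil", "open water"]
nir = [0.50, 0.30, 0.25, 0.02]
red = [0.05, 0.12, 0.20, 0.04]

ndvi_values = ndvi_from_reflectance(nir, red)
for name, value in zip(surfaces, ndvi_values):
    print(f"{name:<18} NDVI = {value:+.2f}")
```

With these assumed inputs the dense canopy comes out near 0.82 and the water pixel goes negative, because water reflects slightly more red than near-infrared.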

NDVI value interpretation:

NDVI Range   | What It Means              | Agricultural Implication
0.8 to 1.0   | Dense, healthy vegetation  | Peak crop canopy, excellent health
0.6 to 0.8   | Moderate-high vegetation   | Good crop growth
0.4 to 0.6   | Moderate vegetation        | Average crop health, possible stress
0.2 to 0.4   | Sparse vegetation          | Stress, early season, or thin canopy
0.0 to 0.2   | Very sparse/bare soil      | Bare ground, very early stage
Below 0.0    | Non-vegetation             | Water bodies, built-up areas, clouds

In Indian agricultural contexts, NDVI is used for:

  1. Crop condition monitoring — Is the paddy crop healthy in July? Are wheat fields uniform in November?
  2. Drought and stress detection — Which districts show early signs of moisture stress?
  3. Yield forecasting — NDVI at critical crop growth stages correlates strongly with final yield
  4. Kharif/Rabi crop mapping — Identifying crop types across a region based on NDVI time series
  5. Assessment of PMFBY claims — Insurance companies and state governments use NDVI to validate crop loss claims

We use NDVI-based analyses to support everything from crop condition reports to state-level advisories. Once you know how to compute it in Python, you unlock all of this yourself.


What Satellite Data Will We Use?

We’ll use Sentinel-2 imagery from the European Space Agency (ESA). It’s free, it covers all of India, and it has a 10-metre spatial resolution for the bands we need — fine enough to see individual farm fields.

Key Sentinel-2 bands for NDVI:

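Band   | Name | Wavelength | Resolution | Role in NDVI
Band 4 | Red  | 665 nm     | 10 m       | Subtracted from NIR in the numerator, added to it in the denominator
Band 8 | NIR  | 842 nm     | 10 m       | First term of both the numerator and the denominator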

You can download Sentinel-2 data freely from:

  • Copernicus Data Space Ecosystem: https://dataspace.copernicus.eu/ (this replaced the Copernicus Open Access Hub, https://scihub.copernicus.eu/, in 2023)
  • Google Earth Engine (for larger areas or time series)
  • AWS Open Data Registry: free access, no sign-up needed for many datasets
  • USGS EarthExplorer: also carries some ESA products

For this tutorial, I’ll walk you through reading a pre-downloaded GeoTIFF. If you’d like a separate post on how to download Sentinel-2 data from Copernicus Hub using Python, let me know in the comments.


Setting Up Your Python Environment

Before writing any code, let’s make sure the right libraries are installed.

pip install rasterio numpy matplotlib geopandas

If you’re working on a government server or a restricted environment, you can also run this in Google Colab for free — no installation needed.

Libraries we’ll use:

  • rasterio — read and write geospatial raster files (GeoTIFFs)
  • numpy — numerical operations, including the NDVI formula
  • matplotlib — visualisation and NDVI map plotting
  • geopandas — optional, for masking to a specific study area

Step 1: Import Libraries and Load the Satellite Bands

import rasterio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import warnings
warnings.filterwarnings('ignore')

# Define paths to your Sentinel-2 band files
# Replace these with your actual file paths
red_band_path = "T44QKF_20231015_B04_10m.tif"   # Band 4 (Red)
nir_band_path  = "T44QKF_20231015_B08_10m.tif"   # Band 8 (NIR)

# Open the Red band
with rasterio.open(red_band_path) as red_src:
    red_band = red_src.read(1).astype(float)    # Read as float for division
    profile = red_src.profile                    # Save metadata for output file
    print(f"Red band shape: {red_band.shape}")
    print(f"CRS: {red_src.crs}")
    print(f"Transform: {red_src.transform}")

# Open the NIR band
with rasterio.open(nir_band_path) as nir_src:
    nir_band = nir_src.read(1).astype(float)
    print(f"NIR band shape: {nir_band.shape}")

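What this does: We read both bands as 2D numpy arrays. The .astype(float) is important: Sentinel-2 bands are usually stored as unsigned 16-bit integers, and subtracting them without casting silently wraps around wherever Red is larger than NIR, which produces nonsense NDVI values.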
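
A small sketch of why the cast matters, using two made-up pixels rather than a real image. The assumption here is that the bands are stored as uint16, which is typical for Sentinel-2 but worth confirming with red_src.dtypes on your own files.

```python
import numpy as np

# Two made-up pixels stored as unsigned 16-bit integers.
# In the first pixel Red is larger than NIR, as often happens over water.
red_raw = np.array([1200, 600], dtype=np.uint16)
nir_raw = np.array([800, 3000], dtype=np.uint16)

# Unsigned subtraction cannot go negative, so 800 - 1200 wraps around
wrapped_diff = nir_raw - red_raw

# Casting to float first keeps the sign of the difference
red_f = red_raw.astype(float)
nir_f = nir_raw.astype(float)
ndvi_ok = (nir_f - red_f) / (nir_f + red_f)

print("uint16 difference:", wrapped_diff)
print("float NDVI:       ", ndvi_ok.round(3))
```

The first uint16 difference prints as a large positive number instead of -400, while the float version gives the expected negative NDVI for that pixel.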

Note: Sentinel-2 filenames follow the format T[tile]_[date]_[band]_[resolution].tif. The tile code for Assam and northeast India is typically T46RBN, T46RBM, or T46QBL depending on your study area.


Step 2: Calculate NDVI

This is the core calculation. Two lines of Python.

# Avoid division by zero — set to NaN where both bands are zero
denominator = nir_band + red_band
denominator[denominator == 0] = np.nan

# Calculate NDVI
ndvi = (nir_band - red_band) / denominator

# Quick summary statistics
print(f"NDVI Statistics:")
print(f"  Min:  {np.nanmin(ndvi):.4f}")
print(f"  Max:  {np.nanmax(ndvi):.4f}")
print(f"  Mean: {np.nanmean(ndvi):.4f}")
print(f"  Std:  {np.nanstd(ndvi):.4f}")

Sample output for an agricultural area in Assam (October, post-Kharif):

NDVI Statistics:
  Min:  -0.2341
  Max:   0.8923
  Mean:  0.4512
  Std:   0.1876

A mean NDVI of 0.45 in October for this area is consistent with late Kharif paddy — the crop is still standing but past its peak greenness.
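
One caveat with the calculation above: the numbers stored in Sentinel-2 Level-2A files are scaled digital numbers, not reflectance. As far as I know, products from processing baseline 04.00 onwards (early 2022) also carry an additive offset, and ignoring it shifts the NDVI. The sketch below converts digital numbers to reflectance first. The scale of 10000, the offset of -1000 and the no-data value of 0 are assumptions; check BOA_QUANTIFICATION_VALUE and BOA_ADD_OFFSET in your product's metadata before relying on them. The function name dn_to_reflectance is my own, and the pixel values are made up.

```python
import numpy as np

# Assumed constants: read the real ones from the product metadata
QUANTIFICATION_VALUE = 10000.0   # assumed scale factor
BOA_ADD_OFFSET = -1000.0         # assumed offset for recent processing baselines
NODATA_DN = 0                    # assumed no-data digital number

def dn_to_reflectance(dn, offset=BOA_ADD_OFFSET, scale=QUANTIFICATION_VALUE):
    """Convert digital numbers to reflectance; no-data pixels become NaN."""
    dn = np.asarray(dn, dtype=float)
    refl = (dn + offset) / scale
    refl[dn == NODATA_DN] = np.nan
    return refl

# Made-up digital numbers for a 2 x 2 patch (one no-data pixel)
red_dn = np.array([[1400, 0], [1300, 2000]])
nir_dn = np.array([[4000, 0], [3500, 2200]])

red_refl = dn_to_reflectance(red_dn)
nir_refl = dn_to_reflectance(nir_dn)
ndvi_refl = (nir_refl - red_refl) / (nir_refl + red_refl)

# NDVI straight from the digital numbers, for comparison
with np.errstate(invalid="ignore"):
    ndvi_raw = (nir_dn - red_dn) / (nir_dn + red_dn).astype(float)

print("NDVI from reflectance:\n", ndvi_refl.round(3))
print("NDVI from raw DN:\n", ndvi_raw.round(3))
```

Under these assumptions the raw-DN NDVI of the first pixel is noticeably lower than the reflectance-based value, so it is worth applying the conversion before comparing NDVI across dates or sensors.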


Step 3: Visualise the NDVI Map

A raw array of numbers isn’t useful for agricultural interpretation. We need a colour-coded map.

# Set up a professional NDVI colour map
# RdYlGn: Red (low/stressed) → Yellow (moderate) → Green (healthy)
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
fig.suptitle("NDVI Analysis — Agricultural Area, Assam (October 2023)",
             fontsize=14, fontweight='bold')

# Panel 1: Red Band (visible light)
ax1 = axes[0]
im1 = ax1.imshow(red_band, cmap='Reds_r', vmin=0, vmax=3000)
ax1.set_title("Red Band (Band 4)", fontsize=11)
ax1.axis('off')
plt.colorbar(im1, ax=ax1, fraction=0.046, pad=0.04, label="Reflectance")

# Panel 2: NIR Band
ax2 = axes[1]
im2 = ax2.imshow(nir_band, cmap='YlOrRd', vmin=0, vmax=5000)
ax2.set_title("NIR Band (Band 8)", fontsize=11)
ax2.axis('off')
plt.colorbar(im2, ax=ax2, fraction=0.046, pad=0.04, label="Reflectance")

# Panel 3: NDVI Map
ax3 = axes[2]
im3 = ax3.imshow(ndvi, cmap='RdYlGn', vmin=-0.3, vmax=0.9)
ax3.set_title("NDVI — Vegetation Health Index", fontsize=11)
ax3.axis('off')
cbar = plt.colorbar(im3, ax=ax3, fraction=0.046, pad=0.04)
cbar.set_label("NDVI Value", rotation=270, labelpad=15)
cbar.set_ticks([-0.3, 0.0, 0.2, 0.4, 0.6, 0.8])
cbar.set_ticklabels(['Water/Cloud', 'Bare Soil', 'Sparse', 'Moderate', 'Good', 'Dense'])

plt.tight_layout()
plt.savefig("ndvi_analysis_assam.png", dpi=200, bbox_inches='tight')
plt.show()
print("Map saved as ndvi_analysis_assam.png")

Step 4: Classify NDVI into Agricultural Categories

Raw NDVI values are useful, but for field reports and policy briefs, classified maps are more interpretable. Let’s create a crop health classification.

# Create NDVI classification for agricultural interpretation
def classify_ndvi_agriculture(ndvi_array):
    """
    Classify NDVI into agricultural health categories.
    Suitable for Kharif and Rabi crop monitoring in India.
    """
    classified = np.full(ndvi_array.shape, np.nan)
    
    classified[ndvi_array < 0.0]                       = 0  # Water / Non-veg
    classified[(ndvi_array >= 0.0) & (ndvi_array < 0.2)] = 1  # Bare/Very sparse
    classified[(ndvi_array >= 0.2) & (ndvi_array < 0.4)] = 2  # Sparse / stressed crop
    classified[(ndvi_array >= 0.4) & (ndvi_array < 0.6)] = 3  # Moderate crop health
    classified[(ndvi_array >= 0.6) & (ndvi_array < 0.8)] = 4  # Good crop health
    classified[ndvi_array >= 0.8]                       = 5  # Excellent / dense canopy
    
    return classified

ndvi_classified = classify_ndvi_agriculture(ndvi)

# Create a custom colour map for the classified map
class_colors = ['#2166AC',  # 0 — Blue: Water/Non-veg
                '#D7191C',  # 1 — Red: Bare/Very sparse
                '#FDAE61',  # 2 — Orange: Stressed crop
                '#FEE08B',  # 3 — Yellow: Moderate
                '#A6D96A',  # 4 — Light green: Good
                '#1A9641']  # 5 — Dark green: Excellent

cmap_custom = mcolors.ListedColormap(class_colors)
bounds = [-0.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5]
norm = mcolors.BoundaryNorm(bounds, cmap_custom.N)

fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(ndvi_classified, cmap=cmap_custom, norm=norm)
ax.set_title("NDVI Crop Health Classification\nAgricultural Area, Assam (October 2023)",
             fontsize=13, fontweight='bold')
ax.axis('off')

cbar = plt.colorbar(im, ax=ax, fraction=0.03, pad=0.04)
cbar.set_ticks([0, 1, 2, 3, 4, 5])
cbar.set_ticklabels(['Water / Non-veg',
                      'Bare / Very Sparse',
                      'Stressed Crop',
                      'Moderate Health',
                      'Good Health',
                      'Excellent / Dense'])
cbar.set_label("Crop Health Category", rotation=270, labelpad=20)

plt.tight_layout()
plt.savefig("ndvi_classified_map.png", dpi=200, bbox_inches='tight')
plt.show()
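
The chain of boolean masks in the classification function can also be sketched more compactly with np.digitize, which assigns each value to a bin in one call. This is an alternative under the same thresholds, not a replacement for the function above; the name classify_with_digitize and the sample values are my own.

```python
import numpy as np

def classify_with_digitize(ndvi_array, edges=(0.0, 0.2, 0.4, 0.6, 0.8)):
    """Return class codes 0-5 using the same thresholds; NaN stays NaN."""
    ndvi_array = np.asarray(ndvi_array, dtype=float)
    # digitize returns i where edges[i-1] <= value < edges[i]
    classes = np.digitize(ndvi_array, edges).astype(float)
    # digitize puts NaN in the last bin, so restore NaN explicitly
    classes[np.isnan(ndvi_array)] = np.nan
    return classes

sample = np.array([-0.2, 0.0, 0.19, 0.2, 0.45, 0.6, 0.8, 0.95, np.nan])
sample_classes = classify_with_digitize(sample)
print(sample_classes)
```

Keeping the thresholds in a single tuple makes it easier to try crop-specific class breaks without editing several lines.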

Step 5: Calculate Area Statistics by Category

For an agricultural report, you’ll want to know how many hectares fall in each category.

# Calculate pixel counts and area for each class
# Sentinel-2 at 10m resolution: each pixel = 10m × 10m = 0.01 hectares

pixel_area_ha = 0.01  # hectares per pixel (10m × 10m)

categories = {
    0: "Water / Non-veg",
    1: "Bare / Very Sparse",
    2: "Stressed Crop",
    3: "Moderate Health",
    4: "Good Health",
    5: "Excellent / Dense"
}

print("=" * 55)
print("NDVI AREA STATISTICS — Assam Agricultural Area")
print("=" * 55)
print(f"{'Category':<25} {'Pixels':>10} {'Area (ha)':>12} {'%':>8}")
print("-" * 55)

total_valid_pixels = np.sum(~np.isnan(ndvi_classified))

for class_val, class_name in categories.items():
    pixel_count = np.sum(ndvi_classified == class_val)
    area_ha = pixel_count * pixel_area_ha
    pct = (pixel_count / total_valid_pixels) * 100 if total_valid_pixels > 0 else 0
    print(f"{class_name:<25} {pixel_count:>10,} {area_ha:>12,.1f} {pct:>7.1f}%")

print("-" * 55)
print(f"{'TOTAL':<25} {total_valid_pixels:>10,} "
      f"{total_valid_pixels * pixel_area_ha:>12,.1f} {'100.0':>8}%")
print("=" * 55)

Sample output:

=======================================================
NDVI AREA STATISTICS — Assam Agricultural Area
=======================================================
Category                  Pixels      Area (ha)        %
-------------------------------------------------------
Water / Non-veg            12,450        124.5      5.2%
Bare / Very Sparse          8,230         82.3      3.4%
Stressed Crop              18,750        187.5      7.8%
Moderate Health            64,300        643.0     26.8%
Good Health               98,760        987.6     41.2%
Excellent / Dense          37,320        373.2     15.6%
-------------------------------------------------------
TOTAL                     239,810      2,398.1    100.0%
=======================================================

This kind of output goes directly into a district crop situation report or a field survey planning document.
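
The 0.01 hectares per pixel used above holds only for 10 m pixels in a projected CRS measured in metres (such as UTM). A small helper, sketched here with the hypothetical name pixel_area_hectares, derives the figure from the pixel size instead of hard-coding it.

```python
def pixel_area_hectares(x_res_m, y_res_m):
    """Area of one pixel in hectares, given its width and height in metres."""
    return abs(x_res_m * y_res_m) / 10_000.0   # 1 ha = 10,000 square metres

# With rasterio, the pixel size is available as src.res, for example:
#   with rasterio.open(red_band_path) as src:
#       x_res, y_res = src.res
# The values below are typed in by hand so the sketch runs without a file.
examples = [("Sentinel-2 (10 m)", (10.0, 10.0)),
            ("Sentinel-2 (20 m)", (20.0, 20.0)),
            ("Landsat 8/9 (30 m)", (30.0, 30.0))]

for sensor, res in examples:
    print(f"{sensor:<20} {pixel_area_hectares(*res):.2f} ha per pixel")
```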


Step 6: Save the NDVI Raster as a GeoTIFF

For use in GIS software (QGIS, ArcGIS) or further analysis, save the NDVI output as a proper georeferenced GeoTIFF.

import os

# Save NDVI as GeoTIFF with geospatial metadata preserved
output_path = "ndvi_output_assam.tif"

# Update profile for float32 single-band output
profile.update({
    'dtype': rasterio.float32,
    'count': 1,
    'nodata': -9999
})

# Replace NaN with the nodata value before writing
ndvi_to_save = ndvi.copy()
ndvi_to_save[np.isnan(ndvi_to_save)] = -9999

with rasterio.open(output_path, 'w', **profile) as dst:
    dst.write(ndvi_to_save.astype(rasterio.float32), 1)

print(f"NDVI GeoTIFF saved successfully: {output_path}")
print(f"File size: {os.path.getsize(output_path) / (1024*1024):.1f} MB")

You can now open ndvi_output_assam.tif directly in QGIS for further GIS analysis, or share it with colleagues who don’t use Python.


Real-World Application: Monitoring Paddy Crop Health in Assam

Let me show you how this workflow plays out in a real agricultural scenario.

Scenario: It’s late October. The Kharif paddy crop is in its grain-filling stage across Kamrup district, Assam. There were reports of drought stress in some blocks. The district agriculture officer wants to know which blocks are showing low NDVI before planning the relief package.

Using the workflow above, here’s what you can do in a single afternoon:

  1. Download Sentinel-2 images for October 2023 for Kamrup district (free from Copernicus Hub)
  2. Calculate NDVI for each 10m pixel across the district
  3. Clip the raster to administrative block boundaries using geopandas
  4. Compute mean NDVI per block
  5. Map the blocks by average NDVI to show immediately which blocks are stressed

import geopandas as gpd
import pandas as pd
from rasterio.mask import mask
from shapely.geometry import mapping

# Load block-level shapefile for Kamrup district
# (Available from ICRISAT VDSA or state GIS portals)
blocks_gdf = gpd.read_file("kamrup_blocks.shp")

# Compute mean NDVI per block
block_ndvi_results = []

# Open the NDVI raster once and reuse it for every block
with rasterio.open("ndvi_output_assam.tif") as ndvi_src:
    # Reproject the blocks to the raster's own CRS
    # (e.g. EPSG:32646, UTM Zone 46N, for Assam)
    blocks_gdf = blocks_gdf.to_crs(ndvi_src.crs)

    for idx, row in blocks_gdf.iterrows():
        # rasterio expects GeoJSON-like geometry dictionaries
        geom = [mapping(row.geometry)]

        try:
            # Mask the raster to the block boundary
            masked_ndvi, _ = mask(ndvi_src, geom, crop=True)
            valid_vals = masked_ndvi[masked_ndvi != ndvi_src.nodata]
            mean_ndvi = float(valid_vals.mean()) if valid_vals.size > 0 else np.nan
        except ValueError:
            # Raised when the block does not overlap the raster
            mean_ndvi = np.nan

        block_ndvi_results.append({
            'block_name': row['BLOCK_NAME'],
            'mean_ndvi': mean_ndvi
        })

# Add results back to GeoDataFrame
ndvi_df = pd.DataFrame(block_ndvi_results)
blocks_gdf = blocks_gdf.merge(ndvi_df, left_on='BLOCK_NAME', right_on='block_name')

# Identify stressed blocks (mean NDVI < 0.4 in October = concern)
stressed_blocks = blocks_gdf[blocks_gdf['mean_ndvi'] < 0.4]
print("\nBlocks showing possible crop stress (NDVI < 0.40):")
print(stressed_blocks[['BLOCK_NAME', 'mean_ndvi']].sort_values('mean_ndvi'))

This analysis — which would take days in traditional software — runs in minutes with Python. And it’s fully reproducible, documented, and shareable.


NDVI Time Series: Tracking Crop Growth Through the Season

One of the most powerful applications of NDVI is monitoring how it changes across the crop season. Here’s a simple example using multi-date NDVI values:

import matplotlib.pyplot as plt
import numpy as np

# NDVI values for a paddy field in Assam — Kharif 2023
# (You would compute these from actual multi-date satellite images)
dates = ['June 15', 'July 01', 'July 15', 'Aug 01', 'Aug 15',
         'Sep 01', 'Sep 15', 'Oct 01', 'Oct 15', 'Nov 01']

# Healthy paddy field (good season)
ndvi_healthy = [0.12, 0.28, 0.45, 0.63, 0.72, 0.78, 0.74, 0.65, 0.52, 0.35]

# Stress-affected paddy field (drought mid-season)
ndvi_stressed = [0.10, 0.22, 0.38, 0.45, 0.42, 0.40, 0.37, 0.30, 0.22, 0.18]

fig, ax = plt.subplots(figsize=(12, 6))

ax.plot(dates, ndvi_healthy, 'g-o', linewidth=2.5, markersize=8,
        label='Healthy Paddy (Kamrup Block A)')
ax.plot(dates, ndvi_stressed, 'r--s', linewidth=2.5, markersize=8,
        label='Stress-Affected Paddy (Kamrup Block B)')

# Add reference lines
ax.axhline(y=0.6, color='gray', linestyle=':', alpha=0.7, label='Moderate-Good threshold (0.6)')
ax.axhline(y=0.4, color='orange', linestyle=':', alpha=0.7, label='Stress warning threshold (0.4)')

# Shade the growing season
ax.axvspan(2, 7, alpha=0.05, color='green', label='Active growing season')

ax.set_xlabel('Date (Kharif 2023)', fontsize=12)
ax.set_ylabel('Mean NDVI', fontsize=12)
ax.set_title('NDVI Time Series — Paddy Crop Monitoring, Kamrup District, Assam\n'
             'Kharif Season 2023', fontsize=13, fontweight='bold')
ax.legend(loc='lower right', fontsize=10)
ax.set_ylim(0, 0.95)
ax.grid(True, alpha=0.3)
plt.xticks(rotation=30)

plt.tight_layout()
plt.savefig("ndvi_time_series_paddy.png", dpi=200, bbox_inches='tight')
plt.show()

The divergence between the two curves — starting from August — clearly shows where drought stress began. This kind of analysis supports early warning systems for district agriculture offices.
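
The point where the curves separate can also be picked out numerically rather than by eye. The sketch below reuses the illustrative NDVI values from the plot; the 0.10 gap threshold is an assumption of mine, not an established standard.

```python
import numpy as np

dates = ['June 15', 'July 01', 'July 15', 'Aug 01', 'Aug 15',
         'Sep 01', 'Sep 15', 'Oct 01', 'Oct 15', 'Nov 01']
ndvi_healthy = np.array([0.12, 0.28, 0.45, 0.63, 0.72, 0.78, 0.74, 0.65, 0.52, 0.35])
ndvi_stressed = np.array([0.10, 0.22, 0.38, 0.45, 0.42, 0.40, 0.37, 0.30, 0.22, 0.18])

GAP_THRESHOLD = 0.10   # assumed: flag dates where the fields differ by more than this

gap = ndvi_healthy - ndvi_stressed
exceeds = gap > GAP_THRESHOLD

# First date on which the gap exceeds the threshold (None if it never does)
first_divergence = dates[int(np.argmax(exceeds))] if exceeds.any() else None

# Date of peak greenness for each field
peak_healthy = dates[int(np.argmax(ndvi_healthy))]
peak_stressed = dates[int(np.argmax(ndvi_stressed))]

print(f"Gap first exceeds {GAP_THRESHOLD:.2f} on: {first_divergence}")
print(f"Peak NDVI, healthy field:  {ndvi_healthy.max():.2f} on {peak_healthy}")
print(f"Peak NDVI, stressed field: {ndvi_stressed.max():.2f} on {peak_stressed}")
```

An early, low peak in the stressed field is often as informative as the gap itself, since it suggests the canopy stopped developing mid-season.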


Complete Code: All Steps in One Script

Here’s the full workflow in a clean, ready-to-run script:

"""
NDVI Calculation with Python — Complete Workflow
Author: Data Science with DEB
Website: https://dibyendudeb.com
Use case: Agricultural crop health monitoring using Sentinel-2 data

Requirements:
    pip install rasterio numpy matplotlib geopandas
"""

import os
import numpy as np
import rasterio
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import warnings
warnings.filterwarnings('ignore')

# ─── CONFIG ────────────────────────────────────────────────────────────────
RED_BAND_PATH = "T44QKF_20231015_B04_10m.tif"   # Sentinel-2 Band 4 (Red)
NIR_BAND_PATH = "T44QKF_20231015_B08_10m.tif"   # Sentinel-2 Band 8 (NIR)
OUTPUT_NDVI   = "ndvi_output.tif"
PIXEL_AREA_HA = 0.01  # 10m x 10m pixel = 0.01 hectares

# ─── LOAD BANDS ────────────────────────────────────────────────────────────
print("Loading satellite bands...")
with rasterio.open(RED_BAND_PATH) as red_src:
    red = red_src.read(1).astype(float)
    profile = red_src.profile

with rasterio.open(NIR_BAND_PATH) as nir_src:
    nir = nir_src.read(1).astype(float)

# ─── CALCULATE NDVI ────────────────────────────────────────────────────────
print("Calculating NDVI...")
denom = nir + red
denom[denom == 0] = np.nan
ndvi = (nir - red) / denom

print(f"NDVI range: {np.nanmin(ndvi):.4f} to {np.nanmax(ndvi):.4f}")
print(f"Mean NDVI:  {np.nanmean(ndvi):.4f}")

# ─── VISUALISE ─────────────────────────────────────────────────────────────
print("Creating NDVI map...")
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(ndvi, cmap='RdYlGn', vmin=-0.3, vmax=0.9)
ax.set_title("NDVI — Vegetation Health Index", fontsize=14, fontweight='bold')
ax.axis('off')
cbar = plt.colorbar(im, ax=ax, fraction=0.03, pad=0.04)
cbar.set_label("NDVI Value", rotation=270, labelpad=20)
plt.tight_layout()
plt.savefig("ndvi_map.png", dpi=200, bbox_inches='tight')
plt.close()

# ─── CLASSIFY ──────────────────────────────────────────────────────────────
print("Classifying NDVI into agricultural categories...")
classified = np.full(ndvi.shape, np.nan)
classified[ndvi < 0.0]                       = 0
classified[(ndvi >= 0.0) & (ndvi < 0.2)]    = 1
classified[(ndvi >= 0.2) & (ndvi < 0.4)]    = 2
classified[(ndvi >= 0.4) & (ndvi < 0.6)]    = 3
classified[(ndvi >= 0.6) & (ndvi < 0.8)]    = 4
classified[ndvi >= 0.8]                      = 5

# ─── AREA STATS ────────────────────────────────────────────────────────────
categories = {0: "Water/Non-veg", 1: "Bare/Sparse",
              2: "Stressed Crop", 3: "Moderate Health",
              4: "Good Health",   5: "Excellent/Dense"}

print("\n" + "=" * 55)
print("NDVI AREA SUMMARY")
print("=" * 55)
total = np.sum(~np.isnan(classified))
for k, v in categories.items():
    n = np.sum(classified == k)
    print(f"{v:<22} {n:>8,} px  {n*PIXEL_AREA_HA:>10,.1f} ha  "
          f"{100*n/total:>6.1f}%")

# ─── SAVE GEOTIFF ──────────────────────────────────────────────────────────
print(f"\nSaving NDVI GeoTIFF to {OUTPUT_NDVI}...")
profile.update({'dtype': rasterio.float32, 'count': 1, 'nodata': -9999})
ndvi_save = ndvi.copy()
ndvi_save[np.isnan(ndvi_save)] = -9999
with rasterio.open(OUTPUT_NDVI, 'w', **profile) as dst:
    dst.write(ndvi_save.astype(rasterio.float32), 1)

print("✓ All outputs saved. NDVI analysis complete.")

Frequently Asked Questions

Q: Can I use this code with Landsat data instead of Sentinel-2?
Yes. For Landsat 8/9, the Red band is Band 4 and NIR is Band 5. The formula and code remain the same: update the file paths, and change the pixel area to 0.09 hectares, since Landsat pixels are 30 m × 30 m.

Q: What if I don’t have downloaded satellite files? Can I compute NDVI online?
Yes — Google Earth Engine is ideal for large-area or time-series analysis. I’ll cover GEE with Python in a future post.

Q: How do I download Sentinel-2 data for free for my study area in India?
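Register at the Copernicus Data Space Ecosystem (https://dataspace.copernicus.eu/), which replaced the Copernicus Open Access Hub in 2023, and use the map interface to select your area and date. Downloads are free. I’ll write a dedicated post on this.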

Q: What’s the difference between NDVI and other indices like EVI, SAVI, NDWI?
NDVI is the most widely used but has limitations over dense canopies or bare soils. SAVI (Soil Adjusted Vegetation Index) is better for sparse vegetation and is common in semi-arid agricultural research. I’ll cover these in the next post in this series.
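
As a quick preview of how SAVI differs, here is a minimal sketch using the standard form SAVI = (NIR - Red) / (NIR + Red + L) × (1 + L), with the commonly used soil factor L = 0.5. The reflectance values are illustrative assumptions. Note that the inputs must be reflectance on a 0 to 1 scale, because L is not scale-free.

```python
import numpy as np

def ndvi(nir, red):
    nir = np.asarray(nir, dtype=float)
    red = np.asarray(red, dtype=float)
    return (nir - red) / (nir + red)

def savi(nir, red, L=0.5):
    """Soil Adjusted Vegetation Index; L = 0 reduces it to NDVI."""
    nir = np.asarray(nir, dtype=float)
    red = np.asarray(red, dtype=float)
    return (nir - red) / (nir + red + L) * (1.0 + L)

# Illustrative (assumed) reflectance for a sparse canopy over bright soil
nir_sparse, red_sparse = 0.30, 0.12

print(f"NDVI: {float(ndvi(nir_sparse, red_sparse)):.3f}")
print(f"SAVI: {float(savi(nir_sparse, red_sparse)):.3f}")
```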


What’s Next in the Agricultural Data Science Series?

This post is Part 1 of my Agricultural Data Science with Python series. Here’s what’s coming:

  • Part 2: Downloading Sentinel-2 data from Copernicus Hub using Python
  • Part 3: SAVI, EVI, and NDWI — Which vegetation index should you use?
  • Part 4: Crop classification using Random Forest and Sentinel-2 data
  • Part 5: Time series NDVI analysis for yield forecasting

Conclusion

NDVI is one of the most powerful tools available to modern agricultural scientists — and with Python, it’s accessible, free, and fully customisable.

In this post, you learned how to:

  • Calculate NDVI from Sentinel-2 satellite data using rasterio and numpy
  • Visualise NDVI as a colour-coded map with matplotlib
  • Classify NDVI into agricultural health categories
  • Compute area statistics for field or district-level crop reporting
  • Build an NDVI time series to track crop health through the season
  • Save analysis results as a georeferenced GeoTIFF

The same workflow I’ve shown you is used in real crop monitoring projects. It’s not just an academic exercise; this is how modern science works.


How to do Exploratory Data Analysis (EDA) with Python?

This article presents a thorough discussion of how to perform Exploratory Data Analysis (EDA) to extract meaningful insights from a data set. To do this, I am going to use the Python programming language and four of its most popular libraries for data handling and plotting.

EDA is one of the most basic and most important steps in data science. It helps us plan more advanced analytics by revealing the nature of the feature and target variables and their interrelationships.

Every advanced application of data science, such as machine learning, deep learning or artificial intelligence, requires a thorough knowledge of the variables in your data. Without a good exploratory data analysis, you cannot have sufficient information about those variables.

In this article, we will discuss four very popular and useful Python libraries: Pandas, NumPy, Matplotlib and Seaborn. The first two handle tabular data, arrays and matrices, whereas the last two are for creating plots.

I have created this exploratory data analysis code file as a Jupyter notebook with a common data file name, and I reuse it whenever a new data set is to be analyzed. Only the variable names need to be changed. It saves me considerable time, and it leaves me familiar with all the variables and with a good enough idea for further data science tasks.

Data structures for exploratory data analysis

Pandas and NumPy provide us with data structures for data handling. Pandas has two main data containers, the Series and the DataFrame. A Series is a one-dimensional labelled array. A DataFrame is a two-dimensional table in which each column is a Series: values within a column share one data type, but different columns may have different types. A DataFrame can therefore be thought of as a dictionary of Series.
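
A minimal sketch of the two containers, using a made-up three-row table rather than the Titanic data; the variable names are my own.

```python
import pandas as pd

# Two Series: one numeric, one text (made-up values)
age = pd.Series([22.0, 38.0, 26.0], name="Age")
sex = pd.Series(["male", "female", "female"], name="Sex")

# A DataFrame built like a dictionary of Series
toy_df = pd.DataFrame({"Age": age, "Sex": sex})

print(toy_df)
print(toy_df.dtypes)          # one data type per column
print(type(toy_df["Age"]))    # selecting a column gives back a Series
```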

Let’s first import all the required libraries.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

The data set used here is the very popular Titanic data set from Kaggle (https://www.kaggle.com/c/titanic/data). It contains details of the passengers who travelled on the ship and experienced the disaster. The data frame contains 12 variables in total, listed below.

df=pd.read_csv("Titanic_data.csv")
df.columns
Index(['PassengerId', 'Survived', 'Pclass', 'Name', 'Sex', 'Age', 'SibSp',
       'Parch', 'Ticket', 'Fare', 'Cabin', 'Embarked'],
      dtype='object')

Feature and target variables

The target variable here is ‘Survived’, which records whether the passenger survived the disaster. It is a binary variable: ‘1’ means the passenger survived and ‘0’ means the passenger did not.

The other variables are all feature variables. Among them, ‘Pclass’ contains the ticket class, which has three levels (1 = Upper, 2 = Middle, 3 = Lower); ‘SibSp’ is the number of siblings or spouses the passenger had aboard; ‘Parch’ is the number of parents or children aboard; and ‘Embarked’ gives the port of embarkation. All other variables carry the information their names suggest.

Here is the shape of the data set.

df.shape
(891, 12)

It shows that the data set has 891 rows and 12 columns.

Basic information

The info() function displays some more basic information like the variable names, their variable type and if they have null values or not.

df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  891 non-null    int64  
 1   Survived     891 non-null    int64  
 2   Pclass       891 non-null    int64  
 3   Name         891 non-null    object 
 4   Sex          891 non-null    object 
 5   Age          714 non-null    float64
 6   SibSp        891 non-null    int64  
 7   Parch        891 non-null    int64  
 8   Ticket       891 non-null    object 
 9   Fare         891 non-null    float64
 10  Cabin        204 non-null    object 
 11  Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB

Displaying the data set

The head() function prints the first few rows of the data set (five by default) for our understanding.

df.head()
Figure: Sample of the data set

Summary statistics

The describe() function prints some basic statistics of the data in hand.

df.describe()
Figure: Summary statistics

To get some idea about the non-numeric variables in the data set, pass the include argument:

df.describe(include=[bool,object])
Figure: Boolean, object count

We can also inspect any particular variable more closely, for example the mean fare:

df.Fare.mean()
32.2042079685746

What if we take a categorical variable for inspection? Let’s consider the target variable “Survived” here. It is a binary variable, as I mentioned before.

df[df['Survived']==1].mean(numeric_only=True)
PassengerId    444.368421
Survived         1.000000
Pclass           1.950292
Age             28.343690
SibSp            0.473684
Parch            0.464912
Fare            48.395408
dtype: float64

So, it reveals important information: those who survived the disaster had an average age of about 28 and paid an average fare of about 48.

Let’s find out some more specific information using logical operators. For example, what is the maximum age of a survivor travelling in Class I?

Use of logical operators

df[(df['Survived'] == 1) & (df['Pclass'] == 1)]['Age'].max()
80.0

So, the oldest survivor travelling in class I was 80 years old.

df[(df['Survived'] == 0) & (df['Pclass'] == 1)]['Age'].max()
71.0

Similarly, the oldest class I passenger who did not survive was 71 years old. Such queries can retrieve very interesting information.

Suppose we want to inspect the details of passengers whose names start with ‘A’. Here I have used a ‘lambda’ function for the purpose. It makes the task very easy.

df[df['Name'].apply(lambda P_name: P_name[0] == 'A')].head()
Figure: Few lines from the data set
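
The same filter can also be sketched with the vectorised string methods in Pandas, which cope with missing names without raising an error. The small frame below is made up for illustration; in the Titanic data the ‘Name’ column has no missing values, so both approaches give the same rows there.

```python
import pandas as pd

# Made-up names in the Titanic style, including one missing value
toy = pd.DataFrame({
    "Name": ["Adams, Mr. A", "Brown, Mrs. B", "Ahmed, Miss. C", None],
    "Survived": [0, 1, 1, 0],
})

# na=False treats a missing name as "does not start with A"
starts_with_a = toy["Name"].str.startswith("A", na=False)
print(toy[starts_with_a])
```

The lambda version would fail on the missing name, because None cannot be indexed with [0].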

Another very useful function is ‘replace()’ from Pandas. It allows us to replace particular values of any variable with values of our choice. For example, if we want to replace the ‘Pclass’ values 1, 2 and 3 with ‘Class I’, ‘Class II’ and ‘Class III’ respectively, we can use the following piece of code.

x = {1 : 'Class I', 2 : 'Class II', 3:'Class III'}
df_new=df.replace({'Pclass': x})
df_new.head()
Figure: Use of replace() function

Application of ‘groupby’

This is another important function, frequently used to get summary statistics. Below is an example that summarises the variables ‘Fare’ and ‘Age’ separately for each value of the target variable ‘Survived’.

var_of_interest = ['Fare', 'Age']

df.groupby(['Survived'])[var_of_interest].agg(['mean', 'std', 'min', 'max'])
Figure: Application of groupby() function

Contingency table

A contingency table, or cross-tabulation, is a very popular technique for displaying the frequency distribution of one variable against another in a multivariate data set. Here we will use the crosstab() function of Pandas to perform the task.

pd.crosstab(df['Survived'], df['Pclass'])
Figure: Contingency table

So, you can see how quickly a contingency table gives us the tally of survivors and deaths for each passenger class.
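
Counts are often easier to compare as proportions. The sketch below uses the normalize argument of crosstab() on a small made-up table, not the Titanic data, to turn each row into survival rates.

```python
import pandas as pd

# Made-up records: eight passengers across three classes
toy = pd.DataFrame({
    "Survived": [1, 0, 1, 0, 0, 1, 0, 0],
    "Pclass":   [1, 1, 2, 2, 3, 3, 3, 3],
})

# Raw counts per class
counts = pd.crosstab(toy["Pclass"], toy["Survived"])

# normalize="index" divides each row by its total, so rows sum to 1
rates = pd.crosstab(toy["Pclass"], toy["Survived"], normalize="index")

print(counts)
print(rates)
```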

Pivot table

The ‘pivot_table()‘ function summarizes some variables with respect to another variable. Below is an example of how we get the mean ‘Fare‘ and ‘Age’ of the passengers who survived and of those who died.

df.pivot_table(['Fare','Age'],['Survived'],aggfunc='mean')
Pivot table

Sorting the data set

We can sort the data set with respect to any of the variables. For example, below we have sorted the data set with respect to the variable “Fare“. The parameter “ascending=False” specifies that the table will be arranged in descending order of ‘Fare‘.

df.sort_values(by=["Fare"], ascending=False)
Sorted with respect to 'Fare'

Visualization using different plots

Visualization is one of the most important parts of exploratory data analysis. It reveals many interesting patterns among the variables which are otherwise tough to recognise from numbers alone.

Here we will use two very capable python libraries called matplotlib and seaborn to create different plots and charts.

Check for missing values in the data set

A heat map created using the seaborn library helps us find missing values easily. When the data frame is big and the missing values are few, locating them by eye is not easy, and this is where such a heatmap is useful.

import seaborn as sns
plt.rcParams['figure.dpi'] = 100  # the dpi can be set to enhance the resolution of the image
# Configuring retina format (IPython/Jupyter magic, works only in a notebook)
%config InlineBackend.figure_format = 'retina'
sns.heatmap(df.isnull(), cmap='viridis',yticklabels=False)
Heatmap to locate missing values

So, we can see from here that, out of 12 variables in total, only “Age” and “Cabin” have missing values. We have used the ‘retina’ figure format of the Jupyter inline backend to make the plot sharper and more legible.
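The heatmap shows where the gaps are, but not how many there are. A minimal sketch of the numerical counterpart, using a hypothetical toy frame with the same column names:

```python
import numpy as np
import pandas as pd

# Hypothetical toy data with gaps in 'Age' and 'Cabin'
toy = pd.DataFrame({
    "Age": [22.0, np.nan, 35.0, np.nan],
    "Cabin": [None, "C85", None, None],
    "Fare": [7.25, 71.28, 8.05, 8.46],
})

# isnull() gives a True/False frame; summing counts the True cells per column
missing_count = toy.isnull().sum()

# The mean of a True/False column is the share of missing values
missing_share = toy.isnull().mean()
print(missing_count)
```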

Also, see the code to create these two plots as subplots and how the figure size has been mentioned. You can create separate plots without specifying all these details and see the effect. These specifications will help you adjust the plots and make them more legible.

Plotting the variable “Survived” to compare the dead and alive passengers of Titanic with a bar chart

sns.countplot(x=df.Survived)
Bar plot for variable 'Survived'

The above plot displays how many people survived out of all passengers. If we want this comparison broken down by the sex of the passengers, we should incorporate another variable in the chart.

sns.countplot(x=df.Survived, hue=df.Sex)  # pass x by keyword; recent seaborn versions reject it positionally
Bar plot showing the survival according to passengers' sex


The above plot reveals important information regarding the survival of the passengers. From the plot we drew before, it was evident that more passengers died than survived the disaster.

Now if we group this survival according to sex, it further reveals that far more female passengers survived than male passengers, and that the death count was much higher for male passengers than for female passengers.

Let’s inspect the same information with a contingency table.

pd.crosstab(df['Survived'], df['Sex'], margins=True)
Contingency table for count of passengers survived according to their sex

Again if we want to categorize the plot of survival of the passenger depending on the class of the passengers, then we can have the information about how many passengers of a particular class have survived.

Bar plot with two categorical variables

There were three passenger classes, represented as 1, 2 and 3. Let’s prepare a count plot of survival with passenger class as the subcategory.

sns.countplot(x=df.Survived, hue=df.Pclass)
Count plot for passenger class wise survival

The above plot clearly shows that the death toll was much higher among class 3 passengers, while class 1 passengers had the highest number of survivors. Class 2 passengers had almost equal numbers of deaths and survivals.

Class 3 had the highest number of passengers and also the highest death toll. The count plot below makes it evident that class 3 had many more passengers than the other classes.

Again we can check the exact figure of passenger survival according to the passenger class with a contingency table too as below.

pd.crosstab(df['Survived'], df['Pclass'], margins=True)
sns.countplot(x=df.Pclass)

Creating distribution plot

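Below, a seaborn distribution plot has been created with the simple “distplot()” function, with all other parameters left at their defaults. By default, it draws a histogram normalized to a density, with a kernel density estimate (KDE) curve overlaid. Note that distplot() is deprecated in recent seaborn releases in favour of histplot() and displot().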

sns.distplot(df.Age, color='red')
Distribution plot-1

If we want the raw counts of the ‘Age’ values to be displayed instead of the density, we need to set ‘kde’ to ‘False’.

Distribution plot-2
sns.distplot(df['Age'].dropna(),color='darkred',bins=40)
Distribution plot-3
sns.distplot(df.Fare, color='green')
Distribution plot-4
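The difference between a count histogram and a density histogram can be checked numerically. The sketch below uses NumPy's histogram() on a handful of hypothetical ages; it illustrates the idea and is not what seaborn runs internally.

```python
import numpy as np

# Hypothetical ages, for illustration only
ages = np.array([2, 18, 22, 25, 31, 38, 40, 54, 61, 80], dtype=float)

# Raw counts per bin (what a plain histogram shows)
counts, edges = np.histogram(ages, bins=4)

# density=True rescales the bars so that their total area equals 1
density, _ = np.histogram(ages, bins=4, density=True)

bin_width = edges[1] - edges[0]
total_area = (density * bin_width).sum()
print(counts, total_area)
```

The bar heights change but the shape does not: the density is simply each count divided by (number of observations × bin width).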

Box plot and violin plot

Box plots and violin plots are also very good visualization methods to determine the distribution of a variable. See the application of these two plots to the variable ‘Fare‘ below.

plt.subplot(1,2,1)
sns.boxplot(x=df['Fare'])
plt.subplot(1,2,2)
sns.violinplot(x=df['Fare'],color='red')
Box plot and violin plot

The whiskers in the boxplot above extend to the most extreme data points that lie within the interval (Q1−1.5⋅IQR, Q3+1.5⋅IQR), where Q1 is the first quartile, Q3 is the third quartile and IQR is the interquartile range, i.e. the difference between the third and first quartiles.

The black dots represent outliers, which lie beyond the range marked by the whiskers. The violin plot, on the other hand, displays the kernel density estimate mirrored on both sides.
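The whisker interval is easy to compute by hand. Here is a minimal sketch with NumPy on hypothetical fare values; the numbers are made up for illustration.

```python
import numpy as np

# Hypothetical fares with one very large value
fares = np.array([7.0, 8.0, 10.0, 13.0, 15.0, 26.0, 30.0, 52.0, 80.0, 512.0])

# First and third quartiles, and the interquartile range
q1, q3 = np.percentile(fares, [25, 75])
iqr = q3 - q1

# Limits beyond which a point is drawn as an outlier
lower = q1 - 1.5 * iqr
upper = q3 + 1.5 * iqr

outliers = fares[(fares < lower) | (fares > upper)]
print(lower, upper, outliers)
```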

Creating a boxplot to inspect the distribution

Below is a boxplot showing the distribution of fare for each passenger class. As expected, the first class has the highest fares. Another boxplot has been created with the same ‘Pclass‘ variable against the “Age” variable.

These two boxplots side by side let us understand the relation between passengers’ age and their choice of class. We can clearly see that older passengers tended to spend more and travel in the higher classes.

plt.subplot(1,2,1)
sns.boxplot(x=df.Pclass,y=df.Fare)
plt.subplot(1,2,2)
sns.boxplot(x=df.Pclass, y=df.Age)

Correlation plot

Here we will inspect the relationship between the numerical variables using the correlation coefficient. This data set is not ideal for a correlation study, as it lacks numerical variables with a meaningful interrelation.

But for the sake of completing the EDA steps, we will perform the study with the numerical variables we have at hand. We will produce a heatmap that displays the correlations in different colour shades.

# Considering only numerical variables
scatter_var = list(set(df.columns)-set(['Name', 'Survived', 'Ticket','Cabin','Embarked','Sex','SibSp','Parch']))

# Creating heatmap
corr_matrix = df[scatter_var].corr()
sns.heatmap(corr_matrix,annot=True);
Heat map showing correlation coefficients
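Each cell of the heatmap is a Pearson correlation coefficient. As a small sketch on hypothetical values, the number pandas reports can be cross-checked against NumPy:

```python
import numpy as np
import pandas as pd

# Hypothetical toy values for two numerical columns
toy = pd.DataFrame({
    "Age": [22.0, 38.0, 26.0, 35.0, 54.0],
    "Fare": [7.25, 71.28, 7.92, 53.10, 51.86],
})

# corr() computes the Pearson coefficient for every pair of columns
r_pandas = toy.corr().loc["Age", "Fare"]

# The same coefficient from NumPy
r_numpy = np.corrcoef(toy["Age"], toy["Fare"])[0, 1]
print(r_pandas, r_numpy)
```

Note that in recent pandas versions corr() raises an error if the frame still contains text columns, which is one reason to select the numerical columns first, as done above.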

Scatter plot

Scatter plots are very handy for displaying the relationship between two numeric variables. The scatter() function of the matplotlib library does this very quickly and gives us a first-hand idea about the variables.

Below is a scatterplot created between the ‘Fare‘ and ‘Age‘ variables. Here the two variables are taken as Cartesian coordinates in the 2D space. But even 3D scatterplots are also possible.

plt.scatter(df['Age'], df['Fare'])
plt.title("Age Vs Fare")
plt.xlabel('Age')
plt.ylabel('Fare')
Scatter plot

Creating a scatterplot matrix

If we want a glimpse of the joint distribution and one to one scatterplots among all combinations of the variables, a scatterplot matrix can be a good solution. The pairplot() function of the seaborn library does the job for us.

Below is an example with the scatter_var variable we created before with all the numerical variables in the data set.

sns.pairplot(df[scatter_var])
Scatter plot matrix

In the above scatterplot matrix, the diagonal plots are the distribution plots of the corresponding variables, while the remaining plots are scatterplots for each pair of variables.

To conclude, I will discuss a very handy tool that works with Pandas. Pandas profiling can create a summary of the data set in a jiffy.

Pandas profiling

First of all you need to install the library using the pip command.

pip install pandas-profiling

It will take some time to install all its modules. Once it is installed, run the lines of code below to execute it. The ProfileReport() function creates the EDA_report and finally an interactive HTML file is created for the user. Note that newer releases of this package are published under the name ydata-profiling and imported as ydata_profiling.

from pandas_profiling import ProfileReport
EDA_report = ProfileReport(df)
EDA_report.to_file(output_file='EDA.html')

It is a very helpful way to perform exploratory data analysis, especially for those who are not very familiar with coding and statistical analysis and just want a basic idea about their data. The interactive report allows them to dig further for particular information.

Disadvantage

The main demerit of the pandas profiler is that it takes too much time to generate a report when the data set is huge, and practical real-world data sets often have many thousands of records. If you throw the entire data set at the profiler, you might get frustrated.

In this situation, you should ideally use only a part of the data to generate the report. A random sample from the whole data set can still give you some idea about the variables of interest.
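Drawing such a random sample takes one line in Pandas. A minimal sketch follows; the frame and the sample size of 5,000 are arbitrary assumptions for illustration.

```python
import numpy as np
import pandas as pd

# Hypothetical large data set with 100,000 records
big = pd.DataFrame({"value": np.arange(100_000)})

# Draw 5,000 rows at random without replacement;
# random_state makes the sample reproducible
subset = big.sample(n=5_000, random_state=42)
print(subset.shape)
```

The subset can then be passed to ProfileReport() in place of the full data frame.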

Conclusion

Exploratory data analysis is the key to knowing your data. Any data science task starts with data exploration. So, you need to be good at exploratory data analysis, and it needs a lot of practice.

There are a lot of tools which can prepare a summary report from the data at once. Here I have also discussed the Pandas profiling function, which does all the data exploration on your behalf. But in my experience, these are not that effective and may give misleading results if the data is not filtered properly.

If you do the exploration by hand, step by step, you may need to devote some more time, but in this way you become more familiar with the data. You get a good grasp of the variables, which helps you in advanced data science applications.

So, that’s all about exploratory data analysis using four popular Python libraries. I have discussed, with examples, the functions that are generally required to explore any data set. Please let me know how you find this article and if I have missed anything here. I will certainly improve it according to your suggestions.

Decision tree for classification and regression using Python

Decision tree

The decision tree is a popular supervised machine learning algorithm, frequently used to classify categorical data as well as to regress continuous data. In this article, we will learn how to implement decision tree classification and regression using the Scikit-learn package of Python.

Decision tree classification helps to take vital decisions in the banking and finance sectors, such as whether a loan should be given to a customer depending on his risk-bearing credentials; in medicine, such as whether a new medicine should be tried on a patient depending on his/her medical history; and in many more fields.

The above two cases are ones where the target variable is binary, i.e. with only two categories of response. There can be cases where the target variable has more than two categories, and the decision tree can be applied in such multinomial cases too. The decision tree can also handle both numerical and categorical data. So, no doubt a decision tree gives a lot of liberty to its users.

Introduction to decision tree

Decision tree problems generally consist of some existing conditions which determine a categorical response. If we arrange the conditions and the decisions that depend on them, with some of those decisions leading to further decisions, the whole structure of decision making resembles a tree. Hence the name decision tree.

The first and topmost condition, which initiates the decision-making process, is called the root node. The nodes below the root are called decision nodes if they lead to further decisions and leaf nodes if they do not. The process continues recursively until all the elements are grouped into particular categories and the final nodes are all leaf nodes.

An example of decision tree

Here we can take the example of the recent COVID-19 epidemic and the problem of testing for positive cases. We all know that the main problem with this disease is that it is very infectious. So, identifying COVID-positive patients and isolating them is essential to stop its further spread. This needs rigorous testing. But COVID testing is a time-consuming and resource-intensive process. It becomes more of a challenge in countries like India, with a population of 1.3 billion.

So, if we can identify which persons actually need testing, it can save a lot of time and resources. We can straightaway downsize the testing population significantly. So, it is a kind of divide and conquer policy. See the decision tree below for classifying persons who need to be tested.

An example of decision tree

The whole classification process is very similar to how a human being judges a situation and makes a decision. That’s why this machine learning technique is simple to understand and easy to implement. Further, being a non-parametric approach, this algorithm is applicable to any kind of data, even when the distribution is not known.

The distinct character of a decision tree, which makes it special among machine learning algorithms, is that it is a white box technique. That means the logic used in the classification process is visible to us. Due to its simple logic, the training time for this algorithm is comparatively short even when the data is large and high-dimensional. Moreover, the decision tree is the foundation of advanced machine learning techniques like random forest, bagging, gradient boosting etc.

Advantages of decision tree

  • The decision tree has a great advantage of being capable of handling both numerical and categorical variables. Many other modelling techniques can handle only one kind of variable.
  • Very little data preprocessing is required. Apart from handling missing values, steps like data standardization are not needed, which saves a lot of the user’s time. Whether categorical data needs dummy variables depends on the implementation; the trees in scikit-learn expect numeric input.
  • The assumptions are not too rigid and model can slightly deviate from them.
  • The decision tree model validation can be done through statistical tests and the reliability can be established easily.
  • As it is a white box model, the logic behind it is visible to us and we can easily interpret the result, unlike a black-box model such as an artificial neural network.

No technique is without flaws, and the decision tree is no exception.

Disadvantages of Decision tree

  • A very serious problem with a decision tree is that it is very much prone to overfitting. That means an overly complex tree fits the training data, including its noise, too closely and then predicts poorly on new data.
  • The classification by decision tree generally uses an algorithm which finds a locally optimal split for each node. As this process is repeated recursively for each node, the whole process ends up with a locally optimal tree instead of a globally optimal one.
  • The result obtained from a decision tree is very unstable. A little variation in the data can lead to a completely different classification/regression result. That’s why ensemble techniques like the random forest were developed; they combine the results obtained from a number of models instead of relying on a single one.
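The "locally optimal result for each node" mentioned above can be sketched in a few lines. The example below searches a single feature for the threshold with the lowest weighted Gini impurity, a criterion commonly used by CART-style trees. The function names gini() and best_split() and the glucose values are hypothetical, chosen only for illustration; this is a sketch of the idea, not scikit-learn's implementation.

```python
import numpy as np

def gini(labels):
    """Gini impurity of a 1-D array of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split(feature, labels):
    """Return (threshold, weighted impurity) of the best single split."""
    best_t, best_score = None, np.inf
    values = np.unique(feature)
    # Candidate thresholds are midpoints between consecutive distinct values
    for lo, hi in zip(values[:-1], values[1:]):
        t = (lo + hi) / 2.0
        left, right = labels[feature <= t], labels[feature > t]
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
        if score < best_score:
            best_t, best_score = t, score
    return best_t, best_score

# Hypothetical data: glucose readings and a 0/1 diabetes label
glucose = np.array([85, 90, 100, 130, 150, 160], dtype=float)
diabetic = np.array([0, 0, 0, 1, 1, 1])

threshold, impurity = best_split(glucose, diabetic)
print(threshold, impurity)
```

A tree builder applies this search at every node and never revisits an earlier choice, which is why the final tree is only locally optimal.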

Classification and Regression Tree (CART)

The decision tree has two main categories: the classification tree and the regression tree. Together they are referred to as CART. The term was coined in 1984 by Leo Breiman, Jerome Friedman, Richard Olshen and Charles Stone.

Classification

When the response is categorical in nature, the decision tree performs classification. Examples are the ones I gave before, or whether a person is sick or not, or whether a product passes or fails a quality test. In all these cases the problem at hand is to assign the target variable to a group.

The target variable can be binomial, that is with only two categories like yes-no, male-female, sick-not sick etc., or it can be multinomial, that is with more than two categories. An example of a multinomial variable is the economic status of people. It can have categories like very rich, rich, middle class, lower-middle class, poor, very poor etc. The benefit of the decision tree is that it is capable of handling both binomial and multinomial variables.

Regression

On the other hand, the decision tree is applied to regression problems when the target variable is continuous. An example is predicting the rainfall on a future date from other weather parameters. Here the target variable is a continuous one, so it is a regression problem.

Application of Decision tree with Python

Here we will use the scikit-learn package to implement the decision tree. The package has a class called DecisionTreeClassifier() which is capable of classifying both binomial (target variable with only two classes) and multinomial (target variable having more than two classes) variables.

Performing classification using decision tree

Importing required libraries

The first step to start coding is to import all the libraries we are going to use. The basic libraries for any kind of data science project are pandas, numpy, matplotlib etc. The purpose of these libraries is discussed at length in the article on simple linear regression with Python.

# importing libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

About the data

The example dataset I have used here for demonstration purposes is from kaggle.com. The data, collected by the “National Institute of Diabetes and Digestive and Kidney Diseases”, contains vital parameters of diabetes patients of Pima Indian heritage.

Here is a glimpse of the first ten rows of the data set:

Diabetes data set

The independent variables of the data set are several physiological parameters of the patients. The dependent variable is whether the patient is suffering from diabetes or not. The dependent column contains a binary variable, where 1 indicates that the person has diabetes and 0 that he does not.

dataset=pd.read_csv('diabetes.csv')
dataset.head()
# Printing data details
dataset.info()            # for a quick view of the data (prints directly)
print(dataset.head())     # printing first few rows of the data
dataset.tail()            # to show last few rows of the data
dataset.sample(10)        # display a sample of 10 rows from the data
dataset.describe()        # printing summary statistics of the data
pd.isnull(dataset)  # check for any null values in the data
Checking if the dataset has any null value

Creating variables

As we can see, the data frame contains nine variables in nine columns. The first eight columns contain the independent variables. These are physiological variables correlated with diabetes symptoms. The ninth column shows whether the patient is diabetic or not. So, here x stores the independent variables and y stores the dependent variable, the diabetes outcome.

x=dataset.iloc[:,:-1].values
y=dataset.iloc[:,-1].values

Performing the classification

To do the classification we need to import the DecisionTreeClassifier() from sklearn. This special classifier is capable of classifying binary variable i.e. variable with only two classes as well as multiclass variables.

# Use of the classifier
from sklearn import tree
clf = tree.DecisionTreeClassifier()
clf = clf.fit(x, y)
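The classifier above is fitted on the full data set, so its performance on unseen data is not measured. As a hedged sketch, the snippet below holds out a test set and compares an unrestricted tree with a depth-limited one. It uses synthetic data from make_classification() because diabetes.csv is not bundled here; the sample size, noise level and max_depth=3 are arbitrary assumptions.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the diabetes data: 8 features, noisy binary target
X, y = make_classification(n_samples=500, n_features=8, n_informative=4,
                           flip_y=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0)

# An unrestricted tree keeps splitting until every training point is classified
deep = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
# A depth-limited tree is a simple guard against overfitting
shallow = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_train, y_train)

deep_train_acc = deep.score(X_train, y_train)
deep_test_acc = deep.score(X_test, y_test)
shallow_test_acc = shallow.score(X_test, y_test)
print(deep_train_acc, deep_test_acc, shallow_test_acc)
```

A large gap between the training and test accuracy of the unrestricted tree is the usual sign of the overfitting discussed earlier.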

Plotting the tree

Now that the model is ready, we can plot the tree. The line below will create the plot.

tree.plot_tree(clf)

Generally, the plot thus created is of very low resolution and gets distorted when used as an image. One solution to this problem is to save it in PDF format, so that the resolution is maintained.

# The decision tree creation
tree.plot_tree(clf) 
plt.savefig('DT.pdf')

Another way to get a high-resolution, good-quality image of the tree is to use the Graphviz format, using export_graphviz() from tree.

# Creating better graph
import graphviz 
dot_data = tree.export_graphviz(clf, out_file=None) 
graph = graphviz.Source(dot_data) 
graph.render("diabetes") 
Decision tree to classify the data
Decision tree created using Graphviz
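If Graphviz is not installed, the same logic can be printed as plain text with export_text() from sklearn.tree. The sketch below uses synthetic data and hypothetical feature names (f0 to f3), since the diabetes file is not bundled here.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_text

# Synthetic data with four features; the names are placeholders
X, y = make_classification(n_samples=200, n_features=4, random_state=0)
names = ["f0", "f1", "f2", "f3"]

# A shallow tree keeps the printed rules short
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

# One line per node: the split condition, or the predicted class at a leaf
rules = export_text(clf, feature_names=names)
print(rules)
```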

The tree represents the logic of classification in a very simple way. We can easily understand how the data has been classified and the steps to achieve that.

Performing regression using decision tree

About the data set

The dataset I have used here for demonstration purposes is from https://www.kaggle.com. The dataset contains the height and weight of persons and a column with their gender. The original dataset has thousands of rows, but for this regression I have used only the first 50 rows, containing data on 25 males and 25 females.

Importing libraries

In addition to the basic libraries we imported for the classification problem, here we will need to import DecisionTreeRegressor() from sklearn.

# Import the necessary modules and libraries
import numpy as np
import pandas as pd  # needed for read_csv() below
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt

Reading the dataset

I have already described the dataset used here. The code below imports the data and stores it in a dataframe called dataset.

dataset=pd.read_csv('weight-height.csv')
print(dataset)

Here is a glimpse of the dataset

Dataset for random forest regression

Creating variables

As we can see, the dataframe contains three variables in three columns. Only the last two columns are of interest to us. We want to regress the weight of a person on his/her height. So, here the independent variable height is x and the dependent variable weight is y.

x=dataset.iloc[:,1:2].values
y=dataset.iloc[:,-1].values

Splitting the dataset

Splitting the whole data set into training and testing sets is common practice. Here we have set test_size to 20%, which means the training data set will consist of 80% of the total data. The test data set works as an independent data set when we need to test the regressor after it has been trained on the training data.

# Splitting the data for training and testing
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test=train_test_split(x,y, test_size=0.20, random_state=0)

Fitting the decision tree regression

Here we have fitted the decision tree regression with two different depth values to draw a comparison between them.

# Creating regression models with two different depths
regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_1.fit(x_train, y_train)
regr_2.fit(x_train, y_train)

Prediction

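The lines of code below give predictions from both regression models, with their two different depth values, for a new set of height values X_test. Note that this X_test is a regular grid of heights and is distinct from the x_test produced by the split above.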

# Making prediction
X_test = np.arange(50,75, 0.5)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)
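A regression tree predicts one constant value per leaf, so a tree of depth d can return at most 2^d different predictions. The sketch below checks this on synthetic height and weight values; the linear relation and the noise level are made-up assumptions standing in for weight-height.csv.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Synthetic stand-in for the height/weight data (50 rows)
rng = np.random.default_rng(0)
height = rng.uniform(55, 75, size=50).reshape(-1, 1)
weight = 4.0 * height.ravel() - 100.0 + rng.normal(0, 5, size=50)

# Same prediction grid as in the article
grid = np.arange(50, 75, 0.5)[:, np.newaxis]

shallow = DecisionTreeRegressor(max_depth=2).fit(height, weight)
deep = DecisionTreeRegressor(max_depth=5).fit(height, weight)

# Number of distinct predicted values = number of leaves reached
n_levels_shallow = len(np.unique(shallow.predict(grid)))
n_levels_deep = len(np.unique(deep.predict(grid)))
print(n_levels_shallow, n_levels_deep)
```

This is why the prediction lines of a regression tree look like staircases, with more and smaller steps as the depth grows.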

Visualizing prediction performance

The lines of code below generate a height vs weight scatter plot along with the two prediction lines from the two regression models.

# Plot the results
plt.figure()
plt.scatter(x, y, s=20, edgecolor="black",
            c="darkorange", label="data")
plt.plot(X_test, y_1, color="cornflowerblue",
         label="max_depth=2", linewidth=2)
plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2)
plt.xlabel("Height")
plt.ylabel("Weight")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()

Conclusion

In this post, you have learned about the decision tree and how it can be applied to classification as well as regression problems using scikit-learn in Python.

The decision tree is a popular supervised machine learning algorithm, frequently used by data scientists. Its simple logic is the main reason behind its popularity. Being a white box algorithm, it lets us clearly understand how it does its work.

The DecisionTreeClassifier() and DecisionTreeRegressor() of scikit-learn are two very useful classes for applying decision trees, and I hope you are confident about their use after reading this article.

If you have any questions regarding this article or any confusion about its application in Python, post them in the comments below and I will try my best to answer them.

Multiple Linear Regression with Python

Multiple linear regression

Multiple linear regression (MLR) is also a kind of linear regression, but unlike simple linear regression it has more than one independent variable. It is sometimes loosely called multivariate regression, although strictly that term refers to models with more than one response variable. As in real-world situations almost all dependent variables are explained by more than one variable, MLR is the most prevalent regression method and can be implemented through machine learning.

Mathematical equation for Multiple Linear Regression

An MLR model can be expressed as:

$$Y_n = a_0 + a_1 X_{n1} + a_2 X_{n2} + \cdots + a_i X_{ni} + \epsilon_n$$

In the above model, $Y_n$ represents the response for case $n$. It has a deterministic part, $a_0 + a_1 X_{n1} + \cdots + a_i X_{ni}$, and a stochastic part, the error term $\epsilon_n$. Here $a_0$ is the intercept, $i$ is the number of independent variables, $a_1, \dots, a_i$ are the regression coefficients and $X_{n1}, \dots, X_{ni}$ are the values of the independent variables for case $n$.

The main purpose of applying this regression technique is to develop a model which can explain as much of the variance in the response as possible using the independent variables. The ratio of the variance explained by the model to the total variance of the response is known as the coefficient of determination and is denoted by R². We will discuss this statistic in detail later.

It is an important statistic in regression modelling for ascertaining how good the model is. For a least-squares fit, the value of R² on the fitted data varies between 0 and 1. Regarding the fit of the model, we may face three situations: an underfitted model, a good fit and an overfitted model.
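The definition of R² can be checked with a few lines of NumPy. The sketch below fits a straight line to made-up values and computes R² in two equivalent ways; for simplicity it uses one independent variable, and the numbers are hypothetical.

```python
import numpy as np

# Hypothetical data with a nearly linear relationship
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = np.array([2.1, 3.9, 6.2, 7.8, 10.1])

# Least-squares straight line
slope, intercept = np.polyfit(x, y, deg=1)
y_hat = slope * x + intercept

ss_total = np.sum((y - y.mean()) ** 2)          # total variation
ss_residual = np.sum((y - y_hat) ** 2)          # unexplained variation
ss_explained = np.sum((y_hat - y.mean()) ** 2)  # explained variation

r2 = ss_explained / ss_total           # explained / total
r2_alt = 1.0 - ss_residual / ss_total  # same value for a least-squares fit
print(r2, r2_alt)
```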

Underfit model

This situation arises when the value of R² is low. A low R² value indicates that the proposed model is not explaining the variation of the response adequately. So, the model needs improvement.

Good-fit model

In this case we have a good R² value, which suggests a good fit of the model, and it can be used for prediction.

Overfit model

Sometimes models become too complex, with lots of variables and parameters. Such complex models fit the training data too well and give a very high R² value, close to 1.0. But they cannot predict well when tested with a different set of data. This is because the model, being too complex, becomes too specific to a particular situation. Such models are called overfitted models.

Dataset used

The dataset used here is the same we used in the Simple Linear Regression. But in this case all the explanatory/independent variables were considered for modelling purpose. The database is an imaginary one and based on my experience of modelling tree data. 

The dataset contains data on tree total biomass above the ground and several other tree physical parameters like tree commercial bole height,  diameter, height, first forking height, diameter at breast height, basal area. Tree_biomass is the dependent variable here which depends on all other independent variables.

Here is a glimpse of the database:

If you find it difficult to understand the variables, don’t bother about their names. Take them as two categories of variables: one is the dependent variable, which I have denoted by y here, and the others are independent variables 1, 2, 3 etc. What matters is the relationship between these two categories of variables, whatever their names may be.

Assumptions for multiple linear regression

We conduct the regression process assuming some conditions. If these conditions do not hold, the results of the regression cannot be relied upon. These are called regression assumptions and they are as below:

Assumption of linearity:

There must be a linear relationship between the independent variables and the response variable. The variables in this imaginary dataset have a linear relationship between them. You can easily check this property by plotting the response variable against each of the explanatory variables. 

Assumption of Homoscedasticity:

The residuals or errors, that is the differences between the observed and estimated values, must have constant variance.

Assumption of multivariate normality:

The residuals should follow a normal distribution. We can prepare a normal quantile-quantile plot to check this assumption.

Assumption of absence of multicollinearity:

There should be no multicollinearity between the independent variables i.e. the independent variables should not be linearly related to each other.
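A common way to check this assumption is the variance inflation factor (VIF), which regresses each independent variable on all the others and computes 1 / (1 − R²). The helper vif() below is a hypothetical NumPy sketch written for this illustration, and the tree measurements are synthetic.

```python
import numpy as np

def vif(X):
    """Variance inflation factor for each column of a 2-D array."""
    X = np.asarray(X, dtype=float)
    factors = []
    for j in range(X.shape[1]):
        target = X[:, j]
        others = np.delete(X, j, axis=1)
        # Regress column j on the remaining columns plus an intercept
        design = np.column_stack([np.ones(len(X)), others])
        coef, *_ = np.linalg.lstsq(design, target, rcond=None)
        resid = target - design @ coef
        r2 = 1.0 - resid.var() / target.var()
        factors.append(1.0 / (1.0 - r2))
    return np.array(factors)

# Synthetic tree measurements: diameter is nearly a multiple of height
rng = np.random.default_rng(0)
height = rng.normal(20, 3, 200)
diameter = 0.9 * height + rng.normal(0, 0.3, 200)
forking = rng.normal(5, 1, 200)  # unrelated to the other two

factors = vif(np.column_stack([height, diameter, forking]))
print(factors)
```

A VIF close to 1 means the variable is not explained by the others; a frequently quoted rule of thumb treats values above about 10 as a sign of serious multicollinearity.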

Application of Multiple Linear Regression using Python

The main purpose of this article is to apply multiple linear regression using Python. This is the most important and also the most interesting part. So let’s jump into writing some python code. Like simple linear regression here also the required libraries have to be called first.

Calling the required libraries

We will be using four main libraries here: NumPy and pandas for handling arrays and data frames, matplotlib for creating plots and sklearn for metrics. These are the most important libraries for data science applications.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import metrics

Importing the dataset

To import the tree dataset mentioned earlier, we will use the read_csv() function of the pandas library.

#***** Importing the dataset ***********
dataset=pd.read_csv('tree.csv')

Defining variables

Now the next important task is to tell Python about the dependent and independent variables of the dataset. As the protocol says we will store the dependent variable in y and the independent variables in x. As I have already explained above the dataset contains one dependent variable and 7 independent variables.

So we will store the variables in two NumPy arrays. As x has to store 7 independent variables, it has to be a 2-dimensional array. Whereas being a variable with only one column, y can do with one dimension. So, the python code for this purpose is as below:

#***** Defining variables *************
x=dataset.iloc[:,: -1].values
y=dataset.iloc[:,-1].values

Here the first “:” selects all the rows. As the dependent variable, tree_biomass, is the rightmost column of the dataset, Python can index it with -1, and “:-1” selects every column except that last one.
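The slicing can be tried out on a tiny data frame first. This is a minimal sketch with made-up values, and the column names other than tree_height and tree_biomass are hypothetical.

```python
import pandas as pd

# Hypothetical miniature dataset; the values are illustrative only
toy = pd.DataFrame({
    "tree_girth": [10.0, 12.5, 15.0],
    "tree_height": [4.0, 5.5, 7.0],
    "tree_biomass": [20.0, 31.0, 45.0],
})

x_toy = toy.iloc[:, :-1].values   # all rows, every column except the last
y_toy = toy.iloc[:, -1].values    # all rows, last column only

print(x_toy.shape)  # (3, 2)
print(y_toy.shape)  # (3,)
```

The two-dimensional shape of x_toy and the one-dimensional shape of y_toy mirror the x and y arrays defined for the tree dataset.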

Checking the assumption of the linear relationship between variables

For example, here I have plotted tree_height against the dependent variable tree_biomass. It is fairly evident that biomass will increase as tree height increases, but a scatterplot is still a very handy way to double-check the relationship. You can prepare this plot very easily using the code below:

#********* Plotting dependent variable against any independent variable 
plt.scatter(x[:,2],y) # accessing the variable tree_height
plt.title("Checking linearity between dependent and independent variables")
plt.xlabel("Tree height")
plt.ylabel("Tree biomass")

I stored the variables in NumPy arrays earlier, so to access them we just have to mention which column we intend to plot. For plotting we have used the scatter function of matplotlib’s pyplot module, imported as plt.

And here is the plot:

The plot suggests almost a linear relationship between the variables.
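The visual impression can be backed up with a number. A minimal sketch, assuming synthetic height and biomass values (not the tree dataset), computes the Pearson correlation coefficient with NumPy:

```python
import numpy as np

# Hypothetical height and biomass values, generated only to illustrate the check
rng = np.random.default_rng(1)
height = rng.uniform(2.0, 20.0, size=100)
biomass = 3.0 * height + rng.normal(scale=2.0, size=100)

# Pearson correlation: values near +1 or -1 indicate a strong linear relationship
r = np.corrcoef(height, biomass)[0, 1]
print("Pearson r:", round(float(r), 3))
```

With the arrays defined earlier, the same check would be np.corrcoef(x[:,2], y)[0, 1].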

Splitting the dataset in training and test data

For testing purposes, we need to set aside a part of the complete dataset that will not be used for model building. The rule of thumb is to use 80% of the data for modelling and keep the rest aside. The held-out part will work as an independent dataset once we have built the model and need to test it.

#****** Dividing the dataset into training and testing dataset
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test= train_test_split(x, y, test_size=0.2, random_state=0)

Here the split has been performed with the help of the model_selection module of the sklearn library. This module has a built-in function called train_test_split which automatically divides the dataset into two parts. The argument test_size controls the proportion of the test data. Here it has been set to 0.2, so the test dataset will contain 20% of the complete data. The argument random_state fixes the random seed, so the same split is obtained on every run.
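The effect of test_size and random_state is easy to verify on dummy data. The following sketch uses hypothetical arrays with 100 rows and 7 columns, not the tree dataset:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical data: 100 rows, 7 independent variables
x_demo = np.arange(700).reshape(100, 7)
y_demo = np.arange(100)

x_tr, x_te, y_tr, y_te = train_test_split(x_demo, y_demo, test_size=0.2, random_state=0)
print(x_tr.shape, x_te.shape)  # (80, 7) (20, 7)

# The same random_state reproduces exactly the same split
x_tr2, x_te2, _, _ = train_test_split(x_demo, y_demo, test_size=0.2, random_state=0)
print(np.array_equal(x_te, x_te2))  # True
```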

Application of multiple linear regression

Here comes the main part of this article: using regression to estimate the response from the known values of more than one independent variable. In the section above we created the training dataset. The following code will use this training data for model building.

#********* Application of regression
from sklearn.linear_model import LinearRegression
regressor=LinearRegression()
regressor.fit(x_train, y_train)

As this is also a linear regression method, the linear_model module of the sklearn library is the one containing the required class, LinearRegression. Here regressor is an instance of LinearRegression, and its fit method estimates the model from the training data.

Getting the regression coefficients for the regression equation

As the regression is done, we need the regression equation. This equation is actually the relation between the dependent and independent variables defined by some coefficients. Using these coefficients we can determine how a unit change in any of the independent variables is going to affect the dependent variable.

#******** Getting the coefficients stored in a dataframe
#*****************************************************************
# storing the column names of the independent variables
# (every column except the last one, matching how x was defined)
colnames=dataset.columns[:-1]
print(colnames)
# printing the intercept of the regression equation
print('Intercept:',regressor.intercept_)
# creating a dataframe storing the coefficients along with the independent variable names
coef_df=pd.DataFrame(regressor.coef_,index=colnames,columns=['Coefficients'])
print(coef_df)

In the above section of code, you can see that first of all the column names of the independent variables are stored in a variable. Then the corresponding coefficients are fetched from regressor, the instance of the LinearRegression class from the linear_model module of sklearn. The coefficients are in regressor.coef_ and the intercept is in regressor.intercept_.

Printing the regression coefficients

The regression equation

With the help of these coefficients, we can now write down the multiple linear regression equation.

The multiple linear regression equation

So, this is the final equation for the multiple linear regression model.
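In general form, a model with k independent variables reads y = b0 + b1*x1 + b2*x2 + … + bk*xk, where b0 is the intercept and b1 to bk are the coefficients. The sketch below shows how such an equation can be written out and evaluated by hand. The intercept, coefficients and variable names are hypothetical and are NOT the values fitted in this article.

```python
import numpy as np

# Hypothetical intercept and coefficients, not the values fitted in this article
intercept = 1.5
coefficients = np.array([0.5, 2.0, -1.0])
names = ["x1", "x2", "x3"]

def equation_string(intercept, coefficients, names):
    """Format the model as a readable equation."""
    terms = [f"{c:+.2f}*{n}" for c, n in zip(coefficients, names)]
    return f"y = {intercept:.2f} " + " ".join(terms)

def predict_by_hand(intercept, coefficients, row):
    """Evaluate intercept + sum(coefficient * value) for one observation."""
    return intercept + float(np.dot(coefficients, row))

print(equation_string(intercept, coefficients, names))
print(predict_by_hand(intercept, coefficients, [2.0, 1.0, 3.0]))  # 1.5
```

For the fitted model, regressor.intercept_, regressor.coef_ and the column names of the independent variables would take the place of the hypothetical values.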

Using the model to predict using the test dataset

Now we have the model in hand. But how can we test its performance? If the model is a good one, it should be able to predict with precision. To test that, we need independent data that was not involved in model building.

This is where the test dataset that we set aside earlier comes in. We will predict the response using the test dataset and compare the predictions with the observations we already have. The following code will do the trick for us.
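A minimal sketch of this prediction step is given here. Synthetic data stands in for tree.csv so that the block runs on its own, and the column labels of the comparison dataframe are assumptions. The name y_predict matches the one used in the fit-statistics code further down.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# Synthetic stand-in for tree.csv (hypothetical values) so the block runs on its own
rng = np.random.default_rng(0)
x = rng.uniform(0.0, 10.0, size=(100, 7))
y = x @ np.arange(1.0, 8.0) + rng.normal(scale=0.5, size=100)

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0)
regressor = LinearRegression()
regressor.fit(x_train, y_train)

# Predict the response for the held-out test data
y_predict = regressor.predict(x_test)

# Observed and predicted values side by side for easy comparison
comparison = pd.DataFrame({"Observed": y_test, "Predicted": y_predict})
print(comparison.head(15))
```

With the real dataset, only the last three statements are needed, because x_test, y_test and regressor already exist from the earlier steps.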

And here is the comparison. I have created a dataframe with the observed and predicted values side by side for the ease of comparison.

Comparing the original and predicted values

In the above figure, I have shown only the first 15 values of the dataframe. But it is enough to show that the prediction is satisfactory. 

Goodness of fit of the model 

We have tested the model and the predictions look good. However, we have not quantified this yet: we do not have any number that tells us how good the model is. Here I will discuss the fit statistics that are very useful in this respect. If we have to compare multiple models, these numbers play a crucial role in finding the best one.

The following code delivers the fit statistics most commonly used to judge the goodness of a statistical model. The first is the coefficient of determination, denoted R2, which is the proportion of the variance in the response variable that is explained by the model. So the higher its value, the better the model.

Coefficient of determination (R2)

Suppose our test dataset has n sets of independent and dependent variables, i.e. (x1, x2, …, xn) and (y1, y2, …, yn) respectively, and let $\bar{y}$ be the mean of the observed values. Using our model, the prediction gives the values (v1, v2, …, vn). So, the total sum of squares will be:

$$SS_{\text{total}} = \sum_{i=1}^{n} (y_i - \bar{y})^2$$

This is the total existing variation in the response variable.

Now the variation explained by the model we developed is the regression sum of squares, which can be calculated as:

$$SS_{\text{regression}} = \sum_{i=1}^{n} (v_i - \bar{y})^2$$

So, as the definition of the coefficient of determination goes, it can be calculated as:

$$R^2 = \frac{SS_{\text{regression}}}{SS_{\text{total}}}$$

It can be further simplified by writing the regression sum of squares as the total variation minus the unexplained variation. The unexplained variation is the part the model is not able to explain. It is also known as the error or residual sum of squares and is calculated as:

$$SS_{\text{residual}} = \sum_{i=1}^{n} (y_i - v_i)^2$$

So, now we can rewrite the equation of R2 as:

$$R^2 = 1 - \frac{SS_{\text{residual}}}{SS_{\text{total}}}$$

#****** Calculating fit statistics
from sklearn import metrics
# predicting the response for the test data
y_predict=regressor.predict(x_test)
# R square computed on the test data, like the error measures below
r_square=regressor.score(x_test, y_test)
print('Coefficient of determination(R square):',r_square)
print('Mean Absolute Error:', metrics.mean_absolute_error(y_test, y_predict))
print('Mean Squared Error:', metrics.mean_squared_error(y_test,y_predict))
print('Root Mean Squared Error:', np.sqrt(metrics.mean_squared_error(y_test, y_predict)))
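The coefficient of determination can also be reproduced by hand from the sums of squares described above. This is a minimal sketch with four hypothetical observed and predicted values, not the tree data:

```python
import numpy as np

# Hypothetical observed values (y) and model predictions (v)
y_obs = np.array([3.0, 5.0, 7.0, 9.0])
y_hat = np.array([2.5, 5.5, 7.0, 9.0])

sst = np.sum((y_obs - y_obs.mean()) ** 2)   # total sum of squares
sse = np.sum((y_obs - y_hat) ** 2)          # residual (error) sum of squares
r2 = 1.0 - sse / sst

print(float(sst), float(sse), round(float(r2), 3))  # 20.0 0.5 0.975
```

The score method of LinearRegression and metrics.r2_score use the same 1 - SSE/SST definition, so on the same data they should agree with this manual calculation.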

Mean Absolute Error(MAE)

This is another popular measure of model fit. As the name suggests, it is the mean of the absolute differences between the observed and predicted values. As we are only interested in the size of the deviations, we take the absolute value of each difference. So the expression will be:

$$MAE = \frac{1}{n}\sum_{i=1}^{n} |y_i - v_i|$$

As it measures the error of the estimated values, a lower MAE suggests a better model.

Mean Squared Error (MSE)

This is also a measure of how far the model estimates deviate from the original values. But instead of the absolute values, we take the squares of the deviations. It is often also called the Mean Squared Deviation (MSD) and is calculated as:

$$MSE = \frac{1}{n}\sum_{i=1}^{n} (y_i - v_i)^2$$

Root Mean Squared Error (RMSE)

As the name suggests, this measure of fit first calculates the difference between the observed and model-predicted values, squares each error, takes the mean and finally takes the square root to get the RMSE. So its equation is:

$$RMSE = \sqrt{\frac{1}{n}\sum_{i=1}^{n} (y_i - v_i)^2}$$
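The three error measures can be computed directly with NumPy, which makes the definitions easy to check. The observed and predicted values below are hypothetical:

```python
import numpy as np

# Hypothetical observed values and model predictions
y_obs = np.array([3.0, 5.0, 7.0, 9.0])
y_hat = np.array([2.5, 5.5, 7.0, 9.0])

errors = y_obs - y_hat
mae = np.mean(np.abs(errors))   # mean absolute error
mse = np.mean(errors ** 2)      # mean squared error
rmse = np.sqrt(mse)             # root mean squared error

print(float(mae), float(mse), round(float(rmse), 4))  # 0.25 0.125 0.3536
```

Unlike MSE, RMSE is expressed in the same units as the response variable, which makes it easier to interpret.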

Fit statistics

How can the fitting further be improved?

There is always scope for improving the model so that it gives more precise predictions. As we already know, the main purpose of Multiple Linear Regression is to attribute as much of the variance of the response variable as possible to the independent variables.

Here lies the trick to improving the predictions of a multiple linear regression model. The response variable you are dealing with is affected by a number of explanatory variables. Some of them are immediately visible to us, and we can say with confidence that they are the main contributors to the response. Together they can give you a good explanation too.

But with good knowledge of the domain, one can identify many other variables that are not directly recognizable as causal factors. For example, in an agricultural experiment, crop yield is determined by many direct and indirect factors: physiological and chemical variables, weather variables, soil conditions and so on.

So, the skill and domain knowledge of the researcher play a vital role in choosing variables wisely to improve the model’s fit. Using too few variables will result in a poor R2, whereas using too many variables may produce a very complex model with a very high R2 that fits the training data closely but predicts new data poorly. In both of these scenarios the model’s performance will not be up to the mark.

Getting started with Python for Machine Learning: a beginner’s guide

If you are reading this article, then you are without doubt a Machine Learning enthusiast. You must have already gone through its theoretical basics and are getting impatient to try your hand at your first Machine Learning application. Python is the most popular programming language for machine learning. If you want a career in data science, I would suggest that Python is the language to bet on.

So, this article is for you. Here I will demonstrate how to complete the Python setup and get started with your first simple program.

But first of all, the question is…

Why Python for machine learning?

Why have I chosen Python for Machine Learning? There are lots of tools available and some of them are very popular too. For example, R is a well-regarded language that has been around for a long time.

People with a traditional statistical or mathematical background, in particular, have a strong inclination towards R. One of the reasons behind this popularity is that R came into existence as a successor to S, a statistical programming language that was hugely popular amongst statisticians.

Python Vs R

R was developed in 1992 and has a specific edge for data analysis tasks. Being a procedural language, it breaks a task down into a series of steps and procedures. Both R and Python are open source, so they are freely available to use, and a huge amount of online resources exists for both.

R is mainly helpful for core statistical and data analytics purposes. The language was developed by statisticians, mainly with the needs of statisticians in mind. It has very powerful graphics and visualisation packages such as ggplot2, ggvis and shiny. If you want to create eye-catching plots from your data, R should be your best friend.

On the other hand, Python came a little earlier: Guido van Rossum, a Dutch programmer, started developing it in 1989. It saw slow, steady growth until 2010, but after that, with the start of the data explosion era, its popularity shot up quickly.

The main reason behind this rapid rise in popularity is its simplicity and versatility. Machine Learning and Artificial Intelligence involve many complex algorithms that perform complex tasks. But the beauty of Python is that it makes these tasks easier with its vast collection of simple-to-use functions.

The use of Python in data science is just one of its capabilities. Being a general-purpose language, Python can be used for developing web applications, desktop software and mobile applications, and even for reading and modifying files or connecting to databases. This versatility has won the hearts of millions of people, irrespective of whether they are data scientists or computer science enthusiasts.

If you are a beginner in data science, you can jump-start learning and applying Python even with little or no background in programming. It is also generally considered a better performer than R when it comes to analysing large datasets.

The following chart from Economist.com will help you realize how popular Python has become recently, surpassing other big names like Java, R and C++.

Source: steelkiwi.com, economist.com

In the data science world these two programming languages are close competitors. Both of them are very popular and have their own pluses and minuses. Ultimately, which platform you use is purely your choice.

Having said this, I think the popularity and simplicity of Python in machine learning applications will keep it slightly ahead of R. And if you are looking to build a career as a data scientist, in my opinion, the future is brighter with Python skills.

Setting up Python on your computer

To start working with Python, the first step is to install it on your computer. If your desktop or laptop is a new one, there is a chance that it has Python preinstalled. You can check your start menu for it. If you find it there, skip this step.

Download Python

If it is not installed already then you have to download and install it from Python.org. 

So, from here choose the specific Python version that suits your computer and download it. As of this writing, Python 3.8.2 is the latest version, so you can download that. If you have an old system running Windows XP, you will have to download an older compatible version, lower than Python 3.5.

After you have downloaded the file, click it to start the installation. Just go with the recommended installation process. It is quick, and within minutes Python will be installed on your system.

The following window will appear when the Python installation is finished.

Now you can check your computer’s start menu, and the Python folder with its associated applications will be there.

As I have installed it just now, all its applications carry a “New” tag. Now that Python is installed, you can launch it directly and start coding.

Here in the above screenshot, you can see that the console shows the details of the installed Python version. I have also tried some basic commands, such as print and a simple calculation.

But to really get started with Python coding, we will need a good IDE that helps us write Python syntax in an intuitive way.
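Commands of this kind can be typed directly into the Python console. The lines below are a small sketch of such a first session. The exact commands shown in the screenshot are not reproduced here, so treat these as illustrative.

```python
import sys

# Report which Python version is running this code
print("Python version:", sys.version.split()[0])

# A first print statement and a simple calculation
print("Hello, Python!")
total = 2 + 3 * 4   # multiplication is evaluated before addition
print(total)        # 14
```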

Selecting a Python IDE

Although a simple IDE called IDLE gets installed automatically along with Python, we prefer to use a more popular and advanced IDE called PyCharm. The reason is that getting familiar with any IDE takes significant time. So, we should choose a good IDE to start with, so that we can continue working in it.

PyCharm is currently one of the most popular IDEs for Python. See the following table, which compares some popular Python IDEs. PyCharm also comes in a paid version, but you get a full-featured integrated environment in both the free and the paid editions.

Source: www.softwaretestinghelp.com

Apart from these IDEs, some simple text editors like Notepad++ are also very popular amongst data scientists. The only issue with text editors is that you have to use additional plugins to run the code written in them. In this context IDEs come in handy, as you can do the complete task, from writing the code to running it, in one place.

Having said this, the final selection is completely your choice. With Python being such a popular programming language, users have the luxury of choosing an IDE from a vast collection. And honestly speaking, it is difficult to judge any single IDE to be the best one. Every one of them has its own strengths and weaknesses.

So though I have selected PyCharm here, you can select any other too. It will not take you much time to switch between IDEs once your basics are strong.

So, let’s start installing PyCharm

PyCharm is a product of JetBrains. Open the download page by following this link:

https://www.jetbrains.com/pycharm/download/#section=windows

Step 1:

Click the download button under Community, which is the free, open-source version of PyCharm.

Step 2:

Run the downloaded installer and click the Next button.

Step 3:

In the next step, if you want to change the program location, provide the path; otherwise you can go with the default path assigned. I am going with the default folder here. Then click Next.

Step 4:

The next window allows you to create a desktop icon for PyCharm, and you can also update the PATH variable. To proceed, click Next.

Step 5:

Here you can change the start menu folder, then click “Install”.

Step 6:

The installation will start. It takes a few minutes. Once the installation is done, click Next.

Step 7:

In the next window click “Finish” to complete the installation process.

Starting PyCharm for the first time

Now PyCharm is installed on your computer. Go to your computer’s start menu and launch the program. The first window that appears shows the privacy policy. Tick the box to agree with the terms and conditions and click Continue.

Next is the data sharing window. It’s completely your choice. Choose any of the options and proceed.

In the next window, you get to choose the appearance of your IDE. Choose whichever you feel comfortable with and click “Skip Remaining and Set Defaults”. You can change all these options later, anytime you want.

The next window is important: it allows you to choose the location where you want to create your Python project. I like to save all my important files in the cloud, so I have provided that particular path there. You can change it here or later.

So now you are all set to start your journey with Python programming with PyCharm IDE. 
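As a first small script in the new project, it can be useful to check whether the libraries typically needed for machine learning are available in your Python installation. This is a minimal sketch. The helper name check_packages and the list of packages are assumptions chosen for illustration.

```python
import importlib.util

def check_packages(names):
    """Return a dict mapping each top-level package name to True if it can be imported."""
    return {name: importlib.util.find_spec(name) is not None for name in names}

# Import names of some libraries commonly used for machine learning
status = check_packages(["numpy", "pandas", "matplotlib", "sklearn"])
for name, installed in status.items():
    print(f"{name}: {'installed' if installed else 'missing'}")
```

Packages reported as missing can usually be installed with pip, for example pip install numpy. Note that the package providing sklearn is installed under the name scikit-learn.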

References:

  • https://www.python.org
  • https://towardsdatascience.com
  • https://www.geeksforgeeks.org
  • https://steelkiwi.com/blog