Using LangChain and OpenAI to Query SQL Databases: A Practical Example
In this article, we will explore how to use LangChain and OpenAI to interact with an SQL database. We’ll walk through a Python script that uses these tools to convert natural-language questions into SQL queries, execute those queries, and print the results. I’ll also recommend a sample CSV file to populate your database, and we’ll discuss the expected output of each query.
Setting Up the Environment
Before diving into the code, install the necessary libraries (the imports in the script follow the LangChain 0.x package layout):
pip install langchain langchain-community langchain-openai pymysql python-dotenv
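Also, make sure you have a .env file containing your database credentials and your OpenAI API key, which ChatOpenAI reads from the OPENAI_API_KEY environment variable:
MYSQL_HOST=localhost
MYSQL_PORT=3306
MYSQL_USER=root
MYSQL_PASSWORD=xxxxx
MYSQL_DB=university
OPENAI_API_KEY=xxxxx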
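If the password contains characters such as @, : or /, pasting it into the connection URI with a plain f-string breaks the URI. A minimal sketch of building the URI that SQLDatabase.from_uri expects from the variables above, assuming those exact variable names; build_mysql_uri is a hypothetical helper, and urllib.parse.quote_plus percent-encodes the password as SQLAlchemy requires:

```python
import os
from urllib.parse import quote_plus


def build_mysql_uri(env=os.environ):
    """Hypothetical helper: build a mysql+pymysql URI from the MYSQL_* variables."""
    host = env.get("MYSQL_HOST", "localhost")
    port = env.get("MYSQL_PORT", "3306")
    user = env["MYSQL_USER"]
    # Percent-encode the password so '@', ':' or '/' cannot be mistaken for URI delimiters
    password = quote_plus(env["MYSQL_PASSWORD"])
    database = env["MYSQL_DB"]
    return f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}"


if __name__ == "__main__":
    # In the real script, call load_dotenv() first and then build_mysql_uri() with no arguments
    example = {"MYSQL_USER": "root", "MYSQL_PASSWORD": "p@ss/word", "MYSQL_DB": "university"}
    print(build_mysql_uri(example))
    # mysql+pymysql://root:p%40ss%2Fword@localhost:3306/university
```

Passing the environment as a mapping keeps the helper easy to test without touching the real process environment.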
The Code
Here’s a complete Python script that connects to a university database, has the LLM translate several natural-language questions into SQL, and runs each generated query:
import os
from dotenv import load_dotenv
from langchain.chains import create_sql_query_chain
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
# Load the MYSQL_* settings and OPENAI_API_KEY from .env into the environment
load_dotenv()
# temperature=0 reduces randomness in the generated SQL
llm = ChatOpenAI(temperature=0)
# Read the connection settings from .env instead of hard-coding them
host = os.getenv("MYSQL_HOST", "localhost")
port = os.getenv("MYSQL_PORT", "3306")
username = os.getenv("MYSQL_USER")
password = os.getenv("MYSQL_PASSWORD")
database_schema = os.getenv("MYSQL_DB")
mysql_uri = f"mysql+pymysql://{username}:{password}@{host}:{port}/{database_schema}"
# Include two example rows per table in the schema description sent to the LLM
db = SQLDatabase.from_uri(mysql_uri, sample_rows_in_table_info=2)
# Chain that turns a natural-language question into a SQL query string (it does not run the query)
chain = create_sql_query_chain(llm, db)
# Test the setup
print(db.dialect)
print(db.get_usable_table_names())
print(db.run("SELECT COUNT(*) AS total_students FROM students;"))
# Queries: print the generated SQL, then execute it and print the rows
questions = [
    "How many students are enrolled in the university?",
    "How many students are enrolled in the 'Computer Science' course?",
    "Which course has the highest number of enrollments?",
    "Give me the top 5 students with the highest grades in 'Mathematics'.",
]
for question in questions:
    response = chain.invoke({"question": question})
    print(response)
    print(db.run(response))
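The script passes whatever SQL the model returns straight to db.run. A minimal sketch of a pre-execution check, assuming you only want read-only queries: clean_generated_sql is a hypothetical helper that strips a Markdown code fence or a leading "SQLQuery:" label (some models and prompt versions add these) and rejects anything other than a single SELECT or WITH statement. It is a keyword heuristic, not a SQL parser; connecting with a read-only MySQL user remains the more robust safeguard.

```python
import re

# Data- or schema-modifying keywords; a heuristic list, not exhaustive
_WRITE_KEYWORDS = re.compile(
    r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|TRUNCATE|GRANT|REVOKE)\b",
    re.IGNORECASE,
)


def clean_generated_sql(text):
    """Hypothetical helper: normalize LLM output and allow only one read-only statement."""
    sql = text.strip()
    # Remove a surrounding Markdown code fence (three backticks, optional "sql" tag)
    fence = re.match(r"^`{3}(?:sql)?\s*(.*?)\s*`{3}$", sql, re.DOTALL | re.IGNORECASE)
    if fence:
        sql = fence.group(1)
    # Remove a leading "SQLQuery:" label
    sql = re.sub(r"^SQLQuery:\s*", "", sql, flags=re.IGNORECASE).strip()
    # A trailing semicolon is fine; any other semicolon means multiple statements
    body = sql.rstrip(";").strip()
    if ";" in body:
        raise ValueError("multiple statements are not allowed")
    if not re.match(r"^(SELECT|WITH)\b", body, re.IGNORECASE):
        raise ValueError("only SELECT queries are allowed")
    if _WRITE_KEYWORDS.search(body):
        raise ValueError("query contains a data-modifying keyword")
    return body + ";"


print(clean_generated_sql("SQLQuery: SELECT COUNT(*) FROM students"))
# SELECT COUNT(*) FROM students;
```

In the script, you would pass clean_generated_sql(response) to db.run instead of response, so a rejected query raises an error before it ever reaches MySQL.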
Sample CSV File
To try this out, you’ll need a sample CSV file to populate your database. Here’s a recommendation:
students.csv
student_id,first_name,last_name,course_name,grade
1001,John,Doe,Mathematics,95
1002,Jane,Smith,Computer Science,88
1003,Emily,Jones,Mathematics,92
1004,Michael,Brown,Physics,85
1005,Jessica,Davis,Mathematics,89
1006,David,Wilson,Computer Science,91
1007,Alice,Moore,Mathematics,87
1008,Robert,Taylor,Computer Science,85
1009,Mary,Anderson,Mathematics,90
1010,James,Thomas,Physics,82
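These ten rows cover three courses (Mathematics, Computer Science, and Physics) with a range of grades, which is enough to exercise every query below. Load them into a table named students with the same column names, since the expected outputs below assume that layout.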
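One way to get students.csv into the database is a short Python loader. A minimal sketch, assuming a students table whose columns match the CSV header: load_students is a hypothetical helper that works with any DB-API 2.0 connection. The default "%s" placeholder is the one pymysql uses; the demo passes "?" to Python’s built-in sqlite3 only so the sketch runs without a MySQL server.

```python
import csv
import io
import sqlite3

# Column types are valid in MySQL (SQLite also accepts them)
CREATE_STUDENTS = """
CREATE TABLE IF NOT EXISTS students (
    student_id  INT PRIMARY KEY,
    first_name  VARCHAR(50),
    last_name   VARCHAR(50),
    course_name VARCHAR(100),
    grade       INT
)
"""


def load_students(conn, csv_file, placeholder="%s"):
    """Hypothetical helper: create the students table and insert every CSV row."""
    reader = csv.DictReader(csv_file)
    rows = [
        (int(r["student_id"]), r["first_name"], r["last_name"], r["course_name"], int(r["grade"]))
        for r in reader
    ]
    marks = ", ".join([placeholder] * 5)
    cur = conn.cursor()
    cur.execute(CREATE_STUDENTS)
    # Parameterized insert: values are never pasted into the SQL string
    cur.executemany(
        f"INSERT INTO students (student_id, first_name, last_name, course_name, grade) VALUES ({marks})",
        rows,
    )
    conn.commit()
    return len(rows)


SAMPLE = """student_id,first_name,last_name,course_name,grade
1001,John,Doe,Mathematics,95
1002,Jane,Smith,Computer Science,88
1003,Emily,Jones,Mathematics,92
1004,Michael,Brown,Physics,85
1005,Jessica,Davis,Mathematics,89
1006,David,Wilson,Computer Science,91
1007,Alice,Moore,Mathematics,87
1008,Robert,Taylor,Computer Science,85
1009,Mary,Anderson,Mathematics,90
1010,James,Thomas,Physics,82
"""

if __name__ == "__main__":
    # In-memory SQLite stands in for MySQL in this demo
    conn = sqlite3.connect(":memory:")
    print(load_students(conn, io.StringIO(SAMPLE), placeholder="?"))  # 10
    print(conn.execute("SELECT COUNT(*) FROM students").fetchone())  # (10,)
```

With MySQL, open students.csv with newline="" and pass it together with a pymysql.connect(...) connection, keeping the default "%s" placeholder.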
Expected Outputs
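Here are the expected results for each query, assuming your database is populated with the sample data. The exact SQL the model generates may differ slightly between runs or models, but it should return the same rows.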
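Dialect, Table Names, and Row Count:
print(db.dialect)
print(db.get_usable_table_names())
print(db.run("SELECT COUNT(*) AS total_students FROM students;"))
Expected output:
mysql
['students']
[(10,)]
If your schema contains other tables as well (for example courses, enrollments, and instructors), get_usable_table_names() lists them too, in alphabetical order. Note that db.run returns the rows as a string of value tuples, so column aliases such as total_students do not appear in its output.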
Total Students Enrolled:
response = chain.invoke({"question": "How many students are enrolled in the university?"})
print(response)
print(db.run(response))
Expected output:
SELECT COUNT(*) AS total_students FROM students;
[(10,)]
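Because db.run hands back a string such as "[(10,)]", you cannot do arithmetic on the count directly. A minimal sketch, assuming the result contains only plain literals such as numbers and strings: parse_run_result is a hypothetical helper that turns the string back into a list of tuples with ast.literal_eval, which only evaluates Python literals, unlike eval. Values whose repr is not a literal (Decimal or datetime, for example) would make it fail, so it only suits simple results like the ones in this article.

```python
import ast


def parse_run_result(text):
    """Hypothetical helper: convert a db.run()-style string such as "[(10,)]" into Python rows."""
    if not text.strip():
        # Treat an empty string as "no rows"
        return []
    # literal_eval accepts only literals (tuples, lists, numbers, strings, ...)
    return ast.literal_eval(text)


rows = parse_run_result("[(10,)]")
total_students = rows[0][0]
print(total_students)  # 10
```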
Students Enrolled in ‘Computer Science’:
response = chain.invoke({"question": "How many students are enrolled in the 'Computer Science' course?"})
print(response)
print(db.run(response))
Expected output:
SELECT COUNT(*) AS num_students FROM students WHERE course_name = 'Computer Science';
[(3,)]
Course with the Highest Number of Enrollments:
response = chain.invoke({"question": "Which course has the highest number of enrollments?"})
print(response)
print(db.run(response))
Expected output:
SELECT course_name, COUNT(*) AS total_enrollments FROM students GROUP BY course_name ORDER BY total_enrollments DESC LIMIT 1;
[('Mathematics', 5)]
Top 5 Students with Highest Grades in ‘Mathematics’:
response = chain.invoke({"question": "Give me the top 5 students with the highest grades in 'Mathematics'."})
print(response)
print(db.run(response))
Expected output:
SELECT student_id, first_name, last_name, grade FROM students WHERE course_name = 'Mathematics' ORDER BY grade DESC LIMIT 5;
[(1001, 'John', 'Doe', 95), (1003, 'Emily', 'Jones', 92), (1009, 'Mary', 'Anderson', 90), (1005, 'Jessica', 'Davis', 89), (1007, 'Alice', 'Moore', 87)]
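Since the generated SQL can vary, it helps to know the correct answers independently of the LLM. A minimal sketch that recomputes the four expected results straight from the sample data with the csv module and collections.Counter, with no database or model involved; the inline STUDENTS_CSV string is an assumption standing in for students.csv on disk.

```python
import csv
import io
from collections import Counter

# Same rows as students.csv
STUDENTS_CSV = """student_id,first_name,last_name,course_name,grade
1001,John,Doe,Mathematics,95
1002,Jane,Smith,Computer Science,88
1003,Emily,Jones,Mathematics,92
1004,Michael,Brown,Physics,85
1005,Jessica,Davis,Mathematics,89
1006,David,Wilson,Computer Science,91
1007,Alice,Moore,Mathematics,87
1008,Robert,Taylor,Computer Science,85
1009,Mary,Anderson,Mathematics,90
1010,James,Thomas,Physics,82
"""

rows = list(csv.DictReader(io.StringIO(STUDENTS_CSV)))

# Ground truth for each question
total_students = len(rows)
cs_students = sum(1 for r in rows if r["course_name"] == "Computer Science")
top_course, top_count = Counter(r["course_name"] for r in rows).most_common(1)[0]
math_rows = [r for r in rows if r["course_name"] == "Mathematics"]
best = sorted(math_rows, key=lambda r: int(r["grade"]), reverse=True)[:5]
top_math = [(int(r["student_id"]), r["first_name"], r["last_name"], int(r["grade"])) for r in best]

print(total_students)         # 10
print(cs_students)            # 3
print(top_course, top_count)  # Mathematics 5
print(top_math)
```

If a model-generated query returns something different from these values, the SQL, not the data, is the first thing to inspect.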
Using LangChain and OpenAI in conjunction with an SQL database can simplify the process of querying and analyzing data. This setup allows you to interact with complex databases using natural language, making data analysis more accessible to everyone, regardless of their SQL expertise.