SQL
This example demonstrates how to use SQLDatabaseChain to answer questions over a SQL database.
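Under the hood, LangChain uses SQLAlchemy to connect to SQL databases. SQLDatabaseChain can therefore be used with any SQL dialect supported by SQLAlchemy, such as MS SQL, MySQL, MariaDB, PostgreSQL, Oracle SQL, Databricks, and SQLite. See the SQLAlchemy documentation for details on the requirements for connecting to your database. For example, a connection to MySQL requires an appropriate connector such as PyMySQL. A URI for a MySQL connection might look like: mysql+pymysql://user:pass@some_mysql_db_address/db_name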
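As a small illustration of the URI shape described above, here is a sketch that assembles such a string with the standard library only. The helper name build_db_uri and the sample password are hypothetical; the one detail it adds is that credentials containing characters such as "@" or "/" need percent-encoding before they go into a URI.

```python
from urllib.parse import quote_plus


def build_db_uri(dialect, driver, user, password, host, database):
    """Assemble a URI of the form dialect+driver://user:password@host/database."""
    # Percent-encode the credentials so characters such as "@" or "/"
    # are not mistaken for URI delimiters.
    return (
        f"{dialect}+{driver}://{quote_plus(user)}:{quote_plus(password)}"
        f"@{host}/{database}"
    )


# Hypothetical credentials, chosen to show the encoding step.
uri = build_db_uri(
    "mysql", "pymysql", "user", "p@ss/word", "some_mysql_db_address", "db_name"
)
print(uri)
```

The resulting string is what would be passed to SQLDatabase.from_uri in place of the SQLite URI used in the rest of this example.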
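This demonstration uses SQLite and the sample Chinook database.
To set it up, follow the instructions at https://database.guide/2-sample-databases-sqlite/, placing the .db file in a notebooks folder at the root of this repository.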
from langchain import OpenAI, SQLDatabase, SQLDatabaseChain
db = SQLDatabase.from_uri("sqlite:///../../../../notebooks/Chinook.db")
llm = OpenAI(temperature=0, verbose=True)
NOTE: For data-sensitive projects, you can specify return_direct=True in the SQLDatabaseChain initialization to return the output of the SQL query directly, without any additional formatting. This prevents the LLM from seeing any contents of the database. Note, however, that by default the LLM still has access to the database schema (i.e. dialect, table and key names).
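To make the data flow behind this note concrete, here is a toy sketch that uses only the standard library. It is not LangChain's implementation: the function answer and the two fake "LLM" callables are hypothetical stand-ins, and the in-memory table is invented for the demonstration. It shows that in direct mode the summarizing step is never called, so the rows never reach it.

```python
import sqlite3


def answer(question, conn, write_sql, summarize, return_direct=False):
    """Toy two-step flow: generate SQL, run it, optionally summarize the rows."""
    sql = write_sql(question)            # stand-in for the LLM call that writes SQL
    rows = conn.execute(sql).fetchall()  # the query runs against the database
    if return_direct:
        return rows                      # rows go straight back to the caller
    return summarize(question, rows)     # otherwise a second "LLM" call sees the rows


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)")
conn.executemany("INSERT INTO Employee (LastName) VALUES (?)", [("Adams",), ("Edwards",)])

seen_by_summarizer = []  # records every result set the summarizer was shown


def fake_write_sql(question):
    return "SELECT COUNT(*) FROM Employee"


def fake_summarize(question, rows):
    seen_by_summarizer.append(rows)
    return f"There are {rows[0][0]} employees."


question = "How many employees are there?"
direct = answer(question, conn, fake_write_sql, fake_summarize, return_direct=True)
formatted = answer(question, conn, fake_write_sql, fake_summarize)
print(direct, formatted, len(seen_by_summarizer))
```

Only the second call adds an entry to seen_by_summarizer, which is the property the note relies on.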
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
db_chain.run("How many employees are there?")
> Entering new SQLDatabaseChain chain...
How many employees are there?
SQLQuery:
/workspace/langchain/langchain/sql_database.py:191: SAWarning: Dialect sqlite+pysqlite does *not* support Decimal objects natively, and SQLAlchemy must convert from floating point - rounding errors and other issues may occur. Please consider storing Decimal numbers as strings or integers on this platform for lossless storage.
sample_rows = connection.execute(command)
SELECT COUNT(*) FROM "Employee";
SQLResult: [(8,)]
Answer:There are 8 employees.
> Finished chain.
'There are 8 employees.'
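Using Query Checker
Sometimes the language model generates invalid SQL with small mistakes. Such SQL can be self-corrected using the same technique that the SQL Database Agent uses. You only need to specify this option when creating the chain: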
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, use_query_checker=True)
db_chain.run("How many albums by Aerosmith?")
> Entering new SQLDatabaseChain chain...
How many albums by Aerosmith?
SQLQuery:SELECT COUNT(*) FROM Album WHERE ArtistId = 3;
SQLResult: [(1,)]
Answer:There is 1 album by Aerosmith.
> Finished chain.
'There is 1 album by Aerosmith.'
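The query checker relies on the language model to review the generated SQL. As a complementary sketch, and explicitly a different technique, the snippet below asks SQLite itself whether a statement compiles by prefixing it with EXPLAIN, which prepares the statement without running it. The helper name is_valid_sql and the in-memory table are assumptions made for the example, and the check only covers a single statement against a SQLite connection.

```python
import sqlite3


def is_valid_sql(conn, sql):
    """Return True if SQLite can compile the single statement, without running it."""
    try:
        conn.execute(f"EXPLAIN {sql}")
        return True
    except (sqlite3.Error, sqlite3.Warning):
        # Syntax errors and unknown tables or columns surface here.
        return False


conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "Employee" ("EmployeeId" INTEGER NOT NULL)')

print(is_valid_sql(conn, 'SELECT COUNT(*) FROM "Employee";'))
print(is_valid_sql(conn, "SELEC COUNT(*) FROM Employee"))
print(is_valid_sql(conn, "SELECT COUNT(*) FROM Staff"))
```

A check like this catches statements that cannot run at all; it cannot tell whether a valid statement answers the question that was asked.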
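Customize Prompt
You can also customize the prompt that is used. Here is an example that prompts the model to understand that foobar is the same as the Employee table: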
from langchain.prompts.prompt import PromptTemplate
_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:
Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"
Only use the following tables:
{table_info}
If someone asks for the table foobar, they really mean the employee table.
Question: {input}"""
PROMPT = PromptTemplate(
input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE
)
db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT, verbose=True)
db_chain.run("How many employees are there in the foobar table?")
> Entering new SQLDatabaseChain chain...
How many employees are there in the foobar table?
SQLQuery:SELECT COUNT(*) FROM Employee;
SQLResult: [(8,)]
Answer:There are 8 employees in the foobar table.
> Finished chain.
'There are 8 employees in the foobar table.'
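To see what the model receives once the three placeholders are filled, the sketch below renders a shortened copy of the template with plain str.format. This assumes the template's placeholders behave like ordinary Python format fields; the table_info value is a made-up one-line schema rather than the real Chinook definition.

```python
# Shortened copy of the custom template, with the same three placeholders.
TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run.

Only use the following tables:
{table_info}

If someone asks for the table foobar, they really mean the employee table.

Question: {input}"""

rendered = TEMPLATE.format(
    dialect="sqlite",
    table_info='CREATE TABLE "Employee" ("EmployeeId" INTEGER NOT NULL)',  # made-up schema
    input="How many employees are there in the foobar table?",
)
print(rendered)
```

Printing the rendered text before sending it to a model is a cheap way to confirm that the extra instruction sits where you expect it.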
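Return Intermediate Steps
You can also return the intermediate steps of the SQLDatabaseChain. This gives you access to the SQL statement that was generated, as well as the result of running it against the SQL database.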
db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT, verbose=True, use_query_checker=True, return_intermediate_steps=True)
result = db_chain("How many employees are there in the foobar table?")
result["intermediate_steps"]
> Entering new SQLDatabaseChain chain...
How many employees are there in the foobar table?
SQLQuery:SELECT COUNT(*) FROM Employee;
SQLResult: [(8,)]
Answer:There are 8 employees in the foobar table.
> Finished chain.
[{'input': 'How many employees are there in the foobar table?\nSQLQuery:SELECT COUNT(*) FROM Employee;\nSQLResult: [(8,)]\nAnswer:',
'top_k': '5',
'dialect': 'sqlite',
'table_info': '\nCREATE TABLE "Artist" (\n\t"ArtistId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(120), \n\tPRIMARY KEY ("ArtistId")\n)\n\n/*\n3 rows from Artist table:\nArtistId\tName\n1\tAC/DC\n2\tAccept\n3\tAerosmith\n*/\n\n\nCREATE TABLE "Employee" (\n\t"EmployeeId" INTEGER NOT NULL, \n\t"LastName" NVARCHAR(20) NOT NULL, \n\t"FirstName" NVARCHAR(20) NOT NULL, \n\t"Title" NVARCHAR(30), \n\t"ReportsTo" INTEGER, \n\t"BirthDate" DATETIME, \n\t"HireDate" DATETIME, \n\t"Address" NVARCHAR(70), \n\t"City" NVARCHAR(40), \n\t"State" NVARCHAR(40), \n\t"Country" NVARCHAR(40), \n\t"PostalCode" NVARCHAR(10), \n\t"Phone" NVARCHAR(24), \n\t"Fax" NVARCHAR(24), \n\t"Email" NVARCHAR(60), \n\tPRIMARY KEY ("EmployeeId"), \n\tFOREIGN KEY("ReportsTo") REFERENCES "Employee" ("EmployeeId")\n)\n\n/*\n3 rows from Employee table:\nEmployeeId\tLastName\tFirstName\tTitle\tReportsTo\tBirthDate\tHireDate\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\n1\tAdams\tAndrew\tGeneral Manager\tNone\t1962-02-18 00:00:00\t2002-08-14 00:00:00\t11120 Jasper Ave NW\tEdmonton\tAB\tCanada\tT5K 2N1\t+1 (780) 428-9482\t+1 (780) 428-3457\[email protected]\n2\tEdwards\tNancy\tSales Manager\t1\t1958-12-08 00:00:00\t2002-05-01 00:00:00\t825 8 Ave SW\tCalgary\tAB\tCanada\tT2P 2T3\t+1 (403) 262-3443\t+1 (403) 262-3322\[email protected]\n3\tPeacock\tJane\tSales Support Agent\t2\t1973-08-29 00:00:00\t2002-04-01 00:00:00\t1111 6 Ave SW\tCalgary\tAB\tCanada\tT2P 5M5\t+1 (403) 262-3443\t+1 (403) 262-6712\[email protected]\n*/\n\n\nCREATE TABLE "Genre" (\n\t"GenreId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(120), \n\tPRIMARY KEY ("GenreId")\n)\n\n/*\n3 rows from Genre table:\nGenreId\tName\n1\tRock\n2\tJazz\n3\tMetal\n*/\n\n\nCREATE TABLE "MediaType" (\n\t"MediaTypeId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(120), \n\tPRIMARY KEY ("MediaTypeId")\n)\n\n/*\n3 rows from MediaType table:\nMediaTypeId\tName\n1\tMPEG audio file\n2\tProtected AAC audio file\n3\tProtected MPEG-4 video file\n*/\n\n\nCREATE TABLE "Playlist" 
(\n\t"PlaylistId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(120), \n\tPRIMARY KEY ("PlaylistId")\n)\n\n/*\n3 rows from Playlist table:\nPlaylistId\tName\n1\tMusic\n2\tMovies\n3\tTV Shows\n*/\n\n\nCREATE TABLE "Album" (\n\t"AlbumId" INTEGER NOT NULL, \n\t"Title" NVARCHAR(160) NOT NULL, \n\t"ArtistId" INTEGER NOT NULL, \n\tPRIMARY KEY ("AlbumId"), \n\tFOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")\n)\n\n/*\n3 rows from Album table:\nAlbumId\tTitle\tArtistId\n1\tFor Those About To Rock We Salute You\t1\n2\tBalls to the Wall\t2\n3\tRestless and Wild\t2\n*/\n\n\nCREATE TABLE "Customer" (\n\t"CustomerId" INTEGER NOT NULL, \n\t"FirstName" NVARCHAR(40) NOT NULL, \n\t"LastName" NVARCHAR(20) NOT NULL, \n\t"Company" NVARCHAR(80), \n\t"Address" NVARCHAR(70), \n\t"City" NVARCHAR(40), \n\t"State" NVARCHAR(40), \n\t"Country" NVARCHAR(40), \n\t"PostalCode" NVARCHAR(10), \n\t"Phone" NVARCHAR(24), \n\t"Fax" NVARCHAR(24), \n\t"Email" NVARCHAR(60) NOT NULL, \n\t"SupportRepId" INTEGER, \n\tPRIMARY KEY ("CustomerId"), \n\tFOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")\n)\n\n/*\n3 rows from Customer table:\nCustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. 
Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\[email protected]\t3\n2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\[email protected]\t5\n3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\[email protected]\t3\n*/\n\n\nCREATE TABLE "Invoice" (\n\t"InvoiceId" INTEGER NOT NULL, \n\t"CustomerId" INTEGER NOT NULL, \n\t"InvoiceDate" DATETIME NOT NULL, \n\t"BillingAddress" NVARCHAR(70), \n\t"BillingCity" NVARCHAR(40), \n\t"BillingState" NVARCHAR(40), \n\t"BillingCountry" NVARCHAR(40), \n\t"BillingPostalCode" NVARCHAR(10), \n\t"Total" NUMERIC(10, 2) NOT NULL, \n\tPRIMARY KEY ("InvoiceId"), \n\tFOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")\n)\n\n/*\n3 rows from Invoice table:\nInvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n*/\n\n\nCREATE TABLE "Track" (\n\t"TrackId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(200) NOT NULL, \n\t"AlbumId" INTEGER, \n\t"MediaTypeId" INTEGER NOT NULL, \n\t"GenreId" INTEGER, \n\t"Composer" NVARCHAR(220), \n\t"Milliseconds" INTEGER NOT NULL, \n\t"Bytes" INTEGER, \n\t"UnitPrice" NUMERIC(10, 2) NOT NULL, \n\tPRIMARY KEY ("TrackId"), \n\tFOREIGN KEY("MediaTypeId") REFERENCES "MediaType" ("MediaTypeId"), \n\tFOREIGN KEY("GenreId") REFERENCES "Genre" ("GenreId"), \n\tFOREIGN KEY("AlbumId") REFERENCES "Album" ("AlbumId")\n)\n\n/*\n3 rows from Track table:\nTrackId\tName\tAlbumId\tMediaTypeId\tGenreId\tComposer\tMilliseconds\tBytes\tUnitPrice\n1\tFor Those About To Rock (We Salute You)\t1\t1\t1\tAngus Young, Malcolm Young, Brian 
Johnson\t343719\t11170334\t0.99\n2\tBalls to the Wall\t2\t2\t1\tNone\t342562\t5510424\t0.99\n3\tFast As a Shark\t3\t2\t1\tF. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman\t230619\t3990994\t0.99\n*/\n\n\nCREATE TABLE "InvoiceLine" (\n\t"InvoiceLineId" INTEGER NOT NULL, \n\t"InvoiceId" INTEGER NOT NULL, \n\t"TrackId" INTEGER NOT NULL, \n\t"UnitPrice" NUMERIC(10, 2) NOT NULL, \n\t"Quantity" INTEGER NOT NULL, \n\tPRIMARY KEY ("InvoiceLineId"), \n\tFOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), \n\tFOREIGN KEY("InvoiceId") REFERENCES "Invoice" ("InvoiceId")\n)\n\n/*\n3 rows from InvoiceLine table:\nInvoiceLineId\tInvoiceId\tTrackId\tUnitPrice\tQuantity\n1\t1\t2\t0.99\t1\n2\t1\t4\t0.99\t1\n3\t2\t6\t0.99\t1\n*/\n\n\nCREATE TABLE "PlaylistTrack" (\n\t"PlaylistId" INTEGER NOT NULL, \n\t"TrackId" INTEGER NOT NULL, \n\tPRIMARY KEY ("PlaylistId", "TrackId"), \n\tFOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), \n\tFOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")\n)\n\n/*\n3 rows from PlaylistTrack table:\nPlaylistId\tTrackId\n1\t3402\n1\t3389\n1\t3390\n*/',
'stop': ['\nSQLResult:']},
'SELECT COUNT(*) FROM Employee;',
{'query': 'SELECT COUNT(*) FROM Employee;', 'dialect': 'sqlite'},
'SELECT COUNT(*) FROM Employee;',
'[(8,)]']
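The list printed above mixes prompt dictionaries with plain strings. The sketch below pulls out the two items most callers want. It assumes, based only on this output, that the last element is the raw result and the one before it is the SQL that ran; the list here is an abbreviated hand-made copy, and ast.literal_eval will only work for results made of plain Python literals.

```python
import ast

# Abbreviated copy of the printed list; the long prompt dictionary is shortened.
intermediate_steps = [
    {"input": "How many employees are there in the foobar table?", "dialect": "sqlite"},
    "SELECT COUNT(*) FROM Employee;",
    {"query": "SELECT COUNT(*) FROM Employee;", "dialect": "sqlite"},
    "SELECT COUNT(*) FROM Employee;",
    "[(8,)]",
]

# Assumption: positions follow the output shown in this document.
final_sql = intermediate_steps[-2]
rows = ast.literal_eval(intermediate_steps[-1])  # turn the string into a list of tuples

print(final_sql)
print(rows[0][0])
```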
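Choosing how to limit the number of rows returned
If you are querying for several rows of a table, you can select the maximum number of results you want to get with the 'top_k' parameter (default is 5, as the 'top_k' value in the intermediate steps of this document shows). This is useful for avoiding query results that exceed the maximum prompt length or consume tokens unnecessarily.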
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, use_query_checker=True, top_k=3)
db_chain.run("What are some example tracks by composer Johann Sebastian Bach?")
> Entering new SQLDatabaseChain chain...
What are some example tracks by composer Johann Sebastian Bach?
SQLQuery:SELECT Name FROM Track WHERE Composer = 'Johann Sebastian Bach' LIMIT 3
SQLResult: [('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace',), ('Aria Mit 30 Veränderungen, BWV 988 "Goldberg Variations": Aria',), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude',)]
Answer:Examples of tracks by Johann Sebastian Bach are Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace, Aria Mit 30 Veränderungen, BWV 988 "Goldberg Variations": Aria, and Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude.
> Finished chain.
'Examples of tracks by Johann Sebastian Bach are Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace, Aria Mit 30 Veränderungen, BWV 988 "Goldberg Variations": Aria, and Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude.'
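Adding example rows from each table
Sometimes the format of the data is not obvious, and the best option is to include a sample of rows from the tables in the prompt so that the LLM can understand the data before producing a final query. Here we use this feature to let the LLM know that artists are saved with their full names, by providing two rows from the Track table.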
db = SQLDatabase.from_uri(
"sqlite:///../../../../notebooks/Chinook.db",
include_tables=['Track'], # we include only one table to save tokens in the prompt :)
sample_rows_in_table_info=2)
The sample rows are added to the prompt after each corresponding table's column information:
print(db.table_info)
CREATE TABLE "Track" (
"TrackId" INTEGER NOT NULL,
"Name" NVARCHAR(200) NOT NULL,
"AlbumId" INTEGER,
"MediaTypeId" INTEGER NOT NULL,
"GenreId" INTEGER,
"Composer" NVARCHAR(220),
"Milliseconds" INTEGER NOT NULL,
"Bytes" INTEGER,
"UnitPrice" NUMERIC(10, 2) NOT NULL,
PRIMARY KEY ("TrackId"),
FOREIGN KEY("MediaTypeId") REFERENCES "MediaType" ("MediaTypeId"),
FOREIGN KEY("GenreId") REFERENCES "Genre" ("GenreId"),
FOREIGN KEY("AlbumId") REFERENCES "Album" ("AlbumId")
)
/*
2 rows from Track table:
TrackId Name AlbumId MediaTypeId GenreId Composer Milliseconds Bytes UnitPrice
1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99
2 Balls to the Wall 2 2 1 None 342562 5510424 0.99
*/
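For readers curious how text in this shape can be produced, here is a rough sketch that uses sqlite3 directly. It approximates the printed layout and is not LangChain's own code; the function name table_info and the three-row in-memory Track table are assumptions for the example.

```python
import sqlite3


def table_info(conn, table, sample_rows=2):
    """Return the table's CREATE statement followed by a comment with sample rows."""
    (create_sql,) = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()
    # Table names cannot be bound as parameters, so quote the identifier by hand.
    quoted = '"' + table.replace('"', '""') + '"'
    cursor = conn.execute(f"SELECT * FROM {quoted} LIMIT ?", (sample_rows,))
    header = "\t".join(column[0] for column in cursor.description)
    lines = ["\t".join(str(value) for value in row) for row in cursor]
    body = "\n".join([header, *lines])
    return f"{create_sql}\n\n/*\n{sample_rows} rows from {table} table:\n{body}\n*/"


conn = sqlite3.connect(":memory:")
conn.execute(
    'CREATE TABLE "Track" ("TrackId" INTEGER NOT NULL, "Name" TEXT NOT NULL, "Composer" TEXT)'
)
conn.executemany(
    'INSERT INTO "Track" VALUES (?, ?, ?)',
    [
        (1, "For Those About To Rock (We Salute You)", "Angus Young, Malcolm Young, Brian Johnson"),
        (2, "Balls to the Wall", None),
        (3, "Fast As a Shark", "F. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman"),
    ],
)

info = table_info(conn, "Track", sample_rows=2)
print(info)
```

Because every sample row is spent as prompt tokens, keeping sample_rows small and the table list short has a direct effect on cost.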
db_chain = SQLDatabaseChain.from_llm(llm, db, use_query_checker=True, verbose=True)
db_chain.run("What are some example tracks by Bach?")
> Entering new SQLDatabaseChain chain...
What are some example tracks by Bach?
SQLQuery:SELECT "Name", "Composer" FROM "Track" WHERE "Composer" LIKE '%Bach%' LIMIT 5
SQLResult: [('American Woman', 'B. Cummings/G. Peterson/M.J. Kale/R. Bachman'), ('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Johann Sebastian Bach'), ('Aria Mit 30 Veränderungen, BWV 988 "Goldberg Variations": Aria', 'Johann Sebastian Bach'), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude', 'Johann Sebastian Bach'), ('Toccata and Fugue in D Minor, BWV 565: I. Toccata', 'Johann Sebastian Bach')]
Answer:Tracks by Bach include 'American Woman', 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Aria Mit 30 Veränderungen, BWV 988 "Goldberg Variations": Aria', 'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude', and 'Toccata and Fugue in D Minor, BWV 565: I. Toccata'.
> Finished chain.
'Tracks by Bach include \'American Woman\', \'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\', \'Aria Mit 30 Veränderungen, BWV 988 "Goldberg Variations": Aria\', \'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude\', and \'Toccata and Fugue in D Minor, BWV 565: I. Toccata\'.'
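Custom Table Info
In some cases, it can be useful to provide custom table information instead of using the automatically generated table definitions and the first sample_rows_in_table_info sample rows. For example, if you know that the first few rows of a table are uninformative, it may help to manually provide sample rows that are more diverse or that give the model more information. You can also limit the columns visible to the model if some columns are unnecessary.
This information can be provided as a dictionary with table names as the keys and table information as the values. For example, let's provide a custom definition and sample rows for the Track table with only a few columns: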
custom_table_info = {
"Track": """CREATE TABLE Track (
"TrackId" INTEGER NOT NULL,
"Name" NVARCHAR(200) NOT NULL,
"Composer" NVARCHAR(220),
PRIMARY KEY ("TrackId")
)
/*
3 rows from Track table:
TrackId Name Composer
1 For Those About To Rock (We Salute You) Angus Young, Malcolm Young, Brian Johnson
2 Balls to the Wall None
3 My favorite song ever The coolest composer of all time
*/"""
}
db = SQLDatabase.from_uri(
"sqlite:///../../../../notebooks/Chinook.db",
include_tables=['Track', 'Playlist'],
sample_rows_in_table_info=2,
custom_table_info=custom_table_info)
print(db.table_info)
CREATE TABLE "Playlist" (
"PlaylistId" INTEGER NOT NULL,
"Name" NVARCHAR(120),
PRIMARY KEY ("PlaylistId")
)
/*
2 rows from Playlist table:
PlaylistId Name
1 Music
2 Movies
*/
CREATE TABLE Track (
"TrackId" INTEGER NOT NULL,
"Name" NVARCHAR(200) NOT NULL,
"Composer" NVARCHAR(220),
PRIMARY KEY ("TrackId")
)
/*
3 rows from Track table:
TrackId Name Composer
1 For Those About To Rock (We Salute You) Angus Young, Malcolm Young, Brian Johnson
2 Balls to the Wall None
3 My favorite song ever The coolest composer of all time
*/
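Note how our custom table definition and sample rows for Track override the sample_rows_in_table_info parameter. Tables that are not overridden by custom_table_info (in this example, Playlist) have their table info gathered automatically as usual.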
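The override behavior described here amounts to a per-table lookup with a fallback. A minimal sketch of that rule, with a hypothetical helper name and placeholder strings instead of real table definitions:

```python
def merge_table_info(auto_info, custom_info):
    """Prefer a hand-written description when one exists for a table."""
    return {table: custom_info.get(table, text) for table, text in auto_info.items()}


# Placeholder strings stand in for the generated and hand-written definitions.
auto_info = {"Playlist": "auto Playlist info", "Track": "auto Track info"}
custom_info = {"Track": "custom Track info"}

merged = merge_table_info(auto_info, custom_info)
print(merged)
```

In this sketch a custom entry for a table that is not in the included set is simply ignored; whether the library behaves the same way is not something this document states.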
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
db_chain.run("What are some example tracks by Bach?")
> Entering new SQLDatabaseChain chain...
What are some example tracks by Bach?
SQLQuery:SELECT "Name" FROM Track WHERE "Composer" LIKE '%Bach%' LIMIT 5;
SQLResult: [('American Woman',), ('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace',), ('Aria Mit 30 Veränderungen, BWV 988 "Goldberg Variations": Aria',), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude',), ('Toccata and Fugue in D Minor, BWV 565: I. Toccata',)]
Answer:text='You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n\nUse the following format:\n\nQuestion: "Question here"\nSQLQuery: "SQL Query to run"\nSQLResult: "Result of the SQLQuery"\nAnswer: "Final answer here"\n\nOnly use the following tables:\n\nCREATE TABLE "Playlist" (\n\t"PlaylistId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(120), \n\tPRIMARY KEY ("PlaylistId")\n)\n\n/*\n2 rows from Playlist table:\nPlaylistId\tName\n1\tMusic\n2\tMovies\n*/\n\nCREATE TABLE Track (\n\t"TrackId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(200) NOT NULL,\n\t"Composer" NVARCHAR(220),\n\tPRIMARY KEY ("TrackId")\n)\n/*\n3 rows from Track table:\nTrackId\tName\tComposer\n1\tFor Those About To Rock (We Salute You)\tAngus Young, Malcolm Young, Brian Johnson\n2\tBalls to the Wall\tNone\n3\tMy favorite song ever\tThe coolest composer of all time\n*/\n\nQuestion: What are some example tracks by Bach?\nSQLQuery:SELECT "Name" FROM Track WHERE "Composer" LIKE \'%Bach%\' LIMIT 5;\nSQLResult: [(\'American Woman\',), (\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\',), (\'Aria Mit 30 Veränderungen, BWV 988 "Goldberg Variations": Aria\',), (\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. 
Prélude\',), (\'Toccata and Fugue in D Minor, BWV 565: I. Toccata\',)]\nAnswer:'
You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Use the following format:
Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"
Only use the following tables:
CREATE TABLE "Playlist" (
"PlaylistId" INTEGER NOT NULL,
"Name" NVARCHAR(120),
PRIMARY KEY ("PlaylistId")
)
/*
2 rows from Playlist table:
PlaylistId Name
1 Music
2 Movies
*/
CREATE TABLE Track (
"TrackId" INTEGER NOT NULL,
"Name" NVARCHAR(200) NOT NULL,
"Composer" NVARCHAR(220),
PRIMARY KEY ("TrackId")
)
/*
3 rows from Track table:
TrackId Name Composer
1 For Those About To Rock (We Salute You) Angus Young, Malcolm Young, Brian Johnson
2 Balls to the Wall None
3 My favorite song ever The coolest composer of all time
*/
Question: What are some example tracks by Bach?
SQLQuery:SELECT "Name" FROM Track WHERE "Composer" LIKE '%Bach%' LIMIT 5;
SQLResult: [('American Woman',), ('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace',), ('Aria Mit 30 Veränderungen, BWV 988 "Goldberg Variations": Aria',), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude',), ('Toccata and Fugue in D Minor, BWV 565: I. Toccata',)]
Answer:
{'input': 'What are some example tracks by Bach?\nSQLQuery:SELECT "Name" FROM Track WHERE "Composer" LIKE \'%Bach%\' LIMIT 5;\nSQLResult: [(\'American Woman\',), (\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\',), (\'Aria Mit 30 Veränderungen, BWV 988 "Goldberg Variations": Aria\',), (\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude\',), (\'Toccata and Fugue in D Minor, BWV 565: I. Toccata\',)]\nAnswer:', 'top_k': '5', 'dialect': 'sqlite', 'table_info': '\nCREATE TABLE "Playlist" (\n\t"PlaylistId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(120), \n\tPRIMARY KEY ("PlaylistId")\n)\n\n/*\n2 rows from Playlist table:\nPlaylistId\tName\n1\tMusic\n2\tMovies\n*/\n\nCREATE TABLE Track (\n\t"TrackId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(200) NOT NULL,\n\t"Composer" NVARCHAR(220),\n\tPRIMARY KEY ("TrackId")\n)\n/*\n3 rows from Track table:\nTrackId\tName\tComposer\n1\tFor Those About To Rock (We Salute You)\tAngus Young, Malcolm Young, Brian Johnson\n2\tBalls to the Wall\tNone\n3\tMy favorite song ever\tThe coolest composer of all time\n*/', 'stop': ['\nSQLResult:']}
Examples of tracks by Bach include "American Woman", "Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace", "Aria Mit 30 Veränderungen, BWV 988 'Goldberg Variations': Aria", "Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude", and "Toccata and Fugue in D Minor, BWV 565: I. Toccata".
> Finished chain.
'Examples of tracks by Bach include "American Woman", "Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace", "Aria Mit 30 Veränderungen, BWV 988 \'Goldberg Variations\': Aria", "Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude", and "Toccata and Fugue in D Minor, BWV 565: I. Toccata".'
SQLDatabaseSequentialChain
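A sequential chain for querying a SQL database.
The chain runs in the following order:
1. Based on the query, determine which tables to use.
2. Based on those tables, call the normal SQL database chain.
This is useful when the database contains a large number of tables.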
from langchain.chains import SQLDatabaseSequentialChain
db = SQLDatabase.from_uri("sqlite:///../../../../notebooks/Chinook.db")
chain = SQLDatabaseSequentialChain.from_llm(llm, db, verbose=True)
chain.run("How many employees are also customers?")
> Entering new SQLDatabaseSequentialChain chain...
Table names to use:
['Employee', 'Customer']
> Entering new SQLDatabaseChain chain...
How many employees are also customers?
SQLQuery:SELECT COUNT(*) FROM Employee e INNER JOIN Customer c ON e.EmployeeId = c.SupportRepId;
SQLResult: [(59,)]
Answer:59 employees are also customers.
> Finished chain.
> Finished chain.
'59 employees are also customers.'
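The two-step flow of this section can be sketched without a language model. In the snippet below, a naive keyword match stands in for the model's table selection, and the second step only receives the schema of the chosen tables. The function names, the matching rule, and the small in-memory schema are all assumptions made for illustration.

```python
import sqlite3


def pick_tables(question, table_names):
    """Step 1 stand-in: keep tables whose name (or its plural) appears in the question."""
    words = {word.strip("?.,!").lower() for word in question.split()}
    return [
        name
        for name in table_names
        if name.lower() in words or name.lower() + "s" in words
    ]


def schema_for(conn, tables):
    """Step 2 input: CREATE statements for the chosen tables only."""
    wanted = set(tables)
    rows = conn.execute(
        "SELECT name, sql FROM sqlite_master WHERE type = 'table' ORDER BY name"
    )
    return "\n\n".join(sql for name, sql in rows if name in wanted)


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)")
conn.execute("CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, SupportRepId INTEGER)")
conn.execute("CREATE TABLE Track (TrackId INTEGER PRIMARY KEY, Name TEXT)")
conn.execute("CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, Total REAL)")

all_tables = [
    row[0]
    for row in conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    )
]
chosen = pick_tables("How many employees are also customers?", all_tables)
prompt_schema = schema_for(conn, chosen)
print(chosen)
print(prompt_schema)
```

The benefit is the smaller prompt: only two of the four table definitions reach the second step.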
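Using Local Language Models
Sometimes you may not be able to use OpenAI or other hosted large language models. You can, of course, try to use SQLDatabaseChain with a local model, but you will quickly find that most models you can run locally, even on a large GPU, still struggle to generate the right output.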
import logging
import torch
from transformers import AutoTokenizer, GPT2TokenizerFast, pipeline, AutoModelForSeq2SeqLM, AutoModelForCausalLM
from langchain import HuggingFacePipeline
# Note: This model requires a large GPU, e.g. an 80GB A100. See documentation for other ways to run private non-OpenAI models.
model_id = "google/flan-ul2"
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, temperature=0)
device_id = -1  # default to no-GPU, but use GPU and half precision mode if available
if torch.cuda.is_available():
    device_id = 0
    try:
        model = model.half()
    except RuntimeError as exc:
        # logging.warn is a deprecated alias; logging.warning is the supported call
        logging.warning(f"Could not run model in half precision mode: {str(exc)}")
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline(task="text2text-generation", model=model, tokenizer=tokenizer, max_length=1024, device=device_id)
local_llm = HuggingFacePipeline(pipeline=pipe)
/workspace/langchain/.venv/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 8/8 [00:32<00:00, 4.11s/it]
from langchain import SQLDatabase, SQLDatabaseChain
db = SQLDatabase.from_uri("sqlite:///../../../../notebooks/Chinook.db", include_tables=['Customer'])
local_chain = SQLDatabaseChain.from_llm(local_llm, db, verbose=True, return_intermediate_steps=True, use_query_checker=True)
Very simple SQL queries can work with this model, as long as you use the query checker as specified above, for example:
local_chain("How many customers are there?")
> Entering new SQLDatabaseChain chain...
How many customers are there?
SQLQuery:
/workspace/langchain/.venv/lib/python3.9/site-packages/transformers/pipelines/base.py:1070: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
warnings.warn(
/workspace/langchain/.venv/lib/python3.9/site-packages/transformers/pipelines/base.py:1070: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
warnings.warn(
SELECT count(*) FROM Customer
SQLResult: [(59,)]
Answer:
/workspace/langchain/.venv/lib/python3.9/site-packages/transformers/pipelines/base.py:1070: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
warnings.warn(
[59]
> Finished chain.
{'query': 'How many customers are there?',
'result': '[59]',
'intermediate_steps': [{'input': 'How many customers are there?\nSQLQuery:SELECT count(*) FROM Customer\nSQLResult: [(59,)]\nAnswer:',
'top_k': '5',
'dialect': 'sqlite',
'table_info': '\nCREATE TABLE "Customer" (\n\t"CustomerId" INTEGER NOT NULL, \n\t"FirstName" NVARCHAR(40) NOT NULL, \n\t"LastName" NVARCHAR(20) NOT NULL, \n\t"Company" NVARCHAR(80), \n\t"Address" NVARCHAR(70), \n\t"City" NVARCHAR(40), \n\t"State" NVARCHAR(40), \n\t"Country" NVARCHAR(40), \n\t"PostalCode" NVARCHAR(10), \n\t"Phone" NVARCHAR(24), \n\t"Fax" NVARCHAR(24), \n\t"Email" NVARCHAR(60) NOT NULL, \n\t"SupportRepId" INTEGER, \n\tPRIMARY KEY ("CustomerId"), \n\tFOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")\n)\n\n/*\n3 rows from Customer table:\nCustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\[email protected]\t3\n2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\[email protected]\t5\n3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\[email protected]\t3\n*/',
'stop': ['\nSQLResult:']},
'SELECT count(*) FROM Customer',
{'query': 'SELECT count(*) FROM Customer', 'dialect': 'sqlite'},
'SELECT count(*) FROM Customer',
'[(59,)]']}
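Even this relatively large model will most likely fail to generate more complicated SQL on its own. However, you can log its inputs and outputs so that you can correct them by hand and use the corrected examples as few shot prompt examples later. In practice, you could log any chain execution that raises an exception (as in the example below), or collect direct user feedback in cases where the result is incorrect but no exception was raised.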
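One way to collect such records is a thin wrapper around the chain call. The sketch below is a hypothetical illustration, not the approach this document goes on to use: run_and_log and flaky_chain are invented names, the failing question is contrived, and the records are printed as JSON from the standard library rather than YAML.

```python
import json


def run_and_log(chain, question, log):
    """Call the chain; record the question plus either the output or the error."""
    record = {"question": question}
    try:
        record["output"] = chain(question)
    except Exception as exc:  # broad on purpose: any failure is worth reviewing
        record["error"] = f"{type(exc).__name__}: {exc}"
    log.append(record)
    return record


def flaky_chain(question):
    """Hypothetical stand-in for a chain call that fails on harder questions."""
    if "genre" in question:
        raise ValueError("no such column: Genre")
    return "[59]"


log = []
run_and_log(flaky_chain, "How many customers are there?", log)
run_and_log(flaky_chain, "How many tracks per genre?", log)

# Failed runs are the candidates for hand-correction and later few shot examples.
failures = [entry for entry in log if "error" in entry]
print(json.dumps(failures, indent=2))
```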
poetry run pip install pyyaml chromadb
import yaml
Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /workspace/langchain/.venv/lib/python3.9/site-packages (from torch>=1.6.0->sentence-transformers>=2.2.2->chromadb) (8.5.0.96)
Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /workspace/langchain/.venv/lib/python3.9/site-packages (from torch>=1.6.0->sentence-transformers>=2.2.2->chromadb) (11.10.3.66)
Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /workspace/langchain/.venv/lib/python3.9/site-packages (from torch>=1.6.0->sentence-transformers>=2.2.2->chromadb) (11.7.99)
Requirement already satisfied: setuptools in /workspace/langchain/.venv/lib/python3.9/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.6.0->sentence-transformers>=2.2.2->chromadb) (67.7.1)
Requirement already satisfied: wheel in /workspace/langchain/.venv/lib/python3.9/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.6.0->sentence-transformers>=2.2.2->chromadb) (0.40.0)
Requirement already satisfied: regex!=2019.12.17 in /workspace/langchain/.venv/lib/python3.9/site-packages (from transformers<5.0.0,>=4.6.0->sentence-transformers>=2.2.2->chromadb) (2023.3.23)
Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /workspace/langchain/.venv/lib/python3.9/site-packages (from transformers<5.0.0,>=4.6.0->sentence-transformers>=2.2.2->chromadb) (0.13.3)
Requirement already satisfied: joblib in /workspace/langchain/.venv/lib/python3.9/site-packages (from nltk->sentence-transformers>=2.2.2->chromadb) (1.2.0)
Requirement already satisfied: threadpoolctl>=2.0.0 in /workspace/langchain/.venv/lib/python3.9/site-packages (from scikit-learn->sentence-transformers>=2.2.2->chromadb) (3.1.0)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /workspace/langchain/.venv/lib/python3.9/site-packages (from torchvision->sentence-transformers>=2.2.2->chromadb) (9.5.0)
Requirement already satisfied: sniffio>=1.1 in /workspace/langchain/.venv/lib/python3.9/site-packages (from anyio<5,>=3.4.0->starlette<0.27.0,>=0.26.1->fastapi>=0.85.1->chromadb) (1.3.0)
from typing import Any, Dict

import yaml

QUERY = "List all the customer first names that start with 'a'"


def _parse_example(result: Dict) -> Dict:
    sql_cmd_key = "sql_cmd"
    sql_result_key = "sql_result"
    table_info_key = "table_info"
    input_key = "input"
    final_answer_key = "answer"

    _example = {
        "input": result.get("query"),
    }

    # Fall back to an empty list so a result without steps does not raise
    steps = result.get("intermediate_steps") or []
    answer_key = sql_cmd_key  # the first one
    for step in steps:
        # The steps are in pairs, a dict (input) followed by a string (output).
        # Unfortunately there is no schema but you can look at the input key of the
        # dict to see what the output is supposed to be
        if isinstance(step, dict):
            # Grab the table info from input dicts in the intermediate steps once
            if table_info_key not in _example:
                _example[table_info_key] = step.get(table_info_key)

            if input_key in step:
                if step[input_key].endswith("SQLQuery:"):
                    answer_key = sql_cmd_key  # this is the SQL generation input
                if step[input_key].endswith("Answer:"):
                    answer_key = final_answer_key  # this is the final answer input
            elif sql_cmd_key in step:
                _example[sql_cmd_key] = step[sql_cmd_key]
                answer_key = sql_result_key  # this is SQL execution input
        elif isinstance(step, str):
            # The preceding element should have set the answer_key
            _example[answer_key] = step
    return _example


example: Any
try:
    result = local_chain(QUERY)
    print("*** Query succeeded")
    example = _parse_example(result)
except Exception as exc:
    print("*** Query failed")
    result = {
        "query": QUERY,
        # Not every exception carries intermediate_steps, so default to an empty list
        "intermediate_steps": getattr(exc, "intermediate_steps", []),
    }
    example = _parse_example(result)

# print for now, in reality you may want to write this out to a YAML file or database for manual fix-ups offline
yaml_example = yaml.dump(example, allow_unicode=True)
print("\n" + yaml_example)
> Entering new SQLDatabaseChain chain...
List all the customer first names that start with 'a'
SQLQuery:
SELECT firstname FROM customer WHERE firstname LIKE '%a%'
SQLResult: [('François',), ('František',), ('Helena',), ('Astrid',), ('Daan',), ('Kara',), ('Eduardo',), ('Alexandre',), ('Fernanda',), ('Mark',), ('Frank',), ('Jack',), ('Dan',), ('Kathy',), ('Heather',), ('Frank',), ('Richard',), ('Patrick',), ('Julia',), ('Edward',), ('Martha',), ('Aaron',), ('Madalena',), ('Hannah',), ('Niklas',), ('Camille',), ('Marc',), ('Wyatt',), ('Isabelle',), ('Ladislav',), ('Lucas',), ('Johannes',), ('Stanisław',), ('Joakim',), ('Emma',), ('Mark',), ('Manoj',), ('Puja',)]
Answer:
[('François', 'Frantiek', 'Helena', 'Astrid', 'Daan', 'Kara', 'Eduardo', 'Alexandre', 'Fernanda', 'Mark', 'Frank', 'Jack', 'Dan', 'Kathy', 'Heather', 'Frank', 'Richard', 'Patrick', 'Julia', 'Edward', 'Martha', 'Aaron', 'Madalena', 'Hannah', 'Niklas', 'Camille', 'Marc', 'Wyatt', 'Isabelle', 'Ladislav', 'Lucas', 'Johannes', 'Stanisaw', 'Joakim', 'Emma', 'Mark', 'Manoj', 'Puja']
> Finished chain.
*** Query succeeded
answer: '[(''François'', ''Frantiek'', ''Helena'', ''Astrid'', ''Daan'', ''Kara'',
''Eduardo'', ''Alexandre'', ''Fernanda'', ''Mark'', ''Frank'', ''Jack'', ''Dan'',
''Kathy'', ''Heather'', ''Frank'', ''Richard'', ''Patrick'', ''Julia'', ''Edward'',
''Martha'', ''Aaron'', ''Madalena'', ''Hannah'', ''Niklas'', ''Camille'', ''Marc'',
''Wyatt'', ''Isabelle'', ''Ladislav'', ''Lucas'', ''Johannes'', ''Stanisaw'', ''Joakim'',
''Emma'', ''Mark'', ''Manoj'', ''Puja'']'
input: List all the customer first names that start with 'a'
sql_cmd: SELECT firstname FROM customer WHERE firstname LIKE '%a%'
sql_result: '[(''François'',), (''František'',), (''Helena'',), (''Astrid'',), (''Daan'',),
(''Kara'',), (''Eduardo'',), (''Alexandre'',), (''Fernanda'',), (''Mark'',), (''Frank'',),
(''Jack'',), (''Dan'',), (''Kathy'',), (''Heather'',), (''Frank'',), (''Richard'',),
(''Patrick'',), (''Julia'',), (''Edward'',), (''Martha'',), (''Aaron'',), (''Madalena'',),
(''Hannah'',), (''Niklas'',), (''Camille'',), (''Marc'',), (''Wyatt'',), (''Isabelle'',),
(''Ladislav'',), (''Lucas'',), (''Johannes'',), (''Stanisław'',), (''Joakim'',),
(''Emma'',), (''Mark'',), (''Manoj'',), (''Puja'',)]'
table_info: "\nCREATE TABLE \"Customer\" (\n\t\"CustomerId\" INTEGER NOT NULL, \n\t\
\"FirstName\" NVARCHAR(40) NOT NULL, \n\t\"LastName\" NVARCHAR(20) NOT NULL, \n\t\
\"Company\" NVARCHAR(80), \n\t\"Address\" NVARCHAR(70), \n\t\"City\" NVARCHAR(40),\
\ \n\t\"State\" NVARCHAR(40), \n\t\"Country\" NVARCHAR(40), \n\t\"PostalCode\" NVARCHAR(10),\
\ \n\t\"Phone\" NVARCHAR(24), \n\t\"Fax\" NVARCHAR(24), \n\t\"Email\" NVARCHAR(60)\
\ NOT NULL, \n\t\"SupportRepId\" INTEGER, \n\tPRIMARY KEY (\"CustomerId\"), \n\t\
FOREIGN KEY(\"SupportRepId\") REFERENCES \"Employee\" (\"EmployeeId\")\n)\n\n/*\n\
3 rows from Customer table:\nCustomerId\tFirstName\tLastName\tCompany\tAddress\t\
City\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n1\tLuís\tGonçalves\t\
Embraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\t\
São José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\t\
[email protected]\t3\n2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\t\
None\tGermany\t70174\t+49 0711 2842222\tNone\[email protected]\t5\n3\tFrançois\t\
Tremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\t\
None\[email protected]\t3\n*/"
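The captured sql_cmd above is a good example of what needs fixing: LIKE '%a%' matches any first name that contains an "a", while the question asked for names that start with it. A minimal sketch of the difference, using Python's built-in sqlite3 module and a small in-memory table holding a handful of the first names from the output above (the table is a stand-in for illustration, not the Chinook database, and first_names is a hypothetical helper):

```python
import sqlite3

# A tiny stand-in table; the names are taken from the SQLResult shown above.
conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "Customer" ("FirstName" NVARCHAR(40) NOT NULL)')
names = ["François", "Helena", "Astrid", "Daan", "Alexandre", "Aaron", "Mark"]
conn.executemany('INSERT INTO "Customer" VALUES (?)', [(n,) for n in names])


def first_names(pattern: str) -> list:
    """Return the first names matching a LIKE pattern, sorted alphabetically."""
    rows = conn.execute(
        'SELECT "FirstName" FROM "Customer" '
        'WHERE "FirstName" LIKE ? ORDER BY "FirstName"',
        (pattern,),
    )
    return [row[0] for row in rows]


contains_a = first_names("%a%")  # the pattern the model generated
starts_with_a = first_names("a%")  # the pattern the question asked for

print(len(contains_a), "names contain 'a'")
print("Names starting with 'a':", starts_with_a)
```

SQLite's LIKE is case-insensitive for ASCII letters by default, so 'a%' also matches names that begin with an uppercase "A".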
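Run the snippet above several times, or log exceptions in your deployed environment, to collect a good number of examples of the inputs, table_info and sql_cmd generated by your language model. Where the sql_cmd values are incorrect, you can fix them by hand to build up a collection of examples. Here, for instance, we use YAML to keep a tidy record of our inputs and corrected SQL outputs, which can grow over time.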
YAML_EXAMPLES = """
- input: How many customers are not from Brazil?
  table_info: |
    CREATE TABLE "Customer" (
      "CustomerId" INTEGER NOT NULL,
      "FirstName" NVARCHAR(40) NOT NULL,
      "LastName" NVARCHAR(20) NOT NULL,
      "Company" NVARCHAR(80),
      "Address" NVARCHAR(70),
      "City" NVARCHAR(40),
      "State" NVARCHAR(40),
      "Country" NVARCHAR(40),
      "PostalCode" NVARCHAR(10),
      "Phone" NVARCHAR(24),
      "Fax" NVARCHAR(24),
      "Email" NVARCHAR(60) NOT NULL,
      "SupportRepId" INTEGER,
      PRIMARY KEY ("CustomerId"),
      FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
    )
  sql_cmd: SELECT COUNT(*) FROM "Customer" WHERE NOT "Country" = "Brazil";
  sql_result: "[(54,)]"
  answer: 54 customers are not from Brazil.
- input: list all the genres that start with 'r'
  table_info: |
    CREATE TABLE "Genre" (
      "GenreId" INTEGER NOT NULL,
      "Name" NVARCHAR(120),
      PRIMARY KEY ("GenreId")
    )

    /*
    3 rows from Genre table:
    GenreId Name
    1 Rock
    2 Jazz
    3 Metal
    */
  sql_cmd: SELECT "Name" FROM "Genre" WHERE "Name" LIKE 'r%';
  sql_result: "[('Rock',), ('Rock and Roll',), ('Reggae',), ('R&B/Soul',)]"
  answer: The genres that start with 'r' are Rock, Rock and Roll, Reggae and R&B/Soul.
"""
Now that you have some examples (with manually corrected output SQL), you can seed a few-shot prompt in the usual way:
import yaml

from langchain import FewShotPromptTemplate, PromptTemplate
from langchain.chains.sql_database.prompt import _sqlite_prompt, PROMPT_SUFFIX
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.prompts.example_selector.semantic_similarity import SemanticSimilarityExampleSelector
from langchain.vectorstores import Chroma

example_prompt = PromptTemplate(
    input_variables=["table_info", "input", "sql_cmd", "sql_result", "answer"],
    template="{table_info}\n\nQuestion: {input}\nSQLQuery: {sql_cmd}\nSQLResult: {sql_result}\nAnswer: {answer}",
)

examples_dict = yaml.safe_load(YAML_EXAMPLES)

local_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

example_selector = SemanticSimilarityExampleSelector.from_examples(
    # This is the list of examples available to select from.
    examples_dict,
    # This is the embedding class used to produce embeddings which are used to measure semantic similarity.
    local_embeddings,
    # This is the VectorStore class that is used to store the embeddings and do a similarity search over.
    Chroma,  # type: ignore
    # This is the number of examples to produce and include per prompt
    k=min(3, len(examples_dict)),
)

few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=_sqlite_prompt + "Here are some examples:",
    suffix=PROMPT_SUFFIX,
    input_variables=["table_info", "input", "top_k"],
)
Using embedded DuckDB without persistence: data will be transient
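Because the seed examples are edited by hand, it can help to check that each one carries every field the example_prompt above expects, and to preview how one example reads once formatted. The following is a minimal sketch under stated assumptions: it mirrors the template text with plain str.format rather than calling LangChain, missing_fields is a hypothetical helper, and incomplete_example is made up for illustration.

```python
from typing import Dict, List

# Same field names and template text as example_prompt above.
EXAMPLE_FIELDS = ["table_info", "input", "sql_cmd", "sql_result", "answer"]
EXAMPLE_TEMPLATE = (
    "{table_info}\n\nQuestion: {input}\nSQLQuery: {sql_cmd}\n"
    "SQLResult: {sql_result}\nAnswer: {answer}"
)


def missing_fields(example: Dict) -> List[str]:
    """Return the template fields that are absent or empty in an example."""
    return [field for field in EXAMPLE_FIELDS if not example.get(field)]


# The second seed example from YAML_EXAMPLES, written out as a dict.
genre_example = {
    "table_info": (
        'CREATE TABLE "Genre" (\n'
        '  "GenreId" INTEGER NOT NULL,\n'
        '  "Name" NVARCHAR(120),\n'
        '  PRIMARY KEY ("GenreId")\n'
        ")"
    ),
    "input": "list all the genres that start with 'r'",
    "sql_cmd": "SELECT \"Name\" FROM \"Genre\" WHERE \"Name\" LIKE 'r%';",
    "sql_result": "[('Rock',), ('Rock and Roll',), ('Reggae',), ('R&B/Soul',)]",
    "answer": "The genres that start with 'r' are Rock, Rock and Roll, Reggae and R&B/Soul.",
}

# A made-up entry that was captured but never corrected or completed.
incomplete_example = {"input": "How many albums are there?"}

print(missing_fields(incomplete_example))  # fields still to be filled in
print(missing_fields(genre_example))  # empty list: ready to use as a seed

# Preview the text this example would contribute to the prompt.
rendered = EXAMPLE_TEMPLATE.format(**genre_example)
print(rendered)
```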
With this few-shot prompt, the model should now perform better, especially on inputs similar to the examples you seeded it with.
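To get a feel for why similar inputs benefit most, here is a toy sketch of picking the k seed examples closest to a question. It swaps in simple word overlap (Jaccard similarity) for the embedding-based comparison the selector performs, so it illustrates the idea only; select_similar and _words are hypothetical helpers, not LangChain functions.

```python
import re
from typing import Dict, List


def _words(text: str) -> set:
    """Lowercase a string and split it into a set of words."""
    return set(re.findall(r"[a-z']+", text.lower()))


def select_similar(question: str, examples: List[Dict], k: int) -> List[Dict]:
    """Return the k examples whose input shares the most words with the question.

    Word overlap is a toy stand-in here; the real selector compares
    embedding vectors instead.
    """
    query_words = _words(question)

    def score(example: Dict) -> float:
        example_words = _words(example["input"])
        union = query_words | example_words
        return len(query_words & example_words) / len(union) if union else 0.0

    return sorted(examples, key=score, reverse=True)[:k]


# Only the inputs of the two seed examples are needed for this illustration.
seeds = [
    {"input": "How many customers are not from Brazil?"},
    {"input": "list all the genres that start with 'r'"},
]

closest = select_similar("How many customers are from Brazil?", seeds, k=1)
print(closest[0]["input"])
```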
local_chain = SQLDatabaseChain.from_llm(local_llm, db, prompt=few_shot_prompt, use_query_checker=True, verbose=True, return_intermediate_steps=True)
result = local_chain("How many customers are from Brazil?")
> Entering new SQLDatabaseChain chain...
How many customers are from Brazil?
SQLQuery:SELECT count(*) FROM Customer WHERE Country = "Brazil";
SQLResult: [(5,)]
Answer:[5]
> Finished chain.
result = local_chain("How many customers are not from Brazil?")
> Entering new SQLDatabaseChain chain...
How many customers are not from Brazil?
SQLQuery:SELECT count(*) FROM customer WHERE country NOT IN (SELECT country FROM customer WHERE country = 'Brazil')
SQLResult: [(54,)]
Answer:54 customers are not from Brazil.
> Finished chain.
result = local_chain("How many customers are there in total?")
> Entering new SQLDatabaseChain chain...
How many customers are there in total?
SQLQuery:SELECT count(*) FROM Customer;
SQLResult: [(59,)]
Answer:There are 59 customers in total.
> Finished chain.
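The second query above returns the right count but takes a detour through a subquery. If you record it as a seed example, you may prefer a simpler equivalent. A minimal sketch using Python's built-in sqlite3 module on a made-up five-row table (not Chinook data), checking that the generated form and a plain inequality agree:

```python
import sqlite3

# A made-up stand-in table with two customers from Brazil and three from elsewhere.
conn = sqlite3.connect(":memory:")
conn.execute(
    'CREATE TABLE "Customer" ("CustomerId" INTEGER PRIMARY KEY, "Country" NVARCHAR(40))'
)
countries = ["Brazil", "Germany", "Canada", "Brazil", "Norway"]
conn.executemany(
    'INSERT INTO "Customer" ("Country") VALUES (?)', [(c,) for c in countries]
)

# The query the model generated in the run above.
generated = (
    "SELECT count(*) FROM customer WHERE country NOT IN "
    "(SELECT country FROM customer WHERE country = 'Brazil')"
)
# A simpler form you might store as the corrected sql_cmd.
simpler = "SELECT COUNT(*) FROM \"Customer\" WHERE \"Country\" <> 'Brazil'"

generated_count = conn.execute(generated).fetchone()[0]
simpler_count = conn.execute(simpler).fetchone()[0]
print(generated_count, simpler_count)
```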