# Spaces: Runtime error / Runtime error
# (Page-status text captured together with this file; it is not part of the module.)
from __future__ import annotations | |
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union | |
from langchain_core.language_models import BaseLanguageModel | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.prompts import BasePromptTemplate | |
from langchain_core.runnables import Runnable, RunnablePassthrough | |
from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS | |
if TYPE_CHECKING: | |
from langchain_community.utilities.sql_database import SQLDatabase | |
def _strip(text: str) -> str: | |
return text.strip() | |
class SQLInput(TypedDict):
    """Input for a SQL Chain."""

    # The natural-language question the generated SQL query should answer.
    question: str
class SQLInputWithTables(TypedDict):
    """Input for a SQL Chain that also names the tables to use."""

    # The natural-language question the generated SQL query should answer.
    question: str
    # Names of the tables whose info is passed to the prompt.
    table_names_to_use: List[str]
def create_sql_query_chain(
    llm: BaseLanguageModel,
    db: SQLDatabase,
    prompt: Optional[BasePromptTemplate] = None,
    k: int = 5,
) -> Runnable[Union[SQLInput, SQLInputWithTables, Dict[str, Any]], str]:
    """Create a chain that generates SQL queries.

    *Security Note*: This chain generates SQL queries for the given database.

        The SQLDatabase class provides a get_table_info method that can be used
        to get column information as well as sample data from the table.

        To mitigate risk of leaking sensitive data, limit permissions
        to read and scope to the tables that are needed.

        Optionally, use the SQLInputWithTables input type to specify which tables
        are allowed to be accessed.

        Control access to who can submit requests to this chain.

        See https://python.langchain.com/docs/security for more information.

    Args:
        llm: The language model to use.
        db: The SQLDatabase to generate the query for.
        prompt: The prompt to use. If none is provided, will choose one
            based on dialect. Defaults to None. See Prompt section below for more.
        k: The number of results per select statement to return. Defaults to 5.

    Returns:
        A chain that takes in a question and generates a SQL query that answers
        that question.

    Raises:
        ValueError: If the selected prompt exposes any of 'input', 'top_k' or
            'table_info' neither as an input variable nor as a partial variable.

    Example:

        .. code-block:: python

            # pip install -U langchain langchain-community langchain-openai
            from langchain_openai import ChatOpenAI
            from langchain.chains import create_sql_query_chain
            from langchain_community.utilities import SQLDatabase

            db = SQLDatabase.from_uri("sqlite:///Chinook.db")
            llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
            chain = create_sql_query_chain(llm, db)
            response = chain.invoke({"question": "How many employees are there"})

    Prompt:
        If no prompt is provided, a default prompt is selected based on the SQLDatabase dialect. If one is provided, it must support the following variables (either as input variables or as variables already bound with ``partial``):
            * input: The user question plus suffix "\\nSQLQuery: " is passed here.
            * top_k: The number of results per select statement (the `k` argument to
                this function) is passed in here.
            * table_info: Table definitions and sample rows are passed in here. If the
                user specifies "table_names_to_use" when invoking chain, only those
                will be included. Otherwise, all tables are included.
            * dialect (optional): If dialect input variable is in prompt, the db
                dialect will be passed in here.

        Here's an example prompt:

        .. code-block:: python

            from langchain_core.prompts import PromptTemplate

            template = '''Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
            Unless the question asks for a specific number of results, limit the query to at most {top_k} results.
            Use the following format:

            Question: "Question here"
            SQLQuery: "SQL Query to run"
            SQLResult: "Result of the SQLQuery"
            Answer: "Final answer here"

            Only use the following tables:

            {table_info}.

            Question: {input}'''
            prompt = PromptTemplate.from_template(template)
    """  # noqa: E501
    # Prompt precedence: explicit argument, then dialect-specific default,
    # then the generic default.
    if prompt is not None:
        prompt_to_use = prompt
    elif db.dialect in SQL_PROMPTS:
        prompt_to_use = SQL_PROMPTS[db.dialect]
    else:
        prompt_to_use = PROMPT

    # A variable that was already bound with ``.partial(...)`` no longer shows
    # up in ``input_variables``; count it as provided so such prompts are not
    # rejected.
    provided_variables = set(prompt_to_use.input_variables) | set(
        prompt_to_use.partial_variables
    )
    if {"input", "top_k", "table_info"} - provided_variables:
        raise ValueError(
            f"Prompt must have input variables: 'input', 'top_k', "
            f"'table_info'. Received prompt with input variables: "
            f"{prompt_to_use.input_variables}. Full prompt:\n\n{prompt_to_use}"
        )
    if "dialect" in prompt_to_use.input_variables:
        prompt_to_use = prompt_to_use.partial(dialect=db.dialect)

    inputs = {
        # The suffix cues the model to continue with the SQL query itself.
        "input": lambda x: x["question"] + "\nSQLQuery: ",
        # ``table_names_to_use`` is optional; None is passed when it is absent.
        "table_info": lambda x: db.get_table_info(
            table_names=x.get("table_names_to_use")
        ),
    }
    return (
        RunnablePassthrough.assign(**inputs)  # type: ignore
        # Drop the raw request keys so only derived values reach the prompt.
        | (
            lambda x: {
                key: value
                for key, value in x.items()
                if key not in ("question", "table_names_to_use")
            }
        )
        # ``top_k`` is always bound from this function's ``k`` argument.
        | prompt_to_use.partial(top_k=str(k))
        # Stop before the model starts inventing a query result.
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
        | _strip
    )