Spaces:
Runtime error
Runtime error
File size: 4,406 Bytes
63deadc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
"""Pass input through a moderation endpoint."""
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import check_package_version, get_from_dict_or_env
from langchain.chains.base import Chain
class OpenAIModerationChain(Chain):
    """Pass input through a moderation endpoint.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``OPENAI_API_KEY`` set with your API key.

    Any parameters that are valid to be passed to the openai.create call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain.chains import OpenAIModerationChain
            moderation = OpenAIModerationChain()
    """

    client: Any  #: :meta private:
    async_client: Any  #: :meta private:
    model_name: Optional[str] = None
    """Moderation model name to use."""
    error: bool = False
    """Whether or not to error if bad content was found."""
    input_key: str = "input"  #: :meta private:
    output_key: str = "output"  #: :meta private:
    openai_api_key: Optional[str] = None
    openai_organization: Optional[str] = None
    # Populated by ``validate_environment``: True when the installed ``openai``
    # package is older than 1.0 (dict-style responses, ``openai.Moderation``).
    _openai_pre_1_0: bool = Field(default=None)

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Resolves the API key (required) and the organization (optional) from
        ``values`` or from the ``OPENAI_API_KEY`` / ``OPENAI_ORGANIZATION``
        environment variables, then stores in ``values`` the moderation
        client(s) matching the installed ``openai`` version.

        Args:
            values: Field values collected for this chain.

        Returns:
            ``values`` with ``client``, ``_openai_pre_1_0`` and, for
            openai>=1.0, ``async_client`` filled in.

        Raises:
            ImportError: If the ``openai`` package cannot be imported.
        """
        openai_api_key = get_from_dict_or_env(
            values, "openai_api_key", "OPENAI_API_KEY"
        )
        openai_organization = get_from_dict_or_env(
            values,
            "openai_organization",
            "OPENAI_ORGANIZATION",
            default="",
        )
        try:
            import openai

            # Module-level settings are what the pre-1.0 ``openai.Moderation``
            # resource reads; they are kept for >=1.0 as well so the previous
            # side effects are preserved.
            openai.api_key = openai_api_key
            if openai_organization:
                openai.organization = openai_organization
            values["_openai_pre_1_0"] = False
            try:
                check_package_version("openai", gte_version="1.0")
            except ValueError:
                values["_openai_pre_1_0"] = True
            if values["_openai_pre_1_0"]:
                values["client"] = openai.Moderation
            else:
                # Explicitly constructed clients do not read the module-level
                # ``openai.api_key`` / ``openai.organization``, so credentials
                # given to this chain must be handed over directly. Passing
                # ``None`` for the organization keeps the client's own default.
                client_kwargs = {
                    "api_key": openai_api_key,
                    "organization": openai_organization or None,
                }
                values["client"] = openai.OpenAI(**client_kwargs)
                values["async_client"] = openai.AsyncOpenAI(**client_kwargs)
        except ImportError as err:
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            ) from err
        return values

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    def _model_kwargs(self) -> Dict[str, Any]:
        """Return the ``model`` keyword for a moderation request, if configured.

        An empty dict is returned when ``model_name`` is unset so that the
        endpoint's default model is used.
        """
        if self.model_name:
            return {"model": self.model_name}
        return {}

    def _moderate(self, text: str, results: Any) -> str:
        """Turn a single moderation result into the chain's output string.

        Args:
            text: The text that was submitted for moderation.
            results: One moderation result; read as a mapping for
                openai<1.0 and as an object with attributes otherwise.

        Returns:
            ``text`` unchanged when it was not flagged, otherwise a fixed
            policy-violation message (only when ``self.error`` is False).

        Raises:
            ValueError: If the text was flagged and ``self.error`` is True.
        """
        if self._openai_pre_1_0:
            condition = results["flagged"]
        else:
            condition = results.flagged
        if condition:
            error_str = "Text was found that violates OpenAI's content policy."
            if self.error:
                raise ValueError(error_str)
            else:
                return error_str
        return text

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Moderate ``inputs[self.input_key]`` synchronously.

        Args:
            inputs: Chain inputs; must contain ``self.input_key``.
            run_manager: Callback manager for this run (unused here).

        Returns:
            A dict mapping ``self.output_key`` to the moderated output.
        """
        text = inputs[self.input_key]
        model_kwargs = self._model_kwargs()
        if self._openai_pre_1_0:
            results = self.client.create(text, **model_kwargs)
            output = self._moderate(text, results["results"][0])
        else:
            results = self.client.moderations.create(input=text, **model_kwargs)
            output = self._moderate(text, results.results[0])
        return {self.output_key: output}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Moderate ``inputs[self.input_key]`` asynchronously.

        With openai<1.0 no async client is created, so the call is delegated
        to the base class implementation.

        Args:
            inputs: Chain inputs; must contain ``self.input_key``.
            run_manager: Async callback manager for this run.

        Returns:
            A dict mapping ``self.output_key`` to the moderated output.
        """
        if self._openai_pre_1_0:
            return await super()._acall(inputs, run_manager=run_manager)
        text = inputs[self.input_key]
        results = await self.async_client.moderations.create(
            input=text, **self._model_kwargs()
        )
        output = self._moderate(text, results.results[0])
        return {self.output_key: output}
|