rohan13's picture
Flowise Changes
4114d85
import { BaseLanguageModel } from 'langchain/base_language'
import { ICommonObject, INode, INodeData, INodeParams, PromptRetriever } from '../../../src/Interface'
import { CustomChainHandler, getBaseClasses } from '../../../src/utils'
import { MultiPromptChain } from 'langchain/chains'
/**
 * Node definition that builds and runs a `MultiPromptChain` from a language
 * model and a list of prompt retrievers. Each prompt retriever contributes a
 * name, a description and a system message used as the prompt template.
 */
class MultiPromptChain_Chains implements INode {
    label: string
    name: string
    type: string
    icon: string
    category: string
    baseClasses: string[]
    description: string
    inputs: INodeParams[]

    constructor() {
        this.label = 'Multi Prompt Chain'
        this.name = 'multiPromptChain'
        this.type = 'MultiPromptChain'
        this.icon = 'chain.svg'
        this.category = 'Chains'
        this.description = 'Chain automatically picks an appropriate prompt from multiple prompt templates'
        this.baseClasses = [this.type, ...getBaseClasses(MultiPromptChain)]
        this.inputs = [
            {
                label: 'Language Model',
                name: 'model',
                type: 'BaseLanguageModel'
            },
            {
                label: 'Prompt Retriever',
                name: 'promptRetriever',
                type: 'PromptRetriever',
                list: true
            }
        ]
    }

    /**
     * Builds the chain from the node's inputs.
     *
     * @param nodeData - node configuration; reads `inputs.model` and `inputs.promptRetriever`.
     * @returns the chain created by `MultiPromptChain.fromPrompts`.
     * @throws Error when the `promptRetriever` input is missing.
     */
    async init(nodeData: INodeData): Promise<any> {
        const model = nodeData.inputs?.model as BaseLanguageModel
        const promptRetriever = nodeData.inputs?.promptRetriever as PromptRetriever[] | undefined

        // Without this guard a missing input surfaces as an opaque
        // "promptRetriever is not iterable" TypeError from the loop below.
        if (promptRetriever == null) {
            throw new Error('Multi Prompt Chain: missing "Prompt Retriever" input')
        }

        // fromPrompts takes three parallel arrays, index-aligned per prompt.
        const promptNames = []
        const promptDescriptions = []
        const promptTemplates = []
        for (const prompt of promptRetriever) {
            promptNames.push(prompt.name)
            promptDescriptions.push(prompt.description)
            promptTemplates.push(prompt.systemMessage)
        }

        // No default chain is supplied (fifth argument left undefined).
        const chain = MultiPromptChain.fromPrompts(model, promptNames, promptDescriptions, promptTemplates, undefined, {
            verbose: process.env.DEBUG === 'true'
        } as any)
        return chain
    }

    /**
     * Runs the previously built chain on a single input.
     *
     * @param nodeData - node configuration; `instance` holds the chain returned by `init`.
     * @param input - text passed to the chain under the `input` key.
     * @param options - when both `socketIO` and `socketIOClientId` are set, a
     *   `CustomChainHandler` built from them is passed to the chain as a callback.
     * @returns the `text` field of the chain result (undefined if the result is nullish).
     */
    async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
        const chain = nodeData.instance as MultiPromptChain
        const obj = { input }

        // Attach the socket handler only when both socket options are present.
        const callbacks =
            options.socketIO && options.socketIOClientId
                ? [new CustomChainHandler(options.socketIO, options.socketIOClientId)]
                : undefined

        const res = await chain.call(obj, callbacks)
        return res?.text
    }
}
// Expose the node class under the `nodeClass` key of the CommonJS exports object.
module.exports = { nodeClass: MultiPromptChain_Chains }