File size: 2,126 Bytes
4114d85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import { LLM, BaseLLMParams } from 'langchain/llms/base'

/**
 * Constructor options for the Cohere LLM wrapper, in addition to the base LLM
 * parameters inherited from `BaseLLMParams`. Every field is optional at the type level.
 */
export interface CohereInput extends BaseLLMParams {
    /** Sampling temperature to use */
    temperature?: number

    /**
     * Maximum number of tokens to generate in the completion.
     */
    maxTokens?: number

    /** Model to use */
    model?: string

    /** Cohere API key; the wrapper hands it to the `cohere-ai` client before issuing a request. */
    apiKey?: string
}

/**
 * LLM wrapper that sends prompts to Cohere's `generate` endpoint.
 *
 * The `cohere-ai` SDK is pulled in with a dynamic import inside `_call`, so the
 * package only has to be installed when this class is actually used.
 */
export class Cohere extends LLM implements CohereInput {
    /** Sampling temperature sent with each generate request. */
    temperature = 0

    /** Token limit sent to the API as `max_tokens`. */
    maxTokens = 250

    /**
     * Model name sent with each request. This class defines no default, so the
     * value stays unset at runtime unless the caller provides one.
     * NOTE(review): presumably the API then falls back to its own default model — confirm.
     */
    model: string

    /** API key passed to the `cohere-ai` client before each request. */
    apiKey: string

    /**
     * @param fields Optional overrides for temperature, token limit, model and API key.
     * @throws Error when no API key is passed and `COHERE_API_KEY` is not set in the environment.
     */
    constructor(fields?: CohereInput) {
        super(fields ?? {})

        // An explicitly passed key wins; otherwise honor the environment variable
        // that the error message below tells the user to set.
        const apiKey = fields?.apiKey ?? Cohere.apiKeyFromEnvironment()

        if (!apiKey) {
            throw new Error('Please set the COHERE_API_KEY environment variable or pass it to the constructor as the apiKey field.')
        }

        this.apiKey = apiKey
        this.maxTokens = fields?.maxTokens ?? this.maxTokens
        this.temperature = fields?.temperature ?? this.temperature
        // No field initializer exists for `model`, so without an override this leaves it unset.
        this.model = fields?.model ?? this.model
    }

    /**
     * Reads `COHERE_API_KEY` from the process environment, if there is one.
     * Goes through `globalThis` so the lookup is harmless in runtimes without `process`.
     */
    private static apiKeyFromEnvironment(): string | undefined {
        const env = (globalThis as { process?: { env?: Record<string, string | undefined> } }).process?.env
        return env?.COHERE_API_KEY
    }

    /** Identifier for this LLM implementation. */
    _llmType(): string {
        return 'cohere'
    }

    /**
     * Sends `prompt` to the Cohere `generate` endpoint and returns the text of the first generation.
     *
     * @param prompt Text to complete.
     * @param options Parsed call options; `signal` is forwarded to the caller and `stop` is sent as `end_sequences`.
     * @throws Error 'Could not parse response.' when the response has no readable first generation.
     * @ignore
     */
    async _call(prompt: string, options: this['ParsedCallOptions']): Promise<string> {
        const { cohere } = await Cohere.imports()

        cohere.init(this.apiKey)

        // Hit the `generate` endpoint with the configured model and sampling settings
        const generateResponse = await this.caller.callWithOptions({ signal: options.signal }, cohere.generate.bind(cohere), {
            prompt,
            model: this.model,
            max_tokens: this.maxTokens,
            temperature: this.temperature,
            end_sequences: options.stop
        })
        try {
            return generateResponse.body.generations[0].text
        } catch {
            // Any shape mismatch (e.g. missing `generations`) surfaces as one uniform error.
            throw new Error('Could not parse response.')
        }
    }

    /**
     * Dynamically imports the `cohere-ai` package.
     *
     * @returns The SDK's default export under the `cohere` key.
     * @throws Error with install instructions when the import fails.
     * @ignore
     */
    static async imports(): Promise<{
        cohere: typeof import('cohere-ai')
    }> {
        try {
            const { default: cohere } = await import('cohere-ai')
            return { cohere }
        } catch {
            throw new Error('Please install cohere-ai as a dependency with, e.g. `yarn add cohere-ai`')
        }
    }
}