diff --git a/ChatGPT/config.py b/ChatGPT/config.py index 056e1c8..442c021 100644 --- a/ChatGPT/config.py +++ b/ChatGPT/config.py @@ -74,6 +74,19 @@ conf.registerGlobalValue( ), ) +conf.registerGlobalValue( + ChatGPT, + "base_url", + registry.String( + "https://api.openai.com/v1/", + _( + """ + API server, for using a non-OpenAI model which has a compatible API, default: "https://api.openai.com/v1/" + """ + ), + ), +) + conf.registerChannelValue( ChatGPT, "prompt", @@ -100,6 +113,7 @@ conf.registerChannelValue( ), ) + conf.registerChannelValue( ChatGPT, "reply_intact", diff --git a/ChatGPT/plugin.py b/ChatGPT/plugin.py index 1f36f3d..3ec2e80 100644 --- a/ChatGPT/plugin.py +++ b/ChatGPT/plugin.py @@ -32,7 +32,8 @@ from supybot import utils, plugins, ircutils, callbacks from supybot.commands import * from supybot.i18n import PluginInternationalization import re -import openai +#import openai +from openai import OpenAI _ = PluginInternationalization("ChatGPT") @@ -57,26 +58,41 @@ class ChatGPT(callbacks.Plugin): return if self.registryValue("nick_include", msg.channel): text = "%s: %s" % (msg.nick, text) + + # Initialize client + client = OpenAI( + api_key=self.registryValue("api_key"), + base_url=self.registryValue("base_url") + ) self.history.setdefault(channel, None) max_history = self.registryValue("max_history", msg.channel) prompt = self.registryValue("prompt", msg.channel).replace("$botnick", irc.nick) + if not self.history[channel] or max_history < 1: self.history[channel] = [] - openai.api_key = self.registryValue("api_key") - completion = openai.chat.completions.create( - model=self.registryValue("model", msg.channel), - messages=self.history[channel][-max_history:] - + [ + + model_name = self.registryValue("model", msg.channel) + + # Base request parameters + request_params = { + "model": model_name, + "messages": self.history[channel][-max_history:] + [ {"role": "system", "content": prompt}, - {"role": "user", "content": text}, + {"role": "user", "content": text} ], - temperature=self.registryValue("temperature", msg.channel), - top_p=self.registryValue("top_p", msg.channel), - max_tokens=self.registryValue("max_tokens", msg.channel), - presence_penalty=self.registryValue("presence_penalty", msg.channel), - frequency_penalty=self.registryValue("frequency_penalty", msg.channel), - user=msg.nick, - ) + "temperature": self.registryValue("temperature", msg.channel), + "top_p": self.registryValue("top_p", msg.channel), + "max_tokens": self.registryValue("max_tokens", msg.channel), + "presence_penalty": self.registryValue("presence_penalty", msg.channel), + "user": msg.nick, + } + + # Gemini models fail if frequency_penalty is included + if "gemini" not in model_name.lower(): + request_params["frequency_penalty"] = self.registryValue("frequency_penalty", msg.channel) + + completion = client.chat.completions.create(**request_params) + if self.registryValue("nick_strip", msg.channel): content = re.sub( r"^%s: " % (irc.nick), "", completion.choices[0].message.content