diff --git a/ChatGPT/plugin.py b/ChatGPT/plugin.py index dbf5374..3ec2e80 100644 --- a/ChatGPT/plugin.py +++ b/ChatGPT/plugin.py @@ -61,29 +61,38 @@ class ChatGPT(callbacks.Plugin): # Initialize client client = OpenAI( - api_key=self.registryValue("api_key"), - base_url=self.registryValue("base_url") - ) + api_key=self.registryValue("api_key"), + base_url=self.registryValue("base_url") + ) self.history.setdefault(channel, None) max_history = self.registryValue("max_history", msg.channel) prompt = self.registryValue("prompt", msg.channel).replace("$botnick", irc.nick) + if not self.history[channel] or max_history < 1: self.history[channel] = [] - completion = client.chat.completions.create( - model=self.registryValue("model", msg.channel), - messages=self.history[channel][-max_history:] + [ + model_name = self.registryValue("model", msg.channel) + + # Base request parameters + request_params = { + "model": model_name, + "messages": self.history[channel][-max_history:] + [ {"role": "system", "content": prompt}, {"role": "user", "content": text} ], - temperature=self.registryValue("temperature", msg.channel), - top_p=self.registryValue("top_p", msg.channel), - max_tokens=self.registryValue("max_tokens", msg.channel), - presence_penalty=self.registryValue("presence_penalty", msg.channel), - frequency_penalty=self.registryValue("frequency_penalty", msg.channel), - user=msg.nick, - ) + "temperature": self.registryValue("temperature", msg.channel), + "top_p": self.registryValue("top_p", msg.channel), + "max_tokens": self.registryValue("max_tokens", msg.channel), + "presence_penalty": self.registryValue("presence_penalty", msg.channel), + "user": msg.nick, + } + # Gemini models fail if frequency_penalty is included + if "gemini" not in model_name.lower(): + request_params["frequency_penalty"] = self.registryValue("frequency_penalty", msg.channel) + + completion = client.chat.completions.create(**request_params) + if self.registryValue("nick_strip", msg.channel): content = re.sub( r"^%s: " % (irc.nick), "", completion.choices[0].message.content