diff --git a/src/irclib.py b/src/irclib.py index 5c505a79f..a2a466897 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -32,6 +32,7 @@ import copy import time import random import base64 +import collections try: from ecdsa import SigningKey, BadDigestError @@ -341,6 +342,7 @@ class ChannelState(utils.python.Object): ret = ret and getattr(self, name) == getattr(other, name) return ret +Batch = collections.namedtuple('Batch', 'type arguments messages') class IrcState(IrcCommandDispatcher, log.Firewalled): """Maintains state of the Irc connection. Should also become smarter. @@ -366,6 +368,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled): self.history = history self.channels = channels self.nicksToHostmasks = nicksToHostmasks + self.batches = {} def reset(self): """Resets the state to normal, unconnected state.""" @@ -374,6 +377,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled): self.supported.clear() self.nicksToHostmasks.clear() self.history.resize(conf.supybot.protocols.irc.maxHistoryLength()) + self.batches = {} def __reduce__(self): return (self.__class__, (self.history, self.supported, @@ -383,7 +387,8 @@ class IrcState(IrcCommandDispatcher, log.Firewalled): return self.history == other.history and \ self.channels == other.channels and \ self.supported == other.supported and \ - self.nicksToHostmasks == other.nicksToHostmasks + self.nicksToHostmasks == other.nicksToHostmasks and \ + self.batches == other.batches def __ne__(self, other): return not self == other @@ -393,6 +398,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled): ret.history = copy.deepcopy(self.history) ret.nicksToHostmasks = copy.deepcopy(self.nicksToHostmasks) ret.channels = copy.deepcopy(self.channels) + ret.batches = copy.deepcopy(self.batches) return ret def addMsg(self, irc, msg): @@ -400,6 +406,11 @@ class IrcState(IrcCommandDispatcher, log.Firewalled): self.history.append(msg) if ircutils.isUserHostmask(msg.prefix) and not msg.command == 'NICK': self.nicksToHostmasks[msg.nick] = msg.prefix + if 'batch' in msg.server_tags: + batch = msg.server_tags['batch'] + assert batch in self.batches, \ + 'Server references undeclared batch %s' % batch + self.batches[batch].messages.append(msg) method = self.dispatchCommand(msg.command) if method is not None: method(irc, msg) @@ -647,6 +658,18 @@ class IrcState(IrcCommandDispatcher, log.Firewalled): for channel in self.channels.values(): channel.replaceUser(oldNick, newNick) + def doBatch(self, irc, msg): + batch_name = msg.args[0][1:] + if msg.args[0].startswith('+'): + batch_type = msg.args[1] + batch_arguments = tuple(msg.args[2:]) + self.batches[batch_name] = Batch(type=batch_type, + arguments=batch_arguments, messages=[]) + elif msg.args[0].startswith('-'): + batch = self.batches.pop(batch_name) + msg.tag('batch', batch) + else: + assert False, msg.args[0] ### @@ -967,7 +990,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): REQUEST_CAPABILITIES = set(['account-notify', 'extended-join', 'multi-prefix', 'metadata-notify', 'account-tag', 'userhost-in-names', 'invite-notify', 'server-time', - 'chghost']) + 'chghost', 'batch']) def _queueConnectMessages(self): if self.zombie: diff --git a/src/ircmsgs.py b/src/ircmsgs.py index 0ac502c48..a2062215a 100644 --- a/src/ircmsgs.py +++ b/src/ircmsgs.py @@ -186,6 +186,7 @@ class IrcMsg(object): assert all(ircutils.isValidArgument, args), args self.args = args self.time = None + self.server_tags = {} self.args = tuple(self.args) if isUserHostmask(self.prefix): (self.nick,self.user,self.host)=ircutils.splitHostmask(self.prefix) diff --git a/test/test_irclib.py b/test/test_irclib.py index f4ae0b52b..049454ac2 100644 --- a/test/test_irclib.py +++ b/test/test_irclib.py @@ -482,7 +482,31 @@ class IrcTestCase(SupyTestCase): self.irc.feedMsg(ircmsgs.IrcMsg(':someuser QUIT')) finally: self.irc.removeCallback(c.name()) - self.assertEqual(c.channels_set, ircutils.IrcSet({'#foo', '#bar'})) + self.assertEqual(c.channels_set, ircutils.IrcSet(['#foo', '#bar'])) + + def testBatch(self): + self.irc.reset() + self.irc.feedMsg(ircmsgs.IrcMsg(':someuser1 JOIN #foo')) + self.irc.feedMsg(ircmsgs.IrcMsg(':host BATCH +name netjoin')) + m1 = ircmsgs.IrcMsg('@batch=name :someuser2 JOIN #foo') + self.irc.feedMsg(m1) + self.irc.feedMsg(ircmsgs.IrcMsg(':someuser3 JOIN #foo')) + m2 = ircmsgs.IrcMsg('@batch=name :someuser4 JOIN #foo') + self.irc.feedMsg(m2) + class Callback(irclib.IrcCallback): + batch = None + def name(self): + return 'testcallback' + def doBatch(self2, irc, msg): + self2.batch = msg.tagged('batch') + c = Callback() + self.irc.addCallback(c) + try: + self.irc.feedMsg(ircmsgs.IrcMsg(':host BATCH -name')) + finally: + self.irc.removeCallback(c.name()) + self.assertEqual(c.batch, irclib.Batch('netjoin', (), [m1, m2])) +