diff --git a/src/irclib.py b/src/irclib.py index 95b555733..1157bc10a 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -30,6 +30,7 @@ import re import copy import time +import enum import random import base64 import textwrap @@ -389,14 +390,108 @@ class ChannelState(utils.python.Object): Batch = collections.namedtuple('Batch', 'type arguments messages') +class IrcStateFsm(object): + '''Finite State Machine keeping track of what part of the connection + initialization we are in.''' + __slots__ = ('state',) + + @enum.unique + class States(enum.Enum): + UNINITIALIZED = 10 + '''Nothing received yet (except server notices)''' + + INIT_CAP_NEGOTIATION = 20 + '''Sent CAP LS, did not send CAP END yet''' + + INIT_SASL = 30 + '''In an AUTHENTICATE session''' + + INIT_WAITING_MOTD = 50 + '''Waiting for start of MOTD''' + + INIT_MOTD = 60 + '''Waiting for end of MOTD''' + + CONNECTED = 70 + '''Normal state of the connections''' + + CONNECTED_SASL = 80 + '''Doing SASL authentication in the middle of a connection.''' + + def __init__(self): + self.reset() + + def reset(self): + self.state = self.States.UNINITIALIZED + + def _transition(self, to_state, expected_from=None): + if expected_from is None or self.state in expected_from: + log.debug('transition from %s to %s', self.state, to_state) + self.state = to_state + else: + raise ValueError('unexpected transition to %s while in state %s' % + (to_state, self.state)) + + def expect_state(self, expected_states): + if self.state not in expected_states: + raise ValueError(('Connection in state %s, but expected to be ' + 'in state %s') % (self.state, expected_states)) + + def on_init_messages_sent(self): + '''As soon as USER/NICK/CAP LS are sent''' + self._transition(self.States.INIT_CAP_NEGOTIATION, [ + self.States.UNINITIALIZED, + ]) + + def on_sasl_cap(self): + '''Whenever we see the 'sasl' capability in a CAP LS response''' + if self.state == self.States.INIT_CAP_NEGOTIATION: + self._transition(self.States.INIT_SASL) + elif self.state == self.States.CONNECTED: + self._transition(self.States.CONNECTED_SASL) + else: + raise ValueError('Got sasl cap while in state %s' % self.state) + + def on_sasl_auth_finished(self): + '''When sasl auth either succeeded or failed.''' + if self.state == self.States.INIT_SASL: + self._transition(self.States.INIT_CAP_NEGOTIATION) + elif self.state == self.States.CONNECTED_SASL: + self._transition(self.States.CONNECTED) + else: + raise ValueError('Finished SASL auth while in state %s' % self.state) + + def on_cap_end(self): + '''When we send CAP END''' + self._transition(self.States.INIT_WAITING_MOTD, [ + self.States.INIT_CAP_NEGOTIATION, + ]) + + def on_start_motd(self): + '''On 375 (RPL_MOTDSTART)''' + self._transition(self.States.INIT_MOTD, [ + self.States.INIT_CAP_NEGOTIATION, + self.States.INIT_WAITING_MOTD, + ]) + + def on_end_motd(self): + '''On 376 (RPL_ENDOFMOTD) or 422 (ERR_NOMOTD)''' + self._transition(self.States.CONNECTED, [ + self.States.INIT_CAP_NEGOTIATION, + self.States.INIT_WAITING_MOTD, + self.States.INIT_MOTD + ]) + class IrcState(IrcCommandDispatcher, log.Firewalled): """Maintains state of the Irc connection. Should also become smarter. """ __firewalled__ = {'addMsg': None} def __init__(self, history=None, supported=None, nicksToHostmasks=None, channels=None, + capabilities_req=None, capabilities_ack=None, capabilities_nak=None, capabilities_ls=None): + self.fsm = IrcStateFsm() if history is None: history = RingBuffer(conf.supybot.protocols.irc.maxHistoryLength()) if supported is None: @@ -405,6 +500,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled): nicksToHostmasks = ircutils.IrcDict() if channels is None: channels = ircutils.IrcDict() + self.capabilities_req = capabilities_req or set() self.capabilities_ack = capabilities_ack or set() self.capabilities_nak = capabilities_nak or set() self.capabilities_ls = capabilities_ls or {} @@ -417,6 +513,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled): def reset(self): """Resets the state to normal, unconnected state.""" + self.fsm.reset() self.history.reset() self.history.resize(conf.supybot.protocols.irc.maxHistoryLength()) self.ircd = None @@ -424,6 +521,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled): self.supported.clear() self.nicksToHostmasks.clear() self.batches = {} + self.capabilities_req = set() self.capabilities_ack = set() self.capabilities_nak = set() self.capabilities_ls = {} @@ -1115,6 +1213,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.sendAuthenticationMessages() + self.state.fsm.on_init_messages_sent() + def sendAuthenticationMessages(self): # Notes: # * using sendMsg instead of queueMsg because these messages cannot @@ -1135,10 +1235,43 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.sendMsg(ircmsgs.user(self.ident, self.user)) + def capUpkeep(self): + self.state.fsm.expect_state([ + # Normal CAP ACK / CAP NAK during cap negotiation + IrcStateFsm.States.INIT_CAP_NEGOTIATION, + # CAP ACK / CAP NAK after a CAP NEW (probably) + IrcStateFsm.States.CONNECTED, + ]) + + capabilities_responded = (self.state.capabilities_ack | + self.state.capabilities_nak) + if not capabilities_responded <= self.state.capabilities_req: + log.error('Server responded with unrequested ACK/NAK ' + 'capabilities: req=%r, ack=%r, nak=%r', + self.state.capabilities_req, + self.state.capabilities_ack, + self.state.capabilities_nak) + self.driver.reconnect() + elif capabilities_responded == self.state.capabilities_req: + log.debug('Got all capabilities ACKed/NAKed') + # We got all the capabilities we asked for + if 'sasl' in self.state.capabilities_ack: + if self.state.fsm.state in [ + IrcStateFsm.States.INIT_CAP_NEGOTIATION, + IrcStateFsm.States.CONNECTED]: + self._maybeStartSasl() + else: + pass # Already in the middle of a SASL auth + else: + self.endCapabilityNegociation() + else: + log.debug('Waiting for ACK/NAK of capabilities: %r', + self.state.capabilities_req - capabilities_responded) + pass # Do nothing, we'll get more + def endCapabilityNegociation(self): - if not self.capNegociationEnded: - self.capNegociationEnded = True - self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',))) + self.state.fsm.on_cap_end() + self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',))) def sendSaslString(self, string): for chunk in ircutils.authenticate_generator(string): @@ -1146,6 +1279,10 @@ class Irc(IrcCommandDispatcher, log.Firewalled): args=(chunk,))) def tryNextSaslMechanism(self): + self.state.fsm.expect_state([ + IrcStateFsm.States.INIT_SASL, + IrcStateFsm.States.CONNECTED_SASL, + ]) if self.sasl_next_mechanisms: self.sasl_current_mechanism = self.sasl_next_mechanisms.pop(0) self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE', @@ -1155,15 +1292,30 @@ class Irc(IrcCommandDispatcher, log.Firewalled): 'aborting connection.') else: self.sasl_current_mechanism = None - self.endCapabilityNegociation() + self.state.fsm.on_sasl_auth_finished() + if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION: + self.endCapabilityNegociation() - def filterSaslMechanisms(self, available): - available = set(map(str.lower, available)) - self.sasl_next_mechanisms = [ - x for x in self.sasl_next_mechanisms - if x.lower() in available] + def _maybeStartSasl(self): + if not self.sasl_authenticated and \ + 'sasl' in self.state.capabilities_ack: + self.state.fsm.on_sasl_cap() + assert 'sasl' in self.state.capabilities_ls, ( + 'Got "CAP ACK sasl" without receiving "CAP LS sasl" or ' + '"CAP NEW sasl" first.') + s = self.state.capabilities_ls['sasl'] + if s is not None: + available = set(map(str.lower, s.split(','))) + self.sasl_next_mechanisms = [ + x for x in self.sasl_next_mechanisms + if x.lower() in available] + self.tryNextSaslMechanism() def doAuthenticate(self, msg): + self.state.fsm.expect_state([ + IrcStateFsm.States.INIT_SASL, + IrcStateFsm.States.CONNECTED_SASL, + ]) if not self.authenticate_decoder: self.authenticate_decoder = ircutils.AuthenticateDecoder() self.authenticate_decoder.feed(msg) @@ -1265,7 +1417,9 @@ class Irc(IrcCommandDispatcher, log.Firewalled): def do903(self, msg): log.info('%s: SASL authentication successful', self.network) self.sasl_authenticated = True - self.endCapabilityNegociation() + self.state.fsm.on_sasl_auth_finished() + if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION: + self.endCapabilityNegociation() def do904(self, msg): log.warning('%s: SASL authentication failed (mechanism: %s)', @@ -1301,10 +1455,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.network, caps) self.state.capabilities_ack.update(caps) - if 'sasl' in caps: - self.tryNextSaslMechanism() - else: - self.endCapabilityNegociation() + self.capUpkeep() + def doCapNak(self, msg): if len(msg.args) != 3: log.warning('Bad CAP NAK from server: %r', msg) @@ -1314,7 +1466,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.state.capabilities_nak.update(caps) log.warning('%s: Server refused capabilities: %L', self.network, caps) - self.endCapabilityNegociation() + self.capUpkeep() + def _addCapabilities(self, capstring): for item in capstring.split(): while item.startswith(('=', '~')): @@ -1324,6 +1477,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.state.capabilities_ls[cap] = value else: self.state.capabilities_ls[item] = None + def doCapLs(self, msg): if len(msg.args) == 4: # Multi-line LS @@ -1333,12 +1487,14 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self._addCapabilities(msg.args[3]) elif len(msg.args) == 3: # End of LS self._addCapabilities(msg.args[2]) - - if 'sasl' in self.state.capabilities_ls: - s = self.state.capabilities_ls['sasl'] - if s is not None: - self.filterSaslMechanisms(set(s.split(','))) - + self.state.fsm.expect_state([ + # Normal case: + IrcStateFsm.States.INIT_CAP_NEGOTIATION, + # Should only happen if a plugin sends a CAP LS (which they + # shouldn't do): + IrcStateFsm.States.CONNECTED, + IrcStateFsm.States.CONNECTED_SASL, + ]) # Normally at this point, self.state.capabilities_ack should be # empty; but let's just make sure we're not requesting the same # caps twice for no reason. @@ -1356,6 +1512,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): else: log.warning('Bad CAP LS from server: %r', msg) return + def doCapDel(self, msg): if len(msg.args) != 3: log.warning('Bad CAP DEL from server: %r', msg) @@ -1374,18 +1531,16 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.state.capabilities_ack.remove(cap) except KeyError: pass + def doCapNew(self, msg): + # Note that in theory, this method may be called at any time, even + # before CAP END (or even before the initial CAP LS). if len(msg.args) != 3: log.warning('Bad CAP NEW from server: %r', msg) return caps = msg.args[2].split() assert caps, 'Empty list of capabilities' self._addCapabilities(msg.args[2]) - if not self.sasl_authenticated and 'sasl' in self.state.capabilities_ls: - self.resetSasl() - s = self.state.capabilities_ls['sasl'] - if s is not None: - self.filterSaslMechanisms(set(s.split(','))) common_supported_unrequested_capabilities = ( set(self.state.capabilities_ls) & self.REQUEST_CAPABILITIES - @@ -1394,6 +1549,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self._requestCaps(common_supported_unrequested_capabilities) def _requestCaps(self, caps): + self.state.capabilities_req |= caps + caps = ' '.join(sorted(caps)) # textwrap works here because in ASCII, all chars are 1 bytes: cap_lines = textwrap.wrap(caps, MAX_LINE_SIZE-len('CAP REQ :')) @@ -1474,6 +1631,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.outstandingPing = False def do376(self, msg): + self.state.fsm.on_end_motd() log.info('Got end of MOTD from %s', self.server) self.afterConnect = True # Let's reset nicks in case we had to use a weird one. diff --git a/test/test_irclib.py b/test/test_irclib.py index bcd2719a8..63fff6db2 100644 --- a/test/test_irclib.py +++ b/test/test_irclib.py @@ -832,6 +832,8 @@ class SaslTestCase(SupyTestCase): while self.irc.takeMsg(): pass + self.irc.feedMsg(ircmsgs.IrcMsg(command='422')) # ERR_NOMOTD + self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'NEW', 'sasl=EXTERNAL')))