diff --git a/src/drivers/Socket.py b/src/drivers/Socket.py index 354fc1f7c..8b25f004d 100644 --- a/src/drivers/Socket.py +++ b/src/drivers/Socket.py @@ -39,9 +39,10 @@ import os import sys import time import errno -import threading import select import socket +import asyncio +import threading import ipaddress @@ -58,9 +59,8 @@ except: class SSLError(Exception): pass + class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): - _instances = [] - _selecting = threading.Lock() def __init__(self, irc): assert irc is not None self.irc = irc @@ -108,26 +108,19 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): # hasn't finished yet. We'll keep track of how many we get. if e is None or e.args[0] != 11 or self.eagains > 120: drivers.log.disconnect(self.currentServer, e) - if self in self._instances: - self._instances.remove(self) try: self.conn.close() except: pass self.connected = False if self.irc is None: - # This driver is dead already, but we're still running because - # of select() running in an other driver's thread that started - # before this one died and stil holding a reference to this - # instance. - # Just return, and we should never be called again. return self.scheduleReconnect() else: log.debug('Got EAGAIN, current count: %s.', self.eagains) self.eagains += 1 - def _sendIfMsgs(self): + async def _sendIfMsgs(self): if not self.connected: return if not self.zombie: @@ -137,71 +130,34 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): del msgs[-1] self.outbuffer += ''.join(map(str, msgs)) if self.outbuffer: + loop = asyncio.get_event_loop() try: - if minisix.PY2: - sent = self.conn.send(self.outbuffer) - else: - sent = self.conn.send(self.outbuffer.encode()) - self.outbuffer = self.outbuffer[sent:] + await loop.sock_sendall(self.conn, self.outbuffer.encode()) + self.outbuffer = '' self.eagains = 0 except socket.error as e: self._handleSocketError(e) if self.zombie and not self.outbuffer: self._reallyDie() - @classmethod - def _select(cls): - timeout = conf.supybot.drivers.poll() - try: - if not cls._selecting.acquire(blocking=False): - # there's already a thread running this code, abort. - return - for inst in cls._instances: - # Do not use a list comprehension here, we have to edit the list - # and not to reassign it. - if not inst.connected or \ - (minisix.PY3 and inst.conn._closed) or \ - (minisix.PY2 and - inst.conn._sock.__class__ is socket._closedsocket): - cls._instances.remove(inst) - elif inst.conn.fileno() == -1: - inst.reconnect() - if not cls._instances: - return - rlist, wlist, xlist = select.select([x.conn for x in cls._instances], - [], [], timeout) - for instance in cls._instances: - if instance.conn in rlist: - instance._read() - except select.error as e: - if e.args[0] != errno.EINTR: - # 'Interrupted system call' - raise - finally: - cls._selecting.release() - for instance in cls._instances: - if instance.irc and not instance.irc.zombie: - instance._sendIfMsgs() - - - def run(self): + async def run(self): now = time.time() if self.nextReconnectTime is not None and now > self.nextReconnectTime: self.reconnect() elif self.writeCheckTime is not None and now > self.writeCheckTime: self._checkAndWriteOrReconnect() if not self.connected: - # We sleep here because otherwise, if we're the only driver, we'll - # spin at 100% CPU while we're disconnected. - time.sleep(conf.supybot.drivers.poll()) return - self._sendIfMsgs() - self._select() + await asyncio.gather( + self._sendIfMsgs(), + self._read(), + ) - def _read(self): + async def _read(self): """Called by _select() when we can read data.""" + loop = asyncio.get_event_loop() try: - new_data = self.conn.recv(1024) + new_data = await loop.sock_recv(self.conn, 1024) if not new_data: # Socket was closed self._handleSocketError(None) @@ -241,7 +197,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): self._handleSocketError(e) return if self.irc and not self.irc.zombie: - self._sendIfMsgs() + await self._sendIfMsgs() def connect(self, **kwargs): self.reconnect(reset=False, **kwargs) @@ -252,8 +208,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): if self.connected: self.onDisconnect() drivers.log.reconnect(self.irc.network) - if self in self._instances: - self._instances.remove(self) try: self.conn.shutdown(socket.SHUT_RDWR) except: # "Transport endpoint not connected" @@ -314,12 +268,10 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): return # We allow more time for the connect here, since it might take longer. # At least 10 seconds. - self.conn.settimeout(max(10, conf.supybot.drivers.poll()*10)) try: # Connect before SSL, otherwise SSL is disabled if we use SOCKS. # See http://stackoverflow.com/q/16136916/539465 - self.conn.connect( - (address, self.currentServer.port)) + self.conn.connect((address, self.currentServer.port)) if network_config.ssl() or \ self.currentServer.force_tls_verification: self.starttls() @@ -343,8 +295,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): '') % self.irc.network) - conf.supybot.drivers.poll.addCallback(self.setTimeout) - self.setTimeout() self.connected = True self.resetDelay() except socket.error as e: @@ -361,13 +311,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): drivers.log.connectError(self.currentServer, e) self.scheduleReconnect() return - self._instances.append(self) - - def setTimeout(self): - try: - self.conn.settimeout(conf.supybot.drivers.poll()) - except Exception: - drivers.log.exception('Could not set socket timeout:') def _checkAndWriteOrReconnect(self): self.writeCheckTime = None @@ -393,9 +336,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): self.nextReconnectTime = when def die(self): - if self in self._instances: - self._instances.remove(self) - conf.supybot.drivers.poll.removeCallback(self.setTimeout) self.zombie = True if self.nextReconnectTime is not None: self.nextReconnectTime = None diff --git a/src/drivers/__init__.py b/src/drivers/__init__.py index 453901587..e8e7f1333 100644 --- a/src/drivers/__init__.py +++ b/src/drivers/__init__.py @@ -35,6 +35,7 @@ Contains various drivers (network, file, and otherwise) for using IRC objects. import time import socket +import asyncio from collections import namedtuple from .. import conf, ircdb, ircmsgs, ircutils, log as supylog, utils @@ -146,15 +147,59 @@ def remove(name): """Removes the driver with the given name from the loop.""" _deadDrivers.add(name) + +def _loop_exception_handler(loop, context): + if 'Task was destroyed but it is pending' in context['message']: + # This happens when cancelling tasks because of KeyboardInterrupt. + # There is really nothing we can do about this, so it's + # pointless to log it. + return + log.error('Exception in drivers loop: %s', context['message']) + def run(): """Runs the whole driver loop.""" + loop = asyncio.new_event_loop() + loop.set_exception_handler(_loop_exception_handler) + driver_names = [] + futures = [] + coroutines = [] # Used to cleanup on shutdown to avoid warnings for (name, driver) in _drivers.items(): + if name not in _deadDrivers: + try: + coroutine = driver.run() + future = asyncio.ensure_future(coroutine, loop=loop) + except Exception: + log.exception('Exception in drivers.run for driver %s:', name) + continue + driver_names.append(name) + coroutines.append(coroutine) + futures.append(future) + + gather_task = asyncio.gather(*futures, return_exceptions=True, loop=loop) + timeout_gather_task = asyncio.wait_for( + gather_task, + timeout=conf.supybot.drivers.poll()) + coroutines.append(timeout_gather_task) + try: + loop.run_until_complete(timeout_gather_task) + except KeyboardInterrupt: + # Cleanup all the objects so they don't throw warnings. + gather_task.cancel() + for future in futures: + future.cancel() + for coroutine in coroutines: + coroutine.close() + raise + except asyncio.TimeoutError: + pass + + for (name, future) in zip(driver_names, futures): try: - if name not in _deadDrivers: - driver.run() + future.result() # Raises an exception if driver.run() did. except: - log.exception('Uncaught exception in in drivers.run:') + log.exception('Uncaught exception in drivers.run:') _deadDrivers.add(name) + for name in _deadDrivers: try: driver = _drivers[name] diff --git a/src/schedule.py b/src/schedule.py index 2904ad68e..fed3ef159 100644 --- a/src/schedule.py +++ b/src/schedule.py @@ -137,7 +137,7 @@ class Schedule(drivers.IrcDriver): removePeriodicEvent = removeEvent - def run(self): + async def run(self): if len(drivers._drivers) == 1 and not world.testing: log.error('Schedule is the only remaining driver, ' 'why do we continue to live?') diff --git a/test/test_schedule.py b/test/test_schedule.py index 6b86399ee..1c6ba7d54 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -31,6 +31,7 @@ from supybot.test import * import time +import asyncio import supybot.schedule as schedule @@ -43,6 +44,11 @@ class FakeSchedule(schedule.Schedule): def name(self): return 'FakeSchedule' + +def run(sched): + loop = asyncio.new_event_loop() + loop.run_until_complete(sched.run()) + class TestSchedule(SupyTestCase): def testSchedule(self): sched = FakeSchedule() @@ -55,14 +61,14 @@ class TestSchedule(SupyTestCase): sched.addEvent(add10, time.time() + 3) sched.addEvent(add1, time.time() + 1) timeFastForward(1.2) - sched.run() + run(sched) self.assertEqual(i[0], 1) timeFastForward(1.9) - sched.run() + run(sched) self.assertEqual(i[0], 11) sched.addEvent(add10, time.time() + 3, 'test') - sched.run() + run(sched) self.assertEqual(i[0], 11) sched.removeEvent('test') self.assertEqual(i[0], 11) @@ -77,10 +83,10 @@ class TestSchedule(SupyTestCase): n = sched.addEvent(inc, time.time() + 1) sched.rescheduleEvent(n, time.time() + 3) timeFastForward(1.2) - sched.run() + run(sched) self.assertEqual(i[0], 0) timeFastForward(2) - sched.run() + run(sched) self.assertEqual(i[0], 1) def testPeriodic(self): @@ -90,20 +96,20 @@ class TestSchedule(SupyTestCase): i[0] += 1 n = sched.addPeriodicEvent(inc, 1, name='test_periodic') timeFastForward(0.6) - sched.run() # 0.6 + run(sched) # 0.6 self.assertEqual(i[0], 1) timeFastForward(0.6) - sched.run() # 1.2 + run(sched) # 1.2 self.assertEqual(i[0], 2) timeFastForward(0.6) - sched.run() # 1.8 + run(sched) # 1.8 self.assertEqual(i[0], 2) timeFastForward(0.6) - sched.run() # 2.4 + run(sched) # 2.4 self.assertEqual(i[0], 3) sched.removePeriodicEvent(n) timeFastForward(1) - sched.run() # 3.4 + run(sched) # 3.4 self.assertEqual(i[0], 3) def testCountedPeriodic(self): @@ -113,19 +119,19 @@ class TestSchedule(SupyTestCase): i[0] += 1 n = sched.addPeriodicEvent(inc, 1, name='test_periodic', count=3) timeFastForward(0.6) - sched.run() # 0.6 + run(sched) # 0.6 self.assertEqual(i[0], 1) timeFastForward(0.6) - sched.run() # 1.2 + run(sched) # 1.2 self.assertEqual(i[0], 2) timeFastForward(0.6) - sched.run() # 1.8 + run(sched) # 1.8 self.assertEqual(i[0], 2) timeFastForward(0.6) - sched.run() # 2.4 + run(sched) # 2.4 self.assertEqual(i[0], 3) timeFastForward(1) - sched.run() # 3.4 + run(sched) # 3.4 self.assertEqual(i[0], 3)