Use asyncio to manage the driver loop

This allows us to remove the custom shoehorned select() handling
without changing the driver design too much.
This commit is contained in:
Valentin Lorentz 2021-07-31 14:31:12 +02:00
parent 0ed743bb8e
commit f164ac7fbe
4 changed files with 87 additions and 96 deletions

View File

@ -39,9 +39,10 @@ import os
import sys
import time
import errno
import threading
import select
import socket
import asyncio
import threading
import ipaddress
@ -58,9 +59,8 @@ except:
class SSLError(Exception):
pass
class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
_instances = []
_selecting = threading.Lock()
def __init__(self, irc):
assert irc is not None
self.irc = irc
@ -108,26 +108,19 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
# hasn't finished yet. We'll keep track of how many we get.
if e is None or e.args[0] != 11 or self.eagains > 120:
drivers.log.disconnect(self.currentServer, e)
if self in self._instances:
self._instances.remove(self)
try:
self.conn.close()
except:
pass
self.connected = False
if self.irc is None:
# This driver is dead already, but we're still running because
# of select() running in another driver's thread that started
# before this one died and is still holding a reference to this
# instance.
# Just return, and we should never be called again.
return
self.scheduleReconnect()
else:
log.debug('Got EAGAIN, current count: %s.', self.eagains)
self.eagains += 1
def _sendIfMsgs(self):
async def _sendIfMsgs(self):
if not self.connected:
return
if not self.zombie:
@ -137,71 +130,34 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
del msgs[-1]
self.outbuffer += ''.join(map(str, msgs))
if self.outbuffer:
loop = asyncio.get_event_loop()
try:
if minisix.PY2:
sent = self.conn.send(self.outbuffer)
else:
sent = self.conn.send(self.outbuffer.encode())
self.outbuffer = self.outbuffer[sent:]
await loop.sock_sendall(self.conn, self.outbuffer.encode())
self.outbuffer = ''
self.eagains = 0
except socket.error as e:
self._handleSocketError(e)
if self.zombie and not self.outbuffer:
self._reallyDie()
@classmethod
def _select(cls):
timeout = conf.supybot.drivers.poll()
try:
if not cls._selecting.acquire(blocking=False):
# there's already a thread running this code, abort.
return
for inst in cls._instances:
# Do not use a list comprehension here, we have to edit the list
# and not to reassign it.
if not inst.connected or \
(minisix.PY3 and inst.conn._closed) or \
(minisix.PY2 and
inst.conn._sock.__class__ is socket._closedsocket):
cls._instances.remove(inst)
elif inst.conn.fileno() == -1:
inst.reconnect()
if not cls._instances:
return
rlist, wlist, xlist = select.select([x.conn for x in cls._instances],
[], [], timeout)
for instance in cls._instances:
if instance.conn in rlist:
instance._read()
except select.error as e:
if e.args[0] != errno.EINTR:
# 'Interrupted system call'
raise
finally:
cls._selecting.release()
for instance in cls._instances:
if instance.irc and not instance.irc.zombie:
instance._sendIfMsgs()
def run(self):
async def run(self):
now = time.time()
if self.nextReconnectTime is not None and now > self.nextReconnectTime:
self.reconnect()
elif self.writeCheckTime is not None and now > self.writeCheckTime:
self._checkAndWriteOrReconnect()
if not self.connected:
# We sleep here because otherwise, if we're the only driver, we'll
# spin at 100% CPU while we're disconnected.
time.sleep(conf.supybot.drivers.poll())
return
self._sendIfMsgs()
self._select()
await asyncio.gather(
self._sendIfMsgs(),
self._read(),
)
def _read(self):
async def _read(self):
"""Waits until data can be received on the socket, then reads it."""
loop = asyncio.get_event_loop()
try:
new_data = self.conn.recv(1024)
new_data = await loop.sock_recv(self.conn, 1024)
if not new_data:
# Socket was closed
self._handleSocketError(None)
@ -241,7 +197,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
self._handleSocketError(e)
return
if self.irc and not self.irc.zombie:
self._sendIfMsgs()
await self._sendIfMsgs()
def connect(self, **kwargs):
self.reconnect(reset=False, **kwargs)
@ -252,8 +208,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
if self.connected:
self.onDisconnect()
drivers.log.reconnect(self.irc.network)
if self in self._instances:
self._instances.remove(self)
try:
self.conn.shutdown(socket.SHUT_RDWR)
except: # "Transport endpoint not connected"
@ -314,12 +268,10 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
return
# We allow more time for the connect here, since it might take longer.
# At least 10 seconds.
self.conn.settimeout(max(10, conf.supybot.drivers.poll()*10))
try:
# Connect before SSL, otherwise SSL is disabled if we use SOCKS.
# See http://stackoverflow.com/q/16136916/539465
self.conn.connect(
(address, self.currentServer.port))
self.conn.connect((address, self.currentServer.port))
if network_config.ssl() or \
self.currentServer.force_tls_verification:
self.starttls()
@ -343,8 +295,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
'<http://docs.limnoria.net/en/latest/use/faq.html#how-to-make-a-connection-secure>')
% self.irc.network)
conf.supybot.drivers.poll.addCallback(self.setTimeout)
self.setTimeout()
self.connected = True
self.resetDelay()
except socket.error as e:
@ -361,13 +311,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
drivers.log.connectError(self.currentServer, e)
self.scheduleReconnect()
return
self._instances.append(self)
def setTimeout(self):
try:
self.conn.settimeout(conf.supybot.drivers.poll())
except Exception:
drivers.log.exception('Could not set socket timeout:')
def _checkAndWriteOrReconnect(self):
self.writeCheckTime = None
@ -393,9 +336,6 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
self.nextReconnectTime = when
def die(self):
if self in self._instances:
self._instances.remove(self)
conf.supybot.drivers.poll.removeCallback(self.setTimeout)
self.zombie = True
if self.nextReconnectTime is not None:
self.nextReconnectTime = None

View File

@ -35,6 +35,7 @@ Contains various drivers (network, file, and otherwise) for using IRC objects.
import time
import socket
import asyncio
from collections import namedtuple
from .. import conf, ircdb, ircmsgs, ircutils, log as supylog, utils
@ -146,15 +147,59 @@ def remove(name):
"""Removes the driver with the given name from the loop."""
_deadDrivers.add(name)
def _loop_exception_handler(loop, context):
if 'Task was destroyed but it is pending' in context['message']:
# This happens when cancelling tasks because of KeyboardInterrupt.
# There is really nothing we can do about this, so it's
# pointless to log it.
return
log.error('Exception in drivers loop: %s', context['message'])
def run():
"""Runs the whole driver loop."""
loop = asyncio.new_event_loop()
loop.set_exception_handler(_loop_exception_handler)
driver_names = []
futures = []
coroutines = [] # Used to cleanup on shutdown to avoid warnings
for (name, driver) in _drivers.items():
if name not in _deadDrivers:
try:
coroutine = driver.run()
future = asyncio.ensure_future(coroutine, loop=loop)
except Exception:
log.exception('Exception in drivers.run for driver %s:', name)
continue
driver_names.append(name)
coroutines.append(coroutine)
futures.append(future)
gather_task = asyncio.gather(*futures, return_exceptions=True, loop=loop)
timeout_gather_task = asyncio.wait_for(
gather_task,
timeout=conf.supybot.drivers.poll())
coroutines.append(timeout_gather_task)
try:
loop.run_until_complete(timeout_gather_task)
except KeyboardInterrupt:
# Cleanup all the objects so they don't throw warnings.
gather_task.cancel()
for future in futures:
future.cancel()
for coroutine in coroutines:
coroutine.close()
raise
except asyncio.TimeoutError:
pass
for (name, future) in zip(driver_names, futures):
try:
if name not in _deadDrivers:
driver.run()
future.result() # Raises an exception if driver.run() did.
except:
log.exception('Uncaught exception in in drivers.run:')
log.exception('Uncaught exception in drivers.run:')
_deadDrivers.add(name)
for name in _deadDrivers:
try:
driver = _drivers[name]

View File

@ -137,7 +137,7 @@ class Schedule(drivers.IrcDriver):
removePeriodicEvent = removeEvent
def run(self):
async def run(self):
if len(drivers._drivers) == 1 and not world.testing:
log.error('Schedule is the only remaining driver, '
'why do we continue to live?')

View File

@ -31,6 +31,7 @@
from supybot.test import *
import time
import asyncio
import supybot.schedule as schedule
@ -43,6 +44,11 @@ class FakeSchedule(schedule.Schedule):
def name(self):
return 'FakeSchedule'
def run(sched):
    """Run ``sched.run()`` to completion on a private, short-lived event loop.

    ``sched`` is any object whose ``run()`` method returns an awaitable
    (here, a schedule driver).  Exceptions raised by ``sched.run()``
    propagate to the caller.
    """
    loop = asyncio.new_event_loop()
    try:
        loop.run_until_complete(sched.run())
    finally:
        # Each call creates a new loop; close it so its selector and
        # self-pipe sockets are released instead of leaking (and emitting
        # ResourceWarning) once per call.
        loop.close()
class TestSchedule(SupyTestCase):
def testSchedule(self):
sched = FakeSchedule()
@ -55,14 +61,14 @@ class TestSchedule(SupyTestCase):
sched.addEvent(add10, time.time() + 3)
sched.addEvent(add1, time.time() + 1)
timeFastForward(1.2)
sched.run()
run(sched)
self.assertEqual(i[0], 1)
timeFastForward(1.9)
sched.run()
run(sched)
self.assertEqual(i[0], 11)
sched.addEvent(add10, time.time() + 3, 'test')
sched.run()
run(sched)
self.assertEqual(i[0], 11)
sched.removeEvent('test')
self.assertEqual(i[0], 11)
@ -77,10 +83,10 @@ class TestSchedule(SupyTestCase):
n = sched.addEvent(inc, time.time() + 1)
sched.rescheduleEvent(n, time.time() + 3)
timeFastForward(1.2)
sched.run()
run(sched)
self.assertEqual(i[0], 0)
timeFastForward(2)
sched.run()
run(sched)
self.assertEqual(i[0], 1)
def testPeriodic(self):
@ -90,20 +96,20 @@ class TestSchedule(SupyTestCase):
i[0] += 1
n = sched.addPeriodicEvent(inc, 1, name='test_periodic')
timeFastForward(0.6)
sched.run() # 0.6
run(sched) # 0.6
self.assertEqual(i[0], 1)
timeFastForward(0.6)
sched.run() # 1.2
run(sched) # 1.2
self.assertEqual(i[0], 2)
timeFastForward(0.6)
sched.run() # 1.8
run(sched) # 1.8
self.assertEqual(i[0], 2)
timeFastForward(0.6)
sched.run() # 2.4
run(sched) # 2.4
self.assertEqual(i[0], 3)
sched.removePeriodicEvent(n)
timeFastForward(1)
sched.run() # 3.4
run(sched) # 3.4
self.assertEqual(i[0], 3)
def testCountedPeriodic(self):
@ -113,19 +119,19 @@ class TestSchedule(SupyTestCase):
i[0] += 1
n = sched.addPeriodicEvent(inc, 1, name='test_periodic', count=3)
timeFastForward(0.6)
sched.run() # 0.6
run(sched) # 0.6
self.assertEqual(i[0], 1)
timeFastForward(0.6)
sched.run() # 1.2
run(sched) # 1.2
self.assertEqual(i[0], 2)
timeFastForward(0.6)
sched.run() # 1.8
run(sched) # 1.8
self.assertEqual(i[0], 2)
timeFastForward(0.6)
sched.run() # 2.4
run(sched) # 2.4
self.assertEqual(i[0], 3)
timeFastForward(1)
sched.run() # 3.4
run(sched) # 3.4
self.assertEqual(i[0], 3)