Skip to content
Closed
9 changes: 6 additions & 3 deletions pytest_django/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import pytest

from . import live_server_helper
from .django_compat import is_django_unittest
from .lazy_django import skip_if_no_django

Expand Down Expand Up @@ -385,6 +384,7 @@ def live_server(request):
"""
skip_if_no_django()

from . import live_server_helper
import django

addr = request.config.getvalue("liveserver") or os.getenv(
Expand Down Expand Up @@ -434,8 +434,11 @@ def _live_server_helper(request):
request.getfixturevalue("transactional_db")

live_server = request.getfixturevalue("live_server")
live_server._live_server_modified_settings.enable()
request.addfinalizer(live_server._live_server_modified_settings.disable)

modified_settings = live_server._dj_testcase._live_server_modified_settings
if not hasattr(modified_settings, "wrapped"):
modified_settings.enable()
request.addfinalizer(modified_settings.disable)


@contextmanager
Expand Down
86 changes: 10 additions & 76 deletions pytest_django/live_server_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,66 +10,28 @@ class LiveServer(object):
"""

def __init__(self, addr):
import django
from django.db import connections
from django.test.testcases import LiveServerThread
from django.test.utils import modify_settings

connections_override = {}
for conn in connections.all():
# If using in-memory sqlite databases, pass the connections to
# the server thread.
if (
conn.settings_dict["ENGINE"] == "django.db.backends.sqlite3"
and conn.settings_dict["NAME"] == ":memory:"
):
# Explicitly enable thread-shareability for this connection
conn.allow_thread_sharing = True
connections_override[conn.alias] = conn

liveserver_kwargs = {"connections_override": connections_override}
from django.test.testcases import LiveServerTestCase
from django.conf import settings

if "django.contrib.staticfiles" in settings.INSTALLED_APPS:
from django.contrib.staticfiles.handlers import StaticFilesHandler

liveserver_kwargs["static_handler"] = StaticFilesHandler
else:
from django.test.testcases import _StaticFilesHandler

liveserver_kwargs["static_handler"] = _StaticFilesHandler

if django.VERSION < (1, 11):
host, possible_ports = parse_addr(addr)
self.thread = LiveServerThread(host, possible_ports, **liveserver_kwargs)
else:
try:
host, port = addr.split(":")
except ValueError:
host = addr
else:
liveserver_kwargs["port"] = int(port)
self.thread = LiveServerThread(host, **liveserver_kwargs)
from django.test.testcases import _StaticFilesHandler as StaticFilesHandler

self._live_server_modified_settings = modify_settings(
ALLOWED_HOSTS={"append": host}
)
class CustomLiveServerTestCase(LiveServerTestCase):
static_handler = StaticFilesHandler

self.thread.daemon = True
self.thread.start()
self.thread.is_ready.wait()

if self.thread.error:
raise self.thread.error
self._dj_testcase = CustomLiveServerTestCase("__init__")
self._dj_testcase.setUpClass()

def stop(self):
"""Stop the server"""
self.thread.terminate()
self.thread.join()
if not hasattr(self._dj_testcase._live_server_modified_settings, "wrapped"):
self._dj_testcase._live_server_modified_settings.enable()
self._dj_testcase.tearDownClass()

@property
def url(self):
return "http://%s:%s" % (self.thread.host, self.thread.port)
return self._dj_testcase.live_server_url

def __str__(self):
return self.url
Expand All @@ -79,31 +41,3 @@ def __add__(self, other):

def __repr__(self):
return "<LiveServer listening at %s>" % self.url


def parse_addr(specified_address):
"""Parse the --liveserver argument into a host/IP address and port range"""
# This code is based on
# django.test.testcases.LiveServerTestCase.setUpClass

# The specified ports may be of the form '8000-8010,8080,9200-9300'
# i.e. a comma-separated list of ports or ranges of ports, so we break
# it down into a detailed list of all possible ports.
possible_ports = []
try:
host, port_ranges = specified_address.split(":")
for port_range in port_ranges.split(","):
# A port range can be of either form: '8000' or '8000-8010'.
extremes = list(map(int, port_range.split("-")))
assert len(extremes) in (1, 2)
if len(extremes) == 1:
# Port range of the form '8000'
possible_ports.append(extremes[0])
else:
# Port range of the form '8000-8010'
for port in range(extremes[0], extremes[1] + 1):
possible_ports.append(port)
except Exception:
raise Exception('Invalid address ("%s") for live server.' % specified_address)

return host, possible_ports