diff --git a/.travis.yml b/.travis.yml index a3ef963..67a356c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,10 +10,11 @@ deploy: tags: true install: pip install -U tox-travis language: python +dist: focal python: +- 3.9 - 3.8 - 3.7 - 3.6 -- 3.5 - 2.7 script: tox diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 01d606e..a2315ad 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -68,7 +68,7 @@ Ready to contribute? Here's how to set up `graphql_ws` for local development. $ mkvirtualenv graphql_ws $ cd graphql_ws/ - $ python setup.py develop + $ pip install -e .[dev] 4. Create a branch for local development:: @@ -79,11 +79,8 @@ Ready to contribute? Here's how to set up `graphql_ws` for local development. 5. When you're done making changes, check that your changes pass flake8 and the tests, including testing other Python versions with tox:: $ flake8 graphql_ws tests - $ python setup.py test or py.test $ tox - To get flake8 and tox, just pip install them into your virtualenv. - 6. Commit your changes and push your branch to GitHub:: $ git add . @@ -101,14 +98,6 @@ Before you submit a pull request, check that it meets these guidelines: 2. If the pull request adds functionality, the docs should be updated. Put your new functionality into a function with a docstring, and add the feature to the list in README.rst. -3. The pull request should work for Python 2.6, 2.7, 3.3, 3.4 and 3.5, and for PyPy. Check +3. The pull request should work for Python 2.7, 3.5, 3.6, 3.7 and 3.8. Check https://travis-ci.org/graphql-python/graphql_ws/pull_requests and make sure that the tests pass for all supported Python versions. - -Tips ----- - -To run a subset of tests:: - -$ py.test tests.test_graphql_ws - diff --git a/MANIFEST.in b/MANIFEST.in index 2c0ba39..e58af99 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -5,6 +5,9 @@ include LICENSE include README.rst include tox.ini +graft graphql_ws/django/templates +graft examples +prune examples/django_channels2/.cache graft tests global-exclude __pycache__ global-exclude *.py[co] diff --git a/README.rst b/README.rst index 90ee500..4882551 100644 --- a/README.rst +++ b/README.rst @@ -1,14 +1,24 @@ +========== GraphQL WS ========== -Websocket server for GraphQL subscriptions. +Websocket backend for GraphQL subscriptions. + +Supports the following application servers: + +Python 3 application servers, using asyncio: + + * `aiohttp`_ + * `websockets compatible servers`_ such as Sanic + (via `websockets `__ library) + * `Django v2+`_ -Currently supports: +Python 2 application servers: + + * `Gevent compatible servers`_ such as Flask + * `Django v1.x`_ + (via `channels v1.x `__) -* `aiohttp `__ -* `Gevent `__ -* Sanic (uses `websockets `__ - library) Installation instructions ========================= @@ -19,21 +29,54 @@ For instaling graphql-ws, just run this command in your shell pip install graphql-ws + Examples --------- +======== + +Python 3 servers +---------------- + +Create a subscribable schema like this: + +.. code:: python + + import asyncio + import graphene + + + class Query(graphene.ObjectType): + hello = graphene.String() + + @staticmethod + def resolve_hello(obj, info, **kwargs): + return "world" + + + class Subscription(graphene.ObjectType): + count_seconds = graphene.Float(up_to=graphene.Int()) + + async def resolve_count_seconds(root, info, up_to): + for i in range(up_to): + yield i + await asyncio.sleep(1.) + yield up_to + + + schema = graphene.Schema(query=Query, subscription=Subscription) aiohttp ~~~~~~~ -For setting up, just plug into your aiohttp server. +Then just plug into your aiohttp server. .. code:: python from graphql_ws.aiohttp import AiohttpSubscriptionServer - + from .schema import schema subscription_server = AiohttpSubscriptionServer(schema) + async def subscriptions(request): ws = web.WebSocketResponse(protocols=('graphql-ws',)) await ws.prepare(request) @@ -47,16 +90,20 @@ For setting up, just plug into your aiohttp server. web.run_app(app, port=8000) -Sanic -~~~~~ +You can see a full example here: +https://github.com/graphql-python/graphql-ws/tree/master/examples/aiohttp + + +websockets compatible servers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Works with any framework that uses the websockets library for it’s -websocket implementation. For this example, plug in your Sanic server. +Works with any framework that uses the websockets library for its websocket +implementation. For this example, plug in your Sanic server. .. code:: python from graphql_ws.websockets_lib import WsLibSubscriptionServer - + from . import schema app = Sanic(__name__) @@ -70,49 +117,61 @@ websocket implementation. For this example, plug in your Sanic server. app.run(host="0.0.0.0", port=8000) -And then, plug into a subscribable schema: -.. code:: python +Django v2+ +~~~~~~~~~~ - import asyncio - import graphene +Django Channels 2 +~~~~~~~~~~~~~~~~~ - class Query(graphene.ObjectType): - base = graphene.String() +Set up with Django Channels just takes three steps: +1. Install the apps +2. Set up your schema +3. Configure the channels router application - class Subscription(graphene.ObjectType): - count_seconds = graphene.Float(up_to=graphene.Int()) +First ``pip install channels`` and it to your ``INSTALLED_APPS``. If you +want graphiQL, install the ``graphql_ws.django`` app before +``graphene_django`` to serve a graphiQL template that will work with +websockets: - async def resolve_count_seconds(root, info, up_to): - for i in range(up_to): - yield i - await asyncio.sleep(1.) - yield up_to +.. code:: python + INSTALLED_APPS = [ + "channels", + "graphql_ws.django", + "graphene_django", + # ... + ] - schema = graphene.Schema(query=Query, subscription=Subscription) +Point to your schema in Django settings: -You can see a full example here: -https://github.com/graphql-python/graphql-ws/tree/master/examples/aiohttp +.. code:: python -Gevent -~~~~~~ + GRAPHENE = { + 'SCHEMA': 'yourproject.schema.schema' + } -For setting up, just plug into your Gevent server. +Finally, you can set up channels routing yourself (maybe using +``graphql_ws.django.routing.websocket_urlpatterns`` in your +``URLRouter``), or you can just use one of the preset channels +applications: .. code:: python - subscription_server = GeventSubscriptionServer(schema) - app.app_protocol = lambda environ_path_info: 'graphql-ws' + ASGI_APPLICATION = 'graphql_ws.django.routing.application' + # or + ASGI_APPLICATION = 'graphql_ws.django.routing.auth_application' - @sockets.route('/subscriptions') - def echo_socket(ws): - subscription_server.handle(ws) - return [] +Run ``./manage.py runserver`` and go to +`http://localhost:8000/graphql`__ to test! -And then, plug into a subscribable schema: + +Python 2 servers +----------------- + +Create a subscribable schema like this: .. code:: python @@ -121,7 +180,11 @@ And then, plug into a subscribable schema: class Query(graphene.ObjectType): - base = graphene.String() + hello = graphene.String() + + @staticmethod + def resolve_hello(obj, info, **kwargs): + return "world" class Subscription(graphene.ObjectType): @@ -135,71 +198,54 @@ And then, plug into a subscribable schema: schema = graphene.Schema(query=Query, subscription=Subscription) -You can see a full example here: -https://github.com/graphql-python/graphql-ws/tree/master/examples/flask_gevent - -Django Channels -~~~~~~~~~~~~~~~ +Gevent compatible servers +~~~~~~~~~~~~~~~~~~~~~~~~~ -First ``pip install channels`` and it to your django apps - -Then add the following to your settings.py +Then just plug into your Gevent server, for example, Flask: .. code:: python - CHANNELS_WS_PROTOCOLS = ["graphql-ws", ] - CHANNEL_LAYERS = { - "default": { - "BACKEND": "asgiref.inmemory.ChannelLayer", - "ROUTING": "django_subscriptions.urls.channel_routing", - }, - - } - -Setup your graphql schema + from flask_sockets import Sockets + from graphql_ws.gevent import GeventSubscriptionServer + from schema import schema -.. code:: python - - import graphene - from rx import Observable - - - class Query(graphene.ObjectType): - hello = graphene.String() + subscription_server = GeventSubscriptionServer(schema) + app.app_protocol = lambda environ_path_info: 'graphql-ws' - def resolve_hello(self, info, **kwargs): - return 'world' - class Subscription(graphene.ObjectType): + @sockets.route('/subscriptions') + def echo_socket(ws): + subscription_server.handle(ws) + return [] - count_seconds = graphene.Int(up_to=graphene.Int()) +You can see a full example here: +https://github.com/graphql-python/graphql-ws/tree/master/examples/flask_gevent +Django v1.x +~~~~~~~~~~~ - def resolve_count_seconds( - root, - info, - up_to=5 - ): - return Observable.interval(1000)\ - .map(lambda i: "{0}".format(i))\ - .take_while(lambda i: int(i) <= up_to) - +For Django v1.x and Django Channels v1.x, setup your schema in ``settings.py`` +.. code:: python - schema = graphene.Schema( - query=Query, - subscription=Subscription - ) + GRAPHENE = { + 'SCHEMA': 'yourproject.schema.schema' + } -Setup your schema in settings.py +Then ``pip install "channels<1"`` and it to your django apps, adding the +following to your ``settings.py`` .. code:: python - GRAPHENE = { - 'SCHEMA': 'path.to.schema' + CHANNELS_WS_PROTOCOLS = ["graphql-ws", ] + CHANNEL_LAYERS = { + "default": { + "BACKEND": "asgiref.inmemory.ChannelLayer", + "ROUTING": "django_subscriptions.urls.channel_routing", + }, } -and finally add the channel routes +And finally add the channel routes .. code:: python @@ -209,3 +255,6 @@ and finally add the channel routes channel_routing = [ route_class(GraphQLSubscriptionConsumer, path=r"^/subscriptions"), ] + +You can see a full example here: +https://github.com/graphql-python/graphql-ws/tree/master/examples/django_subscriptions diff --git a/examples/aiohttp/app.py b/examples/aiohttp/app.py index 56dcaff..336a0c6 100644 --- a/examples/aiohttp/app.py +++ b/examples/aiohttp/app.py @@ -10,24 +10,25 @@ async def graphql_view(request): payload = await request.json() - response = await schema.execute(payload.get('query', ''), return_promise=True) + response = await schema.execute(payload.get("query", ""), return_promise=True) data = {} if response.errors: - data['errors'] = [format_error(e) for e in response.errors] + data["errors"] = [format_error(e) for e in response.errors] if response.data: - data['data'] = response.data + data["data"] = response.data jsondata = json.dumps(data,) - return web.Response(text=jsondata, headers={'Content-Type': 'application/json'}) + return web.Response(text=jsondata, headers={"Content-Type": "application/json"}) async def graphiql_view(request): - return web.Response(text=render_graphiql(), headers={'Content-Type': 'text/html'}) + return web.Response(text=render_graphiql(), headers={"Content-Type": "text/html"}) + subscription_server = AiohttpSubscriptionServer(schema) async def subscriptions(request): - ws = web.WebSocketResponse(protocols=('graphql-ws',)) + ws = web.WebSocketResponse(protocols=("graphql-ws",)) await ws.prepare(request) await subscription_server.handle(ws) @@ -35,9 +36,9 @@ async def subscriptions(request): app = web.Application() -app.router.add_get('/subscriptions', subscriptions) -app.router.add_get('/graphiql', graphiql_view) -app.router.add_get('/graphql', graphql_view) -app.router.add_post('/graphql', graphql_view) +app.router.add_get("/subscriptions", subscriptions) +app.router.add_get("/graphiql", graphiql_view) +app.router.add_get("/graphql", graphql_view) +app.router.add_post("/graphql", graphql_view) web.run_app(app, port=8000) diff --git a/examples/aiohttp/schema.py b/examples/aiohttp/schema.py index 3c23d00..ae107c7 100644 --- a/examples/aiohttp/schema.py +++ b/examples/aiohttp/schema.py @@ -20,14 +20,14 @@ async def resolve_count_seconds(root, info, up_to=5): for i in range(up_to): print("YIELD SECOND", i) yield i - await asyncio.sleep(1.) + await asyncio.sleep(1.0) yield up_to async def resolve_random_int(root, info): i = 0 while True: yield RandomType(seconds=i, random_int=random.randint(0, 500)) - await asyncio.sleep(1.) + await asyncio.sleep(1.0) i += 1 diff --git a/examples/aiohttp/template.py b/examples/aiohttp/template.py index 0b74e96..709f7cf 100644 --- a/examples/aiohttp/template.py +++ b/examples/aiohttp/template.py @@ -1,9 +1,9 @@ - from string import Template def render_graphiql(): - return Template(''' + return Template( + """ @@ -116,10 +116,11 @@ def render_graphiql(): ); -''').substitute( - GRAPHIQL_VERSION='0.10.2', - SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', - subscriptionsEndpoint='ws://localhost:8000/subscriptions', +""" + ).substitute( + GRAPHIQL_VERSION="0.10.2", + SUBSCRIPTIONS_TRANSPORT_VERSION="0.7.0", + subscriptionsEndpoint="ws://localhost:8000/subscriptions", # subscriptionsEndpoint='ws://localhost:5000/', - endpointURL='/graphql', + endpointURL="/graphql", ) diff --git a/examples/django_channels2/Pipfile b/examples/django_channels2/Pipfile new file mode 100644 index 0000000..ba4a0a4 --- /dev/null +++ b/examples/django_channels2/Pipfile @@ -0,0 +1,14 @@ +[[source]] +url = "https://pypi.org/simple" +verify_ssl = true +name = "pypi" + +[dev-packages] + +[packages] +graphql-ws = {path = "./../..", editable = true} +channels = "==2.*" +graphene-django = "*" + +[requires] +python_version = "3.6" diff --git a/examples/django_channels2/Pipfile.lock b/examples/django_channels2/Pipfile.lock new file mode 100644 index 0000000..e33bf6c --- /dev/null +++ b/examples/django_channels2/Pipfile.lock @@ -0,0 +1,303 @@ +{ + "_meta": { + "hash": { + "sha256": "75a0ce53afdb6d8ea231f82ff73a9de06bad3a6dba8263f658d14abe9f6cf9f9" + }, + "pipfile-spec": 6, + "requires": { + "python_version": "3.6" + }, + "sources": [ + { + "name": "pypi", + "url": "https://pypi.org/simple", + "verify_ssl": true + } + ] + }, + "default": { + "aniso8601": { + "hashes": [ + "sha256:b8a6a9b24611fc50cf2d9b45d371bfdc4fd0581d1cc52254f5502130a776d4af", + "sha256:bb167645c79f7a438f9dfab6161af9bed75508c645b1f07d1158240841d22673" + ], + "version": "==6.0.0" + }, + "asgiref": { + "hashes": [ + "sha256:865b7ccce5a6e815607b08d9059fe9c058cd75c77f896f5e0b74ff6c1ba81818", + "sha256:b718a9d35e204a96e2456c2271b0ef12e36124c363b3a8fd1d626744f23192aa" + ], + "version": "==3.1.4" + }, + "asn1crypto": { + "hashes": [ + "sha256:2f1adbb7546ed199e3c90ef23ec95c5cf3585bac7d11fb7eb562a3fe89c64e87", + "sha256:9d5c20441baf0cb60a4ac34cc447c6c189024b6b4c6cd7877034f4965c464e49" + ], + "version": "==0.24.0" + }, + "attrs": { + "hashes": [ + "sha256:69c0dbf2ed392de1cb5ec704444b08a5ef81680a61cb899dc08127123af36a79", + "sha256:f0b870f674851ecbfbbbd364d6b5cbdff9dcedbc7f3f5e18a6891057f21fe399" + ], + "version": "==19.1.0" + }, + "autobahn": { + "hashes": [ + "sha256:70f0cfb8005b5429df5709acf5d66a8eba00669765547029371648dffd4a0470", + "sha256:89f94a1535673b1655df28ef91e96b7f34faea76da04a5e56441c9ac779a2f9f" + ], + "version": "==19.7.1" + }, + "automat": { + "hashes": [ + "sha256:cbd78b83fa2d81fe2a4d23d258e1661dd7493c9a50ee2f1a5b2cac61c1793b0e", + "sha256:fdccab66b68498af9ecfa1fa43693abe546014dd25cf28543cbe9d1334916a58" + ], + "version": "==0.7.0" + }, + "cffi": { + "hashes": [ + "sha256:041c81822e9f84b1d9c401182e174996f0bae9991f33725d059b771744290774", + "sha256:046ef9a22f5d3eed06334d01b1e836977eeef500d9b78e9ef693f9380ad0b83d", + "sha256:066bc4c7895c91812eff46f4b1c285220947d4aa46fa0a2651ff85f2afae9c90", + "sha256:066c7ff148ae33040c01058662d6752fd73fbc8e64787229ea8498c7d7f4041b", + "sha256:2444d0c61f03dcd26dbf7600cf64354376ee579acad77aef459e34efcb438c63", + "sha256:300832850b8f7967e278870c5d51e3819b9aad8f0a2c8dbe39ab11f119237f45", + "sha256:34c77afe85b6b9e967bd8154e3855e847b70ca42043db6ad17f26899a3df1b25", + "sha256:46de5fa00f7ac09f020729148ff632819649b3e05a007d286242c4882f7b1dc3", + "sha256:4aa8ee7ba27c472d429b980c51e714a24f47ca296d53f4d7868075b175866f4b", + "sha256:4d0004eb4351e35ed950c14c11e734182591465a33e960a4ab5e8d4f04d72647", + "sha256:4e3d3f31a1e202b0f5a35ba3bc4eb41e2fc2b11c1eff38b362de710bcffb5016", + "sha256:50bec6d35e6b1aaeb17f7c4e2b9374ebf95a8975d57863546fa83e8d31bdb8c4", + "sha256:55cad9a6df1e2a1d62063f79d0881a414a906a6962bc160ac968cc03ed3efcfb", + "sha256:5662ad4e4e84f1eaa8efce5da695c5d2e229c563f9d5ce5b0113f71321bcf753", + "sha256:59b4dc008f98fc6ee2bb4fd7fc786a8d70000d058c2bbe2698275bc53a8d3fa7", + "sha256:73e1ffefe05e4ccd7bcea61af76f36077b914f92b76f95ccf00b0c1b9186f3f9", + "sha256:a1f0fd46eba2d71ce1589f7e50a9e2ffaeb739fb2c11e8192aa2b45d5f6cc41f", + "sha256:a2e85dc204556657661051ff4bab75a84e968669765c8a2cd425918699c3d0e8", + "sha256:a5457d47dfff24882a21492e5815f891c0ca35fefae8aa742c6c263dac16ef1f", + "sha256:a8dccd61d52a8dae4a825cdbb7735da530179fea472903eb871a5513b5abbfdc", + "sha256:ae61af521ed676cf16ae94f30fe202781a38d7178b6b4ab622e4eec8cefaff42", + "sha256:b012a5edb48288f77a63dba0840c92d0504aa215612da4541b7b42d849bc83a3", + "sha256:d2c5cfa536227f57f97c92ac30c8109688ace8fa4ac086d19d0af47d134e2909", + "sha256:d42b5796e20aacc9d15e66befb7a345454eef794fdb0737d1af593447c6c8f45", + "sha256:dee54f5d30d775f525894d67b1495625dd9322945e7fee00731952e0368ff42d", + "sha256:e070535507bd6aa07124258171be2ee8dfc19119c28ca94c9dfb7efd23564512", + "sha256:e1ff2748c84d97b065cc95429814cdba39bcbd77c9c85c89344b317dc0d9cbff", + "sha256:ed851c75d1e0e043cbf5ca9a8e1b13c4c90f3fbd863dacb01c0808e2b5204201" + ], + "version": "==1.12.3" + }, + "channels": { + "hashes": [ + "sha256:9191a85800673b790d1d74666fb7676f430600b71b662581e97dd69c9aedd29a", + "sha256:af7cdba9efb3f55b939917d1b15defb5d40259936013e60660e5e9aff98db4c5" + ], + "index": "pypi", + "version": "==2.2.0" + }, + "constantly": { + "hashes": [ + "sha256:586372eb92059873e29eba4f9dec8381541b4d3834660707faf8ba59146dfc35", + "sha256:dd2fa9d6b1a51a83f0d7dd76293d734046aa176e384bf6e33b7e44880eb37c5d" + ], + "version": "==15.1.0" + }, + "cryptography": { + "hashes": [ + "sha256:24b61e5fcb506424d3ec4e18bca995833839bf13c59fc43e530e488f28d46b8c", + "sha256:25dd1581a183e9e7a806fe0543f485103232f940fcfc301db65e630512cce643", + "sha256:3452bba7c21c69f2df772762be0066c7ed5dc65df494a1d53a58b683a83e1216", + "sha256:41a0be220dd1ed9e998f5891948306eb8c812b512dc398e5a01846d855050799", + "sha256:5751d8a11b956fbfa314f6553d186b94aa70fdb03d8a4d4f1c82dcacf0cbe28a", + "sha256:5f61c7d749048fa6e3322258b4263463bfccefecb0dd731b6561cb617a1d9bb9", + "sha256:72e24c521fa2106f19623a3851e9f89ddfdeb9ac63871c7643790f872a305dfc", + "sha256:7b97ae6ef5cba2e3bb14256625423413d5ce8d1abb91d4f29b6d1a081da765f8", + "sha256:961e886d8a3590fd2c723cf07be14e2a91cf53c25f02435c04d39e90780e3b53", + "sha256:96d8473848e984184b6728e2c9d391482008646276c3ff084a1bd89e15ff53a1", + "sha256:ae536da50c7ad1e002c3eee101871d93abdc90d9c5f651818450a0d3af718609", + "sha256:b0db0cecf396033abb4a93c95d1602f268b3a68bb0a9cc06a7cff587bb9a7292", + "sha256:cfee9164954c186b191b91d4193989ca994703b2fff406f71cf454a2d3c7327e", + "sha256:e6347742ac8f35ded4a46ff835c60e68c22a536a8ae5c4422966d06946b6d4c6", + "sha256:f27d93f0139a3c056172ebb5d4f9056e770fdf0206c2f422ff2ebbad142e09ed", + "sha256:f57b76e46a58b63d1c6375017f4564a28f19a5ca912691fd2e4261b3414b618d" + ], + "version": "==2.7" + }, + "daphne": { + "hashes": [ + "sha256:2329b7a74b5559f7ea012879c10ba945c3a53df7d8d2b5932a904e3b4c9abcc2", + "sha256:3cae286a995ae5b127d7de84916f0480cb5be19f81125b6a150b8326250dadd5" + ], + "version": "==2.3.0" + }, + "django": { + "hashes": [ + "sha256:4d23f61b26892bac785f07401bc38cbf8fa4cec993f400e9cd9ddf28fd51c0ea", + "sha256:6e974d4b57e3b29e4882b244d40171d6a75202ab8d2402b8e8adbd182e25cf0c" + ], + "version": "==2.2.3" + }, + "graphene": { + "hashes": [ + "sha256:77d61618132ccd084c343e64c22d806cee18dce73cc86e0f427378dbdeeac287", + "sha256:acf808d50d053b94f7958414d511489a9e490a7f9563b9be80f6875fc5723d2a" + ], + "version": "==2.1.7" + }, + "graphene-django": { + "hashes": [ + "sha256:3101e8a8353c6b13f33261f5b0437deb3d3614d1c44b2d56932b158e3660c0cd", + "sha256:5714c5dd1200800ddc12d0782b0d82db70aedf387575e5b57ee2cdee4f25c681" + ], + "index": "pypi", + "version": "==2.4.0" + }, + "graphql-core": { + "hashes": [ + "sha256:1488f2a5c2272dc9ba66e3042a6d1c30cea0db4c80bd1e911c6791ad6187d91b", + "sha256:da64c472d720da4537a2e8de8ba859210b62841bd47a9be65ca35177f62fe0e4" + ], + "version": "==2.2.1" + }, + "graphql-relay": { + "hashes": [ + "sha256:0e94201af4089e1f81f07d7bd8f84799768e39d70fa1ea16d1df505b46cc6335", + "sha256:75aa0758971e252964cb94068a4decd472d2a8295229f02189e3cbca1f10dbb5", + "sha256:7fa74661246e826ef939ee92e768f698df167a7617361ab399901eaebf80dce6" + ], + "version": "==2.0.0" + }, + "graphql-ws": { + "editable": true, + "path": "./../.." + }, + "hyperlink": { + "hashes": [ + "sha256:4288e34705da077fada1111a24a0aa08bb1e76699c9ce49876af722441845654", + "sha256:ab4a308feb039b04f855a020a6eda3b18ca5a68e6d8f8c899cbe9e653721d04f" + ], + "version": "==19.0.0" + }, + "idna": { + "hashes": [ + "sha256:c357b3f628cf53ae2c4c05627ecc484553142ca23264e593d327bcde5e9c3407", + "sha256:ea8b7f6188e6fa117537c3df7da9fc686d485087abf6ac197f9c46432f7e4a3c" + ], + "version": "==2.8" + }, + "incremental": { + "hashes": [ + "sha256:717e12246dddf231a349175f48d74d93e2897244939173b01974ab6661406b9f", + "sha256:7b751696aaf36eebfab537e458929e194460051ccad279c72b755a167eebd4b3" + ], + "version": "==17.5.0" + }, + "promise": { + "hashes": [ + "sha256:2ebbfc10b7abf6354403ed785fe4f04b9dfd421eb1a474ac8d187022228332af", + "sha256:348f5f6c3edd4fd47c9cd65aed03ac1b31136d375aa63871a57d3e444c85655c" + ], + "version": "==2.2.1" + }, + "pycparser": { + "hashes": [ + "sha256:a988718abfad80b6b157acce7bf130a30876d27603738ac39f140993246b25b3" + ], + "version": "==2.19" + }, + "pyhamcrest": { + "hashes": [ + "sha256:6b672c02fdf7470df9674ab82263841ce8333fb143f32f021f6cb26f0e512420", + "sha256:8ffaa0a53da57e89de14ced7185ac746227a8894dbd5a3c718bf05ddbd1d56cd" + ], + "version": "==1.9.0" + }, + "pytz": { + "hashes": [ + "sha256:303879e36b721603cc54604edcac9d20401bdbe31e1e4fdee5b9f98d5d31dfda", + "sha256:d747dd3d23d77ef44c6a3526e274af6efeb0a6f1afd5a69ba4d5be4098c8e141" + ], + "version": "==2019.1" + }, + "rx": { + "hashes": [ + "sha256:13a1d8d9e252625c173dc795471e614eadfe1cf40ffc684e08b8fff0d9748c23", + "sha256:7357592bc7e881a95e0c2013b73326f704953301ab551fbc8133a6fadab84105" + ], + "version": "==1.6.1" + }, + "singledispatch": { + "hashes": [ + "sha256:5b06af87df13818d14f08a028e42f566640aef80805c3b50c5056b086e3c2b9c", + "sha256:833b46966687b3de7f438c761ac475213e53b306740f1abfaa86e1d1aae56aa8" + ], + "version": "==3.4.0.3" + }, + "six": { + "hashes": [ + "sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c", + "sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73" + ], + "version": "==1.12.0" + }, + "sqlparse": { + "hashes": [ + "sha256:40afe6b8d4b1117e7dff5504d7a8ce07d9a1b15aeeade8a2d10f130a834f8177", + "sha256:7c3dca29c022744e95b547e867cee89f4fce4373f3549ccd8797d8eb52cdb873" + ], + "version": "==0.3.0" + }, + "twisted": { + "hashes": [ + "sha256:fa2c04c2d68a9be7fc3975ba4947f653a57a656776f24be58ff0fe4b9aaf3e52" + ], + "version": "==19.2.1" + }, + "txaio": { + "hashes": [ + "sha256:67e360ac73b12c52058219bb5f8b3ed4105d2636707a36a7cdafb56fe06db7fe", + "sha256:b6b235d432cc58ffe111b43e337db71a5caa5d3eaa88f0eacf60b431c7626ef5" + ], + "version": "==18.8.1" + }, + "zope.interface": { + "hashes": [ + "sha256:086707e0f413ff8800d9c4bc26e174f7ee4c9c8b0302fbad68d083071822316c", + "sha256:1157b1ec2a1f5bf45668421e3955c60c610e31913cc695b407a574efdbae1f7b", + "sha256:11ebddf765bff3bbe8dbce10c86884d87f90ed66ee410a7e6c392086e2c63d02", + "sha256:14b242d53f6f35c2d07aa2c0e13ccb710392bcd203e1b82a1828d216f6f6b11f", + "sha256:1b3d0dcabc7c90b470e59e38a9acaa361be43b3a6ea644c0063951964717f0e5", + "sha256:20a12ab46a7e72b89ce0671e7d7a6c3c1ca2c2766ac98112f78c5bddaa6e4375", + "sha256:298f82c0ab1b182bd1f34f347ea97dde0fffb9ecf850ecf7f8904b8442a07487", + "sha256:2f6175722da6f23dbfc76c26c241b67b020e1e83ec7fe93c9e5d3dd18667ada2", + "sha256:3b877de633a0f6d81b600624ff9137312d8b1d0f517064dfc39999352ab659f0", + "sha256:4265681e77f5ac5bac0905812b828c9fe1ce80c6f3e3f8574acfb5643aeabc5b", + "sha256:550695c4e7313555549aa1cdb978dc9413d61307531f123558e438871a883d63", + "sha256:5f4d42baed3a14c290a078e2696c5f565501abde1b2f3f1a1c0a94fbf6fbcc39", + "sha256:62dd71dbed8cc6a18379700701d959307823b3b2451bdc018594c48956ace745", + "sha256:7040547e5b882349c0a2cc9b50674b1745db551f330746af434aad4f09fba2cc", + "sha256:7e099fde2cce8b29434684f82977db4e24f0efa8b0508179fce1602d103296a2", + "sha256:7e5c9a5012b2b33e87980cee7d1c82412b2ebabcb5862d53413ba1a2cfde23aa", + "sha256:81295629128f929e73be4ccfdd943a0906e5fe3cdb0d43ff1e5144d16fbb52b1", + "sha256:95cc574b0b83b85be9917d37cd2fad0ce5a0d21b024e1a5804d044aabea636fc", + "sha256:968d5c5702da15c5bf8e4a6e4b67a4d92164e334e9c0b6acf080106678230b98", + "sha256:9e998ba87df77a85c7bed53240a7257afe51a07ee6bc3445a0bf841886da0b97", + "sha256:a0c39e2535a7e9c195af956610dba5a1073071d2d85e9d2e5d789463f63e52ab", + "sha256:a15e75d284178afe529a536b0e8b28b7e107ef39626a7809b4ee64ff3abc9127", + "sha256:a6a6ff82f5f9b9702478035d8f6fb6903885653bff7ec3a1e011edc9b1a7168d", + "sha256:b639f72b95389620c1f881d94739c614d385406ab1d6926a9ffe1c8abbea23fe", + "sha256:bad44274b151d46619a7567010f7cde23a908c6faa84b97598fd2f474a0c6891", + "sha256:bbcef00d09a30948756c5968863316c949d9cedbc7aabac5e8f0ffbdb632e5f1", + "sha256:d788a3999014ddf416f2dc454efa4a5dbeda657c6aba031cf363741273804c6b", + "sha256:eed88ae03e1ef3a75a0e96a55a99d7937ed03e53d0cffc2451c208db445a2966", + "sha256:f99451f3a579e73b5dd58b1b08d1179791d49084371d9a47baad3b22417f0317" + ], + "version": "==4.6.0" + } + }, + "develop": {} +} diff --git a/examples/django_channels2/django_channels2/__init__.py b/examples/django_channels2/django_channels2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/django_channels2/django_channels2/schema.py b/examples/django_channels2/django_channels2/schema.py new file mode 100644 index 0000000..158ba13 --- /dev/null +++ b/examples/django_channels2/django_channels2/schema.py @@ -0,0 +1,54 @@ +import asyncio + +import graphene +from asgiref.sync import async_to_sync +from channels.layers import get_channel_layer + +channel_layer = get_channel_layer() + + +class Query(graphene.ObjectType): + hello = graphene.String() + + def resolve_hello(self, info, **kwargs): + return "world" + + +class TestMessageMutation(graphene.Mutation): + class Arguments: + input_text = graphene.String() + + output_text = graphene.String() + + def mutate(self, info, input_text): + async_to_sync(channel_layer.group_send)("new_message", {"data": input_text}) + return TestMessageMutation(output_text=input_text) + + +class Mutations(graphene.ObjectType): + test_message = TestMessageMutation.Field() + + +class Subscription(graphene.ObjectType): + count_seconds = graphene.Int(up_to=graphene.Int()) + new_message = graphene.String() + + async def resolve_count_seconds(self, info, up_to=5): + i = 1 + while i <= up_to: + yield str(i) + await asyncio.sleep(1) + i += 1 + + async def resolve_new_message(self, info): + channel_name = await channel_layer.new_channel() + await channel_layer.group_add("new_message", channel_name) + try: + while True: + message = await channel_layer.receive(channel_name) + yield message["data"] + finally: + await channel_layer.group_discard("new_message", channel_name) + + +schema = graphene.Schema(query=Query, mutation=Mutations, subscription=Subscription) diff --git a/examples/django_channels2/django_channels2/settings.py b/examples/django_channels2/django_channels2/settings.py new file mode 100644 index 0000000..6c7b22b --- /dev/null +++ b/examples/django_channels2/django_channels2/settings.py @@ -0,0 +1,33 @@ +""" +Django settings for django_channels2 project. +""" +SECRET_KEY = "0%1c709jhmggqhk&=tci06iy+%jedfxpcoai69jd8wjzm+k2f0" +DEBUG = True + + +INSTALLED_APPS = ["channels", "graphql_ws.django", "graphene_django"] + +TEMPLATES = [ + { + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + ] + }, + } +] + +MIDDLEWARE = [ + 'django.middleware.common.CommonMiddleware', +] + +ROOT_URLCONF = "django_channels2.urls" +ASGI_APPLICATION = "graphql_ws.django.routing.application" + + +CHANNEL_LAYERS = {"default": {"BACKEND": "channels.layers.InMemoryChannelLayer"}} +GRAPHENE = {"MIDDLEWARE": [], "SCHEMA": "django_channels2.schema.schema"} diff --git a/examples/django_channels2/django_channels2/urls.py b/examples/django_channels2/django_channels2/urls.py new file mode 100644 index 0000000..d0e41c4 --- /dev/null +++ b/examples/django_channels2/django_channels2/urls.py @@ -0,0 +1,4 @@ +from django.urls import path +from graphene_django.views import GraphQLView + +urlpatterns = [path("graphql/", GraphQLView.as_view(graphiql=True))] diff --git a/examples/django_channels2/manage.py b/examples/django_channels2/manage.py new file mode 100755 index 0000000..1b65bb4 --- /dev/null +++ b/examples/django_channels2/manage.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +import os +import sys + +if __name__ == "__main__": + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_channels2.settings") + try: + from django.core.management import execute_from_command_line + except ImportError as exc: + raise ImportError( + "Couldn't import Django. Are you sure it's installed and " + "available on your PYTHONPATH environment variable? Did you " + "forget to activate a virtual environment?" + ) from exc + execute_from_command_line(sys.argv) diff --git a/examples/django_subscriptions/django_subscriptions/asgi.py b/examples/django_subscriptions/django_subscriptions/asgi.py index e6edd7d..35b4d4d 100644 --- a/examples/django_subscriptions/django_subscriptions/asgi.py +++ b/examples/django_subscriptions/django_subscriptions/asgi.py @@ -3,4 +3,4 @@ os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_subscriptions.settings") -channel_layer = get_channel_layer() \ No newline at end of file +channel_layer = get_channel_layer() diff --git a/examples/django_subscriptions/django_subscriptions/schema.py b/examples/django_subscriptions/django_subscriptions/schema.py index b55d76e..db6893c 100644 --- a/examples/django_subscriptions/django_subscriptions/schema.py +++ b/examples/django_subscriptions/django_subscriptions/schema.py @@ -6,18 +6,19 @@ class Query(graphene.ObjectType): hello = graphene.String() def resolve_hello(self, info, **kwargs): - return 'world' + return "world" + class Subscription(graphene.ObjectType): count_seconds = graphene.Int(up_to=graphene.Int()) - def resolve_count_seconds(root, info, up_to=5): - return Observable.interval(1000)\ - .map(lambda i: "{0}".format(i))\ - .take_while(lambda i: int(i) <= up_to) - + return ( + Observable.interval(1000) + .map(lambda i: "{0}".format(i)) + .take_while(lambda i: int(i) <= up_to) + ) -schema = graphene.Schema(query=Query, subscription=Subscription) \ No newline at end of file +schema = graphene.Schema(query=Query, subscription=Subscription) diff --git a/examples/django_subscriptions/django_subscriptions/settings.py b/examples/django_subscriptions/django_subscriptions/settings.py index 45d0471..7bb3f24 100644 --- a/examples/django_subscriptions/django_subscriptions/settings.py +++ b/examples/django_subscriptions/django_subscriptions/settings.py @@ -20,7 +20,7 @@ # See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/ # SECURITY WARNING: keep the secret key used in production secret! -SECRET_KEY = 'fa#kz8m$l6)4(np9+-j_-z!voa090mah!s9^4jp=kj!^nwdq^c' +SECRET_KEY = "fa#kz8m$l6)4(np9+-j_-z!voa090mah!s9^4jp=kj!^nwdq^c" # SECURITY WARNING: don't run with debug turned on in production! DEBUG = True @@ -31,53 +31,53 @@ # Application definition INSTALLED_APPS = [ - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', - 'channels', + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "channels", ] MIDDLEWARE = [ - 'django.middleware.security.SecurityMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - 'django.middleware.clickjacking.XFrameOptionsMiddleware', + "django.middleware.security.SecurityMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", ] -ROOT_URLCONF = 'django_subscriptions.urls' +ROOT_URLCONF = "django_subscriptions.urls" TEMPLATES = [ { - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.debug', - 'django.template.context_processors.request', - 'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", ], }, }, ] -WSGI_APPLICATION = 'django_subscriptions.wsgi.application' +WSGI_APPLICATION = "django_subscriptions.wsgi.application" # Database # https://docs.djangoproject.com/en/1.11/ref/settings/#databases DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'), + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": os.path.join(BASE_DIR, "db.sqlite3"), } } @@ -87,26 +87,20 @@ AUTH_PASSWORD_VALIDATORS = [ { - 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', + "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", }, + {"NAME": "django.contrib.auth.password_validation.MinimumLengthValidator"}, + {"NAME": "django.contrib.auth.password_validation.CommonPasswordValidator"}, + {"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator"}, ] # Internationalization # https://docs.djangoproject.com/en/1.11/topics/i18n/ -LANGUAGE_CODE = 'en-us' +LANGUAGE_CODE = "en-us" -TIME_ZONE = 'UTC' +TIME_ZONE = "UTC" USE_I18N = True @@ -118,20 +112,16 @@ # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/1.11/howto/static-files/ -STATIC_URL = '/static/' -CHANNELS_WS_PROTOCOLS = ["graphql-ws", ] +STATIC_URL = "/static/" +CHANNELS_WS_PROTOCOLS = [ + "graphql-ws", +] CHANNEL_LAYERS = { "default": { - "BACKEND": "asgi_redis.RedisChannelLayer", - "CONFIG": { - "hosts": [("localhost", 6379)], - }, + "BACKEND": "asgiref.inmemory.ChannelLayer", "ROUTING": "django_subscriptions.urls.channel_routing", }, - } -GRAPHENE = { - 'SCHEMA': 'django_subscriptions.schema.schema' -} \ No newline at end of file +GRAPHENE = {"SCHEMA": "django_subscriptions.schema.schema"} diff --git a/examples/django_subscriptions/django_subscriptions/template.py b/examples/django_subscriptions/django_subscriptions/template.py index b067ae5..738d9e7 100644 --- a/examples/django_subscriptions/django_subscriptions/template.py +++ b/examples/django_subscriptions/django_subscriptions/template.py @@ -1,9 +1,9 @@ - from string import Template def render_graphiql(): - return Template(''' + return Template( + """ @@ -116,10 +116,11 @@ def render_graphiql(): ); -''').substitute( - GRAPHIQL_VERSION='0.11.10', - SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', - subscriptionsEndpoint='ws://localhost:8000/subscriptions', +""" + ).substitute( + GRAPHIQL_VERSION="0.11.10", + SUBSCRIPTIONS_TRANSPORT_VERSION="0.7.0", + subscriptionsEndpoint="ws://localhost:8000/subscriptions", # subscriptionsEndpoint='ws://localhost:5000/', - endpointURL='/graphql', + endpointURL="/graphql", ) diff --git a/examples/django_subscriptions/django_subscriptions/urls.py b/examples/django_subscriptions/django_subscriptions/urls.py index 3848d22..caf790d 100644 --- a/examples/django_subscriptions/django_subscriptions/urls.py +++ b/examples/django_subscriptions/django_subscriptions/urls.py @@ -21,20 +21,21 @@ from graphene_django.views import GraphQLView from django.views.decorators.csrf import csrf_exempt +from channels.routing import route_class +from graphql_ws.django_channels import GraphQLSubscriptionConsumer + def graphiql(request): response = HttpResponse(content=render_graphiql()) return response + urlpatterns = [ - url(r'^admin/', admin.site.urls), - url(r'^graphiql/', graphiql), - url(r'^graphql', csrf_exempt(GraphQLView.as_view(graphiql=True))) + url(r"^admin/", admin.site.urls), + url(r"^graphiql/", graphiql), + url(r"^graphql", csrf_exempt(GraphQLView.as_view(graphiql=True))), ] -from channels.routing import route_class -from graphql_ws.django_channels import GraphQLSubscriptionConsumer - channel_routing = [ route_class(GraphQLSubscriptionConsumer, path=r"^/subscriptions"), -] \ No newline at end of file +] diff --git a/examples/django_subscriptions/requirements.txt b/examples/django_subscriptions/requirements.txt new file mode 100644 index 0000000..557e99f --- /dev/null +++ b/examples/django_subscriptions/requirements.txt @@ -0,0 +1,4 @@ +-e ../.. +django<2 +channels<2 +graphene_django<3 \ No newline at end of file diff --git a/examples/flask_gevent/app.py b/examples/flask_gevent/app.py index dbb0cca..efd145b 100644 --- a/examples/flask_gevent/app.py +++ b/examples/flask_gevent/app.py @@ -1,5 +1,3 @@ -import json - from flask import Flask, make_response from flask_graphql import GraphQLView from flask_sockets import Sockets @@ -14,19 +12,20 @@ sockets = Sockets(app) -@app.route('/graphiql') +@app.route("/graphiql") def graphql_view(): return make_response(render_graphiql()) app.add_url_rule( - '/graphql', view_func=GraphQLView.as_view('graphql', schema=schema, graphiql=False)) + "/graphql", view_func=GraphQLView.as_view("graphql", schema=schema, graphiql=False) +) subscription_server = GeventSubscriptionServer(schema) -app.app_protocol = lambda environ_path_info: 'graphql-ws' +app.app_protocol = lambda environ_path_info: "graphql-ws" -@sockets.route('/subscriptions') +@sockets.route("/subscriptions") def echo_socket(ws): subscription_server.handle(ws) return [] @@ -35,5 +34,6 @@ def echo_socket(ws): if __name__ == "__main__": from gevent import pywsgi from geventwebsocket.handler import WebSocketHandler - server = pywsgi.WSGIServer(('', 5000), app, handler_class=WebSocketHandler) + + server = pywsgi.WSGIServer(("", 5000), app, handler_class=WebSocketHandler) server.serve_forever() diff --git a/examples/flask_gevent/schema.py b/examples/flask_gevent/schema.py index 6e6298c..eb48050 100644 --- a/examples/flask_gevent/schema.py +++ b/examples/flask_gevent/schema.py @@ -19,12 +19,16 @@ class Subscription(graphene.ObjectType): random_int = graphene.Field(RandomType) def resolve_count_seconds(root, info, up_to=5): - return Observable.interval(1000)\ - .map(lambda i: "{0}".format(i))\ - .take_while(lambda i: int(i) <= up_to) + return ( + Observable.interval(1000) + .map(lambda i: "{0}".format(i)) + .take_while(lambda i: int(i) <= up_to) + ) def resolve_random_int(root, info): - return Observable.interval(1000).map(lambda i: RandomType(seconds=i, random_int=random.randint(0, 500))) + return Observable.interval(1000).map( + lambda i: RandomType(seconds=i, random_int=random.randint(0, 500)) + ) schema = graphene.Schema(query=Query, subscription=Subscription) diff --git a/examples/flask_gevent/template.py b/examples/flask_gevent/template.py index 41f52e1..ea0438c 100644 --- a/examples/flask_gevent/template.py +++ b/examples/flask_gevent/template.py @@ -1,9 +1,9 @@ - from string import Template def render_graphiql(): - return Template(''' + return Template( + """ @@ -116,10 +116,11 @@ def render_graphiql(): ); -''').substitute( - GRAPHIQL_VERSION='0.12.0', - SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', - subscriptionsEndpoint='ws://localhost:5000/subscriptions', +""" + ).substitute( + GRAPHIQL_VERSION="0.12.0", + SUBSCRIPTIONS_TRANSPORT_VERSION="0.7.0", + subscriptionsEndpoint="ws://localhost:5000/subscriptions", # subscriptionsEndpoint='ws://localhost:5000/', - endpointURL='/graphql', + endpointURL="/graphql", ) diff --git a/examples/websockets_lib/app.py b/examples/websockets_lib/app.py index 0de6988..7638f3d 100644 --- a/examples/websockets_lib/app.py +++ b/examples/websockets_lib/app.py @@ -8,21 +8,23 @@ app = Sanic(__name__) -@app.listener('before_server_start') +@app.listener("before_server_start") def init_graphql(app, loop): - app.add_route(GraphQLView.as_view(schema=schema, - executor=AsyncioExecutor(loop=loop)), - '/graphql') + app.add_route( + GraphQLView.as_view(schema=schema, executor=AsyncioExecutor(loop=loop)), + "/graphql", + ) -@app.route('/graphiql') +@app.route("/graphiql") async def graphiql_view(request): return response.html(render_graphiql()) + subscription_server = WsLibSubscriptionServer(schema) -@app.websocket('/subscriptions', subprotocols=['graphql-ws']) +@app.websocket("/subscriptions", subprotocols=["graphql-ws"]) async def subscriptions(request, ws): await subscription_server.handle(ws) return ws diff --git a/examples/websockets_lib/schema.py b/examples/websockets_lib/schema.py index 3c23d00..ae107c7 100644 --- a/examples/websockets_lib/schema.py +++ b/examples/websockets_lib/schema.py @@ -20,14 +20,14 @@ async def resolve_count_seconds(root, info, up_to=5): for i in range(up_to): print("YIELD SECOND", i) yield i - await asyncio.sleep(1.) + await asyncio.sleep(1.0) yield up_to async def resolve_random_int(root, info): i = 0 while True: yield RandomType(seconds=i, random_int=random.randint(0, 500)) - await asyncio.sleep(1.) + await asyncio.sleep(1.0) i += 1 diff --git a/examples/websockets_lib/template.py b/examples/websockets_lib/template.py index 03587bb..8f007b9 100644 --- a/examples/websockets_lib/template.py +++ b/examples/websockets_lib/template.py @@ -1,9 +1,9 @@ - from string import Template def render_graphiql(): - return Template(''' + return Template( + """ @@ -116,9 +116,10 @@ def render_graphiql(): ); -''').substitute( - GRAPHIQL_VERSION='0.10.2', - SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', - subscriptionsEndpoint='ws://localhost:8000/subscriptions', - endpointURL='/graphql', +""" + ).substitute( + GRAPHIQL_VERSION="0.10.2", + SUBSCRIPTIONS_TRANSPORT_VERSION="0.7.0", + subscriptionsEndpoint="ws://localhost:8000/subscriptions", + endpointURL="/graphql", ) diff --git a/graphql_ws/__init__.py b/graphql_ws/__init__.py index 44c7dc3..0ffa258 100644 --- a/graphql_ws/__init__.py +++ b/graphql_ws/__init__.py @@ -3,8 +3,5 @@ """Top-level package for GraphQL WS.""" __author__ = """Syrus Akbary""" -__email__ = 'me@syrusakbary.com' -__version__ = '0.3.1' - - -from .base import BaseConnectionContext, BaseSubscriptionServer # noqa: F401 +__email__ = "me@syrusakbary.com" +__version__ = "0.3.1" diff --git a/graphql_ws/aiohttp.py b/graphql_ws/aiohttp.py index 363ca67..baf8837 100644 --- a/graphql_ws/aiohttp.py +++ b/graphql_ws/aiohttp.py @@ -1,23 +1,13 @@ -from inspect import isawaitable -from asyncio import ensure_future, wait, shield +import json +from asyncio import shield from aiohttp import WSMsgType -from graphql.execution.executors.asyncio import AsyncioExecutor -from .base import ( - ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer) -from .observable_aiter import setup_observable_extension +from .base import ConnectionClosedException +from .base_async import BaseAsyncConnectionContext, BaseAsyncSubscriptionServer -from .constants import ( - GQL_CONNECTION_ACK, - GQL_CONNECTION_ERROR, - GQL_COMPLETE -) -setup_observable_extension() - - -class AiohttpConnectionContext(BaseConnectionContext): +class AiohttpConnectionContext(BaseAsyncConnectionContext): async def receive(self): msg = await self.ws.receive() if msg.type == WSMsgType.TEXT: @@ -32,7 +22,7 @@ async def receive(self): async def send(self, data): if self.closed: return - await self.ws.send_str(data) + await self.ws.send_str(json.dumps(data)) @property def closed(self): @@ -42,21 +32,10 @@ async def close(self, code): await self.ws.close(code=code) -class AiohttpSubscriptionServer(BaseSubscriptionServer): - def __init__(self, schema, keep_alive=True, loop=None): - self.loop = loop - super().__init__(schema, keep_alive) - - def get_graphql_params(self, *args, **kwargs): - params = super(AiohttpSubscriptionServer, - self).get_graphql_params(*args, **kwargs) - return dict(params, return_promise=True, - executor=AsyncioExecutor(loop=self.loop)) - +class AiohttpSubscriptionServer(BaseAsyncSubscriptionServer): async def _handle(self, ws, request_context=None): connection_context = AiohttpConnectionContext(ws, request_context) await self.on_open(connection_context) - pending = set() while True: try: if connection_context.closed: @@ -64,59 +43,9 @@ async def _handle(self, ws, request_context=None): message = await connection_context.receive() except ConnectionClosedException: break - finally: - if pending: - (_, pending) = await wait(pending, timeout=0, loop=self.loop) - task = ensure_future( - self.on_message(connection_context, message), loop=self.loop) - pending.add(task) - - self.on_close(connection_context) - for task in pending: - task.cancel() + self.on_message(connection_context, message) + await self.on_close(connection_context) async def handle(self, ws, request_context=None): await shield(self._handle(ws, request_context), loop=self.loop) - - async def on_open(self, connection_context): - pass - - def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) - for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) - - async def on_connect(self, connection_context, payload): - pass - - async def on_connection_init(self, connection_context, op_id, payload): - try: - await self.on_connect(connection_context, payload) - await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) - except Exception as e: - await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) - await connection_context.close(1011) - - async def on_start(self, connection_context, op_id, params): - execution_result = self.execute( - connection_context.request_context, params) - - if isawaitable(execution_result): - execution_result = await execution_result - - if not hasattr(execution_result, '__aiter__'): - await self.send_execution_result( - connection_context, op_id, execution_result) - else: - iterator = await execution_result.__aiter__() - connection_context.register_operation(op_id, iterator) - async for single_result in iterator: - if not connection_context.has_operation(op_id): - break - await self.send_execution_result( - connection_context, op_id, single_result) - await self.send_message(connection_context, op_id, GQL_COMPLETE) - - async def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index f3aa1e7..31ad657 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -1,16 +1,16 @@ import json from collections import OrderedDict -from graphql import graphql, format_error +from graphql import format_error, graphql from .constants import ( + GQL_CONNECTION_ERROR, GQL_CONNECTION_INIT, GQL_CONNECTION_TERMINATE, + GQL_DATA, + GQL_ERROR, GQL_START, GQL_STOP, - GQL_ERROR, - GQL_CONNECTION_ERROR, - GQL_DATA ) @@ -34,7 +34,20 @@ def get_operation(self, op_id): return self.operations[op_id] def remove_operation(self, op_id): - del self.operations[op_id] + try: + return self.operations.pop(op_id) + except KeyError: + return + + def unsubscribe(self, op_id): + async_iterator = self.remove_operation(op_id) + if hasattr(async_iterator, 'dispose'): + async_iterator.dispose() + return async_iterator + + def unsubscribe_all(self): + for op_id in list(self.operations): + self.unsubscribe(op_id) def receive(self): raise NotImplementedError("receive method not implemented") @@ -51,33 +64,19 @@ def close(self, code): class BaseSubscriptionServer(object): + graphql_executor = None def __init__(self, schema, keep_alive=True): self.schema = schema self.keep_alive = keep_alive - def get_graphql_params(self, connection_context, payload): - return { - 'request_string': payload.get('query'), - 'variable_values': payload.get('variables'), - 'operation_name': payload.get('operationName'), - 'context_value': payload.get('context'), - } - - def build_message(self, id, op_type, payload): - message = {} - if id is not None: - message['id'] = id - if op_type is not None: - message['type'] = op_type - if payload is not None: - message['payload'] = payload - return message + def execute(self, params): + return graphql(self.schema, **dict(params, allow_subscriptions=True)) def process_message(self, connection_context, parsed_message): - op_id = parsed_message.get('id') - op_type = parsed_message.get('type') - payload = parsed_message.get('payload') + op_id = parsed_message.get("id") + op_type = parsed_message.get("type") + payload = parsed_message.get("payload") if op_type == GQL_CONNECTION_INIT: return self.on_connection_init(connection_context, op_id, payload) @@ -87,27 +86,59 @@ def process_message(self, connection_context, parsed_message): elif op_type == GQL_START: assert isinstance(payload, dict), "The payload must be a dict" - params = self.get_graphql_params(connection_context, payload) - if not isinstance(params, dict): - error = Exception( - "Invalid params returned from get_graphql_params!" - " Return values must be a dict.") - return self.send_error(connection_context, op_id, error) - - # If we already have a subscription with this id, unsubscribe from - # it first - if connection_context.has_operation(op_id): - self.unsubscribe(connection_context, op_id) - return self.on_start(connection_context, op_id, params) elif op_type == GQL_STOP: return self.on_stop(connection_context, op_id) else: - return self.send_error(connection_context, op_id, Exception( - "Invalid message type: {}.".format(op_type))) + return self.send_error( + connection_context, + op_id, + Exception("Invalid message type: {}.".format(op_type)), + ) + + def on_connection_init(self, connection_context, op_id, payload): + raise NotImplementedError("on_connection_init method not implemented") + + def on_connection_terminate(self, connection_context, op_id): + return connection_context.close(1011) + + def get_graphql_params(self, connection_context, payload): + context = payload.get("context", connection_context.request_context) + return { + "request_string": payload.get("query"), + "variable_values": payload.get("variables"), + "operation_name": payload.get("operationName"), + "context_value": context, + "executor": self.graphql_executor(), + } + + def on_open(self, connection_context): + raise NotImplementedError("on_open method not implemented") + + def on_stop(self, connection_context, op_id): + return connection_context.unsubscribe(op_id) + + def on_close(self, connection_context): + return connection_context.unsubscribe_all() + + def send_message(self, connection_context, op_id=None, op_type=None, payload=None): + if op_id is None or connection_context.has_operation(op_id): + message = self.build_message(op_id, op_type, payload) + return connection_context.send(message) + + def build_message(self, id, op_type, payload): + message = {} + if id is not None: + message["id"] = id + if op_type is not None: + message["type"] = op_type + if payload is not None: + message["payload"] = payload + assert message, "You need to send at least one thing" + return message def send_execution_result(self, connection_context, op_id, execution_result): result = self.execution_result_to_dict(execution_result) @@ -116,86 +147,34 @@ def send_execution_result(self, connection_context, op_id, execution_result): def execution_result_to_dict(self, execution_result): result = OrderedDict() if execution_result.data: - result['data'] = execution_result.data + result["data"] = execution_result.data if execution_result.errors: - result['errors'] = [format_error(error) - for error in execution_result.errors] + result["errors"] = [ + format_error(error) for error in execution_result.errors + ] return result - def send_message(self, connection_context, op_id=None, op_type=None, payload=None): - message = self.build_message(op_id, op_type, payload) - assert message, "You need to send at least one thing" - json_message = json.dumps(message) - return connection_context.send(json_message) - def send_error(self, connection_context, op_id, error, error_type=None): if error_type is None: error_type = GQL_ERROR assert error_type in [GQL_CONNECTION_ERROR, GQL_ERROR], ( - 'error_type should be one of the allowed error messages' - ' GQL_CONNECTION_ERROR or GQL_ERROR' - ) - - error_payload = { - 'message': str(error) - } - - return self.send_message( - connection_context, - op_id, - error_type, - error_payload + "error_type should be one of the allowed error messages" + " GQL_CONNECTION_ERROR or GQL_ERROR" ) - def unsubscribe(self, connection_context, op_id): - if connection_context.has_operation(op_id): - # Close async iterator - connection_context.get_operation(op_id).dispose() - # Close operation - connection_context.remove_operation(op_id) - self.on_operation_complete(connection_context, op_id) + error_payload = {"message": str(error)} - def on_operation_complete(self, connection_context, op_id): - pass - - def on_connection_terminate(self, connection_context, op_id): - return connection_context.close(1011) - - def execute(self, request_context, params): - return graphql( - self.schema, **dict(params, allow_subscriptions=True)) - - def handle(self, ws, request_context=None): - raise NotImplementedError("handle method not implemented") + return self.send_message(connection_context, op_id, error_type, error_payload) def on_message(self, connection_context, message): try: if not isinstance(message, dict): parsed_message = json.loads(message) - assert isinstance( - parsed_message, dict), "Payload must be an object." + assert isinstance(parsed_message, dict), "Payload must be an object." else: parsed_message = message except Exception as e: return self.send_error(connection_context, None, e) return self.process_message(connection_context, parsed_message) - - def on_open(self, connection_context): - raise NotImplementedError("on_open method not implemented") - - def on_connect(self, connection_context, payload): - raise NotImplementedError("on_connect method not implemented") - - def on_close(self, connection_context): - raise NotImplementedError("on_close method not implemented") - - def on_connection_init(self, connection_context, op_id, payload): - raise NotImplementedError("on_connection_init method not implemented") - - def on_stop(self, connection_context, op_id): - raise NotImplementedError("on_stop method not implemented") - - def on_start(self, connection_context, op_id, params): - raise NotImplementedError("on_start method not implemented") diff --git a/graphql_ws/base_async.py b/graphql_ws/base_async.py new file mode 100644 index 0000000..a21ca5e --- /dev/null +++ b/graphql_ws/base_async.py @@ -0,0 +1,189 @@ +import asyncio +import inspect +from abc import ABC, abstractmethod +from types import CoroutineType, GeneratorType +from typing import Any, Dict, List, Union +from weakref import WeakSet + +from graphql.execution.executors.asyncio import AsyncioExecutor +from promise import Promise + +from graphql_ws import base + +from .constants import GQL_COMPLETE, GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR +from .observable_aiter import setup_observable_extension + +setup_observable_extension() +CO_ITERABLE_COROUTINE = inspect.CO_ITERABLE_COROUTINE + + +# Copied from graphql-core v3.1.0 (graphql/pyutils/is_awaitable.py) +def is_awaitable(value: Any) -> bool: + """Return true if object can be passed to an ``await`` expression. + Instead of testing if the object is an instance of abc.Awaitable, it checks + the existence of an `__await__` attribute. This is much faster. + """ + return ( + # check for coroutine objects + isinstance(value, CoroutineType) + # check for old-style generator based coroutine objects + or isinstance(value, GeneratorType) + and bool(value.gi_code.co_flags & CO_ITERABLE_COROUTINE) + # check for other awaitables (e.g. futures) + or hasattr(value, "__await__") + ) + + +async def resolve( + data: Any, _container: Union[List, Dict] = None, _key: Union[str, int] = None +) -> None: + """ + Recursively wait on any awaitable children of a data element and resolve any + Promises. + """ + if is_awaitable(data): + data = await data + if isinstance(data, Promise): + data = data.value # type: Any + if _container is not None: + _container[_key] = data + if isinstance(data, dict): + items = data.items() + elif isinstance(data, list): + items = enumerate(data) + else: + items = None + if items is not None: + children = [ + asyncio.ensure_future(resolve(child, _container=data, _key=key)) + for key, child in items + ] + if children: + await asyncio.wait(children) + + +class BaseAsyncConnectionContext(base.BaseConnectionContext, ABC): + def __init__(self, ws, request_context=None): + super().__init__(ws, request_context=request_context) + self.pending_tasks = WeakSet() + + @abstractmethod + async def receive(self): + raise NotImplementedError("receive method not implemented") + + @abstractmethod + async def send(self, data): + ... + + @property + @abstractmethod + def closed(self): + ... + + @abstractmethod + async def close(self, code): + ... + + def remember_task(self, task): + self.pending_tasks.add(task) + # Clear completed tasks + self.pending_tasks -= WeakSet( + task for task in self.pending_tasks if task.done() + ) + + async def unsubscribe(self, op_id): + async_iterator = super().unsubscribe(op_id) + if getattr(async_iterator, "future", None) and async_iterator.future.cancel(): + await async_iterator.future + + async def unsubscribe_all(self): + awaitables = [self.unsubscribe(op_id) for op_id in list(self.operations)] + for task in self.pending_tasks: + task.cancel() + awaitables.append(task) + if awaitables: + try: + await asyncio.gather(*awaitables) + except asyncio.CancelledError: + pass + + +class BaseAsyncSubscriptionServer(base.BaseSubscriptionServer, ABC): + graphql_executor = AsyncioExecutor + + def __init__(self, schema, keep_alive=True, loop=None): + self.loop = loop + super().__init__(schema, keep_alive) + + @abstractmethod + async def handle(self, ws, request_context=None): + ... + + def process_message(self, connection_context, parsed_message): + task = asyncio.ensure_future( + super().process_message(connection_context, parsed_message), loop=self.loop + ) + connection_context.remember_task(task) + return task + + async def on_open(self, connection_context): + pass + + async def on_connect(self, connection_context, payload): + pass + + async def on_connection_init(self, connection_context, op_id, payload): + try: + await self.on_connect(connection_context, payload) + await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) + except Exception as e: + await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) + await connection_context.close(1011) + + async def on_start(self, connection_context, op_id, params): + # Attempt to unsubscribe first in case we already have a subscription + # with this id. + await connection_context.unsubscribe(op_id) + + execution_result = self.execute(params) + + connection_context.register_operation(op_id, execution_result) + if hasattr(execution_result, "__aiter__"): + iterator = await execution_result.__aiter__() + connection_context.register_operation(op_id, iterator) + try: + async for single_result in iterator: + if not connection_context.has_operation(op_id): + break + await self.send_execution_result( + connection_context, op_id, single_result + ) + except Exception as e: + await self.send_error(connection_context, op_id, e) + else: + try: + if is_awaitable(execution_result): + execution_result = await execution_result + await self.send_execution_result( + connection_context, op_id, execution_result + ) + except Exception as e: + await self.send_error(connection_context, op_id, e) + await self.send_message(connection_context, op_id, GQL_COMPLETE) + await connection_context.unsubscribe(op_id) + await self.on_operation_complete(connection_context, op_id) + + async def send_message( + self, connection_context, op_id=None, op_type=None, payload=None + ): + if op_id is None or connection_context.has_operation(op_id): + message = self.build_message(op_id, op_type, payload) + return await connection_context.send(message) + + async def on_operation_complete(self, connection_context, op_id): + pass + + async def send_execution_result(self, connection_context, op_id, execution_result): + # Resolve any pending promises + await resolve(execution_result.data) + await super().send_execution_result(connection_context, op_id, execution_result) diff --git a/graphql_ws/base_sync.py b/graphql_ws/base_sync.py new file mode 100644 index 0000000..f6b6c68 --- /dev/null +++ b/graphql_ws/base_sync.py @@ -0,0 +1,80 @@ +from graphql.execution.executors.sync import SyncExecutor +from rx import Observable, Observer + +from .base import BaseSubscriptionServer +from .constants import GQL_COMPLETE, GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR + + +class BaseSyncSubscriptionServer(BaseSubscriptionServer): + graphql_executor = SyncExecutor + + def on_operation_complete(self, connection_context, op_id): + pass + + def handle(self, ws, request_context=None): + raise NotImplementedError("handle method not implemented") + + def on_open(self, connection_context): + pass + + def on_connect(self, connection_context, payload): + pass + + def on_connection_init(self, connection_context, op_id, payload): + try: + self.on_connect(connection_context, payload) + self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) + + except Exception as e: + self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) + connection_context.close(1011) + + def on_start(self, connection_context, op_id, params): + # Attempt to unsubscribe first in case we already have a subscription + # with this id. + connection_context.unsubscribe(op_id) + try: + execution_result = self.execute(params) + assert isinstance( + execution_result, Observable + ), "A subscription must return an observable" + disposable = execution_result.subscribe( + SubscriptionObserver( + connection_context, + op_id, + self.send_execution_result, + self.send_error, + self.send_message, + ) + ) + connection_context.register_operation(op_id, disposable) + + except Exception as e: + self.send_error(connection_context, op_id, e) + self.send_message(connection_context, op_id, GQL_COMPLETE) + + +class SubscriptionObserver(Observer): + def __init__( + self, connection_context, op_id, send_execution_result, send_error, send_message + ): + self.connection_context = connection_context + self.op_id = op_id + self.send_execution_result = send_execution_result + self.send_error = send_error + self.send_message = send_message + + def on_next(self, value): + if isinstance(value, Exception): + send_method = self.send_error + else: + send_method = self.send_execution_result + send_method(self.connection_context, self.op_id, value) + + def on_completed(self): + self.send_message(self.connection_context, self.op_id, GQL_COMPLETE) + self.connection_context.remove_operation(self.op_id) + + def on_error(self, error): + self.send_error(self.connection_context, self.op_id, error) + self.on_completed() diff --git a/graphql_ws/constants.py b/graphql_ws/constants.py index 4f9d2f1..8b57a60 100644 --- a/graphql_ws/constants.py +++ b/graphql_ws/constants.py @@ -1,15 +1,15 @@ -GRAPHQL_WS = 'graphql-ws' +GRAPHQL_WS = "graphql-ws" WS_PROTOCOL = GRAPHQL_WS -GQL_CONNECTION_INIT = 'connection_init' # Client -> Server -GQL_CONNECTION_ACK = 'connection_ack' # Server -> Client -GQL_CONNECTION_ERROR = 'connection_error' # Server -> Client +GQL_CONNECTION_INIT = "connection_init" # Client -> Server +GQL_CONNECTION_ACK = "connection_ack" # Server -> Client +GQL_CONNECTION_ERROR = "connection_error" # Server -> Client # NOTE: This one here don't follow the standard due to connection optimization -GQL_CONNECTION_TERMINATE = 'connection_terminate' # Client -> Server -GQL_CONNECTION_KEEP_ALIVE = 'ka' # Server -> Client -GQL_START = 'start' # Client -> Server -GQL_DATA = 'data' # Server -> Client -GQL_ERROR = 'error' # Server -> Client -GQL_COMPLETE = 'complete' # Server -> Client -GQL_STOP = 'stop' # Client -> Server +GQL_CONNECTION_TERMINATE = "connection_terminate" # Client -> Server +GQL_CONNECTION_KEEP_ALIVE = "ka" # Server -> Client +GQL_START = "start" # Client -> Server +GQL_DATA = "data" # Server -> Client +GQL_ERROR = "error" # Server -> Client +GQL_COMPLETE = "complete" # Server -> Client +GQL_STOP = "stop" # Client -> Server diff --git a/graphql_ws/django/__init__.py b/graphql_ws/django/__init__.py new file mode 100644 index 0000000..d08b0f3 --- /dev/null +++ b/graphql_ws/django/__init__.py @@ -0,0 +1 @@ +default_app_config = "graphql_ws.django.apps.GraphQLChannelsApp" diff --git a/graphql_ws/django/apps.py b/graphql_ws/django/apps.py new file mode 100644 index 0000000..eb65bb2 --- /dev/null +++ b/graphql_ws/django/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class GraphQLChannelsApp(AppConfig): + name = "graphql_ws.django" + label = "graphql_channels" diff --git a/graphql_ws/django/consumers.py b/graphql_ws/django/consumers.py new file mode 100644 index 0000000..b1c64d1 --- /dev/null +++ b/graphql_ws/django/consumers.py @@ -0,0 +1,31 @@ +import json + +from channels.generic.websocket import AsyncJsonWebsocketConsumer + +from ..constants import WS_PROTOCOL +from .subscriptions import subscription_server + + +class GraphQLSubscriptionConsumer(AsyncJsonWebsocketConsumer): + + async def connect(self): + self.connection_context = None + if WS_PROTOCOL in self.scope["subprotocols"]: + self.connection_context = await subscription_server.handle( + ws=self, request_context=self.scope + ) + await self.accept(subprotocol=WS_PROTOCOL) + else: + await self.close() + + async def disconnect(self, code): + if self.connection_context: + self.connection_context.socket_closed = True + await subscription_server.on_close(self.connection_context) + + async def receive_json(self, content): + subscription_server.on_message(self.connection_context, content) + + @classmethod + async def encode_json(cls, content): + return json.dumps(content) diff --git a/graphql_ws/django/routing.py b/graphql_ws/django/routing.py new file mode 100644 index 0000000..15a1356 --- /dev/null +++ b/graphql_ws/django/routing.py @@ -0,0 +1,24 @@ +from channels.routing import ProtocolTypeRouter, URLRouter +from channels.sessions import SessionMiddlewareStack +from django.apps import apps +from django.urls import path +from .consumers import GraphQLSubscriptionConsumer + +if apps.is_installed("django.contrib.auth"): + from channels.auth import AuthMiddlewareStack +else: + AuthMiddlewareStack = None + + +websocket_urlpatterns = [path("subscriptions", GraphQLSubscriptionConsumer)] + +application = ProtocolTypeRouter({"websocket": URLRouter(websocket_urlpatterns)}) + +session_application = ProtocolTypeRouter( + {"websocket": SessionMiddlewareStack(URLRouter(websocket_urlpatterns))} +) + +if AuthMiddlewareStack: + auth_application = ProtocolTypeRouter( + {"websocket": AuthMiddlewareStack(URLRouter(websocket_urlpatterns))} + ) diff --git a/graphql_ws/django/subscriptions.py b/graphql_ws/django/subscriptions.py new file mode 100644 index 0000000..086445f --- /dev/null +++ b/graphql_ws/django/subscriptions.py @@ -0,0 +1,39 @@ +from graphene_django.settings import graphene_settings +from ..base_async import BaseAsyncConnectionContext, BaseAsyncSubscriptionServer +from ..observable_aiter import setup_observable_extension + +setup_observable_extension() + + +class ChannelsConnectionContext(BaseAsyncConnectionContext): + def __init__(self, *args, **kwargs): + super(ChannelsConnectionContext, self).__init__(*args, **kwargs) + self.socket_closed = False + + async def send(self, data): + if self.closed: + return + await self.ws.send_json(data) + + @property + def closed(self): + return self.socket_closed + + async def close(self, code): + await self.ws.close(code=code) + + async def receive(self, code): + """ + Unused, as the django consumer handles receiving messages and passes + them straight to ChannelsSubscriptionServer.on_message. + """ + + +class ChannelsSubscriptionServer(BaseAsyncSubscriptionServer): + async def handle(self, ws, request_context=None): + connection_context = ChannelsConnectionContext(ws, request_context) + await self.on_open(connection_context) + return connection_context + + +subscription_server = ChannelsSubscriptionServer(schema=graphene_settings.SCHEMA) diff --git a/graphql_ws/django/templates/graphene/graphiql.html b/graphql_ws/django/templates/graphene/graphiql.html new file mode 100644 index 0000000..dce2683 --- /dev/null +++ b/graphql_ws/django/templates/graphene/graphiql.html @@ -0,0 +1,135 @@ + + + + + + + + + + + + + + + + + diff --git a/graphql_ws/django_channels.py b/graphql_ws/django_channels.py index 61a7247..ddba58d 100644 --- a/graphql_ws/django_channels.py +++ b/graphql_ws/django_channels.py @@ -1,129 +1,47 @@ import json -from rx import Observer, Observable -from graphql.execution.executors.sync import SyncExecutor - from channels.generic.websockets import JsonWebsocketConsumer from graphene_django.settings import graphene_settings -from .base import BaseConnectionContext, BaseSubscriptionServer -from .constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR +from .base import BaseConnectionContext +from .base_sync import BaseSyncSubscriptionServer class DjangoChannelConnectionContext(BaseConnectionContext): - - def __init__(self, message, request_context=None): - self.message = message - self.operations = {} - self.request_context = request_context + def __init__(self, message): + super(DjangoChannelConnectionContext, self).__init__( + message.reply_channel, + request_context={"user": message.user, "session": message.http_session}, + ) def send(self, data): - self.message.reply_channel.send(data) + self.ws.send({"text": json.dumps(data)}) def close(self, reason): - data = { - 'close': True, - 'text': reason - } - self.message.reply_channel.send(data) - + data = {"close": True, "text": reason} + self.ws.send(data) -class DjangoChannelSubscriptionServer(BaseSubscriptionServer): - - def get_graphql_params(self, *args, **kwargs): - params = super(DjangoChannelSubscriptionServer, - self).get_graphql_params(*args, **kwargs) - return dict(params, executor=SyncExecutor()) +class DjangoChannelSubscriptionServer(BaseSyncSubscriptionServer): def handle(self, message, connection_context): self.on_message(connection_context, message) - def send_message(self, connection_context, op_id=None, op_type=None, payload=None): - message = {} - if op_id is not None: - message['id'] = op_id - if op_type is not None: - message['type'] = op_type - if payload is not None: - message['payload'] = payload - - assert message, "You need to send at least one thing" - return connection_context.send({'text': json.dumps(message)}) - - def on_open(self, connection_context): - pass - - def on_connect(self, connection_context, payload): - pass - - def on_connection_init(self, connection_context, op_id, payload): - try: - self.on_connect(connection_context, payload) - self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) - - except Exception as e: - self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) - connection_context.close(1011) - def on_start(self, connection_context, op_id, params): - try: - execution_result = self.execute( - connection_context.request_context, params) - assert isinstance(execution_result, Observable), \ - "A subscription must return an observable" - execution_result.subscribe(SubscriptionObserver( - connection_context, - op_id, - self.send_execution_result, - self.send_error, - self.on_close - )) - except Exception as e: - self.send_error(connection_context, op_id, str(e)) - - def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) - for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) - - def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) +subscription_server = DjangoChannelSubscriptionServer(graphene_settings.SCHEMA) class GraphQLSubscriptionConsumer(JsonWebsocketConsumer): http_user_and_session = True strict_ordering = True - def connect(self, message, **_kwargs): + def connect(self, message, **kwargs): message.reply_channel.send({"accept": True}) - def receive(self, content, **_kwargs): + def receive(self, content, **kwargs): """ Called when a message is received with either text or bytes filled out. """ - self.connection_context = DjangoChannelConnectionContext(self.message) - self.subscription_server = DjangoChannelSubscriptionServer( - graphene_settings.SCHEMA) - self.subscription_server.on_open(self.connection_context) - self.subscription_server.handle(content, self.connection_context) - - -class SubscriptionObserver(Observer): - - def __init__(self, connection_context, op_id, - send_execution_result, send_error, on_close): - self.connection_context = connection_context - self.op_id = op_id - self.send_execution_result = send_execution_result - self.send_error = send_error - self.on_close = on_close - - def on_next(self, value): - self.send_execution_result(self.connection_context, self.op_id, value) - - def on_completed(self): - self.on_close(self.connection_context) - - def on_error(self, error): - self.send_error(self.connection_context, self.op_id, error) + context = DjangoChannelConnectionContext(self.message) + subscription_server.on_open(context) + subscription_server.handle(content, context) diff --git a/graphql_ws/gevent.py b/graphql_ws/gevent.py index aadbe64..b7d6849 100644 --- a/graphql_ws/gevent.py +++ b/graphql_ws/gevent.py @@ -1,15 +1,15 @@ from __future__ import absolute_import -from rx import Observer, Observable -from graphql.execution.executors.sync import SyncExecutor +import json from .base import ( - ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer) -from .constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR + BaseConnectionContext, + ConnectionClosedException, +) +from .base_sync import BaseSyncSubscriptionServer class GeventConnectionContext(BaseConnectionContext): - def receive(self): msg = self.ws.receive() return msg @@ -17,7 +17,7 @@ def receive(self): def send(self, data): if self.closed: return - self.ws.send(data) + self.ws.send(json.dumps(data)) @property def closed(self): @@ -27,13 +27,7 @@ def close(self, code): self.ws.close(code) -class GeventSubscriptionServer(BaseSubscriptionServer): - - def get_graphql_params(self, *args, **kwargs): - params = super(GeventSubscriptionServer, - self).get_graphql_params(*args, **kwargs) - return dict(params, executor=SyncExecutor()) - +class GeventSubscriptionServer(BaseSyncSubscriptionServer): def handle(self, ws, request_context=None): connection_context = GeventConnectionContext(ws, request_context) self.on_open(connection_context) @@ -46,62 +40,3 @@ def handle(self, ws, request_context=None): self.on_close(connection_context) return self.on_message(connection_context, message) - - def on_open(self, connection_context): - pass - - def on_connect(self, connection_context, payload): - pass - - def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) - for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) - - def on_connection_init(self, connection_context, op_id, payload): - try: - self.on_connect(connection_context, payload) - self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) - - except Exception as e: - self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) - connection_context.close(1011) - - def on_start(self, connection_context, op_id, params): - try: - execution_result = self.execute( - connection_context.request_context, params) - assert isinstance(execution_result, Observable), \ - "A subscription must return an observable" - execution_result.subscribe(SubscriptionObserver( - connection_context, - op_id, - self.send_execution_result, - self.send_error, - self.on_close - )) - except Exception as e: - self.send_error(connection_context, op_id, str(e)) - - def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) - - -class SubscriptionObserver(Observer): - - def __init__(self, connection_context, op_id, - send_execution_result, send_error, on_close): - self.connection_context = connection_context - self.op_id = op_id - self.send_execution_result = send_execution_result - self.send_error = send_error - self.on_close = on_close - - def on_next(self, value): - self.send_execution_result(self.connection_context, self.op_id, value) - - def on_completed(self): - self.on_close(self.connection_context) - - def on_error(self, error): - self.send_error(self.connection_context, self.op_id, error) diff --git a/graphql_ws/observable_aiter.py b/graphql_ws/observable_aiter.py index 0bd1a59..424d95f 100644 --- a/graphql_ws/observable_aiter.py +++ b/graphql_ws/observable_aiter.py @@ -1,7 +1,7 @@ from asyncio import Future -from rx.internal import extensionmethod from rx.core import Observable +from rx.internal import extensionmethod async def __aiter__(self): @@ -13,15 +13,11 @@ def __init__(self): self.future = Future() self.disposable = source.materialize().subscribe(self.on_next) - # self.disposed = False def __aiter__(self): return self def dispose(self): - # self.future.cancel() - # self.disposed = True - # self.future.set_exception(StopAsyncIteration) self.disposable.dispose() def feeder(self): @@ -30,11 +26,11 @@ def feeder(self): notification = self.notifications.pop(0) kind = notification.kind - if kind == 'N': + if kind == "N": self.future.set_result(notification.value) - if kind == 'E': + if kind == "E": self.future.set_exception(notification.exception) - if kind == 'C': + if kind == "C": self.future.set_exception(StopAsyncIteration) def on_next(self, notification): @@ -42,8 +38,6 @@ def on_next(self, notification): self.feeder() async def __anext__(self): - # if self.disposed: - # raise StopAsyncIteration self.feeder() value = await self.future @@ -53,38 +47,5 @@ async def __anext__(self): return AIterator() -# def __aiter__(self, sentinel=None): -# loop = get_event_loop() -# future = [Future()] -# notifications = [] - -# def feeder(): -# if not len(notifications) or future[0].done(): -# return -# notification = notifications.pop(0) -# if notification.kind == "E": -# future[0].set_exception(notification.exception) -# elif notification.kind == "C": -# future[0].set_exception(StopIteration(sentinel)) -# else: -# future[0].set_result(notification.value) - -# def on_next(value): -# """Takes on_next values and appends them to the notification queue""" -# notifications.append(value) -# loop.call_soon(feeder) - -# self.materialize().subscribe(on_next) - -# @asyncio.coroutine -# def gen(): -# """Generator producing futures""" -# loop.call_soon(feeder) -# future[0] = Future() -# return future[0] - -# return gen - - def setup_observable_extension(): extensionmethod(Observable)(__aiter__) diff --git a/graphql_ws/websockets_lib.py b/graphql_ws/websockets_lib.py index 7e78d5d..c0adc67 100644 --- a/graphql_ws/websockets_lib.py +++ b/graphql_ws/websockets_lib.py @@ -1,19 +1,13 @@ -from inspect import isawaitable -from asyncio import ensure_future, wait, shield -from websockets import ConnectionClosed -from graphql.execution.executors.asyncio import AsyncioExecutor - -from .base import ( - ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer) -from .observable_aiter import setup_observable_extension +import json +from asyncio import shield -from .constants import ( - GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR, GQL_COMPLETE) +from websockets import ConnectionClosed -setup_observable_extension() +from .base import ConnectionClosedException +from .base_async import BaseAsyncConnectionContext, BaseAsyncSubscriptionServer -class WsLibConnectionContext(BaseConnectionContext): +class WsLibConnectionContext(BaseAsyncConnectionContext): async def receive(self): try: msg = await self.ws.recv() @@ -24,7 +18,7 @@ async def receive(self): async def send(self, data): if self.closed: return - await self.ws.send(data) + await self.ws.send(json.dumps(data)) @property def closed(self): @@ -34,21 +28,10 @@ async def close(self, code): await self.ws.close(code) -class WsLibSubscriptionServer(BaseSubscriptionServer): - def __init__(self, schema, keep_alive=True, loop=None): - self.loop = loop - super().__init__(schema, keep_alive) - - def get_graphql_params(self, *args, **kwargs): - params = super(WsLibSubscriptionServer, - self).get_graphql_params(*args, **kwargs) - return dict(params, return_promise=True, - executor=AsyncioExecutor(loop=self.loop)) - +class WsLibSubscriptionServer(BaseAsyncSubscriptionServer): async def _handle(self, ws, request_context): connection_context = WsLibConnectionContext(ws, request_context) await self.on_open(connection_context) - pending = set() while True: try: if connection_context.closed: @@ -56,61 +39,9 @@ async def _handle(self, ws, request_context): message = await connection_context.receive() except ConnectionClosedException: break - finally: - if pending: - (_, pending) = await wait(pending, timeout=0, loop=self.loop) - - task = ensure_future( - self.on_message(connection_context, message), loop=self.loop) - pending.add(task) - self.on_close(connection_context) - for task in pending: - task.cancel() + self.on_message(connection_context, message) + await self.on_close(connection_context) async def handle(self, ws, request_context=None): await shield(self._handle(ws, request_context), loop=self.loop) - - async def on_open(self, connection_context): - pass - - def on_close(self, connection_context): - remove_operations = list(connection_context.operations.keys()) - for op_id in remove_operations: - self.unsubscribe(connection_context, op_id) - - async def on_connect(self, connection_context, payload): - pass - - async def on_connection_init(self, connection_context, op_id, payload): - try: - await self.on_connect(connection_context, payload) - await self.send_message( - connection_context, op_type=GQL_CONNECTION_ACK) - except Exception as e: - await self.send_error( - connection_context, op_id, e, GQL_CONNECTION_ERROR) - await connection_context.close(1011) - - async def on_start(self, connection_context, op_id, params): - execution_result = self.execute( - connection_context.request_context, params) - - if isawaitable(execution_result): - execution_result = await execution_result - - if not hasattr(execution_result, '__aiter__'): - await self.send_execution_result( - connection_context, op_id, execution_result) - else: - iterator = await execution_result.__aiter__() - connection_context.register_operation(op_id, iterator) - async for single_result in iterator: - if not connection_context.has_operation(op_id): - break - await self.send_execution_result( - connection_context, op_id, single_result) - await self.send_message(connection_context, op_id, GQL_COMPLETE) - - async def on_stop(self, connection_context, op_id): - self.unsubscribe(connection_context, op_id) diff --git a/setup.cfg b/setup.cfg index df50b23..3d07a80 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,7 @@ [metadata] name = graphql-ws version = 0.3.1 -description = Websocket server for GraphQL subscriptions +description = Websocket backend for GraphQL subscriptions long_description = file: README.rst, CHANGES.rst author = Syrus Akbary author_email = me@syrusakbary.com @@ -15,15 +15,12 @@ classifiers = License :: OSI Approved :: MIT License Natural Language :: English Programming Language :: Python :: 2 - Programming Language :: Python :: 2.6 Programming Language :: Python :: 2.7 Programming Language :: Python :: 3 - Programming Language :: Python :: 3.3 - Programming Language :: Python :: 3.4 - Programming Language :: Python :: 3.5 Programming Language :: Python :: 3.6 Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 [options] zip_safe = False @@ -35,6 +32,7 @@ install_requires = [options.packages.find] include = graphql_ws + graphql_ws.* [options.extras_require] maintainer = @@ -90,3 +88,8 @@ ignore = W503 [coverage:run] omit = .tox/* + +[coverage:report] +exclude_lines = + pragma: no cover + @abstract diff --git a/tests/conftest.py b/tests/conftest.py index e551557..595968a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,5 @@ if sys.version_info > (3,): collect_ignore = ["test_django_channels.py"] - if sys.version_info < (3, 6): - collect_ignore.append('test_gevent.py') else: - collect_ignore = ["test_aiohttp.py"] + collect_ignore = ["test_aiohttp.py", "test_base_async.py"] diff --git a/tests/django_routing.py b/tests/django_routing.py new file mode 100644 index 0000000..9d01766 --- /dev/null +++ b/tests/django_routing.py @@ -0,0 +1,6 @@ +from channels.routing import route +from graphql_ws.django_channels import GraphQLSubscriptionConsumer + +channel_routing = [ + route("websocket.receive", GraphQLSubscriptionConsumer), +] diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index f20ca15..40c43fd 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1,15 +1,22 @@ +try: + from aiohttp import WSMsgType + from graphql_ws.aiohttp import AiohttpConnectionContext, AiohttpSubscriptionServer +except ImportError: # pragma: no cover + WSMsgType = None + from unittest import mock import pytest -from aiohttp import WSMsgType -from graphql_ws.aiohttp import AiohttpConnectionContext, AiohttpSubscriptionServer from graphql_ws.base import ConnectionClosedException +if_aiohttp_installed = pytest.mark.skipif( + WSMsgType is None, reason="aiohttp is not installed" +) + class AsyncMock(mock.Mock): def __call__(self, *args, **kwargs): - async def coro(): return super(AsyncMock, self).__call__(*args, **kwargs) @@ -24,6 +31,7 @@ def mock_ws(): return ws +@if_aiohttp_installed @pytest.mark.asyncio class TestConnectionContext: async def test_receive_good_data(self, mock_ws): @@ -55,7 +63,7 @@ async def test_receive_closed(self, mock_ws): async def test_send(self, mock_ws): connection_context = AiohttpConnectionContext(ws=mock_ws) await connection_context.send("test") - mock_ws.send_str.assert_called_with("test") + mock_ws.send_str.assert_called_with('"test"') async def test_send_closed(self, mock_ws): mock_ws.closed = True @@ -69,5 +77,6 @@ async def test_close(self, mock_ws): mock_ws.close.assert_called_with(code=123) +@if_aiohttp_installed def test_subscription_server_smoke(): AiohttpSubscriptionServer(schema=None) diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 0000000..1ce6300 --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,112 @@ +try: + from unittest import mock +except ImportError: + import mock + +import json + +import pytest + +from graphql_ws import base +from graphql_ws.base_sync import SubscriptionObserver + + +def test_not_implemented(): + server = base.BaseSubscriptionServer(schema=None) + with pytest.raises(NotImplementedError): + server.on_connection_init(connection_context=None, op_id=1, payload={}) + with pytest.raises(NotImplementedError): + server.on_open(connection_context=None) + + +def test_on_stop(): + server = base.BaseSubscriptionServer(schema=None) + context = mock.Mock() + server.on_stop(connection_context=context, op_id=1) + context.unsubscribe.assert_called_with(1) + + +def test_terminate(): + server = base.BaseSubscriptionServer(schema=None) + + context = mock.Mock() + server.on_connection_terminate(connection_context=context, op_id=1) + context.close.assert_called_with(1011) + + +def test_send_error(): + server = base.BaseSubscriptionServer(schema=None) + context = mock.Mock() + server.send_error(connection_context=context, op_id=1, error="test error") + context.send.assert_called_with( + {"id": 1, "type": "error", "payload": {"message": "test error"}} + ) + + +def test_message(): + server = base.BaseSubscriptionServer(schema=None) + server.process_message = mock.Mock() + context = mock.Mock() + msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} + server.on_message(context, msg) + server.process_message.assert_called_with(context, msg) + + +def test_message_str(): + server = base.BaseSubscriptionServer(schema=None) + server.process_message = mock.Mock() + context = mock.Mock() + msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} + server.on_message(context, json.dumps(msg)) + server.process_message.assert_called_with(context, msg) + + +def test_message_invalid(): + server = base.BaseSubscriptionServer(schema=None) + server.send_error = mock.Mock() + server.on_message(connection_context=None, message="'not-json") + assert server.send_error.called + + +def test_context_operations(): + ws = mock.Mock() + context = base.BaseConnectionContext(ws) + assert not context.has_operation(1) + context.register_operation(1, None) + assert context.has_operation(1) + context.remove_operation(1) + assert not context.has_operation(1) + # Removing a non-existant operation fails silently. + context.remove_operation(999) + + +def test_observer_data(): + ws = mock.Mock() + context = base.BaseConnectionContext(ws) + send_result, send_error, send_message = mock.Mock(), mock.Mock(), mock.Mock() + observer = SubscriptionObserver( + connection_context=context, + op_id=1, + send_execution_result=send_result, + send_error=send_error, + send_message=send_message, + ) + observer.on_next('data') + assert send_result.called + assert not send_error.called + + +def test_observer_exception(): + ws = mock.Mock() + context = base.BaseConnectionContext(ws) + send_result, send_error, send_message = mock.Mock(), mock.Mock(), mock.Mock() + observer = SubscriptionObserver( + connection_context=context, + op_id=1, + send_execution_result=send_result, + send_error=send_error, + send_message=send_message, + ) + observer.on_next(TypeError('some bad message')) + assert send_error.called + assert not send_result.called diff --git a/tests/test_base_async.py b/tests/test_base_async.py new file mode 100644 index 0000000..d62eda5 --- /dev/null +++ b/tests/test_base_async.py @@ -0,0 +1,100 @@ +from unittest import mock + +import json +import promise + +import pytest + +from graphql_ws import base, base_async + +pytestmark = pytest.mark.asyncio + + +class AsyncMock(mock.MagicMock): + async def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + + +class TstServer(base_async.BaseAsyncSubscriptionServer): + def handle(self, *args, **kwargs): + pass # pragma: no cover + + +@pytest.fixture +def server(): + return TstServer(schema=None) + + +async def test_terminate(server: TstServer): + context = AsyncMock() + await server.on_connection_terminate(connection_context=context, op_id=1) + context.close.assert_called_with(1011) + + +async def test_send_error(server: TstServer): + context = AsyncMock() + context.has_operation = mock.Mock() + await server.send_error(connection_context=context, op_id=1, error="test error") + context.send.assert_called_with( + {"id": 1, "type": "error", "payload": {"message": "test error"}} + ) + + +async def test_message(server): + server.process_message = AsyncMock() + context = AsyncMock() + msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} + await server.on_message(context, msg) + server.process_message.assert_called_with(context, msg) + + +async def test_message_str(server): + server.process_message = AsyncMock() + context = AsyncMock() + msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} + await server.on_message(context, json.dumps(msg)) + server.process_message.assert_called_with(context, msg) + + +async def test_message_invalid(server): + server.send_error = AsyncMock() + await server.on_message(connection_context=None, message="'not-json") + assert server.send_error.called + + +async def test_resolver(server): + server.send_message = AsyncMock() + result = mock.Mock() + result.data = {"test": [1, 2]} + result.errors = None + await server.send_execution_result( + connection_context=None, op_id=1, execution_result=result + ) + assert server.send_message.called + + +@pytest.mark.asyncio +async def test_resolver_with_promise(server): + server.send_message = AsyncMock() + result = mock.Mock() + result.data = {"test": [1, promise.Promise(lambda resolve, reject: resolve(2))]} + result.errors = None + await server.send_execution_result( + connection_context=None, op_id=1, execution_result=result + ) + assert server.send_message.called + assert result.data == {'test': [1, 2]} + + +async def test_resolver_with_nested_promise(server): + server.send_message = AsyncMock() + result = mock.Mock() + inner = promise.Promise(lambda resolve, reject: resolve(2)) + outer = promise.Promise(lambda resolve, reject: resolve({'in': inner})) + result.data = {"test": [1, outer]} + result.errors = None + await server.send_execution_result( + connection_context=None, op_id=1, execution_result=result + ) + assert server.send_message.called + assert result.data == {'test': [1, {'in': 2}]} diff --git a/tests/test_django_channels.py b/tests/test_django_channels.py index e7b054c..0552c7b 100644 --- a/tests/test_django_channels.py +++ b/tests/test_django_channels.py @@ -1,11 +1,35 @@ +from __future__ import unicode_literals + +import json + +import django import mock +from channels import Channel +from channels.test import ChannelTestCase from django.conf import settings +from django.core.management import call_command -settings.configure() # noqa +settings.configure( + CHANNEL_LAYERS={ + "default": { + "BACKEND": "asgiref.inmemory.ChannelLayer", + "ROUTING": "tests.django_routing.channel_routing", + }, + }, + INSTALLED_APPS=[ + "django.contrib.sessions", + "django.contrib.contenttypes", + "django.contrib.auth", + ], + DATABASES={"default": {"ENGINE": "django.db.backends.sqlite3", "NAME": ":memory:"}}, +) +django.setup() -from graphql_ws.django_channels import ( +from graphql_ws.constants import GQL_CONNECTION_ACK, GQL_CONNECTION_INIT # noqa: E402 +from graphql_ws.django_channels import ( # noqa: E402 DjangoChannelConnectionContext, DjangoChannelSubscriptionServer, + GraphQLSubscriptionConsumer, ) @@ -14,7 +38,7 @@ def test_send(self): msg = mock.Mock() connection_context = DjangoChannelConnectionContext(message=msg) connection_context.send("test") - msg.reply_channel.send.assert_called_with("test") + msg.reply_channel.send.assert_called_with({"text": '"test"'}) def test_close(self): msg = mock.Mock() @@ -25,3 +49,21 @@ def test_close(self): def test_subscription_server_smoke(): DjangoChannelSubscriptionServer(schema=None) + + +class TestConsumer(ChannelTestCase): + def test_connect(self): + call_command("migrate") + Channel("websocket.receive").send( + { + "path": "/graphql", + "order": 0, + "reply_channel": "websocket.receive", + "text": json.dumps({"type": GQL_CONNECTION_INIT, "id": 1}), + } + ) + message = self.get_next_message("websocket.receive", require=True) + GraphQLSubscriptionConsumer(message) + result = self.get_next_message("websocket.receive", require=True) + result_content = json.loads(result.content["text"]) + assert result_content == {"type": GQL_CONNECTION_ACK} diff --git a/tests/test_gevent.py b/tests/test_gevent.py index f766c5a..a734970 100644 --- a/tests/test_gevent.py +++ b/tests/test_gevent.py @@ -17,8 +17,8 @@ def test_send(self): ws = mock.Mock() ws.closed = False connection_context = GeventConnectionContext(ws=ws) - connection_context.send("test") - ws.send.assert_called_with("test") + connection_context.send({"text": "test"}) + ws.send.assert_called_with('{"text": "test"}') def test_send_closed(self): ws = mock.Mock() diff --git a/tests/test_graphql_ws.py b/tests/test_graphql_ws.py index 3ba1120..3b85c49 100644 --- a/tests/test_graphql_ws.py +++ b/tests/test_graphql_ws.py @@ -1,12 +1,14 @@ from collections import OrderedDict + try: from unittest import mock except ImportError: import mock import pytest +from graphql.execution.executors.sync import SyncExecutor -from graphql_ws import base, constants +from graphql_ws import base, base_sync, constants @pytest.fixture @@ -18,7 +20,7 @@ def cc(): @pytest.fixture def ss(): - return base.BaseSubscriptionServer(schema=None) + return base_sync.BaseSyncSubscriptionServer(schema=None) class TestConnectionContextOperation: @@ -93,13 +95,13 @@ def test_start_existing_op(self, ss, cc): ss.get_graphql_params.return_value = {"params": True} cc.has_operation = mock.Mock() cc.has_operation.return_value = True - ss.unsubscribe = mock.Mock() - ss.on_start = mock.Mock() + cc.unsubscribe = mock.Mock() + ss.execute = mock.Mock() + ss.send_message = mock.Mock() ss.process_message( cc, {"id": "1", "type": constants.GQL_START, "payload": {"a": "b"}} ) - assert ss.unsubscribe.called - ss.on_start.assert_called_with(cc, "1", {"params": True}) + assert cc.unsubscribe.called def test_start_bad_graphql_params(self, ss, cc): ss.get_graphql_params = mock.Mock() @@ -109,9 +111,7 @@ def test_start_bad_graphql_params(self, ss, cc): ss.send_error = mock.Mock() ss.unsubscribe = mock.Mock() ss.on_start = mock.Mock() - ss.process_message( - cc, {"id": "1", "type": constants.GQL_START, "payload": {"a": "b"}} - ) + ss.process_message(cc, {"id": "1", "type": None, "payload": {"a": "b"}}) assert ss.send_error.called assert ss.send_error.call_args[0][:2] == (cc, "1") assert isinstance(ss.send_error.call_args[0][2], Exception) @@ -135,13 +135,15 @@ def test_get_graphql_params(ss, cc): "query": "req", "variables": "vars", "operationName": "query", - "context": "ctx", + "context": {}, } - assert ss.get_graphql_params(cc, payload) == { + params = ss.get_graphql_params(cc, payload) + assert isinstance(params.pop("executor"), SyncExecutor) + assert params == { "request_string": "req", "variable_values": "vars", "operation_name": "query", - "context_value": "ctx", + "context_value": {}, } @@ -159,7 +161,8 @@ def test_build_message_partial(ss): assert ss.build_message(id=None, op_type=None, payload="PAYLOAD") == { "payload": "PAYLOAD" } - assert ss.build_message(id=None, op_type=None, payload=None) == {} + with pytest.raises(AssertionError): + ss.build_message(id=None, op_type=None, payload=None) def test_send_execution_result(ss): @@ -189,34 +192,10 @@ def test_send_message(ss, cc): cc.send = mock.Mock() cc.send.return_value = "returned" assert "returned" == ss.send_message(cc) - cc.send.assert_called_with('{"mess": "age"}') + cc.send.assert_called_with({"mess": "age"}) class TestSSNotImplemented: def test_handle(self, ss): with pytest.raises(NotImplementedError): ss.handle(ws=None, request_context=None) - - def test_on_open(self, ss): - with pytest.raises(NotImplementedError): - ss.on_open(connection_context=None) - - def test_on_connect(self, ss): - with pytest.raises(NotImplementedError): - ss.on_connect(connection_context=None, payload=None) - - def test_on_close(self, ss): - with pytest.raises(NotImplementedError): - ss.on_close(connection_context=None) - - def test_on_connection_init(self, ss): - with pytest.raises(NotImplementedError): - ss.on_connection_init(connection_context=None, op_id=None, payload=None) - - def test_on_stop(self, ss): - with pytest.raises(NotImplementedError): - ss.on_stop(connection_context=None, op_id=None) - - def test_on_start(self, ss): - with pytest.raises(NotImplementedError): - ss.on_start(connection_context=None, op_id=None, params=None) diff --git a/tox.ini b/tox.ini index 6de6deb..62e2f8b 100644 --- a/tox.ini +++ b/tox.ini @@ -1,15 +1,15 @@ [tox] -envlist = +envlist = coverage_setup - py27, py35, py36, py37, py38, flake8 + py27, py36, py37, py38, py39, flake8 coverage_report [travis] python = - 3.8: py38, flake8 + 3.9: py39, flake8 + 3.8: py38 3.7: py37 3.6: py36 - 3.5: py35 2.7: py27 [testenv] @@ -31,5 +31,6 @@ skip_install = true deps = coverage commands = coverage html + coverage xml coverage report --include="tests/*" --fail-under=100 -m - coverage report --omit="tests/*" # --fail-under=90 -m \ No newline at end of file + coverage report --omit="tests/*" # --fail-under=90 -m