Skip to content

Commit c666bef

Browse files
committed
Added batch support
1 parent ef3f607 commit c666bef

File tree

4 files changed

+125
-28
lines changed

4 files changed

+125
-28
lines changed

Diff for: README.md

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ from flask_graphql import GraphQLView
1313

1414
app.add_url_rule('/graphql', view_func=GraphQLView.as_view('graphql', schema=schema, graphiql=True))
1515

16+
# Optional, for adding batch query support (used in Apollo-Client)
17+
app.add_url_rule('/graphql/batch', view_func=GraphQLView.as_view('graphql', schema=schema, batch=True))
1618
```
1719

1820
This will add `/graphql` and `/graphiql` endpoints to your app.

Diff for: flask_graphql/graphqlview.py

+51-26
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class GraphQLView(View):
3232
graphiql_version = None
3333
graphiql_template = None
3434
middleware = None
35+
batch = False
3536

3637
methods = ['GET', 'POST', 'PUT', 'DELETE']
3738

@@ -41,6 +42,7 @@ def __init__(self, **kwargs):
4142
if hasattr(self, key):
4243
setattr(self, key, value)
4344

45+
assert not all((self.graphiql, self.batch)), 'Use either graphiql or batch processing'
4446
assert isinstance(self.schema, GraphQLSchema), 'A Schema is required to be provided to GraphQLView.'
4547

4648
# noinspection PyUnusedLocal
@@ -66,33 +68,15 @@ def dispatch_request(self):
6668
data = self.parse_body(request)
6769
show_graphiql = self.graphiql and self.can_display_graphiql(data)
6870

69-
query, variables, operation_name = self.get_graphql_params(request, data)
70-
71-
execution_result = self.execute_graphql_request(
72-
data,
73-
query,
74-
variables,
75-
operation_name,
76-
show_graphiql
77-
)
78-
79-
if execution_result:
80-
response = {}
81-
82-
if execution_result.errors:
83-
response['errors'] = [self.format_error(e) for e in execution_result.errors]
84-
85-
if execution_result.invalid:
86-
status_code = 400
87-
else:
88-
status_code = 200
89-
response['data'] = execution_result.data
90-
91-
result = self.json_encode(request, response)
71+
if self.batch:
72+
responses = [self.get_response(request, entry) for entry in data]
73+
result = '[{}]'.format(','.join([response[0] for response in responses]))
74+
status_code = max(responses, key=lambda response: response[1])[1]
9275
else:
93-
result = None
76+
result, status_code = self.get_response(request, data, show_graphiql)
9477

9578
if show_graphiql:
79+
query, variables, operation_name, id = self.get_graphql_params(request, data)
9680
return render_graphiql(
9781
graphiql_version=self.graphiql_version,
9882
graphiql_template=self.graphiql_template,
@@ -118,6 +102,43 @@ def dispatch_request(self):
118102
content_type='application/json'
119103
)
120104

105+
def get_response(self, request, data, show_graphiql=False):
106+
query, variables, operation_name, id = self.get_graphql_params(request, data)
107+
108+
execution_result = self.execute_graphql_request(
109+
data,
110+
query,
111+
variables,
112+
operation_name,
113+
show_graphiql
114+
)
115+
116+
status_code = 200
117+
if execution_result:
118+
response = {}
119+
120+
if execution_result.errors:
121+
response['errors'] = [self.format_error(e) for e in execution_result.errors]
122+
123+
if execution_result.invalid:
124+
status_code = 400
125+
else:
126+
status_code = 200
127+
response['data'] = execution_result.data
128+
129+
if self.batch:
130+
response = {
131+
'id': id,
132+
'payload': response,
133+
'status': status_code,
134+
}
135+
136+
result = self.json_encode(request, response)
137+
else:
138+
result = None
139+
140+
return result, status_code
141+
121142
def json_encode(self, request, d):
122143
if not self.pretty and not request.args.get('pretty'):
123144
return json.dumps(d, separators=(',', ':'))
@@ -134,7 +155,10 @@ def parse_body(self, request):
134155
elif content_type == 'application/json':
135156
try:
136157
request_json = json.loads(request.data.decode('utf8'))
137-
assert isinstance(request_json, dict)
158+
if self.batch:
159+
assert isinstance(request_json, list)
160+
else:
161+
assert isinstance(request_json, dict)
138162
return request_json
139163
except:
140164
raise HttpError(BadRequest('POST body sent invalid JSON.'))
@@ -207,6 +231,7 @@ def request_wants_html(cls, request):
207231
def get_graphql_params(request, data):
208232
query = request.args.get('query') or data.get('query')
209233
variables = request.args.get('variables') or data.get('variables')
234+
id = request.args.get('id') or data.get('id')
210235

211236
if variables and isinstance(variables, six.text_type):
212237
try:
@@ -216,7 +241,7 @@ def get_graphql_params(request, data):
216241

217242
operation_name = request.args.get('operationName') or data.get('operationName')
218243

219-
return query, variables, operation_name
244+
return query, variables, operation_name, id
220245

221246
@staticmethod
222247
def format_error(error):

Diff for: tests/app.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from .schema import Schema
44

55

6-
def create_app(**kwargs):
6+
def create_app(path='/graphql', **kwargs):
77
app = Flask(__name__)
88
app.debug = True
9-
app.add_url_rule('/graphql', view_func=GraphQLView.as_view('graphql', schema=Schema, **kwargs))
9+
app.add_url_rule(path, view_func=GraphQLView.as_view('graphql', schema=Schema, **kwargs))
1010
return app
1111

1212

Diff for: tests/test_graphqlview.py

+70
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ def response_json(response):
3333

3434

3535
j = lambda **kwargs: json.dumps(kwargs)
36+
jl = lambda **kwargs: json.dumps([kwargs])
37+
3638

3739
def test_allows_get_with_query_param(client):
3840
response = client.get(url_string(query='{test}'))
@@ -453,3 +455,71 @@ def test_post_multipart_data(client):
453455

454456
assert response.status_code == 200
455457
assert response_json(response) == {'data': {u'writeTest': {u'test': u'Hello World'}}}
458+
459+
460+
@pytest.mark.parametrize('app', [create_app(batch=True)])
461+
def test_batch_allows_post_with_json_encoding(client):
462+
response = client.post(
463+
url_string(),
464+
data=jl(id=1, query='{test}'),
465+
content_type='application/json'
466+
)
467+
468+
assert response.status_code == 200
469+
assert response_json(response) == [{
470+
'id': 1,
471+
'payload': { 'data': {'test': "Hello World"} },
472+
'status': 200,
473+
}]
474+
475+
476+
@pytest.mark.parametrize('app', [create_app(batch=True)])
477+
def test_batch_supports_post_json_query_with_json_variables(client):
478+
response = client.post(
479+
url_string(),
480+
data=jl(
481+
id=1,
482+
query='query helloWho($who: String){ test(who: $who) }',
483+
variables={'who': "Dolly"}
484+
),
485+
content_type='application/json'
486+
)
487+
488+
assert response.status_code == 200
489+
assert response_json(response) == [{
490+
'id': 1,
491+
'payload': { 'data': {'test': "Hello Dolly"} },
492+
'status': 200,
493+
}]
494+
495+
496+
@pytest.mark.parametrize('app', [create_app(batch=True)])
497+
def test_batch_allows_post_with_operation_name(client):
498+
response = client.post(
499+
url_string(),
500+
data=jl(
501+
id=1,
502+
query='''
503+
query helloYou { test(who: "You"), ...shared }
504+
query helloWorld { test(who: "World"), ...shared }
505+
query helloDolly { test(who: "Dolly"), ...shared }
506+
fragment shared on QueryRoot {
507+
shared: test(who: "Everyone")
508+
}
509+
''',
510+
operationName='helloWorld'
511+
),
512+
content_type='application/json'
513+
)
514+
515+
assert response.status_code == 200
516+
assert response_json(response) == [{
517+
'id': 1,
518+
'payload': {
519+
'data': {
520+
'test': 'Hello World',
521+
'shared': 'Hello Everyone'
522+
}
523+
},
524+
'status': 200,
525+
}]

0 commit comments

Comments
 (0)