Skip to content

Commit d7d16bd

Browse files
committed
fix(#343): reload relationships from database if included during POST
1 parent 262d341 commit d7d16bd

File tree

4 files changed

+218
-21
lines changed

4 files changed

+218
-21
lines changed

src/JsonApiDotNetCore/Data/DefaultEntityRepository.cs

Lines changed: 70 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
namespace JsonApiDotNetCore.Data
1515
{
16+
/// <inheritdoc />
1617
public class DefaultEntityRepository<TEntity>
1718
: DefaultEntityRepository<TEntity, int>,
1819
IEntityRepository<TEntity>
@@ -26,6 +27,10 @@ public DefaultEntityRepository(
2627
{ }
2728
}
2829

30+
/// <summary>
31+
/// Provides a default repository implementation and is responsible for
32+
/// abstracting any EF Core APIs away from the service layer.
33+
/// </summary>
2934
public class DefaultEntityRepository<TEntity, TId>
3035
: IEntityRepository<TEntity, TId>
3136
where TEntity : class, IIdentifiable<TId>
@@ -48,7 +53,7 @@ public DefaultEntityRepository(
4853
_genericProcessorFactory = _jsonApiContext.GenericProcessorFactory;
4954
}
5055

51-
/// </ inheritdoc>
56+
/// <inheritdoc />
5257
public virtual IQueryable<TEntity> Get()
5358
{
5459
if (_jsonApiContext.QuerySet?.Fields != null && _jsonApiContext.QuerySet.Fields.Count > 0)
@@ -57,41 +62,43 @@ public virtual IQueryable<TEntity> Get()
5762
return _dbSet;
5863
}
5964

60-
/// </ inheritdoc>
65+
/// <inheritdoc />
6166
public virtual IQueryable<TEntity> Filter(IQueryable<TEntity> entities, FilterQuery filterQuery)
6267
{
6368
return entities.Filter(_jsonApiContext, filterQuery);
6469
}
6570

66-
/// </ inheritdoc>
71+
/// <inheritdoc />
6772
public virtual IQueryable<TEntity> Sort(IQueryable<TEntity> entities, List<SortQuery> sortQueries)
6873
{
6974
return entities.Sort(sortQueries);
7075
}
7176

72-
/// </ inheritdoc>
77+
/// <inheritdoc />
7378
public virtual async Task<TEntity> GetAsync(TId id)
7479
{
7580
return await Get().SingleOrDefaultAsync(e => e.Id.Equals(id));
7681
}
7782

78-
/// </ inheritdoc>
83+
/// <inheritdoc />
7984
public virtual async Task<TEntity> GetAndIncludeAsync(TId id, string relationshipName)
8085
{
8186
_logger.LogDebug($"[JADN] GetAndIncludeAsync({id}, {relationshipName})");
8287

83-
var result = await Include(Get(), relationshipName).SingleOrDefaultAsync(e => e.Id.Equals(id));
88+
var includedSet = await IncludeAsync(Get(), relationshipName);
89+
var result = await includedSet.SingleOrDefaultAsync(e => e.Id.Equals(id));
8490

8591
return result;
8692
}
8793

88-
/// </ inheritdoc>
94+
/// <inheritdoc />
8995
public virtual async Task<TEntity> CreateAsync(TEntity entity)
9096
{
9197
AttachRelationships();
9298
_dbSet.Add(entity);
9399

94100
await _context.SaveChangesAsync();
101+
95102
return entity;
96103
}
97104

@@ -129,7 +136,7 @@ private void AttachHasOnePointers()
129136
_context.Entry(relationship.Value).State = EntityState.Unchanged;
130137
}
131138

132-
/// </ inheritdoc>
139+
/// <inheritdoc />
133140
public virtual async Task<TEntity> UpdateAsync(TId id, TEntity entity)
134141
{
135142
var oldEntity = await GetAsync(id);
@@ -148,14 +155,14 @@ public virtual async Task<TEntity> UpdateAsync(TId id, TEntity entity)
148155
return oldEntity;
149156
}
150157

151-
/// </ inheritdoc>
158+
/// <inheritdoc />
152159
public async Task UpdateRelationshipsAsync(object parent, RelationshipAttribute relationship, IEnumerable<string> relationshipIds)
153160
{
154161
var genericProcessor = _genericProcessorFactory.GetProcessor<IGenericProcessor>(typeof(GenericProcessor<>), relationship.Type);
155162
await genericProcessor.UpdateRelationshipsAsync(parent, relationship, relationshipIds);
156163
}
157164

158-
/// </ inheritdoc>
165+
/// <inheritdoc />
159166
public virtual async Task<bool> DeleteAsync(TId id)
160167
{
161168
var entity = await GetAsync(id);
@@ -170,7 +177,8 @@ public virtual async Task<bool> DeleteAsync(TId id)
170177
return true;
171178
}
172179

173-
/// </ inheritdoc>
180+
/// <inheritdoc />
181+
[Obsolete("Use IncludeAsync")]
174182
public virtual IQueryable<TEntity> Include(IQueryable<TEntity> entities, string relationshipName)
175183
{
176184
var entity = _jsonApiContext.RequestEntity;
@@ -185,10 +193,57 @@ public virtual IQueryable<TEntity> Include(IQueryable<TEntity> entities, string
185193
{
186194
throw new JsonApiException(400, $"Including the relationship {relationshipName} on {entity.EntityName} is not allowed");
187195
}
196+
197+
return entities.Include(relationship.InternalRelationshipName);
198+
}
199+
200+
/// <inheritdoc />
201+
public virtual async Task<IQueryable<TEntity>> IncludeAsync(IQueryable<TEntity> entities, string relationshipName)
202+
{
203+
var entity = _jsonApiContext.RequestEntity;
204+
var relationship = entity.Relationships.FirstOrDefault(r => r.PublicRelationshipName == relationshipName);
205+
if (relationship == null)
206+
{
207+
throw new JsonApiException(400, $"Invalid relationship {relationshipName} on {entity.EntityName}",
208+
$"{entity.EntityName} does not have a relationship named {relationshipName}");
209+
}
210+
211+
if (!relationship.CanInclude)
212+
{
213+
throw new JsonApiException(400, $"Including the relationship {relationshipName} on {entity.EntityName} is not allowed");
214+
}
215+
216+
await ReloadPointerAsync(relationship);
217+
188218
return entities.Include(relationship.InternalRelationshipName);
189219
}
190220

191-
/// </ inheritdoc>
221+
/// <summary>
222+
/// Ensure relationships on the provided entity have been fully loaded from the database.
223+
/// </summary>
224+
/// <remarks>
225+
/// The only known case when this should be called is when a POST request is
226+
/// sent with an ?include query.
227+
///
228+
/// See https://github.com/json-api-dotnet/JsonApiDotNetCore/issues/343
229+
/// </remarks>
230+
private async Task ReloadPointerAsync(RelationshipAttribute relationshipAttr)
231+
{
232+
if (relationshipAttr.IsHasOne && _jsonApiContext.HasOneRelationshipPointers.Get().TryGetValue(relationshipAttr, out var pointer))
233+
{
234+
await _context.Entry(pointer).ReloadAsync();
235+
}
236+
237+
if (relationshipAttr.IsHasMany && _jsonApiContext.HasManyRelationshipPointers.Get().TryGetValue(relationshipAttr, out var pointers))
238+
{
239+
foreach (var hasManyPointer in pointers)
240+
{
241+
await _context.Entry(hasManyPointer).ReloadAsync();
242+
}
243+
}
244+
}
245+
246+
/// <inheritdoc />
192247
public virtual async Task<IEnumerable<TEntity>> PageAsync(IQueryable<TEntity> entities, int pageSize, int pageNumber)
193248
{
194249
if (pageNumber >= 0)
@@ -209,23 +264,23 @@ public virtual async Task<IEnumerable<TEntity>> PageAsync(IQueryable<TEntity> en
209264
.ToListAsync();
210265
}
211266

212-
/// </ inheritdoc>
267+
/// <inheritdoc />
213268
public async Task<int> CountAsync(IQueryable<TEntity> entities)
214269
{
215270
return (entities is IAsyncEnumerable<TEntity>)
216271
? await entities.CountAsync()
217272
: entities.Count();
218273
}
219274

220-
/// </ inheritdoc>
275+
/// <inheritdoc />
221276
public async Task<TEntity> FirstOrDefaultAsync(IQueryable<TEntity> entities)
222277
{
223278
return (entities is IAsyncEnumerable<TEntity>)
224279
? await entities.FirstOrDefaultAsync()
225280
: entities.FirstOrDefault();
226281
}
227282

228-
/// </ inheritdoc>
283+
/// <inheritdoc />
229284
public async Task<IReadOnlyList<TEntity>> ToListAsync(IQueryable<TEntity> entities)
230285
{
231286
return (entities is IAsyncEnumerable<TEntity>)

src/JsonApiDotNetCore/Data/IEntityReadRepository.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System;
12
using System.Collections.Generic;
23
using System.Linq;
34
using System.Threading.Tasks;
@@ -20,6 +21,9 @@ public interface IEntityReadRepository<TEntity, in TId>
2021
/// </summary>
2122
IQueryable<TEntity> Get();
2223

24+
[Obsolete("Use IncludeAsync")]
25+
IQueryable<TEntity> Include(IQueryable<TEntity> entities, string relationshipName);
26+
2327
/// <summary>
2428
/// Include a relationship in the query
2529
/// </summary>
@@ -28,7 +32,7 @@ public interface IEntityReadRepository<TEntity, in TId>
2832
/// _todoItemsRepository.GetAndIncludeAsync(1, "achieved-date");
2933
/// </code>
3034
/// </example>
31-
IQueryable<TEntity> Include(IQueryable<TEntity> entities, string relationshipName);
35+
Task<IQueryable<TEntity>> IncludeAsync(IQueryable<TEntity> entities, string relationshipName);
3236

3337
/// <summary>
3438
/// Apply a filter to the provided queryable

src/JsonApiDotNetCore/Services/EntityResourceService.cs

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ public virtual async Task<TResource> CreateAsync(TResource resource)
7777

7878
entity = await _entities.CreateAsync(entity);
7979

80+
// this ensures relationships get reloaded from the database if they have
81+
// been requested
82+
// https://github.com/json-api-dotnet/JsonApiDotNetCore/issues/343
83+
if (ShouldIncludeRelationships())
84+
return await GetWithRelationshipsAsync(entity.Id);
85+
8086
return MapOut(entity);
8187
}
8288

@@ -92,7 +98,7 @@ public virtual async Task<IEnumerable<TResource>> GetAsync()
9298
entities = ApplySortAndFilterQuery(entities);
9399

94100
if (ShouldIncludeRelationships())
95-
entities = IncludeRelationships(entities, _jsonApiContext.QuerySet.IncludedRelationships);
101+
entities = await IncludeRelationshipsAsync(entities, _jsonApiContext.QuerySet.IncludedRelationships);
96102

97103
if (_jsonApiContext.Options.IncludeTotalRecordCount)
98104
_jsonApiContext.PageManager.TotalRecords = await _entities.CountAsync(entities);
@@ -218,7 +224,8 @@ protected virtual IQueryable<TEntity> ApplySortAndFilterQuery(IQueryable<TEntity
218224
return entities;
219225
}
220226

221-
protected virtual IQueryable<TEntity> IncludeRelationships(IQueryable<TEntity> entities, List<string> relationships)
227+
[Obsolete("Use IncludeRelationshipsAsync")]
228+
protected IQueryable<TEntity> IncludeRelationships(IQueryable<TEntity> entities, List<string> relationships)
222229
{
223230
_jsonApiContext.IncludedRelationships = relationships;
224231

@@ -228,14 +235,24 @@ protected virtual IQueryable<TEntity> IncludeRelationships(IQueryable<TEntity> e
228235
return entities;
229236
}
230237

238+
protected virtual async Task<IQueryable<TEntity>> IncludeRelationshipsAsync(IQueryable<TEntity> entities, List<string> relationships)
239+
{
240+
_jsonApiContext.IncludedRelationships = relationships;
241+
242+
foreach (var r in relationships)
243+
entities = await _entities.IncludeAsync(entities, r);
244+
245+
return entities;
246+
}
247+
231248
private async Task<TResource> GetWithRelationshipsAsync(TId id)
232249
{
233250
var query = _entities.Get().Where(e => e.Id.Equals(id));
234251

235-
_jsonApiContext.QuerySet.IncludedRelationships.ForEach(r =>
252+
foreach (var r in _jsonApiContext.QuerySet.IncludedRelationships)
236253
{
237-
query = _entities.Include(query, r);
238-
});
254+
query = await _entities.IncludeAsync(query, r);
255+
}
239256

240257
var value = await _entities.FirstOrDefaultAsync(query);
241258

0 commit comments

Comments
 (0)