namespace SharedDATA.Api
{
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using SharedDATA.Context;
///
/// Represents a default generic repository implements the interface.
///
/// The type of the entity.
/// <summary>
/// Represents a default generic repository that implements the <see cref="IRepository{TEntity}"/> interface
/// on top of an Entity Framework Core <see cref="DbContext"/>.
/// </summary>
/// <typeparam name="TEntity">The type of the entity.</typeparam>
/// <remarks>
/// The insert/update/delete members only stage changes in the context's change tracker; this class never calls
/// <c>SaveChanges</c>, so persisting the staged changes is the caller's responsibility.
/// </remarks>
public class Repository<TEntity> : IRepository<TEntity> where TEntity : class
{
    /// <summary>The context through which this repository queries and stages changes.</summary>
    protected readonly DbContext _dbContext;

    /// <summary>The set for <typeparamref name="TEntity"/>, obtained from <see cref="_dbContext"/> at construction.</summary>
    protected readonly DbSet<TEntity> _dbSet;

    /// <summary>
    /// Initializes a new instance of the <see cref="Repository{TEntity}"/> class.
    /// </summary>
    /// <param name="dbContext">The database context.</param>
    /// <exception cref="ArgumentNullException"><paramref name="dbContext"/> is <c>null</c>.</exception>
    public Repository(DbContext dbContext)
    {
        _dbContext = dbContext ?? throw new ArgumentNullException(nameof(dbContext));
        _dbSet = _dbContext.Set<TEntity>();
    }

    /// <summary>
    /// Changes the table name that <typeparamref name="TEntity"/> is mapped to. The tables must be in the same database.
    /// </summary>
    /// <param name="table">The new table name.</param>
    /// <remarks>
    /// Only intended for supporting multiple tables that share the same model.
    /// The call does nothing when the entity type's metadata is not an <see cref="IConventionEntityType"/>.
    /// That includes the case where the type is not found in the model.
    /// NOTE(review): this mutates model metadata, which EF Core normally caches and shares between context
    /// instances of the same type — confirm this is safe for concurrent use.
    /// </remarks>
    public virtual void ChangeTable(string table)
    {
        if (_dbContext.Model.FindEntityType(typeof(TEntity)) is IConventionEntityType relational)
        {
            relational.SetTableName(table);
        }
    }

    /// <summary>
    /// Gets all entities as a tracking query. This method is not recommended.
    /// </summary>
    /// <returns>The underlying <see cref="DbSet{TEntity}"/> as an <see cref="IQueryable{TEntity}"/>.</returns>
    public IQueryable<TEntity> GetAll()
    {
        return _dbSet;
    }

    /// <summary>
    /// Gets all entities that satisfy the optional filter. This method is not recommended.
    /// </summary>
    /// <param name="predicate">A function to test each element for a condition.</param>
    /// <param name="orderBy">A function to order elements.</param>
    /// <param name="include">A function to include navigation properties.</param>
    /// <param name="disableTracking"><c>true</c> to disable change tracking; otherwise, <c>false</c>. Defaults to <c>true</c>.</param>
    /// <param name="ignoreQueryFilters"><c>true</c> to ignore global query filters.</param>
    /// <returns>An <see cref="IQueryable{TEntity}"/> (not yet executed) of the elements that satisfy <paramref name="predicate"/>.</returns>
    /// <remarks>This method defaults to a read-only, no-tracking query.</remarks>
    public IQueryable<TEntity> GetAll(
        Expression<Func<TEntity, bool>> predicate = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
        bool disableTracking = true,
        bool ignoreQueryFilters = false)
    {
        return ApplyOrdering(BuildQuery(predicate, include, disableTracking, ignoreQueryFilters), orderBy);
    }

    /// <summary>
    /// Gets an <see cref="IPagedList{TEntity}"/> based on a predicate, order-by delegate and page information.
    /// </summary>
    /// <param name="predicate">A function to test each element for a condition.</param>
    /// <param name="orderBy">A function to order elements.</param>
    /// <param name="include">A function to include navigation properties.</param>
    /// <param name="pageIndex">The index of the page.</param>
    /// <param name="pageSize">The size of the page.</param>
    /// <param name="disableTracking"><c>true</c> to disable change tracking; otherwise, <c>false</c>. Defaults to <c>true</c>.</param>
    /// <param name="ignoreQueryFilters"><c>true</c> to ignore global query filters.</param>
    /// <returns>An <see cref="IPagedList{TEntity}"/> containing the elements that satisfy <paramref name="predicate"/>.</returns>
    /// <remarks>This method defaults to a no-tracking query.</remarks>
    public virtual IPagedList<TEntity> GetPagedList(
        Expression<Func<TEntity, bool>> predicate = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
        int pageIndex = 0,
        int pageSize = 20,
        bool disableTracking = true,
        bool ignoreQueryFilters = false)
    {
        var query = ApplyOrdering(BuildQuery(predicate, include, disableTracking, ignoreQueryFilters), orderBy);
        return query.ToPagedList(pageIndex, pageSize);
    }

    /// <summary>
    /// Asynchronously gets an <see cref="IPagedList{TEntity}"/> based on a predicate, order-by delegate and page information.
    /// </summary>
    /// <param name="predicate">A function to test each element for a condition.</param>
    /// <param name="orderBy">A function to order elements.</param>
    /// <param name="include">A function to include navigation properties.</param>
    /// <param name="pageIndex">The index of the page.</param>
    /// <param name="pageSize">The size of the page.</param>
    /// <param name="disableTracking"><c>true</c> to disable change tracking; otherwise, <c>false</c>. Defaults to <c>true</c>.</param>
    /// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
    /// <param name="ignoreQueryFilters"><c>true</c> to ignore global query filters.</param>
    /// <returns>A task whose result is an <see cref="IPagedList{TEntity}"/> of the elements that satisfy <paramref name="predicate"/>.</returns>
    /// <remarks>This method defaults to a no-tracking query.</remarks>
    public virtual Task<IPagedList<TEntity>> GetPagedListAsync(
        Expression<Func<TEntity, bool>> predicate = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
        int pageIndex = 0,
        int pageSize = 20,
        bool disableTracking = true,
        CancellationToken cancellationToken = default(CancellationToken),
        bool ignoreQueryFilters = false)
    {
        var query = ApplyOrdering(BuildQuery(predicate, include, disableTracking, ignoreQueryFilters), orderBy);
        return query.ToPagedListAsync(pageIndex, pageSize, 0, cancellationToken);
    }

    /// <summary>
    /// Gets a projected <see cref="IPagedList{TResult}"/> based on a predicate, order-by delegate and page information.
    /// </summary>
    /// <typeparam name="TResult">The projected element type.</typeparam>
    /// <param name="selector">The selector for projection.</param>
    /// <param name="predicate">A function to test each element for a condition.</param>
    /// <param name="orderBy">A function to order elements (applied before projection).</param>
    /// <param name="include">A function to include navigation properties.</param>
    /// <param name="pageIndex">The index of the page.</param>
    /// <param name="pageSize">The size of the page.</param>
    /// <param name="disableTracking"><c>true</c> to disable change tracking; otherwise, <c>false</c>. Defaults to <c>true</c>.</param>
    /// <param name="ignoreQueryFilters"><c>true</c> to ignore global query filters.</param>
    /// <returns>An <see cref="IPagedList{TResult}"/> of the projected elements that satisfy <paramref name="predicate"/>.</returns>
    /// <remarks>This method defaults to a no-tracking query.</remarks>
    public virtual IPagedList<TResult> GetPagedList<TResult>(
        Expression<Func<TEntity, TResult>> selector,
        Expression<Func<TEntity, bool>> predicate = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
        int pageIndex = 0,
        int pageSize = 20,
        bool disableTracking = true,
        bool ignoreQueryFilters = false)
        where TResult : class
    {
        var query = ApplyOrdering(BuildQuery(predicate, include, disableTracking, ignoreQueryFilters), orderBy);
        return query.Select(selector).ToPagedList(pageIndex, pageSize);
    }

    /// <summary>
    /// Asynchronously gets a projected <see cref="IPagedList{TResult}"/> based on a predicate, order-by delegate and page information.
    /// </summary>
    /// <typeparam name="TResult">The projected element type.</typeparam>
    /// <param name="selector">The selector for projection.</param>
    /// <param name="predicate">A function to test each element for a condition.</param>
    /// <param name="orderBy">A function to order elements (applied before projection).</param>
    /// <param name="include">A function to include navigation properties.</param>
    /// <param name="pageIndex">The index of the page.</param>
    /// <param name="pageSize">The size of the page.</param>
    /// <param name="disableTracking"><c>true</c> to disable change tracking; otherwise, <c>false</c>. Defaults to <c>true</c>.</param>
    /// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
    /// <param name="ignoreQueryFilters"><c>true</c> to ignore global query filters.</param>
    /// <returns>A task whose result is an <see cref="IPagedList{TResult}"/> of the projected elements.</returns>
    /// <remarks>This method defaults to a no-tracking query.</remarks>
    public virtual Task<IPagedList<TResult>> GetPagedListAsync<TResult>(
        Expression<Func<TEntity, TResult>> selector,
        Expression<Func<TEntity, bool>> predicate = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
        int pageIndex = 0,
        int pageSize = 20,
        bool disableTracking = true,
        CancellationToken cancellationToken = default(CancellationToken),
        bool ignoreQueryFilters = false)
        where TResult : class
    {
        var query = ApplyOrdering(BuildQuery(predicate, include, disableTracking, ignoreQueryFilters), orderBy);
        return query.Select(selector).ToPagedListAsync(pageIndex, pageSize, 0, cancellationToken);
    }

    /// <summary>
    /// Gets the first entity (or <c>null</c>) based on a predicate, order-by delegate and include delegate.
    /// </summary>
    /// <param name="predicate">A function to test each element for a condition.</param>
    /// <param name="orderBy">A function to order elements.</param>
    /// <param name="include">A function to include navigation properties.</param>
    /// <param name="disableTracking"><c>true</c> to disable change tracking; otherwise, <c>false</c>. Defaults to <c>true</c>.</param>
    /// <param name="ignoreQueryFilters"><c>true</c> to ignore global query filters.</param>
    /// <returns>The first matching entity, or <c>null</c> if there is none.</returns>
    /// <remarks>This method defaults to a no-tracking query.</remarks>
    public virtual TEntity GetFirstOrDefault(
        Expression<Func<TEntity, bool>> predicate = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
        bool disableTracking = true,
        bool ignoreQueryFilters = false)
    {
        return ApplyOrdering(BuildQuery(predicate, include, disableTracking, ignoreQueryFilters), orderBy).FirstOrDefault();
    }

    /// <summary>
    /// Asynchronously gets the first entity (or <c>null</c>) based on a predicate, order-by delegate and include delegate.
    /// </summary>
    /// <param name="predicate">A function to test each element for a condition.</param>
    /// <param name="orderBy">A function to order elements.</param>
    /// <param name="include">A function to include navigation properties.</param>
    /// <param name="disableTracking"><c>true</c> to disable change tracking; otherwise, <c>false</c>. Defaults to <c>true</c>.</param>
    /// <param name="ignoreQueryFilters"><c>true</c> to ignore global query filters.</param>
    /// <returns>A task whose result is the first matching entity, or <c>null</c> if there is none.</returns>
    /// <remarks>This method defaults to a no-tracking query.</remarks>
    public virtual async Task<TEntity> GetFirstOrDefaultAsync(
        Expression<Func<TEntity, bool>> predicate = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
        bool disableTracking = true,
        bool ignoreQueryFilters = false)
    {
        var query = ApplyOrdering(BuildQuery(predicate, include, disableTracking, ignoreQueryFilters), orderBy);
        return await query.FirstOrDefaultAsync();
    }

    /// <summary>
    /// Gets the first projected element (or its default) based on a predicate, order-by delegate and include delegate.
    /// </summary>
    /// <typeparam name="TResult">The projected element type.</typeparam>
    /// <param name="selector">The selector for projection.</param>
    /// <param name="predicate">A function to test each element for a condition.</param>
    /// <param name="orderBy">A function to order elements (applied before projection).</param>
    /// <param name="include">A function to include navigation properties.</param>
    /// <param name="disableTracking"><c>true</c> to disable change tracking; otherwise, <c>false</c>. Defaults to <c>true</c>.</param>
    /// <param name="ignoreQueryFilters"><c>true</c> to ignore global query filters.</param>
    /// <returns>The first projected element, or <c>default(TResult)</c> if there is none.</returns>
    /// <remarks>This method defaults to a no-tracking query.</remarks>
    public virtual TResult GetFirstOrDefault<TResult>(
        Expression<Func<TEntity, TResult>> selector,
        Expression<Func<TEntity, bool>> predicate = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
        bool disableTracking = true,
        bool ignoreQueryFilters = false)
    {
        var query = ApplyOrdering(BuildQuery(predicate, include, disableTracking, ignoreQueryFilters), orderBy);
        return query.Select(selector).FirstOrDefault();
    }

    /// <summary>
    /// Asynchronously gets the first projected element (or its default) based on a predicate, order-by delegate and include delegate.
    /// </summary>
    /// <typeparam name="TResult">The projected element type.</typeparam>
    /// <param name="selector">The selector for projection.</param>
    /// <param name="predicate">A function to test each element for a condition.</param>
    /// <param name="orderBy">A function to order elements (applied before projection).</param>
    /// <param name="include">A function to include navigation properties.</param>
    /// <param name="disableTracking"><c>true</c> to disable change tracking; otherwise, <c>false</c>. Defaults to <c>true</c>.</param>
    /// <param name="ignoreQueryFilters"><c>true</c> to ignore global query filters.</param>
    /// <returns>A task whose result is the first projected element, or <c>default(TResult)</c> if there is none.</returns>
    /// <remarks>This method defaults to a no-tracking query.</remarks>
    public virtual async Task<TResult> GetFirstOrDefaultAsync<TResult>(
        Expression<Func<TEntity, TResult>> selector,
        Expression<Func<TEntity, bool>> predicate = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
        bool disableTracking = true,
        bool ignoreQueryFilters = false)
    {
        var query = ApplyOrdering(BuildQuery(predicate, include, disableTracking, ignoreQueryFilters), orderBy);
        return await query.Select(selector).FirstOrDefaultAsync();
    }

    /// <summary>
    /// Uses a raw SQL query to fetch entities.
    /// </summary>
    /// <param name="sql">The raw SQL.</param>
    /// <param name="parameters">The parameters referenced by <paramref name="sql"/>.</param>
    /// <returns>An <see cref="IQueryable{TEntity}"/> over the rows returned by the raw SQL.</returns>
    /// <remarks>
    /// <paramref name="sql"/> is passed to <c>FromSqlRaw</c> verbatim. Never build it by concatenating untrusted input.
    /// Pass such values through <paramref name="parameters"/> instead.
    /// </remarks>
    public virtual IQueryable<TEntity> FromSql(string sql, params object[] parameters) => _dbSet.FromSqlRaw(sql, parameters);

    /// <summary>
    /// Finds an entity with the given primary key values. If found, it is attached to the context and returned; otherwise <c>null</c> is returned.
    /// </summary>
    /// <param name="keyValues">The values of the primary key for the entity to be found.</param>
    /// <returns>The found entity or <c>null</c>.</returns>
    public virtual TEntity Find(params object[] keyValues) => _dbSet.Find(keyValues);

    /// <summary>
    /// Asynchronously finds an entity with the given primary key values. If found, it is attached to the context and returned; otherwise <c>null</c> is returned.
    /// </summary>
    /// <param name="keyValues">The values of the primary key for the entity to be found.</param>
    /// <returns>A <see cref="ValueTask{TEntity}"/> whose result is the found entity or <c>null</c>.</returns>
    public virtual ValueTask<TEntity> FindAsync(params object[] keyValues) => _dbSet.FindAsync(keyValues);

    /// <summary>
    /// Asynchronously finds an entity with the given primary key values. If found, it is attached to the context and returned; otherwise <c>null</c> is returned.
    /// </summary>
    /// <param name="keyValues">The values of the primary key for the entity to be found.</param>
    /// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
    /// <returns>A <see cref="ValueTask{TEntity}"/> whose result is the found entity or <c>null</c>.</returns>
    public virtual ValueTask<TEntity> FindAsync(object[] keyValues, CancellationToken cancellationToken) => _dbSet.FindAsync(keyValues, cancellationToken);

    /// <summary>
    /// Gets the number of entities that satisfy the optional predicate.
    /// </summary>
    /// <param name="predicate">A filter, or <c>null</c> to count all entities.</param>
    /// <returns>The number of matching entities.</returns>
    public virtual int Count(Expression<Func<TEntity, bool>> predicate = null)
    {
        return predicate == null ? _dbSet.Count() : _dbSet.Count(predicate);
    }

    /// <summary>
    /// Asynchronously gets the number of entities that satisfy the optional predicate.
    /// </summary>
    /// <param name="predicate">A filter, or <c>null</c> to count all entities.</param>
    /// <returns>A task whose result is the number of matching entities.</returns>
    public virtual async Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate = null)
    {
        return predicate == null ? await _dbSet.CountAsync() : await _dbSet.CountAsync(predicate);
    }

    /// <summary>
    /// Gets the number of entities that satisfy the optional predicate, as a <see cref="long"/>.
    /// </summary>
    /// <param name="predicate">A filter, or <c>null</c> to count all entities.</param>
    /// <returns>The number of matching entities.</returns>
    public virtual long LongCount(Expression<Func<TEntity, bool>> predicate = null)
    {
        return predicate == null ? _dbSet.LongCount() : _dbSet.LongCount(predicate);
    }

    /// <summary>
    /// Asynchronously gets the number of entities that satisfy the optional predicate, as a <see cref="long"/>.
    /// </summary>
    /// <param name="predicate">A filter, or <c>null</c> to count all entities.</param>
    /// <returns>A task whose result is the number of matching entities.</returns>
    public virtual async Task<long> LongCountAsync(Expression<Func<TEntity, bool>> predicate = null)
    {
        return predicate == null ? await _dbSet.LongCountAsync() : await _dbSet.LongCountAsync(predicate);
    }

    /// <summary>
    /// Gets the maximum selected value among entities that satisfy the optional predicate.
    /// </summary>
    /// <typeparam name="T">The type of the selected value.</typeparam>
    /// <param name="predicate">A filter, or <c>null</c> to consider all entities.</param>
    /// <param name="selector">The value to aggregate. Required despite its <c>null</c> default: a <c>null</c> selector makes the underlying LINQ call throw <see cref="ArgumentNullException"/>.</param>
    /// <returns>The maximum value.</returns>
    public virtual T Max<T>(Expression<Func<TEntity, bool>> predicate = null, Expression<Func<TEntity, T>> selector = null)
    {
        return Filter(predicate).Max(selector);
    }

    /// <summary>
    /// Asynchronously gets the maximum selected value among entities that satisfy the optional predicate.
    /// </summary>
    /// <typeparam name="T">The type of the selected value.</typeparam>
    /// <param name="predicate">A filter, or <c>null</c> to consider all entities.</param>
    /// <param name="selector">The value to aggregate. Required despite its <c>null</c> default.</param>
    /// <returns>A task whose result is the maximum value.</returns>
    public virtual async Task<T> MaxAsync<T>(Expression<Func<TEntity, bool>> predicate = null, Expression<Func<TEntity, T>> selector = null)
    {
        return await Filter(predicate).MaxAsync(selector);
    }

    /// <summary>
    /// Gets the minimum selected value among entities that satisfy the optional predicate.
    /// </summary>
    /// <typeparam name="T">The type of the selected value.</typeparam>
    /// <param name="predicate">A filter, or <c>null</c> to consider all entities.</param>
    /// <param name="selector">The value to aggregate. Required despite its <c>null</c> default.</param>
    /// <returns>The minimum value.</returns>
    public virtual T Min<T>(Expression<Func<TEntity, bool>> predicate = null, Expression<Func<TEntity, T>> selector = null)
    {
        return Filter(predicate).Min(selector);
    }

    /// <summary>
    /// Asynchronously gets the minimum selected value among entities that satisfy the optional predicate.
    /// </summary>
    /// <typeparam name="T">The type of the selected value.</typeparam>
    /// <param name="predicate">A filter, or <c>null</c> to consider all entities.</param>
    /// <param name="selector">The value to aggregate. Required despite its <c>null</c> default.</param>
    /// <returns>A task whose result is the minimum value.</returns>
    public virtual async Task<T> MinAsync<T>(Expression<Func<TEntity, bool>> predicate = null, Expression<Func<TEntity, T>> selector = null)
    {
        return await Filter(predicate).MinAsync(selector);
    }

    /// <summary>
    /// Gets the average of a <see cref="decimal"/> value among entities that satisfy the optional predicate.
    /// </summary>
    /// <param name="predicate">A filter, or <c>null</c> to consider all entities.</param>
    /// <param name="selector">The value to average. Required despite its <c>null</c> default.</param>
    /// <returns>The average value.</returns>
    public virtual decimal Average(Expression<Func<TEntity, bool>> predicate = null, Expression<Func<TEntity, decimal>> selector = null)
    {
        return Filter(predicate).Average(selector);
    }

    /// <summary>
    /// Asynchronously gets the average of a <see cref="decimal"/> value among entities that satisfy the optional predicate.
    /// </summary>
    /// <param name="predicate">A filter, or <c>null</c> to consider all entities.</param>
    /// <param name="selector">The value to average. Required despite its <c>null</c> default.</param>
    /// <returns>A task whose result is the average value.</returns>
    public virtual async Task<decimal> AverageAsync(Expression<Func<TEntity, bool>> predicate = null, Expression<Func<TEntity, decimal>> selector = null)
    {
        return await Filter(predicate).AverageAsync(selector);
    }

    /// <summary>
    /// Gets the sum of a <see cref="decimal"/> value among entities that satisfy the optional predicate.
    /// </summary>
    /// <param name="predicate">A filter, or <c>null</c> to consider all entities.</param>
    /// <param name="selector">The value to sum. Required despite its <c>null</c> default.</param>
    /// <returns>The sum.</returns>
    public virtual decimal Sum(Expression<Func<TEntity, bool>> predicate = null, Expression<Func<TEntity, decimal>> selector = null)
    {
        return Filter(predicate).Sum(selector);
    }

    /// <summary>
    /// Asynchronously gets the sum of a <see cref="decimal"/> value among entities that satisfy the optional predicate.
    /// </summary>
    /// <param name="predicate">A filter, or <c>null</c> to consider all entities.</param>
    /// <param name="selector">The value to sum. Required despite its <c>null</c> default.</param>
    /// <returns>A task whose result is the sum.</returns>
    public virtual async Task<decimal> SumAsync(Expression<Func<TEntity, bool>> predicate = null, Expression<Func<TEntity, decimal>> selector = null)
    {
        return await Filter(predicate).SumAsync(selector);
    }

    /// <summary>
    /// Determines whether any entity satisfies the optional predicate.
    /// </summary>
    /// <param name="selector">A filter, or <c>null</c> to test whether the set contains any entity.</param>
    /// <returns><c>true</c> if at least one matching entity exists.</returns>
    public bool Exists(Expression<Func<TEntity, bool>> selector = null)
    {
        return selector == null ? _dbSet.Any() : _dbSet.Any(selector);
    }

    /// <summary>
    /// Asynchronously determines whether any entity satisfies the optional predicate.
    /// </summary>
    /// <param name="selector">A filter, or <c>null</c> to test whether the set contains any entity.</param>
    /// <returns>A task whose result is <c>true</c> if at least one matching entity exists.</returns>
    public async Task<bool> ExistsAsync(Expression<Func<TEntity, bool>> selector = null)
    {
        return selector == null ? await _dbSet.AnyAsync() : await _dbSet.AnyAsync(selector);
    }

    /// <summary>
    /// Stages a new entity for insertion.
    /// </summary>
    /// <param name="entity">The entity to insert.</param>
    /// <returns>The entity instance referenced by the change-tracker entry created by <c>Add</c>.</returns>
    public virtual TEntity Insert(TEntity entity)
    {
        return _dbSet.Add(entity).Entity;
    }

    /// <summary>
    /// Stages a range of entities for insertion.
    /// </summary>
    /// <param name="entities">The entities to insert.</param>
    public virtual void Insert(params TEntity[] entities) => _dbSet.AddRange(entities);

    /// <summary>
    /// Stages a range of entities for insertion.
    /// </summary>
    /// <param name="entities">The entities to insert.</param>
    public virtual void Insert(IEnumerable<TEntity> entities) => _dbSet.AddRange(entities);

    /// <summary>
    /// Asynchronously stages a new entity for insertion.
    /// </summary>
    /// <param name="entity">The entity to insert.</param>
    /// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
    /// <returns>A <see cref="ValueTask{T}"/> whose result is the change-tracker entry for the added entity.</returns>
    public virtual ValueTask<EntityEntry<TEntity>> InsertAsync(TEntity entity, CancellationToken cancellationToken = default(CancellationToken))
    {
        return _dbSet.AddAsync(entity, cancellationToken);
    }

    /// <summary>
    /// Asynchronously stages a range of entities for insertion.
    /// </summary>
    /// <param name="entities">The entities to insert.</param>
    /// <returns>A <see cref="Task"/> that represents the asynchronous operation.</returns>
    public virtual Task InsertAsync(params TEntity[] entities) => _dbSet.AddRangeAsync(entities);

    /// <summary>
    /// Asynchronously stages a range of entities for insertion.
    /// </summary>
    /// <param name="entities">The entities to insert.</param>
    /// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
    /// <returns>A <see cref="Task"/> that represents the asynchronous operation.</returns>
    public virtual Task InsertAsync(IEnumerable<TEntity> entities, CancellationToken cancellationToken = default(CancellationToken)) => _dbSet.AddRangeAsync(entities, cancellationToken);

    /// <summary>
    /// Stages the specified entity for update.
    /// </summary>
    /// <param name="entity">The entity.</param>
    public virtual void Update(TEntity entity)
    {
        _dbSet.Update(entity);
    }

    /// <summary>
    /// Stages the specified entity for update.
    /// </summary>
    /// <param name="entity">The entity.</param>
    /// <remarks>
    /// Despite its name, this method is synchronous and returns nothing.
    /// It behaves exactly like <see cref="Update(TEntity)"/>.
    /// </remarks>
    public virtual void UpdateAsync(TEntity entity)
    {
        _dbSet.Update(entity);
    }

    /// <summary>
    /// Stages the specified entities for update.
    /// </summary>
    /// <param name="entities">The entities.</param>
    public virtual void Update(params TEntity[] entities) => _dbSet.UpdateRange(entities);

    /// <summary>
    /// Stages the specified entities for update.
    /// </summary>
    /// <param name="entities">The entities.</param>
    public virtual void Update(IEnumerable<TEntity> entities) => _dbSet.UpdateRange(entities);

    /// <summary>
    /// Stages the specified entity for deletion.
    /// </summary>
    /// <param name="entity">The entity to delete.</param>
    public virtual void Delete(TEntity entity) => _dbSet.Remove(entity);

    /// <summary>
    /// Stages the entity with the specified primary key for deletion.
    /// </summary>
    /// <param name="id">The primary key value.</param>
    /// <remarks>
    /// For a single-property key exposed as a CLR property, no database round-trip is made:
    /// <list type="bullet">
    /// <item>An instance already tracked with that key is removed.</item>
    /// <item>Otherwise a stub entity carrying only the key is attached in the <see cref="EntityState.Deleted"/> state.</item>
    /// </list>
    /// In every other case (composite or shadow keys), the entity is loaded with <c>Find</c> and removed if it exists.
    /// </remarks>
    public virtual void Delete(object id)
    {
        var typeInfo = typeof(TEntity).GetTypeInfo();
        var keyProperties = _dbContext.Model.FindEntityType(typeInfo)?.FindPrimaryKey()?.Properties;

        // A stub can only stand in for a row when the whole key is one property whose value is `id`;
        // using just the first part of a composite key would target the wrong row.
        var property = keyProperties != null && keyProperties.Count == 1
            ? typeInfo.GetProperty(keyProperties[0].Name)
            : null;

        if (property != null)
        {
            // Attaching a second instance with a key that is already tracked makes EF Core throw,
            // so reuse the tracked instance when there is one.
            var tracked = _dbSet.Local.FirstOrDefault(e => object.Equals(property.GetValue(e), id));
            if (tracked != null)
            {
                Delete(tracked);
                return;
            }

            // Use a stub entity to mark the row for deletion without loading it.
            var entity = Activator.CreateInstance<TEntity>();
            property.SetValue(entity, id);
            _dbContext.Entry(entity).State = EntityState.Deleted;
        }
        else
        {
            var entity = _dbSet.Find(id);
            if (entity != null)
            {
                Delete(entity);
            }
        }
    }

    /// <summary>
    /// Stages the specified entities for deletion.
    /// </summary>
    /// <param name="entities">The entities.</param>
    public virtual void Delete(params TEntity[] entities) => _dbSet.RemoveRange(entities);

    /// <summary>
    /// Stages the specified entities for deletion.
    /// </summary>
    /// <param name="entities">The entities.</param>
    public virtual void Delete(IEnumerable<TEntity> entities) => _dbSet.RemoveRange(entities);

    /// <summary>
    /// Asynchronously loads all entities (tracking query). This method is not recommended.
    /// </summary>
    /// <returns>A task whose result is a list of every entity in the set.</returns>
    public async Task<IList<TEntity>> GetAllAsync()
    {
        return await _dbSet.ToListAsync();
    }

    /// <summary>
    /// Asynchronously loads all entities that satisfy the optional filter. This method is not recommended.
    /// </summary>
    /// <param name="predicate">A function to test each element for a condition.</param>
    /// <param name="orderBy">A function to order elements.</param>
    /// <param name="include">A function to include navigation properties.</param>
    /// <param name="disableTracking"><c>true</c> to disable change tracking; otherwise, <c>false</c>. Defaults to <c>true</c>.</param>
    /// <param name="ignoreQueryFilters"><c>true</c> to ignore global query filters.</param>
    /// <returns>A task whose result is a list of the elements that satisfy <paramref name="predicate"/>.</returns>
    /// <remarks>This method defaults to a read-only, no-tracking query.</remarks>
    public async Task<IList<TEntity>> GetAllAsync(
        Expression<Func<TEntity, bool>> predicate = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
        bool disableTracking = true,
        bool ignoreQueryFilters = false)
    {
        var query = ApplyOrdering(BuildQuery(predicate, include, disableTracking, ignoreQueryFilters), orderBy);
        return await query.ToListAsync();
    }

    /// <summary>
    /// Sets the change-tracking state of an entity (e.g. for a PATCH method on a web API).
    /// </summary>
    /// <param name="entity">The entity.</param>
    /// <param name="state">The entity state.</param>
    public void ChangeEntityState(TEntity entity, EntityState state)
    {
        _dbContext.Entry(entity).State = state;
    }

    /// <summary>
    /// Builds the base query shared by the read methods. It applies these steps, in order:
    /// no-tracking, includes, the filter, then query-filter suppression.
    /// </summary>
    private IQueryable<TEntity> BuildQuery(
        Expression<Func<TEntity, bool>> predicate,
        Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include,
        bool disableTracking,
        bool ignoreQueryFilters)
    {
        IQueryable<TEntity> query = _dbSet;
        if (disableTracking)
        {
            query = query.AsNoTracking();
        }
        if (include != null)
        {
            query = include(query);
        }
        if (predicate != null)
        {
            query = query.Where(predicate);
        }
        if (ignoreQueryFilters)
        {
            query = query.IgnoreQueryFilters();
        }
        return query;
    }

    /// <summary>Applies <paramref name="orderBy"/> when supplied; otherwise returns <paramref name="query"/> unchanged.</summary>
    private static IQueryable<TEntity> ApplyOrdering(
        IQueryable<TEntity> query,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy)
    {
        return orderBy != null ? orderBy(query) : query;
    }

    /// <summary>Returns the set filtered by <paramref name="predicate"/>, or the whole set when it is <c>null</c>.</summary>
    private IQueryable<TEntity> Filter(Expression<Func<TEntity, bool>> predicate)
    {
        return predicate == null ? _dbSet : _dbSet.Where(predicate);
    }
}
}