using System.Collections.Immutable;
using Dpz.Core.SourceGenerator.Models;
using Microsoft.CodeAnalysis;
namespace Dpz.Core.SourceGenerator;
/// <summary>
/// Reads cache attributes from service implementation methods and converts them into the
/// metadata needed to generate cache decorators.
/// </summary>
/// <remarks>
/// Caching is an implementation strategy, so only attributes declared on methods of the
/// implementation class are read (<c>CacheAttribute</c> and <c>InvalidateCacheAttribute</c>);
/// attributes with the same name on interface methods are ignored.
/// </remarks>
internal static class CacheMethodInspector
{
    // RS2008 (analyzer release tracking) is suppressed for these descriptor declarations.
#pragma warning disable RS2008
    // DPZ_CACHE001: a [Cache] method has a shape the decorator generator cannot handle.
    private static readonly DiagnosticDescriptor UnsupportedCacheMethodDescriptor = new(
        "DPZ_CACHE001",
        "Unsupported cacheable service method",
        "Cannot generate cache decorator for '{0}.{1}': {2}",
        "Dpz.Core.Cache",
        DiagnosticSeverity.Error,
        isEnabledByDefault: true
    );

    // DPZ_CACHE002: the implementation type is required to inject IFusionCache.
    private static readonly DiagnosticDescriptor MissingFusionCacheDescriptor = new(
        "DPZ_CACHE002",
        "Cacheable service requires IFusionCache",
        "Cannot generate cache decorator for '{0}.{1}': implementation type '{2}' must inject IFusionCache",
        "Dpz.Core.Cache",
        DiagnosticSeverity.Error,
        isEnabledByDefault: true
    );

    // DPZ_CACHE003: the PostProcess method named on [Cache] is missing or has a bad signature.
    private static readonly DiagnosticDescriptor InvalidPostProcessDescriptor = new(
        "DPZ_CACHE003",
        "Invalid cache post-process method",
        "Cannot generate cache decorator for '{0}.{1}': {2}",
        "Dpz.Core.Cache",
        DiagnosticSeverity.Error,
        isEnabledByDefault: true
    );

    // DPZ_CACHE004: an [InvalidateCache] method cannot be decorated.
    private static readonly DiagnosticDescriptor InvalidInvalidateCacheDescriptor = new(
        "DPZ_CACHE004",
        "Invalid cache invalidation method",
        "Cannot generate cache invalidation for '{0}.{1}': {2}",
        "Dpz.Core.Cache",
        DiagnosticSeverity.Error,
        isEnabledByDefault: true
    );
#pragma warning restore RS2008

    /// <summary>
    /// Collects every <c>[Cache]</c>-annotated implementation method that the generated cache
    /// decorator can wrap, reporting an error diagnostic for each one that it cannot.
    /// </summary>
    /// <param name="interfaceType">The registered service interface.</param>
    /// <param name="implementationType">The concrete type whose methods are inspected.</param>
    /// <param name="sourceProductionContext">Context used to report diagnostics.</param>
    /// <returns>Metadata for every method that passed validation.</returns>
    internal static ImmutableArray<CachedMethod> GetCachedMethods(
        INamedTypeSymbol interfaceType,
        INamedTypeSymbol implementationType,
        SourceProductionContext sourceProductionContext
    )
    {
        var methods = ImmutableArray.CreateBuilder<CachedMethod>();
        var interfaceMethods = GetInterfaceMethods(interfaceType);
        // Depends only on the implementation type, so evaluate it once instead of per method.
        var hasFusionCache = HasFusionCacheConstructorParameter(implementationType);

        foreach (
            var implementationMethod in GetAttributedMethods(
                implementationType,
                SourceGeneratorAttributeNames.CacheAttributeFullName
            )
        )
        {
            var attribute = GetAttribute(
                implementationMethod,
                SourceGeneratorAttributeNames.CacheAttributeFullName
            );
            if (attribute is null)
            {
                // Defensive only: GetAttributedMethods already filtered on this attribute.
                continue;
            }

            var implementationLocation = implementationMethod.Locations.FirstOrDefault();
            var interfaceMethod = FindInterfaceMethod(
                implementationType,
                interfaceMethods,
                implementationMethod
            );
            if (interfaceMethod is null)
            {
                Report(
                    sourceProductionContext,
                    UnsupportedCacheMethodDescriptor,
                    implementationLocation,
                    implementationType.ToDisplayString(),
                    implementationMethod.Name,
                    "method is not part of the registered service interface"
                );
                continue;
            }

            var location = implementationLocation ?? interfaceMethod.Locations.FirstOrDefault();
            if (!hasFusionCache)
            {
                Report(
                    sourceProductionContext,
                    MissingFusionCacheDescriptor,
                    location,
                    interfaceType.ToDisplayString(),
                    interfaceMethod.Name,
                    implementationType.ToDisplayString()
                );
                continue;
            }

            if (
                !TryCreateCachedMethod(
                    interfaceMethod,
                    implementationType,
                    attribute,
                    out var cachedMethod,
                    out var message
                )
            )
            {
                // Every message produced by HasValidPostProcessMethod starts with "post-process";
                // those get the dedicated DPZ_CACHE003 descriptor.
                var descriptor = message.StartsWith("post-process", StringComparison.Ordinal)
                    ? InvalidPostProcessDescriptor
                    : UnsupportedCacheMethodDescriptor;
                Report(
                    sourceProductionContext,
                    descriptor,
                    location,
                    interfaceType.ToDisplayString(),
                    interfaceMethod.Name,
                    message
                );
                continue;
            }

            methods.Add(cachedMethod);
        }

        return methods.ToImmutable();
    }

    /// <summary>
    /// Collects every <c>[InvalidateCache]</c>-annotated implementation method, reporting an
    /// error diagnostic for each one that cannot be decorated.
    /// </summary>
    /// <param name="interfaceType">The registered service interface.</param>
    /// <param name="implementationType">The concrete type whose methods are inspected.</param>
    /// <param name="cachedMethods">
    /// The cacheable methods already accepted for this implementation. The names listed in an
    /// attribute's <c>Methods</c> argument must all appear here.
    /// </param>
    /// <param name="sourceProductionContext">Context used to report diagnostics.</param>
    /// <returns>Metadata for every method that passed validation.</returns>
    internal static ImmutableArray<InvalidatedMethod> GetInvalidatedMethods(
        INamedTypeSymbol interfaceType,
        INamedTypeSymbol implementationType,
        ImmutableArray<CachedMethod> cachedMethods,
        SourceProductionContext sourceProductionContext
    )
    {
        var methods = ImmutableArray.CreateBuilder<InvalidatedMethod>();
        var interfaceMethods = GetInterfaceMethods(interfaceType);
        var cachedMethodNames = new HashSet<string>(
            cachedMethods.Select(static method => method.Name)
        );
        // Depends only on the implementation type, so evaluate it once instead of per method.
        var hasFusionCache = HasFusionCacheConstructorParameter(implementationType);

        foreach (
            var implementationMethod in GetAttributedMethods(
                implementationType,
                SourceGeneratorAttributeNames.InvalidateCacheAttributeFullName
            )
        )
        {
            var attribute = GetAttribute(
                implementationMethod,
                SourceGeneratorAttributeNames.InvalidateCacheAttributeFullName
            );
            if (attribute is null)
            {
                // Defensive only: GetAttributedMethods already filtered on this attribute.
                continue;
            }

            var location = implementationMethod.Locations.FirstOrDefault();
            var interfaceMethod = FindInterfaceMethod(
                implementationType,
                interfaceMethods,
                implementationMethod
            );
            if (interfaceMethod is null)
            {
                Report(
                    sourceProductionContext,
                    InvalidInvalidateCacheDescriptor,
                    location,
                    implementationType.ToDisplayString(),
                    implementationMethod.Name,
                    "method is not part of the registered service interface"
                );
                continue;
            }

            if (!hasFusionCache)
            {
                Report(
                    sourceProductionContext,
                    MissingFusionCacheDescriptor,
                    location,
                    interfaceType.ToDisplayString(),
                    interfaceMethod.Name,
                    implementationType.ToDisplayString()
                );
                continue;
            }

            if (
                !TryCreateInvalidatedMethod(
                    interfaceMethod,
                    attribute,
                    cachedMethodNames,
                    out var invalidatedMethod,
                    out var message
                )
            )
            {
                Report(
                    sourceProductionContext,
                    InvalidInvalidateCacheDescriptor,
                    location,
                    interfaceType.ToDisplayString(),
                    interfaceMethod.Name,
                    message
                );
                continue;
            }

            methods.Add(invalidatedMethod);
        }

        return methods.ToImmutable();
    }

    // Creates and reports a diagnostic with the given message arguments.
    private static void Report(
        SourceProductionContext context,
        DiagnosticDescriptor descriptor,
        Location? location,
        params object?[] messageArgs
    )
    {
        context.ReportDiagnostic(Diagnostic.Create(descriptor, location, messageArgs));
    }

    /// <summary>
    /// Returns all ordinary methods of <paramref name="interfaceType"/> and of every interface
    /// it inherits. The interface's own members come first, so that a member hiding an
    /// inherited one with the same signature is found first by the signature fallback in
    /// <see cref="FindInterfaceMethod"/>.
    /// </summary>
    private static ImmutableArray<IMethodSymbol> GetInterfaceMethods(INamedTypeSymbol interfaceType)
    {
        return
        [
            .. interfaceType
                .GetMembers()
                .Concat(interfaceType.AllInterfaces.SelectMany(static type => type.GetMembers()))
                .OfType<IMethodSymbol>()
                .Where(static method => method.MethodKind == MethodKind.Ordinary),
        ];
    }

    /// <summary>
    /// Returns the ordinary methods declared directly on <paramref name="implementationType"/>
    /// (inherited members are not included) that carry the attribute named
    /// <paramref name="attributeFullName"/>.
    /// </summary>
    private static ImmutableArray<IMethodSymbol> GetAttributedMethods(
        INamedTypeSymbol implementationType,
        string attributeFullName
    )
    {
        return
        [
            .. implementationType
                .GetMembers()
                .OfType<IMethodSymbol>()
                .Where(method =>
                    method.MethodKind == MethodKind.Ordinary
                    && GetAttribute(method, attributeFullName) is not null
                ),
        ];
    }

    /// <summary>
    /// Finds the interface method that <paramref name="implementationMethod"/> implements.
    /// </summary>
    /// <remarks>
    /// The compiler's interface mapping (<c>FindImplementationForInterfaceMember</c>) is checked
    /// first for every candidate. Only when none maps to this method is a name/parameter
    /// signature comparison used. That fallback is restricted to interface members the compiler
    /// did not map to some other method. The restriction prevents, for example, an explicitly
    /// implemented member with the same signature from being attributed to a public method.
    /// </remarks>
    /// <returns>The matching interface method, or <c>null</c> when there is none.</returns>
    private static IMethodSymbol? FindInterfaceMethod(
        INamedTypeSymbol implementationType,
        ImmutableArray<IMethodSymbol> interfaceMethods,
        IMethodSymbol implementationMethod
    )
    {
        foreach (var interfaceMethod in interfaceMethods)
        {
            if (
                implementationType.FindImplementationForInterfaceMember(interfaceMethod)
                    is IMethodSymbol mappedMethod
                && SymbolEqualityComparer.Default.Equals(mappedMethod, implementationMethod)
            )
            {
                return interfaceMethod;
            }
        }

        foreach (var interfaceMethod in interfaceMethods)
        {
            if (
                implementationType.FindImplementationForInterfaceMember(interfaceMethod)
                    is not IMethodSymbol
                && HasSameSignature(interfaceMethod, implementationMethod)
            )
            {
                return interfaceMethod;
            }
        }

        return null;
    }

    /// <summary>
    /// Returns <c>true</c> when both methods have the same name, the same number of type
    /// parameters, and pairwise-equal parameter types and ref kinds. Return types are not
    /// compared.
    /// </summary>
    private static bool HasSameSignature(
        IMethodSymbol interfaceMethod,
        IMethodSymbol implementationMethod
    )
    {
        return implementationMethod.Name == interfaceMethod.Name
            && implementationMethod.TypeParameters.Length == interfaceMethod.TypeParameters.Length
            && implementationMethod.Parameters.Length == interfaceMethod.Parameters.Length
            && implementationMethod
                .Parameters.Zip(
                    interfaceMethod.Parameters,
                    (left, right) =>
                        left.RefKind == right.RefKind
                        && SymbolEqualityComparer.Default.Equals(left.Type, right.Type)
                )
                .All(static match => match);
    }

    /// <summary>
    /// Intended to verify that <paramref name="implementationType"/> injects <c>IFusionCache</c>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): this currently always returns <c>true</c>, so DPZ_CACHE002 is never
    /// reported. If the generated decorator really depends on the implementation injecting
    /// IFusionCache, add a constructor-parameter check here. Otherwise, remove the diagnostic.
    /// </remarks>
    private static bool HasFusionCacheConstructorParameter(INamedTypeSymbol implementationType)
    {
        return true;
    }

    /// <summary>
    /// Validates a cacheable interface method and builds its <see cref="CachedMethod"/> metadata.
    /// </summary>
    /// <remarks>
    /// The rules are:
    /// <list type="bullet">
    /// <item><description>The return type must be <c>Task&lt;T&gt;</c>.</description></item>
    /// <item><description><c>ExpirationSeconds</c> must be positive; it defaults to 3600 when not
    /// specified.</description></item>
    /// <item><description>No parameter may have a non-default ref kind.</description></item>
    /// <item><description>When no explicit <c>CacheKey</c> is given, every
    /// non-<c>CancellationToken</c> parameter must have a key-generatable type.</description></item>
    /// <item><description>A named <c>PostProcess</c> method must exist with a compatible
    /// signature.</description></item>
    /// </list>
    /// </remarks>
    /// <param name="method">The interface method being decorated.</param>
    /// <param name="implementationType">Type searched for the <c>PostProcess</c> method.</param>
    /// <param name="attribute">The <c>[Cache]</c> attribute data from the implementation method.</param>
    /// <param name="cachedMethod">The built metadata on success; <c>default</c> otherwise.</param>
    /// <param name="message">The failure reason on failure; empty on success.</param>
    private static bool TryCreateCachedMethod(
        IMethodSymbol method,
        INamedTypeSymbol implementationType,
        AttributeData attribute,
        out CachedMethod cachedMethod,
        out string message
    )
    {
        cachedMethod = default;
        message = string.Empty;

        if (!IsTaskOfT(method.ReturnType, out var valueType))
        {
            message = "only Task<T> return types are supported";
            return false;
        }

        var cacheOptions = GetCacheOptions(attribute);
        if (cacheOptions.ExpirationSeconds <= 0)
        {
            message = "ExpirationSeconds must be greater than 0";
            return false;
        }

        foreach (var parameter in method.Parameters)
        {
            if (parameter.RefKind != RefKind.None)
            {
                message = $"parameter '{parameter.Name}' uses unsupported ref kind";
                return false;
            }

            // An explicit CacheKey replaces the generated key, so parameter types need not be
            // key-generatable in that case. CancellationToken is skipped as well.
            if (IsCancellationToken(parameter.Type) || cacheOptions.CacheKey is not null)
            {
                continue;
            }

            if (!CanGenerateKeyFor(parameter.Type))
            {
                message =
                    $"parameter '{parameter.Name}' has unsupported type '{parameter.Type.ToDisplayString()}'";
                return false;
            }
        }

        if (
            !string.IsNullOrWhiteSpace(cacheOptions.PostProcess)
            && !HasValidPostProcessMethod(
                implementationType,
                cacheOptions.PostProcess!,
                valueType,
                out message
            )
        )
        {
            return false;
        }

        cachedMethod = new CachedMethod(
            method.Name,
            SymbolDisplay.ToFullyQualifiedTypeName(method.ReturnType),
            SymbolDisplay.ToFullyQualifiedTypeName(valueType),
            [
                .. method.Parameters.Select(parameter => new CachedParameter(
                    parameter.Name,
                    SymbolDisplay.ToFullyQualifiedTypeName(parameter.Type),
                    parameter.RefKind,
                    parameter.HasExplicitDefaultValue,
                    SymbolDisplay.ToDefaultValueLiteral(parameter),
                    parameter.IsParams
                )),
            ],
            cacheOptions
        );
        return true;
    }

    /// <summary>
    /// Validates a cache-invalidating interface method and builds its
    /// <see cref="InvalidatedMethod"/> metadata.
    /// </summary>
    /// <remarks>
    /// The return type must be <c>Task</c>, <c>Task&lt;T&gt;</c>, <c>ValueTask</c> or
    /// <c>ValueTask&lt;T&gt;</c>. The attribute's <c>Methods</c> named argument must list at
    /// least one non-blank name, and every listed name must be in
    /// <paramref name="cachedMethodNames"/>.
    /// </remarks>
    private static bool TryCreateInvalidatedMethod(
        IMethodSymbol method,
        AttributeData attribute,
        HashSet<string> cachedMethodNames,
        out InvalidatedMethod invalidatedMethod,
        out string message
    )
    {
        invalidatedMethod = default;
        message = string.Empty;

        if (!IsAwaitableServiceMethod(method.ReturnType))
        {
            message = "only Task, Task<T>, ValueTask and ValueTask<T> return types are supported";
            return false;
        }

        var methods = ImmutableArray.CreateBuilder<string>();
        foreach (var namedArgument in attribute.NamedArguments)
        {
            if (
                namedArgument.Key != "Methods"
                || namedArgument.Value.Kind != TypedConstantKind.Array
            )
            {
                continue;
            }

            foreach (var value in namedArgument.Value.Values)
            {
                if (value.Value is string methodName && !string.IsNullOrWhiteSpace(methodName))
                {
                    methods.Add(methodName);
                }
            }
        }

        if (methods.Count == 0)
        {
            message = "Methods must contain at least one cached method name";
            return false;
        }

        foreach (var methodName in methods)
        {
            if (!cachedMethodNames.Contains(methodName))
            {
                message =
                    $"method '{methodName}' is not a cacheable method on the same implementation";
                return false;
            }
        }

        invalidatedMethod = new InvalidatedMethod(
            method.Name,
            SymbolDisplay.ToFullyQualifiedTypeName(method.ReturnType),
            [
                .. method.Parameters.Select(parameter => new CachedParameter(
                    parameter.Name,
                    SymbolDisplay.ToFullyQualifiedTypeName(parameter.Type),
                    parameter.RefKind,
                    parameter.HasExplicitDefaultValue,
                    SymbolDisplay.ToDefaultValueLiteral(parameter),
                    parameter.IsParams
                )),
            ],
            methods.ToImmutable()
        );
        return true;
    }

    /// <summary>
    /// Returns <c>true</c> when <paramref name="returnType"/> is <c>Task&lt;T&gt;</c>, and
    /// outputs <c>T</c> as <paramref name="valueType"/>. On failure,
    /// <paramref name="valueType"/> is <paramref name="returnType"/> itself.
    /// </summary>
    private static bool IsTaskOfT(ITypeSymbol returnType, out ITypeSymbol valueType)
    {
        valueType = returnType;
        if (
            returnType is not INamedTypeSymbol namedReturnType
            || namedReturnType.ConstructedFrom.ToDisplayString()
                != "System.Threading.Tasks.Task<TResult>"
            || namedReturnType.TypeArguments.Length != 1
        )
        {
            return false;
        }

        valueType = namedReturnType.TypeArguments[0];
        return true;
    }

    /// <summary>
    /// Returns <c>true</c> for <c>Task</c>, <c>ValueTask</c>, <c>Task&lt;T&gt;</c> and
    /// <c>ValueTask&lt;T&gt;</c>.
    /// </summary>
    private static bool IsAwaitableServiceMethod(ITypeSymbol returnType)
    {
        var fullName = returnType.ToDisplayString();
        if (fullName is "System.Threading.Tasks.Task" or "System.Threading.Tasks.ValueTask")
        {
            return true;
        }

        return returnType is INamedTypeSymbol namedReturnType
            && namedReturnType.TypeArguments.Length == 1
            && namedReturnType.ConstructedFrom.ToDisplayString()
                is "System.Threading.Tasks.Task<TResult>"
                    or "System.Threading.Tasks.ValueTask<TResult>";
    }

    /// <summary>
    /// Reads the named arguments of a <c>[Cache]</c> attribute into <see cref="CacheOptions"/>.
    /// <c>ExpirationSeconds</c> defaults to 3600 when it is absent. Blank
    /// <c>AdditionalTags</c> entries are dropped. Unknown arguments are ignored.
    /// </summary>
    private static CacheOptions GetCacheOptions(AttributeData attribute)
    {
        var prefix = default(string);
        var cacheKey = default(string);
        var expirationSeconds = 3600;
        var hasExplicitExpirationSeconds = false;
        var postProcess = default(string);
        var additionalTags = ImmutableArray.CreateBuilder<string>();

        foreach (var namedArgument in attribute.NamedArguments)
        {
            switch (namedArgument.Key)
            {
                case "Prefix":
                    prefix = namedArgument.Value.Value as string;
                    break;
                case "CacheKey":
                    cacheKey = namedArgument.Value.Value as string;
                    break;
                case "ExpirationSeconds" when namedArgument.Value.Value is int configured:
                    expirationSeconds = configured;
                    hasExplicitExpirationSeconds = true;
                    break;
                case "PostProcess":
                    postProcess = namedArgument.Value.Value as string;
                    break;
                case "AdditionalTags" when namedArgument.Value.Kind == TypedConstantKind.Array:
                    foreach (var value in namedArgument.Value.Values)
                    {
                        if (value.Value is string tag && !string.IsNullOrWhiteSpace(tag))
                        {
                            additionalTags.Add(tag);
                        }
                    }
                    break;
            }
        }

        return new CacheOptions(
            prefix,
            cacheKey,
            expirationSeconds,
            hasExplicitExpirationSeconds,
            postProcess,
            additionalTags.ToImmutable()
        );
    }

    /// <summary>
    /// Checks that <paramref name="implementationType"/> declares an instance method named
    /// <paramref name="methodName"/>. The method must take exactly one parameter that can
    /// receive <paramref name="valueType"/> and must return an awaitable type (see
    /// <see cref="IsPostProcessReturnType"/>).
    /// </summary>
    /// <remarks>
    /// All overloads are examined. An overload with a compatible parameter but an unsupported
    /// return type is only reported when no other overload is valid.
    /// </remarks>
    private static bool HasValidPostProcessMethod(
        INamedTypeSymbol implementationType,
        string methodName,
        ITypeSymbol valueType,
        out string message
    )
    {
        string? returnTypeError = null;
        foreach (var method in implementationType.GetMembers(methodName).OfType<IMethodSymbol>())
        {
            if (
                method.IsStatic
                || method.Parameters.Length != 1
                || !CanPassValueTo(valueType, method.Parameters[0].Type)
            )
            {
                continue;
            }

            if (IsPostProcessReturnType(method.ReturnType))
            {
                message = string.Empty;
                return true;
            }

            // Remember the problem, but keep looking: another overload may still be valid.
            returnTypeError ??=
                $"post-process method '{methodName}' must return Task, ValueTask, Task<T> or ValueTask<T>";
        }

        message =
            returnTypeError
            ?? $"post-process method '{methodName}' was not found or its first parameter is not compatible with '{valueType.ToDisplayString()}'";
        return false;
    }

    // Post-process methods accept the same awaitable return types as service methods.
    private static bool IsPostProcessReturnType(ITypeSymbol returnType)
    {
        return IsAwaitableServiceMethod(returnType);
    }

    /// <summary>
    /// Returns <c>true</c> when a value of <paramref name="valueType"/> can be passed to a
    /// parameter of <paramref name="parameterType"/>. The accepted cases are: the same type,
    /// an implemented interface, a base class, or <c>object</c>. User-defined conversions are
    /// not considered.
    /// </summary>
    private static bool CanPassValueTo(ITypeSymbol valueType, ITypeSymbol parameterType)
    {
        if (SymbolEqualityComparer.Default.Equals(valueType, parameterType))
        {
            return true;
        }

        // Every Task<T> result converts to object, including interfaces and type parameters,
        // whose BaseType is null and would otherwise be missed by the base-type walk below.
        if (parameterType.SpecialType == SpecialType.System_Object)
        {
            return true;
        }

        if (
            valueType.AllInterfaces.Any(@interface =>
                SymbolEqualityComparer.Default.Equals(@interface, parameterType)
            )
        )
        {
            return true;
        }

        for (var baseType = valueType.BaseType; baseType is not null; baseType = baseType.BaseType)
        {
            if (SymbolEqualityComparer.Default.Equals(baseType, parameterType))
            {
                return true;
            }
        }

        return false;
    }

    /// <summary>
    /// Returns <c>true</c> when a cache-key fragment can be generated for a parameter of
    /// <paramref name="type"/>. The accepted types are: <c>CancellationToken</c>, scalar key
    /// types, and — recursively — nullable value types, arrays and <c>IEnumerable&lt;T&gt;</c>
    /// whose underlying or element type is accepted.
    /// </summary>
    private static bool CanGenerateKeyFor(ITypeSymbol type)
    {
        if (IsCancellationToken(type))
        {
            return true;
        }

        if (IsNullableValueType(type, out var underlyingType))
        {
            return CanGenerateKeyFor(underlyingType);
        }

        if (type is IArrayTypeSymbol arrayType)
        {
            return CanGenerateKeyFor(arrayType.ElementType);
        }

        if (TryGetEnumerableElementType(type, out var elementType))
        {
            return CanGenerateKeyFor(elementType);
        }

        return IsScalarKeyType(type);
    }

    /// <summary>
    /// Returns <c>true</c> for enums, the built-in numeric, boolean, char and string types, and
    /// <c>Guid</c>, <c>DateTime</c>, <c>DateTimeOffset</c>, <c>DateOnly</c> and
    /// <c>TimeOnly</c>.
    /// </summary>
    private static bool IsScalarKeyType(ITypeSymbol type)
    {
        if (type.TypeKind == TypeKind.Enum)
        {
            return true;
        }

        if (
            type.SpecialType
            is SpecialType.System_Boolean
                or SpecialType.System_Char
                or SpecialType.System_SByte
                or SpecialType.System_Byte
                or SpecialType.System_Int16
                or SpecialType.System_UInt16
                or SpecialType.System_Int32
                or SpecialType.System_UInt32
                or SpecialType.System_Int64
                or SpecialType.System_UInt64
                or SpecialType.System_Decimal
                or SpecialType.System_Single
                or SpecialType.System_Double
                or SpecialType.System_String
        )
        {
            return true;
        }

        var fullName = type.ToDisplayString();
        return fullName
            is "System.Guid"
                or "System.DateTime"
                or "System.DateTimeOffset"
                or "System.DateOnly"
                or "System.TimeOnly";
    }

    /// <summary>
    /// Returns <c>true</c> when <paramref name="type"/> is <c>Nullable&lt;T&gt;</c>, and
    /// outputs <c>T</c> as <paramref name="underlyingType"/>. On failure,
    /// <paramref name="underlyingType"/> is <paramref name="type"/> itself.
    /// </summary>
    private static bool IsNullableValueType(ITypeSymbol type, out ITypeSymbol underlyingType)
    {
        underlyingType = type;
        if (
            type is not INamedTypeSymbol namedType
            || namedType.ConstructedFrom.ToDisplayString() != "System.Nullable<T>"
            || namedType.TypeArguments.Length != 1
        )
        {
            return false;
        }

        underlyingType = namedType.TypeArguments[0];
        return true;
    }

    /// <summary>
    /// Returns <c>true</c> when <paramref name="type"/> is, or implements,
    /// <c>IEnumerable&lt;T&gt;</c>, and outputs <c>T</c> as <paramref name="elementType"/>.
    /// <c>string</c> is deliberately excluded so that it is treated as a scalar. When several
    /// <c>IEnumerable&lt;T&gt;</c> interfaces are implemented, the first one in
    /// <c>AllInterfaces</c> is used.
    /// </summary>
    private static bool TryGetEnumerableElementType(ITypeSymbol type, out ITypeSymbol elementType)
    {
        elementType = type;
        if (type.SpecialType == SpecialType.System_String)
        {
            return false;
        }

        if (
            type is INamedTypeSymbol namedType
            && namedType.ConstructedFrom.ToDisplayString()
                == "System.Collections.Generic.IEnumerable<T>"
            && namedType.TypeArguments.Length == 1
        )
        {
            elementType = namedType.TypeArguments[0];
            return true;
        }

        foreach (var implementedInterface in type.AllInterfaces)
        {
            if (
                implementedInterface.ConstructedFrom.ToDisplayString()
                != "System.Collections.Generic.IEnumerable<T>"
            )
            {
                continue;
            }

            elementType = implementedInterface.TypeArguments[0];
            return true;
        }

        return false;
    }

    private static bool IsCancellationToken(ITypeSymbol type)
    {
        return type.ToDisplayString() == "System.Threading.CancellationToken";
    }

    // Returns the first attribute on the symbol whose class display name equals
    // attributeFullName, or null when there is none.
    private static AttributeData? GetAttribute(ISymbol symbol, string attributeFullName)
    {
        return symbol
            .GetAttributes()
            .FirstOrDefault(attribute =>
                attribute.AttributeClass?.ToDisplayString() == attributeFullName
            );
    }
}
⚠⚠ 以下内容为AI分析的结果,请根据实际情况进行判断。
代码解释
这是一个 C# Source Generator(源代码生成器)的核心类,用于检查服务方法上的缓存特性并生成缓存装饰器所需的元数据。
主要功能
1. 诊断描述符定义
定义了4个编译时诊断错误类型:
- DPZ_CACHE001: 不支持的可缓存服务方法
- DPZ_CACHE002: 缺少 IFusionCache 依赖注入
- DPZ_CACHE003: 无效的缓存后处理方法
- DPZ_CACHE004: 无效的缓存失效方法
2. 核心方法
GetCachedMethods - 获取可缓存方法
- 遍历实现类中标记了 [Cache] 特性的方法
- 验证方法是否存在于接口中
- 检查实现类是否注入了 IFusionCache(注意:HasFusionCacheConstructorParameter 目前恒返回 true,此检查实际不会报错)
- 验证方法是否符合缓存要求(返回类型、参数类型等)
- 返回 CachedMethod 元数据集合
GetInvalidatedMethods - 获取缓存失效方法
- 遍历实现类中标记了 [InvalidateCache] 特性的方法
- 验证方法的有效性(必须属于服务接口,返回类型必须可等待)
- 检查失效目标方法是否存在于可缓存方法中
- 返回 InvalidatedMethod 元数据集合
3. 验证规则
可缓存方法必须满足:
- 返回类型必须是 Task<T>
- 不能有 ref/out/in 参数
- 未指定 CacheKey 时,参数类型必须可用于生成缓存键(支持基本类型、枚举、集合等)
- 过期时间必须大于 0(未指定时默认为 3600 秒)
- 如果指定了后处理方法,该方法必须存在且签名匹配
支持的缓存键参数类型:
- 基本类型(int, string, bool, DateTime, Guid 等)
- 枚举类型
- 可空值类型(Nullable<T>,其基础类型须受支持)
- 元素类型受支持的数组和 IEnumerable<T>
- CancellationToken(跳过缓存键类型校验)
缓存失效方法必须满足:
- 返回类型为 Task、Task<T>、ValueTask 或 ValueTask<T>
- Methods 属性中指定的方法必须是同一实现类中的可缓存方法
4. 辅助功能
接口映射: 通过 FindInterfaceMethod 匹配实现方法与接口方法
签名比较: 通过 HasSameSignature 比较方法签名
类型检查:
- IsTaskOfT: 检查是否为 Task<T>
- CanGenerateKeyFor: 检查类型是否可生成缓存键
- IsScalarKeyType: 检查是否为标量类型
- IsNullableValueType: 检查是否为可空值类型
特性解析: 通过 GetCacheOptions 提取缓存配置(前缀、缓存键、过期时间、后处理方法等)
设计亮点
- 编译时验证: 在编译阶段发现问题,避免运行时错误
- 类型安全: 严格的类型检查确保生成代码的正确性
- 诊断友好: 详细的错误消息帮助开发者快速定位问题
- 灵活性: 支持自定义缓存键、后处理方法、失效策略等
- 仅读取实现类上的特性: 缓存作为实现细节,不污染接口定义
使用场景
此代码是 AOP(面向切面编程)模式的应用,用于为服务方法自动生成缓存装饰器代码,减少样板代码,统一缓存处理逻辑。
评论加载中...