using System.Collections.Immutable;
using Dpz.Core.SourceGenerator.Models;
using Microsoft.CodeAnalysis;

namespace Dpz.Core.SourceGenerator;

/// <summary>
/// Reads cache attributes from service implementation methods and converts them into the
/// metadata required to generate cache decorators.
/// </summary>
/// <remarks>
/// Caching is an implementation strategy, so only <c>CacheAttribute</c> (and
/// <c>InvalidateCacheAttribute</c>) applied to methods of the implementation class are read;
/// same-named attributes on interface methods are ignored.
/// </remarks>
internal static class CacheMethodInspector
{
    /// <summary>Namespace that declares FusionCache's <c>IFusionCache</c> abstraction.</summary>
    private const string FusionCacheNamespace = "ZiggyCreatures.Caching.Fusion";

    /// <summary>Simple (non-generic) name of FusionCache's cache abstraction.</summary>
    private const string FusionCacheInterfaceName = "IFusionCache";

    /// <summary>Diagnostic detail used when an attributed method is not an interface member.</summary>
    private const string NotInInterfaceMessage =
        "method is not part of the registered service interface";

#pragma warning disable RS2008
    private static readonly DiagnosticDescriptor UnsupportedCacheMethodDescriptor = new(
        "DPZ_CACHE001",
        "Unsupported cacheable service method",
        "Cannot generate cache decorator for '{0}.{1}': {2}",
        "Dpz.Core.Cache",
        DiagnosticSeverity.Error,
        true
    );

    private static readonly DiagnosticDescriptor MissingFusionCacheDescriptor = new(
        "DPZ_CACHE002",
        "Cacheable service requires IFusionCache",
        "Cannot generate cache decorator for '{0}.{1}': implementation type '{2}' must inject IFusionCache",
        "Dpz.Core.Cache",
        DiagnosticSeverity.Error,
        true
    );

    private static readonly DiagnosticDescriptor InvalidPostProcessDescriptor = new(
        "DPZ_CACHE003",
        "Invalid cache post-process method",
        "Cannot generate cache decorator for '{0}.{1}': {2}",
        "Dpz.Core.Cache",
        DiagnosticSeverity.Error,
        true
    );

    private static readonly DiagnosticDescriptor InvalidInvalidateCacheDescriptor = new(
        "DPZ_CACHE004",
        "Invalid cache invalidation method",
        "Cannot generate cache invalidation for '{0}.{1}': {2}",
        "Dpz.Core.Cache",
        DiagnosticSeverity.Error,
        true
    );
#pragma warning restore RS2008

    /// <summary>
    /// Collects every <c>[Cache]</c>-annotated implementation method that can be turned into
    /// cache-decorator metadata, reporting a diagnostic for each one that cannot.
    /// </summary>
    /// <param name="interfaceType">The registered service interface.</param>
    /// <param name="implementationType">The concrete service type whose methods are inspected.</param>
    /// <param name="sourceProductionContext">Context used to report diagnostics.</param>
    /// <returns>Metadata for every valid cacheable method; invalid methods are skipped.</returns>
    internal static ImmutableArray<CachedMethod> GetCachedMethods(
        INamedTypeSymbol interfaceType,
        INamedTypeSymbol implementationType,
        SourceProductionContext sourceProductionContext
    )
    {
        var methods = ImmutableArray.CreateBuilder<CachedMethod>();
        var interfaceMethods = GetInterfaceMethods(interfaceType);
        // Independent of the method being inspected, so evaluate it once per type.
        var hasFusionCache = HasFusionCacheConstructorParameter(implementationType);
        var attributedMethods = GetAttributedMethods(
            implementationType,
            SourceGeneratorAttributeNames.CacheAttributeFullName
        );

        foreach (var (implementationMethod, attribute) in attributedMethods)
        {
            var interfaceMethod = FindInterfaceMethod(
                implementationType,
                interfaceMethods,
                implementationMethod
            );
            if (interfaceMethod is null)
            {
                sourceProductionContext.ReportDiagnostic(
                    Diagnostic.Create(
                        UnsupportedCacheMethodDescriptor,
                        GetLocation(implementationMethod),
                        implementationType.ToDisplayString(),
                        implementationMethod.Name,
                        NotInInterfaceMessage
                    )
                );
                continue;
            }

            var location = GetLocation(implementationMethod) ?? GetLocation(interfaceMethod);
            if (!hasFusionCache)
            {
                sourceProductionContext.ReportDiagnostic(
                    Diagnostic.Create(
                        MissingFusionCacheDescriptor,
                        location,
                        interfaceType.ToDisplayString(),
                        interfaceMethod.Name,
                        implementationType.ToDisplayString()
                    )
                );
                continue;
            }

            if (
                !TryCreateCachedMethod(
                    interfaceMethod,
                    implementationType,
                    attribute,
                    out var cachedMethod,
                    out var failureDescriptor,
                    out var message
                )
            )
            {
                sourceProductionContext.ReportDiagnostic(
                    Diagnostic.Create(
                        failureDescriptor,
                        location,
                        interfaceType.ToDisplayString(),
                        interfaceMethod.Name,
                        message
                    )
                );
                continue;
            }

            methods.Add(cachedMethod);
        }

        return methods.ToImmutable();
    }

    /// <summary>
    /// Collects every <c>[InvalidateCache]</c>-annotated implementation method whose targets
    /// are cacheable methods of the same implementation, reporting a diagnostic otherwise.
    /// </summary>
    /// <param name="interfaceType">The registered service interface.</param>
    /// <param name="implementationType">The concrete service type whose methods are inspected.</param>
    /// <param name="cachedMethods">Cacheable methods previously found on the same implementation.</param>
    /// <param name="sourceProductionContext">Context used to report diagnostics.</param>
    /// <returns>Metadata for every valid invalidation method; invalid methods are skipped.</returns>
    internal static ImmutableArray<InvalidatedMethod> GetInvalidatedMethods(
        INamedTypeSymbol interfaceType,
        INamedTypeSymbol implementationType,
        ImmutableArray<CachedMethod> cachedMethods,
        SourceProductionContext sourceProductionContext
    )
    {
        var methods = ImmutableArray.CreateBuilder<InvalidatedMethod>();
        var interfaceMethods = GetInterfaceMethods(interfaceType);
        var hasFusionCache = HasFusionCacheConstructorParameter(implementationType);
        var cachedMethodNames = new HashSet<string>(
            cachedMethods.Select(static method => method.Name),
            StringComparer.Ordinal
        );
        var attributedMethods = GetAttributedMethods(
            implementationType,
            SourceGeneratorAttributeNames.InvalidateCacheAttributeFullName
        );

        foreach (var (implementationMethod, attribute) in attributedMethods)
        {
            var interfaceMethod = FindInterfaceMethod(
                implementationType,
                interfaceMethods,
                implementationMethod
            );
            if (interfaceMethod is null)
            {
                sourceProductionContext.ReportDiagnostic(
                    Diagnostic.Create(
                        InvalidInvalidateCacheDescriptor,
                        GetLocation(implementationMethod),
                        implementationType.ToDisplayString(),
                        implementationMethod.Name,
                        NotInInterfaceMessage
                    )
                );
                continue;
            }

            var location = GetLocation(implementationMethod) ?? GetLocation(interfaceMethod);
            if (!hasFusionCache)
            {
                sourceProductionContext.ReportDiagnostic(
                    Diagnostic.Create(
                        MissingFusionCacheDescriptor,
                        location,
                        interfaceType.ToDisplayString(),
                        interfaceMethod.Name,
                        implementationType.ToDisplayString()
                    )
                );
                continue;
            }

            if (
                !TryCreateInvalidatedMethod(
                    interfaceMethod,
                    attribute,
                    cachedMethodNames,
                    out var invalidatedMethod,
                    out var message
                )
            )
            {
                sourceProductionContext.ReportDiagnostic(
                    Diagnostic.Create(
                        InvalidInvalidateCacheDescriptor,
                        location,
                        interfaceType.ToDisplayString(),
                        interfaceMethod.Name,
                        message
                    )
                );
                continue;
            }

            methods.Add(invalidatedMethod);
        }

        return methods.ToImmutable();
    }

    /// <summary>
    /// Returns the ordinary methods declared on <paramref name="interfaceType"/> and on all of
    /// its base interfaces (base-interface members first).
    /// </summary>
    private static ImmutableArray<IMethodSymbol> GetInterfaceMethods(INamedTypeSymbol interfaceType)
    {
        return
        [
            .. interfaceType
                .AllInterfaces.SelectMany(static type => type.GetMembers())
                .Concat(interfaceType.GetMembers())
                .OfType<IMethodSymbol>()
                .Where(static method => method.MethodKind == MethodKind.Ordinary),
        ];
    }

    /// <summary>
    /// Returns the methods declared directly on <paramref name="implementationType"/> that carry
    /// the attribute named <paramref name="attributeFullName"/>, paired with that attribute.
    /// </summary>
    /// <remarks>
    /// Explicit interface implementations are included so an attribute placed on one is not
    /// silently ignored. Inherited methods are not inspected (only <c>GetMembers()</c> of the
    /// type itself is enumerated).
    /// </remarks>
    private static ImmutableArray<(
        IMethodSymbol Method,
        AttributeData Attribute
    )> GetAttributedMethods(INamedTypeSymbol implementationType, string attributeFullName)
    {
        var result = ImmutableArray.CreateBuilder<(IMethodSymbol Method, AttributeData Attribute)>();
        foreach (var member in implementationType.GetMembers())
        {
            if (
                member is not IMethodSymbol
                {
                    MethodKind: MethodKind.Ordinary or MethodKind.ExplicitInterfaceImplementation,
                } method
            )
            {
                continue;
            }

            var attribute = GetAttribute(method, attributeFullName);
            if (attribute is not null)
            {
                result.Add((method, attribute));
            }
        }

        return result.ToImmutable();
    }

    /// <summary>
    /// Finds the interface method implemented by <paramref name="implementationMethod"/>.
    /// </summary>
    /// <remarks>
    /// The compiler's interface mapping (<c>FindImplementationForInterfaceMember</c>) always wins.
    /// A plain name/parameter signature match is used only as a fallback when no interface
    /// method maps to the implementation method. This way an earlier interface method with an
    /// identical signature cannot shadow the real mapping.
    /// </remarks>
    /// <returns>The matching interface method, or <see langword="null"/> when none matches.</returns>
    private static IMethodSymbol? FindInterfaceMethod(
        INamedTypeSymbol implementationType,
        ImmutableArray<IMethodSymbol> interfaceMethods,
        IMethodSymbol implementationMethod
    )
    {
        IMethodSymbol? signatureMatch = null;
        foreach (var interfaceMethod in interfaceMethods)
        {
            var mappedMethod = implementationType.FindImplementationForInterfaceMember(
                interfaceMethod
            );
            if (
                mappedMethod is IMethodSymbol method
                && SymbolEqualityComparer.Default.Equals(method, implementationMethod)
            )
            {
                return interfaceMethod;
            }

            if (signatureMatch is null && HasSameSignature(interfaceMethod, implementationMethod))
            {
                signatureMatch = interfaceMethod;
            }
        }

        return signatureMatch;
    }

    /// <summary>
    /// Compares name, type-parameter count, and parameter ref kinds/types of two methods.
    /// Return types are not compared.
    /// </summary>
    private static bool HasSameSignature(
        IMethodSymbol interfaceMethod,
        IMethodSymbol implementationMethod
    )
    {
        return implementationMethod.Name == interfaceMethod.Name
            && implementationMethod.TypeParameters.Length == interfaceMethod.TypeParameters.Length
            && implementationMethod.Parameters.Length == interfaceMethod.Parameters.Length
            && implementationMethod
                .Parameters.Zip(
                    interfaceMethod.Parameters,
                    static (left, right) =>
                        left.RefKind == right.RefKind
                        && SymbolEqualityComparer.Default.Equals(left.Type, right.Type)
                )
                .All(static match => match);
    }

    /// <summary>
    /// Determines whether <paramref name="implementationType"/> has access to an
    /// <c>IFusionCache</c>.
    /// </summary>
    /// <remarks>
    /// Primarily checks instance-constructor parameters, which also covers primary constructors.
    /// As a lenient fallback, it also accepts an instance field or property on the type or any of
    /// its base types. A type counts as FusionCache if it is <c>IFusionCache</c> itself or
    /// implements it.
    /// </remarks>
    private static bool HasFusionCacheConstructorParameter(INamedTypeSymbol implementationType)
    {
        if (
            implementationType.InstanceConstructors.Any(static constructor =>
                constructor.Parameters.Any(static parameter => IsFusionCacheType(parameter.Type))
            )
        )
        {
            return true;
        }

        for (var type = implementationType; type is not null; type = type.BaseType)
        {
            foreach (var member in type.GetMembers())
            {
                ITypeSymbol? memberType = member switch
                {
                    IFieldSymbol { IsStatic: false } field => field.Type,
                    IPropertySymbol { IsStatic: false } property => property.Type,
                    _ => null,
                };
                if (memberType is not null && IsFusionCacheType(memberType))
                {
                    return true;
                }
            }
        }

        return false;
    }

    /// <summary>True when <paramref name="type"/> is, or implements, FusionCache's <c>IFusionCache</c>.</summary>
    private static bool IsFusionCacheType(ITypeSymbol type)
    {
        return IsFusionCacheInterface(type) || type.AllInterfaces.Any(IsFusionCacheInterface);
    }

    /// <summary>
    /// Matches by simple name and containing namespace, so nullable annotations on the
    /// reference type (<c>IFusionCache?</c>) do not affect the result.
    /// </summary>
    private static bool IsFusionCacheInterface(ITypeSymbol type)
    {
        return type is INamedTypeSymbol { Name: FusionCacheInterfaceName, Arity: 0 } namedType
            && namedType.ContainingNamespace?.ToDisplayString() == FusionCacheNamespace;
    }

    /// <summary>
    /// Validates a cacheable interface method against its <c>[Cache]</c> attribute and builds
    /// its metadata.
    /// </summary>
    /// <param name="method">The interface method that will be decorated.</param>
    /// <param name="implementationType">Type searched for the configured post-process method.</param>
    /// <param name="attribute">The <c>[Cache]</c> attribute applied to the implementation method.</param>
    /// <param name="cachedMethod">The resulting metadata when validation succeeds.</param>
    /// <param name="failureDescriptor">
    /// The descriptor to report when validation fails: <c>DPZ_CACHE003</c> for post-process
    /// problems, <c>DPZ_CACHE001</c> otherwise.
    /// </param>
    /// <param name="message">Human-readable failure reason; empty on success.</param>
    private static bool TryCreateCachedMethod(
        IMethodSymbol method,
        INamedTypeSymbol implementationType,
        AttributeData attribute,
        out CachedMethod cachedMethod,
        out DiagnosticDescriptor failureDescriptor,
        out string message
    )
    {
        cachedMethod = default;
        failureDescriptor = UnsupportedCacheMethodDescriptor;
        message = string.Empty;

        if (!IsTaskOfT(method.ReturnType, out var valueType))
        {
            message = "only Task<T> return types are supported";
            return false;
        }

        var cacheOptions = GetCacheOptions(attribute);
        if (cacheOptions.ExpirationSeconds <= 0)
        {
            message = "ExpirationSeconds must be greater than 0";
            return false;
        }

        // With an explicit CacheKey, parameter types are not validated for key generation;
        // the ref-kind restriction still applies to every parameter.
        var validateParameterKeys = cacheOptions.CacheKey is null;
        foreach (var parameter in method.Parameters)
        {
            if (parameter.RefKind != RefKind.None)
            {
                message = $"parameter '{parameter.Name}' uses unsupported ref kind";
                return false;
            }

            // CanGenerateKeyFor already accepts CancellationToken.
            if (validateParameterKeys && !CanGenerateKeyFor(parameter.Type))
            {
                message =
                    $"parameter '{parameter.Name}' has unsupported type '{parameter.Type.ToDisplayString()}'";
                return false;
            }
        }

        if (
            !string.IsNullOrWhiteSpace(cacheOptions.PostProcess)
            && !HasValidPostProcessMethod(
                implementationType,
                cacheOptions.PostProcess!,
                valueType,
                out message
            )
        )
        {
            failureDescriptor = InvalidPostProcessDescriptor;
            return false;
        }

        cachedMethod = new CachedMethod(
            method.Name,
            SymbolDisplay.ToFullyQualifiedTypeName(method.ReturnType),
            SymbolDisplay.ToFullyQualifiedTypeName(valueType),
            [.. CreateParameters(method)],
            cacheOptions
        );
        return true;
    }

    /// <summary>
    /// Validates an invalidation method against its <c>[InvalidateCache]</c> attribute and
    /// builds its metadata.
    /// </summary>
    /// <param name="method">The interface method that triggers invalidation.</param>
    /// <param name="attribute">The <c>[InvalidateCache]</c> attribute on the implementation method.</param>
    /// <param name="cachedMethodNames">Names of cacheable methods on the same implementation.</param>
    /// <param name="invalidatedMethod">The resulting metadata when validation succeeds.</param>
    /// <param name="message">Human-readable failure reason; empty on success.</param>
    private static bool TryCreateInvalidatedMethod(
        IMethodSymbol method,
        AttributeData attribute,
        HashSet<string> cachedMethodNames,
        out InvalidatedMethod invalidatedMethod,
        out string message
    )
    {
        invalidatedMethod = default;
        message = string.Empty;

        if (!IsAwaitableServiceMethod(method.ReturnType))
        {
            message = "only Task, Task<T>, ValueTask and ValueTask<T> return types are supported";
            return false;
        }

        var methods = ImmutableArray.CreateBuilder<string>();
        // Repeated names in Methods would otherwise produce duplicate invalidation entries.
        var seenMethodNames = new HashSet<string>(StringComparer.Ordinal);
        foreach (var namedArgument in attribute.NamedArguments)
        {
            if (
                namedArgument.Key != "Methods"
                || namedArgument.Value.Kind != TypedConstantKind.Array
            )
            {
                continue;
            }

            foreach (var value in namedArgument.Value.Values)
            {
                if (
                    value.Value is string methodName
                    && !string.IsNullOrWhiteSpace(methodName)
                    && seenMethodNames.Add(methodName)
                )
                {
                    methods.Add(methodName);
                }
            }
        }

        if (methods.Count == 0)
        {
            message = "Methods must contain at least one cached method name";
            return false;
        }

        foreach (var methodName in methods)
        {
            if (!cachedMethodNames.Contains(methodName))
            {
                message =
                    $"method '{methodName}' is not a cacheable method on the same implementation";
                return false;
            }
        }

        invalidatedMethod = new InvalidatedMethod(
            method.Name,
            SymbolDisplay.ToFullyQualifiedTypeName(method.ReturnType),
            [.. CreateParameters(method)],
            methods.ToImmutable()
        );
        return true;
    }

    /// <summary>Projects a method's parameters into the generator's parameter metadata.</summary>
    private static IEnumerable<CachedParameter> CreateParameters(IMethodSymbol method)
    {
        return method.Parameters.Select(static parameter => new CachedParameter(
            parameter.Name,
            SymbolDisplay.ToFullyQualifiedTypeName(parameter.Type),
            parameter.RefKind,
            parameter.HasExplicitDefaultValue,
            SymbolDisplay.ToDefaultValueLiteral(parameter),
            parameter.IsParams
        ));
    }

    /// <summary>
    /// True when <paramref name="returnType"/> is <c>System.Threading.Tasks.Task&lt;T&gt;</c>;
    /// <paramref name="valueType"/> then receives <c>T</c> (otherwise the return type itself).
    /// </summary>
    private static bool IsTaskOfT(ITypeSymbol returnType, out ITypeSymbol valueType)
    {
        valueType = returnType;
        if (
            returnType is not INamedTypeSymbol namedReturnType
            || namedReturnType.ConstructedFrom.ToDisplayString()
                != "System.Threading.Tasks.Task<TResult>"
            || namedReturnType.TypeArguments.Length != 1
        )
        {
            return false;
        }

        valueType = namedReturnType.TypeArguments[0];
        return true;
    }

    /// <summary>True for <c>Task</c>, <c>ValueTask</c>, <c>Task&lt;T&gt;</c> and <c>ValueTask&lt;T&gt;</c>.</summary>
    private static bool IsAwaitableServiceMethod(ITypeSymbol returnType)
    {
        var fullName = returnType.ToDisplayString();
        if (fullName is "System.Threading.Tasks.Task" or "System.Threading.Tasks.ValueTask")
        {
            return true;
        }

        return returnType is INamedTypeSymbol namedReturnType
            && namedReturnType.TypeArguments.Length == 1
            && namedReturnType.ConstructedFrom.ToDisplayString()
                is "System.Threading.Tasks.Task<TResult>"
                    or "System.Threading.Tasks.ValueTask<TResult>";
    }

    /// <summary>
    /// Reads the named arguments of a <c>[Cache]</c> attribute. <c>ExpirationSeconds</c>
    /// defaults to 3600 when it is not set explicitly to an <see cref="int"/>. Blank entries in
    /// <c>AdditionalTags</c> are dropped.
    /// </summary>
    private static CacheOptions GetCacheOptions(AttributeData attribute)
    {
        var prefix = default(string);
        var cacheKey = default(string);
        var expirationSeconds = 3600;
        var hasExplicitExpirationSeconds = false;
        var postProcess = default(string);
        var additionalTags = ImmutableArray.CreateBuilder<string>();

        foreach (var namedArgument in attribute.NamedArguments)
        {
            switch (namedArgument.Key)
            {
                case "Prefix":
                    prefix = namedArgument.Value.Value as string;
                    break;
                case "CacheKey":
                    cacheKey = namedArgument.Value.Value as string;
                    break;
                case "ExpirationSeconds" when namedArgument.Value.Value is int configured:
                    expirationSeconds = configured;
                    hasExplicitExpirationSeconds = true;
                    break;
                case "PostProcess":
                    postProcess = namedArgument.Value.Value as string;
                    break;
                case "AdditionalTags" when namedArgument.Value.Kind == TypedConstantKind.Array:
                    foreach (var value in namedArgument.Value.Values)
                    {
                        if (value.Value is string tag && !string.IsNullOrWhiteSpace(tag))
                        {
                            additionalTags.Add(tag);
                        }
                    }
                    break;
            }
        }

        return new CacheOptions(
            prefix,
            cacheKey,
            expirationSeconds,
            hasExplicitExpirationSeconds,
            postProcess,
            additionalTags.ToImmutable()
        );
    }

    /// <summary>
    /// Looks for an instance method named <paramref name="methodName"/> on
    /// <paramref name="implementationType"/>. It must take exactly one parameter that accepts
    /// <paramref name="valueType"/> and return an awaitable type.
    /// </summary>
    /// <remarks>
    /// All overloads are examined. A compatible overload with an unsupported return type only
    /// causes failure when no other overload is fully valid.
    /// </remarks>
    private static bool HasValidPostProcessMethod(
        INamedTypeSymbol implementationType,
        string methodName,
        ITypeSymbol valueType,
        out string message
    )
    {
        var hasCompatibleOverloadWithInvalidReturn = false;
        foreach (var method in implementationType.GetMembers(methodName).OfType<IMethodSymbol>())
        {
            if (
                method.IsStatic
                || method.Parameters.Length != 1
                || !CanPassValueTo(valueType, method.Parameters[0].Type)
            )
            {
                continue;
            }

            if (IsPostProcessReturnType(method.ReturnType))
            {
                message = string.Empty;
                return true;
            }

            hasCompatibleOverloadWithInvalidReturn = true;
        }

        message = hasCompatibleOverloadWithInvalidReturn
            ? $"post-process method '{methodName}' must return Task, ValueTask, Task<T> or ValueTask<T>"
            : $"post-process method '{methodName}' was not found or its first parameter is not compatible with '{valueType.ToDisplayString()}'";
        return false;
    }

    /// <summary>Post-process methods accept the same awaitable return types as service methods.</summary>
    private static bool IsPostProcessReturnType(ITypeSymbol returnType)
    {
        return IsAwaitableServiceMethod(returnType);
    }

    /// <summary>
    /// Conservative implicit-reference check: identity, <c>object</c>, an implemented interface,
    /// or a base class of <paramref name="valueType"/>.
    /// </summary>
    private static bool CanPassValueTo(ITypeSymbol valueType, ITypeSymbol parameterType)
    {
        if (SymbolEqualityComparer.Default.Equals(valueType, parameterType))
        {
            return true;
        }

        // Every value converts to object. Interface-typed values have no BaseType chain, so
        // the walk below would otherwise miss this case.
        if (parameterType.SpecialType == SpecialType.System_Object)
        {
            return true;
        }

        if (
            valueType.AllInterfaces.Any(@interface =>
                SymbolEqualityComparer.Default.Equals(@interface, parameterType)
            )
        )
        {
            return true;
        }

        for (var baseType = valueType.BaseType; baseType is not null; baseType = baseType.BaseType)
        {
            if (SymbolEqualityComparer.Default.Equals(baseType, parameterType))
            {
                return true;
            }
        }

        return false;
    }

    /// <summary>
    /// True when a cache-key fragment can be produced for <paramref name="type"/>. This covers
    /// <c>CancellationToken</c>, scalar key types, and nullable value types, arrays, or
    /// <c>IEnumerable&lt;T&gt;</c> whose element/underlying type is itself supported.
    /// </summary>
    private static bool CanGenerateKeyFor(ITypeSymbol type)
    {
        if (IsCancellationToken(type))
        {
            return true;
        }

        if (IsNullableValueType(type, out var underlyingType))
        {
            return CanGenerateKeyFor(underlyingType);
        }

        if (type is IArrayTypeSymbol arrayType)
        {
            return CanGenerateKeyFor(arrayType.ElementType);
        }

        if (TryGetEnumerableElementType(type, out var elementType))
        {
            return CanGenerateKeyFor(elementType);
        }

        return IsScalarKeyType(type);
    }

    /// <summary>
    /// Enums, primitive numeric types, <c>bool</c>, <c>char</c>, <c>string</c>, and the listed
    /// date/time and <c>Guid</c> structs.
    /// </summary>
    private static bool IsScalarKeyType(ITypeSymbol type)
    {
        if (type.TypeKind == TypeKind.Enum)
        {
            return true;
        }

        if (
            type.SpecialType
            is SpecialType.System_Boolean
                or SpecialType.System_Char
                or SpecialType.System_SByte
                or SpecialType.System_Byte
                or SpecialType.System_Int16
                or SpecialType.System_UInt16
                or SpecialType.System_Int32
                or SpecialType.System_UInt32
                or SpecialType.System_Int64
                or SpecialType.System_UInt64
                or SpecialType.System_Decimal
                or SpecialType.System_Single
                or SpecialType.System_Double
                or SpecialType.System_String
        )
        {
            return true;
        }

        var fullName = type.ToDisplayString();
        return fullName
            is "System.Guid"
                or "System.DateTime"
                or "System.DateTimeOffset"
                or "System.DateOnly"
                or "System.TimeOnly";
    }

    /// <summary>True for <c>Nullable&lt;T&gt;</c>; <paramref name="underlyingType"/> receives <c>T</c>.</summary>
    private static bool IsNullableValueType(ITypeSymbol type, out ITypeSymbol underlyingType)
    {
        underlyingType = type;
        if (
            type is not INamedTypeSymbol namedType
            || namedType.ConstructedFrom.SpecialType != SpecialType.System_Nullable_T
            || namedType.TypeArguments.Length != 1
        )
        {
            return false;
        }

        underlyingType = namedType.TypeArguments[0];
        return true;
    }

    /// <summary>
    /// Finds the <c>T</c> of <c>IEnumerable&lt;T&gt;</c> for a type that is or implements it.
    /// <c>string</c> is excluded so it is treated as a scalar. When several
    /// <c>IEnumerable&lt;T&gt;</c> interfaces are implemented, the first one listed wins.
    /// </summary>
    private static bool TryGetEnumerableElementType(ITypeSymbol type, out ITypeSymbol elementType)
    {
        elementType = type;
        if (type.SpecialType == SpecialType.System_String)
        {
            return false;
        }

        if (
            type is INamedTypeSymbol namedType
            && namedType.OriginalDefinition.SpecialType
                == SpecialType.System_Collections_Generic_IEnumerable_T
            && namedType.TypeArguments.Length == 1
        )
        {
            elementType = namedType.TypeArguments[0];
            return true;
        }

        foreach (var implementedInterface in type.AllInterfaces)
        {
            if (
                implementedInterface.OriginalDefinition.SpecialType
                == SpecialType.System_Collections_Generic_IEnumerable_T
            )
            {
                elementType = implementedInterface.TypeArguments[0];
                return true;
            }
        }

        return false;
    }

    /// <summary>True for <c>System.Threading.CancellationToken</c>.</summary>
    private static bool IsCancellationToken(ITypeSymbol type)
    {
        return type.ToDisplayString() == "System.Threading.CancellationToken";
    }

    /// <summary>
    /// Returns the first attribute on <paramref name="symbol"/> whose class display name equals
    /// <paramref name="attributeFullName"/>, or <see langword="null"/>.
    /// </summary>
    private static AttributeData? GetAttribute(ISymbol symbol, string attributeFullName)
    {
        return symbol
            .GetAttributes()
            .FirstOrDefault(attribute =>
                attribute.AttributeClass?.ToDisplayString() == attributeFullName
            );
    }

    /// <summary>First source location of <paramref name="symbol"/>, or <see langword="null"/> when it has none.</summary>
    private static Location? GetLocation(ISymbol symbol)
    {
        var locations = symbol.Locations;
        return locations.IsDefaultOrEmpty ? null : locations[0];
    }
}
⚠⚠    以下内容为AI分析的结果,请根据实际情况进行判断。

代码解释

这是一个 C# Source Generator(源代码生成器)的核心类,用于检查服务方法上的缓存特性,并生成缓存装饰器所需的元数据。

主要功能

1. 诊断描述符定义

定义了4个编译时诊断错误类型:

  • DPZ_CACHE001: 不支持的可缓存服务方法
  • DPZ_CACHE002: 缺少 IFusionCache 依赖注入
  • DPZ_CACHE003: 无效的缓存后处理方法
  • DPZ_CACHE004: 无效的缓存失效方法

2. 核心方法

GetCachedMethods - 获取可缓存方法

  • 遍历实现类中标记了 [Cache] 特性的方法
  • 验证方法是否存在于接口中
  • 检查实现类是否注入了 IFusionCache
  • 验证方法是否符合缓存要求(返回类型、参数类型等)
  • 返回 CachedMethod 元数据集合

GetInvalidatedMethods - 获取缓存失效方法

  • 遍历标记了 [InvalidateCache] 特性的方法
  • 验证方法的有效性
  • 检查失效目标方法是否存在于可缓存方法中
  • 返回 InvalidatedMethod 元数据集合

3. 验证规则

可缓存方法必须满足:

  • 返回类型必须是 Task<T>
  • 不能有 ref/out 参数
  • 参数类型必须可用于生成缓存键(支持基本类型、枚举、集合等)
  • 过期时间必须大于 0
  • 如果指定了后处理方法,该方法必须存在且签名匹配

支持的缓存键参数类型:

  • 基本类型(int, string, bool, DateTime, Guid 等)
  • 枚举类型
  • Nullable 类型
  • 数组和 IEnumerable
  • CancellationToken(会被忽略)

缓存失效方法必须满足:

  • 返回类型为 Task、Task<T>、ValueTask 或 ValueTask<T>
  • Methods 属性中指定的方法必须是同一实现类中的可缓存方法

4. 辅助功能

  • 接口映射: 通过 FindInterfaceMethod 匹配实现方法与接口方法

  • 签名比较: 通过 HasSameSignature 比较方法签名

  • 类型检查:

    • IsTaskOfT: 检查是否为 Task<T>,并取出其结果类型 T
    • CanGenerateKeyFor: 检查类型是否可生成缓存键
    • IsScalarKeyType: 检查是否为标量类型
    • IsNullableValueType: 检查是否为可空值类型
  • 特性解析: GetCacheOptions 提取缓存配置(前缀、缓存键、过期时间、后处理方法等)

设计亮点

  1. 编译时验证: 在编译阶段发现问题,避免运行时错误
  2. 类型安全: 严格的类型检查确保生成代码的正确性
  3. 诊断友好: 详细的错误消息帮助开发者快速定位问题
  4. 灵活性: 支持自定义缓存键、后处理方法、失效策略等
  5. 只读实现特性: 缓存作为实现细节,不污染接口定义

使用场景

此代码是 AOP(面向切面编程)模式的应用,用于为服务方法自动生成缓存装饰器代码,减少样板代码,统一缓存处理逻辑。

评论加载中...