using System.Collections.Immutable;
using Dpz.Core.SourceGenerator.Models;
using Microsoft.CodeAnalysis;

namespace Dpz.Core.SourceGenerator;

/// <summary>
/// Scans the current compilation and produces the service-registration metadata that the
/// generator turns into dependency-injection registration code.
/// </summary>
internal static class ServiceRegistrationProvider
{
    private const string InterfaceNamespace = "Dpz.Core.Service.RepositoryService";
    private const string ImplementNamespace = "Dpz.Core.Service.RepositoryServiceImpl";
    private const string DependencyInjectionAttributeFullName =
        SourceGeneratorAttributeNames.DependencyInjectionAttributeFullName;
    private const string HttpClientAttributeFullName =
        SourceGeneratorAttributeNames.HttpClientAttributeFullName;

    /// <summary>
    /// Pairs every interface declared in <c>Dpz.Core.Service.RepositoryService</c> with an
    /// implementation class declared in <c>Dpz.Core.Service.RepositoryServiceImpl</c> and builds a
    /// registration for each pair.
    /// </summary>
    /// <param name="compilation">
    /// The compilation being generated for. Only types declared in its own assembly are considered.
    /// </param>
    /// <param name="sourceProductionContext">Passed through to the cache-method inspectors.</param>
    /// <returns>
    /// One registration per non-ignored interface that has a non-ignored implementation, ordered by
    /// the interface's display name (ordinal).
    /// </returns>
    internal static ImmutableArray<ServiceRegistration> GetRegistrations(
        Compilation compilation,
        SourceProductionContext sourceProductionContext
    )
    {
        // Walk only the source assembly. compilation.GlobalNamespace is the merged namespace of the
        // compilation plus every referenced assembly. Walking it would enumerate every framework and
        // package type only to discard them, because only source types are ever registered.
        var sourceTypes = new List<INamedTypeSymbol>();
        CollectTypes(compilation.Assembly.GlobalNamespace, sourceTypes);

        // Ordinal ordering by display name keeps the generated output stable between builds.
        var interfaces = sourceTypes
            .Where(type =>
                type.TypeKind == TypeKind.Interface
                && type.ContainingNamespace.ToDisplayString() == InterfaceNamespace
            )
            .OrderBy(type => type.ToDisplayString(), StringComparer.Ordinal)
            .ToList();

        var implementations = sourceTypes
            .Where(type =>
                type.TypeKind == TypeKind.Class
                && !type.IsAbstract
                && type.ContainingNamespace.ToDisplayString() == ImplementNamespace
            )
            .OrderBy(type => type.ToDisplayString(), StringComparer.Ordinal)
            .ToList();

        var registrations = ImmutableArray.CreateBuilder<ServiceRegistration>();
        foreach (var interfaceType in interfaces)
        {
            if (ShouldIgnore(interfaceType))
            {
                continue;
            }

            // If the selected implementation is marked Ignore, the interface is skipped entirely.
            // No other candidate is tried.
            var implementationType = FindImplementation(interfaceType, implementations);
            if (implementationType is null || ShouldIgnore(implementationType))
            {
                continue;
            }

            registrations.Add(
                CreateRegistration(interfaceType, implementationType, sourceProductionContext)
            );
        }

        return registrations.ToImmutable();
    }

    /// <summary>
    /// Selects the implementation for <paramref name="interfaceType"/>. A class that actually
    /// implements the interface is preferred. Otherwise, a class that matches the
    /// <c>IFoo</c> → <c>Foo</c> naming convention is used. Within each tier, the first candidate in
    /// <paramref name="implementations"/> order wins.
    /// </summary>
    /// <returns>The chosen implementation, or <c>null</c> when nothing matches.</returns>
    private static INamedTypeSymbol? FindImplementation(
        INamedTypeSymbol interfaceType,
        List<INamedTypeSymbol> implementations
    )
    {
        // Check real implementers first. Otherwise a name-only match that happens to sort earlier
        // could shadow the class that genuinely implements the interface.
        return implementations.FirstOrDefault(type => ImplementsInterface(type, interfaceType))
            ?? implementations.FirstOrDefault(type =>
                IsConventionImplementation(type, interfaceType)
            );
    }

    /// <summary>
    /// Returns whether <paramref name="interfaceType"/> appears in the
    /// <see cref="ITypeSymbol.AllInterfaces"/> of <paramref name="type"/>.
    /// </summary>
    private static bool ImplementsInterface(INamedTypeSymbol type, INamedTypeSymbol interfaceType)
    {
        return type.AllInterfaces.Any(implementedInterface =>
            SymbolEqualityComparer.Default.Equals(implementedInterface, interfaceType)
        );
    }

    /// <summary>
    /// Builds the registration metadata for one interface/implementation pair. This includes the
    /// cache and invalidation methods and the interface's members.
    /// </summary>
    /// <param name="interfaceType">The service interface.</param>
    /// <param name="implementationType">The class registered for <paramref name="interfaceType"/>.</param>
    /// <param name="sourceProductionContext">Passed through to the cache-method inspectors.</param>
    /// <returns>
    /// An HttpClient registration when either type carries a valid HttpClient attribute. Otherwise,
    /// a plain service registration with the lifetime from <see cref="GetLifetime"/>.
    /// </returns>
    private static ServiceRegistration CreateRegistration(
        INamedTypeSymbol interfaceType,
        INamedTypeSymbol implementationType,
        SourceProductionContext sourceProductionContext
    )
    {
        var serviceType = SymbolDisplay.ToFullyQualifiedTypeName(interfaceType);
        var implementation = SymbolDisplay.ToFullyQualifiedTypeName(implementationType);
        var implementationNamespace = implementationType.ContainingNamespace.ToDisplayString();
        var cachedMethods = CacheMethodInspector.GetCachedMethods(
            interfaceType,
            implementationType,
            sourceProductionContext
        );
        var invalidatedMethods = CacheMethodInspector.GetInvalidatedMethods(
            interfaceType,
            implementationType,
            cachedMethods,
            sourceProductionContext
        );
        var interfaceMethods = InterfaceContractInspector.GetMethods(interfaceType);
        var interfaceProperties = InterfaceContractInspector.GetProperties(interfaceType);
        var interfaceEvents = InterfaceContractInspector.GetEvents(interfaceType);

        // HttpClient registration is a specialized DI registration. If both attributes are
        // present, HttpClient wins and normal lifetime settings are ignored. The interface's
        // HttpClient attribute is consulted before the implementation's.
        var httpClientRegistration =
            GetHttpClientRegistration(interfaceType)
            ?? GetHttpClientRegistration(implementationType);
        if (httpClientRegistration is not null)
        {
            return ServiceRegistration.HttpClient(
                serviceType,
                implementation,
                implementationNamespace,
                httpClientRegistration.Value.BaseAddress,
                httpClientRegistration.Value.TimeoutSeconds,
                cachedMethods,
                invalidatedMethods,
                interfaceMethods,
                interfaceProperties,
                interfaceEvents
            );
        }

        return ServiceRegistration.Service(
            serviceType,
            implementation,
            implementationNamespace,
            GetLifetime(implementationType, interfaceType),
            cachedMethods,
            invalidatedMethods,
            interfaceMethods,
            interfaceProperties,
            interfaceEvents
        );
    }

    /// <summary>
    /// Recursively appends every type declared in <paramref name="namespaceSymbol"/> and its child
    /// namespaces, including nested types, to <paramref name="types"/>.
    /// </summary>
    private static void CollectTypes(INamespaceSymbol namespaceSymbol, List<INamedTypeSymbol> types)
    {
        foreach (var type in namespaceSymbol.GetTypeMembers())
        {
            CollectTypes(type, types);
        }

        foreach (var childNamespace in namespaceSymbol.GetNamespaceMembers())
        {
            CollectTypes(childNamespace, types);
        }
    }

    /// <summary>
    /// Appends <paramref name="typeSymbol"/> and, recursively, all of its nested types to
    /// <paramref name="types"/>.
    /// </summary>
    private static void CollectTypes(INamedTypeSymbol typeSymbol, List<INamedTypeSymbol> types)
    {
        types.Add(typeSymbol);

        foreach (var nestedType in typeSymbol.GetTypeMembers())
        {
            CollectTypes(nestedType, types);
        }
    }

    /// <summary>
    /// Returns <c>true</c> when <paramref name="type"/> carries the DependencyInjection attribute
    /// with the named argument <c>Ignore = true</c>.
    /// </summary>
    private static bool ShouldIgnore(INamedTypeSymbol type)
    {
        var attribute = GetAttribute(type, DependencyInjectionAttributeFullName);
        return attribute is not null
            && attribute.NamedArguments.Any(namedArgument =>
                namedArgument is { Key: "Ignore", Value.Value: true }
            );
    }

    /// <summary>
    /// Returns <c>true</c> when the interface name is <c>I</c> followed exactly by the
    /// implementation name (ordinal comparison), for example <c>IFoo</c> → <c>Foo</c>.
    /// This check does not verify that the class actually implements the interface.
    /// </summary>
    private static bool IsConventionImplementation(
        INamedTypeSymbol implementationType,
        INamedTypeSymbol interfaceType
    )
    {
        if (!interfaceType.Name.StartsWith("I", StringComparison.Ordinal))
        {
            return false;
        }

        return string.Equals(
            implementationType.Name,
            interfaceType.Name.Substring(1),
            StringComparison.Ordinal
        );
    }

    /// <summary>
    /// Reads the service lifetime from the single constructor argument of the DependencyInjection
    /// attribute. The implementation's attribute is used if present; otherwise the interface's
    /// attribute is used.
    /// </summary>
    /// <remarks>
    /// When the implementation carries the attribute without a constructor argument, the
    /// interface's attribute is not consulted and the default applies.
    /// </remarks>
    /// <returns>
    /// The configured lifetime, or <see cref="ServiceLifetime.Scoped"/> when none can be read.
    /// </returns>
    private static ServiceLifetime GetLifetime(
        INamedTypeSymbol implementationType,
        INamedTypeSymbol interfaceType
    )
    {
        var attribute =
            GetAttribute(implementationType, DependencyInjectionAttributeFullName)
            ?? GetAttribute(interfaceType, DependencyInjectionAttributeFullName);

        // An enum-typed constructor argument surfaces as its boxed underlying value. Anything
        // unreadable falls back to the same default as "no attribute", instead of a bare magic
        // number.
        if (
            attribute is { ConstructorArguments.Length: 1 }
            && attribute.ConstructorArguments[0].Value is int configuredLifetime
        )
        {
            return (ServiceLifetime)configuredLifetime;
        }

        return ServiceLifetime.Scoped;
    }

    /// <summary>
    /// Parses the HttpClient attribute on <paramref name="type"/>.
    /// </summary>
    /// <returns>
    /// The base address (single constructor argument) and the optional <c>TimeoutSeconds</c> named
    /// argument, which is 0 when omitted. Returns <c>null</c> when the attribute is absent, has a
    /// different constructor shape, or has a null/whitespace base address.
    /// </returns>
    private static HttpClientRegistration? GetHttpClientRegistration(INamedTypeSymbol type)
    {
        var attribute = GetAttribute(type, HttpClientAttributeFullName);
        if (
            attribute is not { ConstructorArguments.Length: 1 }
            || attribute.ConstructorArguments[0].Value is not string baseAddress
            || string.IsNullOrWhiteSpace(baseAddress)
        )
        {
            return null;
        }

        var timeoutSeconds = 0;
        foreach (var namedArgument in attribute.NamedArguments)
        {
            if (
                namedArgument is
                { Key: "TimeoutSeconds", Value.Value: int configuredTimeoutSeconds }
            )
            {
                timeoutSeconds = configuredTimeoutSeconds;
            }
        }

        return new HttpClientRegistration(baseAddress, timeoutSeconds);
    }

    /// <summary>
    /// Returns the first attribute on <paramref name="symbol"/> whose attribute class display name
    /// equals <paramref name="attributeFullName"/>, or <c>null</c> if there is none.
    /// </summary>
    private static AttributeData? GetAttribute(ISymbol symbol, string attributeFullName)
    {
        return symbol
            .GetAttributes()
            .FirstOrDefault(attribute =>
                attribute.AttributeClass?.ToDisplayString() == attributeFullName
            );
    }
}
⚠ 以下内容为 AI 分析的结果,请根据实际情况进行判断。

代码解释

这是一个 C# Source Generator(源代码生成器) 的核心组件,负责在编译时自动扫描代码并生成服务注册元数据。

主要功能

1. ServiceRegistrationProvider 类

静态工具类,用于自动发现和注册依赖注入服务。

2. 核心方法:GetRegistrations

internal static ImmutableArray<ServiceRegistration> GetRegistrations(...)

工作流程:

  1. 收集所有类型:扫描整个编译单元中的所有类型

  2. 筛选接口

    • 必须是接口类型
    • 位于 Dpz.Core.Service.RepositoryService 命名空间
    • 属于当前编译程序集
  3. 筛选实现类

    • 必须是非抽象类
    • 位于 Dpz.Core.Service.RepositoryServiceImpl 命名空间
    • 属于当前编译程序集
  4. 匹配接口与实现

    • 显式实现:类实现了该接口
    • 约定实现:接口名 IXxx 对应实现类 Xxx
  5. 生成注册元数据:为每对匹配的接口-实现创建注册信息

3. CreateRegistration 方法

创建单个服务注册项,包含:

  • 服务类型和实现类型的完全限定名
  • 缓存方法和缓存失效方法(通过 CacheMethodInspector 获取)
  • 接口的方法、属性、事件(通过 InterfaceContractInspector 获取)
  • 特殊处理 HttpClient 注册(如果标记了 HttpClient 特性)
  • 普通服务的生命周期(Scoped/Singleton/Transient)

4. 辅助方法

CollectTypes:递归收集命名空间和嵌套类型中的所有类型

ShouldIgnore:检查类型是否标记了 Ignore = true 的 DependencyInjection 特性

IsConventionImplementation:检查是否符合命名约定(接口 IFoo 对应实现类 Foo)

GetLifetime:从特性中读取服务生命周期,默认为 Scoped

GetHttpClientRegistration:解析 HttpClient 特性配置(BaseAddress 和 TimeoutSeconds)

GetAttribute:获取指定类型的特性数据

设计模式

  • 约定优于配置:自动匹配特定命名空间下的接口和实现
  • 基于特性的配置:通过 [DependencyInjection] 和 [HttpClient] 特性控制行为
  • 编译时代码生成:利用 Roslyn Source Generator 在编译期生成注册代码,避免运行时反射

使用场景

这个生成器适用于需要大量服务注册的项目,通过自动化减少手动编写依赖注入注册代码的工作量。

评论加载中...