using System.Collections.Immutable;
using Dpz.Core.SourceGenerator.Models;
using Microsoft.CodeAnalysis;
namespace Dpz.Core.SourceGenerator;
/// <summary>
/// Scans the compilation and produces the service registration metadata.
/// </summary>
/// <remarks>
/// Interfaces are taken from <c>Dpz.Core.Service.RepositoryService</c> and concrete classes
/// from <c>Dpz.Core.Service.RepositoryServiceImpl</c>, both restricted to the assembly being
/// compiled. Each interface is paired with at most one implementation.
/// </remarks>
internal static class ServiceRegistrationProvider
{
    private const string InterfaceNamespace = "Dpz.Core.Service.RepositoryService";
    private const string ImplementNamespace = "Dpz.Core.Service.RepositoryServiceImpl";
    private const string DependencyInjectionAttributeFullName =
        SourceGeneratorAttributeNames.DependencyInjectionAttributeFullName;
    private const string HttpClientAttributeFullName =
        SourceGeneratorAttributeNames.HttpClientAttributeFullName;

    /// <summary>
    /// Builds one <see cref="ServiceRegistration"/> per interface that has a matching,
    /// non-ignored implementation.
    /// </summary>
    /// <param name="compilation">Compilation whose own assembly is scanned.</param>
    /// <param name="sourceProductionContext">
    /// Context forwarded to <c>CacheMethodInspector</c> while each registration is created.
    /// </param>
    /// <returns>
    /// Registrations in ordinal order of the interface display name. Interfaces that are
    /// ignored, have no implementation, or whose selected implementation is ignored are
    /// left out.
    /// </returns>
    internal static ImmutableArray<ServiceRegistration> GetRegistrations(
        Compilation compilation,
        SourceProductionContext sourceProductionContext
    )
    {
        // Walk only the assembly being compiled. Compilation.GlobalNamespace is merged with
        // every referenced assembly, so traversing it visits all referenced types just to
        // throw them away again. Every type reachable from Assembly.GlobalNamespace already
        // belongs to compilation.Assembly, so no per-type assembly check is needed.
        var allTypes = new List<INamedTypeSymbol>();
        CollectTypes(compilation.Assembly.GlobalNamespace, allTypes);

        // Sorting by display name keeps the generated output deterministic.
        var interfaces = allTypes
            .Where(type =>
                type.TypeKind == TypeKind.Interface
                && type.ContainingNamespace.ToDisplayString() == InterfaceNamespace
            )
            .OrderBy(type => type.ToDisplayString(), StringComparer.Ordinal)
            .ToList();

        var implementations = allTypes
            .Where(type =>
                type.TypeKind == TypeKind.Class
                && !type.IsAbstract
                && type.ContainingNamespace.ToDisplayString() == ImplementNamespace
            )
            .OrderBy(type => type.ToDisplayString(), StringComparer.Ordinal)
            .ToList();

        var registrations = ImmutableArray.CreateBuilder<ServiceRegistration>();
        foreach (var interfaceType in interfaces)
        {
            if (ShouldIgnore(interfaceType))
            {
                continue;
            }

            // The first candidate (in sorted order) that either implements the interface or
            // follows the IFoo -> Foo naming convention is selected.
            var implementationType = implementations.FirstOrDefault(candidate =>
                candidate.AllInterfaces.Any(implementedInterface =>
                    SymbolEqualityComparer.Default.Equals(implementedInterface, interfaceType)
                ) || IsConventionImplementation(candidate, interfaceType)
            );

            // An ignored implementation drops the interface entirely; no other candidate
            // is tried.
            if (implementationType is null || ShouldIgnore(implementationType))
            {
                continue;
            }

            registrations.Add(
                CreateRegistration(interfaceType, implementationType, sourceProductionContext)
            );
        }

        return registrations.ToImmutable();
    }

    /// <summary>
    /// Creates the registration metadata for one interface/implementation pair.
    /// </summary>
    /// <param name="interfaceType">The service interface.</param>
    /// <param name="implementationType">The class selected as its implementation.</param>
    /// <param name="sourceProductionContext">
    /// Context forwarded to <c>CacheMethodInspector</c>.
    /// </param>
    /// <returns>
    /// An HttpClient registration when the HttpClient attribute yields a usable base address
    /// on the interface or the implementation; otherwise a regular service registration.
    /// </returns>
    private static ServiceRegistration CreateRegistration(
        INamedTypeSymbol interfaceType,
        INamedTypeSymbol implementationType,
        SourceProductionContext sourceProductionContext
    )
    {
        var serviceType = SymbolDisplay.ToFullyQualifiedTypeName(interfaceType);
        var implementation = SymbolDisplay.ToFullyQualifiedTypeName(implementationType);
        var implementationNamespace = implementationType.ContainingNamespace.ToDisplayString();

        var cachedMethods = CacheMethodInspector.GetCachedMethods(
            interfaceType,
            implementationType,
            sourceProductionContext
        );
        var invalidatedMethods = CacheMethodInspector.GetInvalidatedMethods(
            interfaceType,
            implementationType,
            cachedMethods,
            sourceProductionContext
        );

        var interfaceMethods = InterfaceContractInspector.GetMethods(interfaceType);
        var interfaceProperties = InterfaceContractInspector.GetProperties(interfaceType);
        var interfaceEvents = InterfaceContractInspector.GetEvents(interfaceType);

        // HttpClient registration is a specialized DI registration. If both attributes are
        // present, HttpClient wins and normal lifetime settings are ignored. The interface
        // is consulted before the implementation.
        var httpClientRegistration =
            GetHttpClientRegistration(interfaceType)
            ?? GetHttpClientRegistration(implementationType);
        if (httpClientRegistration is not null)
        {
            return ServiceRegistration.HttpClient(
                serviceType,
                implementation,
                implementationNamespace,
                httpClientRegistration.Value.BaseAddress,
                httpClientRegistration.Value.TimeoutSeconds,
                cachedMethods,
                invalidatedMethods,
                interfaceMethods,
                interfaceProperties,
                interfaceEvents
            );
        }

        return ServiceRegistration.Service(
            serviceType,
            implementation,
            implementationNamespace,
            GetLifetime(implementationType, interfaceType),
            cachedMethods,
            invalidatedMethods,
            interfaceMethods,
            interfaceProperties,
            interfaceEvents
        );
    }

    /// <summary>
    /// Adds every type declared in <paramref name="namespaceSymbol"/> and in its child
    /// namespaces, including nested types, to <paramref name="types"/>.
    /// </summary>
    /// <param name="namespaceSymbol">Namespace to walk recursively.</param>
    /// <param name="types">List that receives the discovered types.</param>
    private static void CollectTypes(INamespaceSymbol namespaceSymbol, List<INamedTypeSymbol> types)
    {
        foreach (var type in namespaceSymbol.GetTypeMembers())
        {
            CollectTypes(type, types);
        }

        foreach (var childNamespace in namespaceSymbol.GetNamespaceMembers())
        {
            CollectTypes(childNamespace, types);
        }
    }

    /// <summary>
    /// Adds <paramref name="typeSymbol"/> and, recursively, all of its nested types to
    /// <paramref name="types"/>.
    /// </summary>
    /// <param name="typeSymbol">Type to add.</param>
    /// <param name="types">List that receives the discovered types.</param>
    private static void CollectTypes(INamedTypeSymbol typeSymbol, List<INamedTypeSymbol> types)
    {
        types.Add(typeSymbol);
        foreach (var nestedType in typeSymbol.GetTypeMembers())
        {
            CollectTypes(nestedType, types);
        }
    }

    /// <summary>
    /// Tells whether <paramref name="type"/> carries the dependency-injection attribute with
    /// the named argument <c>Ignore</c> set to <c>true</c>.
    /// </summary>
    /// <param name="type">Interface or implementation to inspect.</param>
    /// <returns>
    /// <c>true</c> only when the attribute is present and <c>Ignore = true</c> is specified.
    /// </returns>
    private static bool ShouldIgnore(INamedTypeSymbol type)
    {
        var attribute = GetAttribute(type, DependencyInjectionAttributeFullName);
        return attribute is not null
            && attribute.NamedArguments.Any(namedArgument =>
                namedArgument is { Key: "Ignore", Value.Value: true }
            );
    }

    /// <summary>
    /// Checks the naming convention: an interface named <c>IFoo</c> matches a class named
    /// <c>Foo</c>. Only the simple names are compared (ordinal, case-sensitive).
    /// </summary>
    /// <param name="implementationType">Candidate implementation class.</param>
    /// <param name="interfaceType">Interface to match against.</param>
    /// <returns><c>true</c> when the names follow the convention.</returns>
    private static bool IsConventionImplementation(
        INamedTypeSymbol implementationType,
        INamedTypeSymbol interfaceType
    )
    {
        var interfaceName = interfaceType.Name;
        if (!interfaceName.StartsWith("I", StringComparison.Ordinal))
        {
            return false;
        }

        return string.Equals(
            implementationType.Name,
            interfaceName.Substring(1),
            StringComparison.Ordinal
        );
    }

    /// <summary>
    /// Resolves the service lifetime from the dependency-injection attribute.
    /// </summary>
    /// <param name="implementationType">Checked first for the attribute.</param>
    /// <param name="interfaceType">
    /// Used only when the implementation does not carry the attribute at all.
    /// </param>
    /// <returns>
    /// The lifetime taken from the attribute's single constructor argument, or
    /// <see cref="ServiceLifetime.Scoped"/> when no attribute is found or it does not have
    /// exactly one constructor argument.
    /// </returns>
    private static ServiceLifetime GetLifetime(
        INamedTypeSymbol implementationType,
        INamedTypeSymbol interfaceType
    )
    {
        var attribute =
            GetAttribute(implementationType, DependencyInjectionAttributeFullName)
            ?? GetAttribute(interfaceType, DependencyInjectionAttributeFullName);
        if (attribute is null || attribute.ConstructorArguments.Length != 1)
        {
            return ServiceLifetime.Scoped;
        }

        // NOTE(review): the fallback literal 1 is used when the argument is not an int;
        // it presumably equals (int)ServiceLifetime.Scoped - confirm against the enum.
        return (ServiceLifetime)(attribute.ConstructorArguments[0].Value as int? ?? 1);
    }

    /// <summary>
    /// Reads the HttpClient attribute configuration from <paramref name="type"/>.
    /// </summary>
    /// <param name="type">Interface or implementation to inspect.</param>
    /// <returns>
    /// The base address and timeout, or <c>null</c> when the attribute is missing, does not
    /// have exactly one constructor argument, or that argument is not a non-blank string.
    /// The timeout is 0 when the <c>TimeoutSeconds</c> named argument is not specified.
    /// </returns>
    private static HttpClientRegistration? GetHttpClientRegistration(INamedTypeSymbol type)
    {
        var attribute = GetAttribute(type, HttpClientAttributeFullName);
        if (attribute is null || attribute.ConstructorArguments.Length != 1)
        {
            return null;
        }

        // The type pattern narrows the value to a non-null string, so no null-forgiving
        // operator is needed below.
        if (
            attribute.ConstructorArguments[0].Value is not string baseAddress
            || string.IsNullOrWhiteSpace(baseAddress)
        )
        {
            return null;
        }

        var timeoutSeconds = 0;
        foreach (var namedArgument in attribute.NamedArguments)
        {
            if (
                namedArgument is
                { Key: "TimeoutSeconds", Value.Value: int configuredTimeoutSeconds }
            )
            {
                timeoutSeconds = configuredTimeoutSeconds;
            }
        }

        return new HttpClientRegistration(baseAddress, timeoutSeconds);
    }

    /// <summary>
    /// Convenience overload of <see cref="GetAttribute(ISymbol, string)"/> for named types.
    /// </summary>
    /// <param name="type">Type whose attributes are searched.</param>
    /// <param name="attributeFullName">Display name of the attribute class to find.</param>
    /// <returns>The first matching attribute, or <c>null</c>.</returns>
    private static AttributeData? GetAttribute(INamedTypeSymbol type, string attributeFullName)
    {
        return GetAttribute((ISymbol)type, attributeFullName);
    }

    /// <summary>
    /// Finds the first attribute on <paramref name="symbol"/> whose attribute class display
    /// string equals <paramref name="attributeFullName"/>.
    /// </summary>
    /// <param name="symbol">Symbol whose attributes are searched.</param>
    /// <param name="attributeFullName">Display name of the attribute class to find.</param>
    /// <returns>The first matching attribute, or <c>null</c> when none matches.</returns>
    private static AttributeData? GetAttribute(ISymbol symbol, string attributeFullName)
    {
        return symbol
            .GetAttributes()
            .FirstOrDefault(attribute =>
                attribute.AttributeClass?.ToDisplayString() == attributeFullName
            );
    }
}
⚠⚠ 以下内容为AI分析的结果,请根据实际情况进行判断。
代码解释
这是一个 C# Source Generator(源代码生成器) 的核心组件,负责在编译时自动扫描代码并生成服务注册元数据。
主要功能
1. ServiceRegistrationProvider 类
静态工具类,用于自动发现和注册依赖注入服务。
2. 核心方法:GetRegistrations
internal static ImmutableArray<ServiceRegistration> GetRegistrations(...)
工作流程:
收集所有类型:扫描整个编译单元中的所有类型
筛选接口:
- 必须是接口类型
- 位于 Dpz.Core.Service.RepositoryService 命名空间
- 属于当前编译程序集
筛选实现类:
- 必须是非抽象类
- 位于 Dpz.Core.Service.RepositoryServiceImpl 命名空间
- 属于当前编译程序集
匹配接口与实现:
- 显式实现:类实现了该接口
- 约定实现:接口名 IXxx 对应实现类 Xxx
生成注册元数据:为每对匹配的接口-实现创建注册信息
3. CreateRegistration 方法
创建单个服务注册项,包含:
- 服务类型和实现类型的完全限定名
- 缓存方法和缓存失效方法(通过 CacheMethodInspector 获取)
- 接口的方法、属性、事件(通过 InterfaceContractInspector 获取)
- 特殊处理 HttpClient 注册(如果标记了 HttpClient 特性)
- 普通服务的生命周期(Scoped/Singleton/Transient)
4. 辅助方法
CollectTypes:递归收集命名空间和嵌套类型中的所有类型
ShouldIgnore:检查类型是否标记了 Ignore = true 的 DependencyInjection 特性
IsConventionImplementation:检查是否符合命名约定(接口 IFoo 对应实现 Foo)
GetLifetime:从特性中读取服务生命周期,默认为 Scoped
GetHttpClientRegistration:解析 HttpClient 特性配置(BaseAddress 和 TimeoutSeconds)
GetAttribute:获取指定类型的特性数据
设计模式
- 约定优于配置:自动匹配特定命名空间下的接口和实现
- 基于特性的配置:通过 [DependencyInjection] 和 [HttpClient] 特性控制行为
- 编译时代码生成:利用 Roslyn Source Generator 在编译期生成注册代码,避免运行时反射
使用场景
这个生成器适用于需要大量服务注册的项目,通过自动化减少手动编写依赖注入注册代码的工作量。
评论加载中...