// LobbyServer/LobbyServerSourceGenerator/LobbyMessageSourceGenerator.cs
//
// 343 lines
// 13 KiB
// C#

using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
namespace LobbyServerDto
{
/// <summary>
/// Incremental source generator that, for every type annotated with
/// <c>[LobbyServerDto.LobbyMessage]</c>, emits a partial declaration containing a sequential
/// <c>TypeId</c> constant, an instance <c>Serialize(byte[] buffer)</c> method returning the number of
/// bytes written, and a static <c>Deserialize(ReadOnlySpan&lt;byte&gt; buffer)</c> factory.
/// </summary>
/// <remarks>
/// <para>
/// Wire format: the <c>TypeId</c> as a 7-bit variable-length unsigned integer (low group first),
/// followed by every property declared directly in the attributed declaration, in declaration order:
/// <c>bool</c> as 1 byte, <c>int</c> as 4 bytes little-endian, <c>Guid</c> as the 16 bytes of
/// <c>ToByteArray()</c>, <c>string</c>/<c>string?</c> as a varint byte count plus UTF-8, and
/// <c>byte[]</c>/<c>byte[]?</c> as a varint count plus raw bytes. A null string/array is written as
/// length 0. Deserialize assigns a string/array only when the length is non-zero, so both null and empty
/// values come back as whatever the parameterless constructor left in the property.
/// </para>
/// <para>
/// Strings are truncated to at most N UTF-16 chars and arrays to at most N bytes. N is 256, lowered by a
/// smaller <c>[MaxLength(n)]</c> on the property.
/// </para>
/// <para>
/// Type ids are assigned in ordinal order of the types' namespace-qualified names. Adding or renaming a
/// message type can therefore shift the ids of the types that sort after it. Peers must be built from
/// the same set of message types.
/// </para>
/// <para>
/// NOTE(review): the predicate also accepts interfaces. The generated <c>Deserialize</c> calls
/// <c>new T()</c>, which cannot compile for an interface.
/// </para>
/// </remarks>
[Generator]
public class LobbyMessageSourceGenerator : IIncrementalGenerator
{
    /// <summary>
    /// Upper bound on string/byte[] length. A smaller <c>[MaxLength(n)]</c> lowers it; a larger one
    /// cannot raise it.
    /// </summary>
    private const int DefaultMaxLength = 256;

    /// <summary>Reported (and the type skipped) when a message property has a type the wire format cannot encode.</summary>
    private static readonly DiagnosticDescriptor UnsupportedPropertyType = new DiagnosticDescriptor(
        id: "LOBBY001",
        title: "Unsupported lobby message property type",
        messageFormat: "Unknown type {0} on field {1}",
        category: "LobbyMessageGenerator",
        defaultSeverity: DiagnosticSeverity.Error,
        isEnabledByDefault: true);

    /// <inheritdoc/>
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        var typeDeclarations = context.SyntaxProvider.ForAttributeWithMetadataName(
            "LobbyServerDto.LobbyMessageAttribute",
            predicate: static (node, token) => node is ClassDeclarationSyntax
                or StructDeclarationSyntax
                or RecordDeclarationSyntax
                or InterfaceDeclarationSyntax,
            transform: static (context, token) => (TypeDeclarationSyntax)context.TargetNode);

        // Comparer.Instance compares only the syntax node, so the paired Compilation does not take part
        // in the equality check used for caching.
        var source = typeDeclarations
            .Combine(context.CompilationProvider)
            .WithComparer(Comparer.Instance)
            .Collect();

        context.RegisterSourceOutput(source, static (context, source) => GenerateCode(context, source));
    }

    /// <summary>
    /// Emits one <c>{Namespace}.{Name}.g.cs</c> file per collected message type. Types are numbered in
    /// ordinal order of their namespace-qualified names.
    /// </summary>
    private static void GenerateCode(SourceProductionContext context, ImmutableArray<(TypeDeclarationSyntax, Compilation)> source)
    {
        if (source.IsDefaultOrEmpty)
            return;

        var messageTypes = CollectMessageTypes(context, source);
        var builder = new StringBuilder();
        int nextId = 0;

        // OrderBy without a comparer is culture-sensitive, which could number the same types differently
        // on build machines with different cultures. Ordinal ordering is identical everywhere.
        foreach (var entry in messageTypes.OrderBy(pair => pair.Key, StringComparer.Ordinal))
        {
            // Consume the id even when the type is skipped, so the remaining types keep their numbers.
            int typeId = nextId++;
            var properties = CollectProperties(context, entry.Value);
            if (properties == null)
                continue;

            builder.Clear();
            AppendType(builder, entry.Value, typeId, properties);
            context.AddSource($"{entry.Key}.g.cs", builder.ToString());
        }
    }

    /// <summary>
    /// Resolves each attributed declaration to its symbol and indexes it by <c>{Namespace}.{Name}</c>
    /// (<c>.{Name}</c> for the global namespace).
    /// </summary>
    private static Dictionary<string, SyntaxInfo> CollectMessageTypes(SourceProductionContext context, ImmutableArray<(TypeDeclarationSyntax, Compilation)> source)
    {
        var result = new Dictionary<string, SyntaxInfo>();
        foreach (var (syntaxNode, compilation) in source)
        {
            var semanticModel = compilation.GetSemanticModel(syntaxNode.SyntaxTree);
            var typeSymbol = semanticModel.GetDeclaredSymbol(syntaxNode, context.CancellationToken);
            if (typeSymbol == null)
                continue; // Skip only this declaration instead of abandoning every message type.

            var ns = typeSymbol.ContainingNamespace.IsGlobalNamespace
                ? null
                : typeSymbol.ContainingNamespace.ToString();
            var fullName = $"{ns}.{typeSymbol.Name}";

            // Two declarations can map to the same key (e.g. same-named nested or generic types).
            // Dictionary.Add would throw, and AddSource would reject the duplicate hint name, so keep the first.
            if (result.ContainsKey(fullName))
                continue;

            result.Add(fullName, new SyntaxInfo()
            {
                name = typeSymbol.Name,
                nameSpace = ns,
                semanticModel = semanticModel,
                syntaxNode = syntaxNode,
                compilation = compilation,
            });
        }
        return result;
    }

    /// <summary>
    /// Lists the properties declared directly in <paramref name="info"/>'s declaration, in declaration order.
    /// </summary>
    /// <returns>
    /// The (name, type spelling, max length) of each property. Returns <c>null</c> when at least one
    /// property has an unsupported type; an <see cref="UnsupportedPropertyType"/> error is reported for
    /// each such property.
    /// </returns>
    private static List<(string Name, string Type, int MaxLength)> CollectProperties(SourceProductionContext context, SyntaxInfo info)
    {
        var properties = new List<(string Name, string Type, int MaxLength)>();
        bool allSupported = true;
        foreach (var member in info.syntaxNode.Members)
        {
            if (member is not PropertyDeclarationSyntax property)
                continue;

            var name = property.Identifier.ToString();
            var type = property.Type.ToString();
            if (!IsSupportedType(type))
            {
                // Report instead of throwing: an exception would fail the generator for every type.
                context.ReportDiagnostic(Diagnostic.Create(UnsupportedPropertyType, property.Type.GetLocation(), type, name));
                allSupported = false;
                continue;
            }

            var symbol = info.semanticModel.GetDeclaredSymbol(property, context.CancellationToken);
            properties.Add((name, type, GetMaxLength(symbol)));
        }
        return allSupported ? properties : null;
    }

    /// <summary>
    /// Returns <see cref="DefaultMaxLength"/>, lowered to the smallest non-negative <c>int</c> passed as
    /// the first constructor argument of any attribute named <c>MaxLengthAttribute</c> on <paramref name="member"/>.
    /// </summary>
    private static int GetMaxLength(ISymbol member)
    {
        int maxLength = DefaultMaxLength;
        if (member == null)
            return maxLength;

        foreach (var attribute in member.GetAttributes())
        {
            if (attribute.AttributeClass?.Name != "MaxLengthAttribute")
                continue;
            // A parameterless [MaxLength], or an argument that is not an int, carries no usable length.
            if (attribute.ConstructorArguments.Length == 0 || attribute.ConstructorArguments[0].Value is not int length)
                continue;
            if (length >= 0 && length < maxLength)
                maxLength = length;
        }
        return maxLength;
    }

    /// <summary>True for the property type spellings the wire format can encode.</summary>
    private static bool IsSupportedType(string type) => type switch
    {
        "bool" or "int" or "Guid" or "string" or "string?" or "byte[]" or "byte[]?" => true,
        _ => false,
    };

    /// <summary>
    /// Returns the keyword(s) the generated partial must repeat to merge with the user's declaration:
    /// <c>class</c>, <c>struct</c>, <c>interface</c>, <c>record</c>, or the spelled-out
    /// <c>record class</c>/<c>record struct</c>.
    /// </summary>
    private static string GetTypeKeyword(TypeDeclarationSyntax node)
    {
        if (node is RecordDeclarationSyntax record && !record.ClassOrStructKeyword.IsKind(SyntaxKind.None))
            return $"record {record.ClassOrStructKeyword.Text}";
        return node.Keyword.Text;
    }

    /// <summary>Writes the complete generated source file for one message type into <paramref name="s"/>.</summary>
    private static void AppendType(StringBuilder s, SyntaxInfo info, int typeId, List<(string Name, string Type, int MaxLength)> properties)
    {
        // Framework names are emitted as global::System.* so the output compiles without `using System;`.
        // This also stops a message property named e.g. "Buffer" or "Math" from capturing them.
        s.Append(@"// <auto-generated />
using System.Collections.Generic;
using System.Text;
");
        if (info.nameSpace != null)
        {
            s.Append($@"namespace {info.nameSpace}
{{
");
        }

        s.Append($@"public partial {GetTypeKeyword(info.syntaxNode)} {info.name}
{{
    public const int TypeId = {typeId};

    public int Serialize(byte[] buffer)
    {{
        int offset = 0;");
        AppendWriteVarUInt(s, "TypeId", "        ");
        foreach (var (name, type, maxLength) in properties)
            AppendSerializeProperty(s, name, type, maxLength);

        s.Append($@"
        return offset;
    }}

    public static {info.name} Deserialize(global::System.ReadOnlySpan<byte> buffer)
    {{
        int offset = 0;
        {info.name} ret = new {info.name}();");
        // The leading TypeId is read only to advance past it; its value is not checked.
        AppendReadVarInt(s, "typeId", "        ");
        foreach (var (name, type, _) in properties)
            AppendDeserializeProperty(s, name, type);

        s.Append(@"
        return ret;
    }
}
");
        if (info.nameSpace != null)
            s.AppendLine("}");
    }

    /// <summary>
    /// Appends statements that write <paramref name="valueExpression"/> at <c>buffer[offset]</c> as an
    /// unsigned 7-bit variable-length integer (low group first, high bit = more follows), advancing <c>offset</c>.
    /// </summary>
    private static void AppendWriteVarUInt(StringBuilder s, string valueExpression, string indent)
    {
        // unchecked: (v | 0x80) exceeds byte range whenever more groups follow, which would throw in checked builds.
        s.Append($@"
{indent}{{
{indent}    uint v = (uint)({valueExpression});
{indent}    while (v >= 0x80)
{indent}    {{
{indent}        buffer[offset++] = unchecked((byte)(v | 0x80));
{indent}        v >>= 7;
{indent}    }}
{indent}    buffer[offset++] = (byte)v;
{indent}}}");
    }

    /// <summary>
    /// Appends statements that declare <c>int {target}</c> and read a 7-bit variable-length integer into it
    /// from <c>buffer[offset]</c>, advancing <c>offset</c>. The generated code throws
    /// <c>FormatException</c> after 5 continuation groups.
    /// </summary>
    private static void AppendReadVarInt(StringBuilder s, string target, string indent)
    {
        s.Append($@"
{indent}int {target} = 0;
{indent}{{
{indent}    int shift = 0;
{indent}    byte b;
{indent}    do
{indent}    {{
{indent}        // Check for a corrupted stream: an Int32 occupies at most 5 bytes.
{indent}        if (shift == 5 * 7)
{indent}            throw new global::System.FormatException(""Format_Bad7BitInt32"");
{indent}        b = buffer[offset++];
{indent}        {target} |= (b & 0x7F) << shift;
{indent}        shift += 7;
{indent}    }} while ((b & 0x80) != 0);
{indent}}}");
    }

    /// <summary>Appends the Serialize statements for one property; <paramref name="type"/> must satisfy <see cref="IsSupportedType"/>.</summary>
    private static void AppendSerializeProperty(StringBuilder s, string name, string type, int maxLength)
    {
        // `this.` keeps a property from being shadowed by the generated locals (offset, bytes, count, v, ...).
        switch (type)
        {
            case "bool":
                s.Append($@"
        buffer[offset++] = (byte)(this.{name} ? 1 : 0);");
                break;
            case "int":
                s.Append($@"
        buffer[offset++] = unchecked((byte)this.{name});
        buffer[offset++] = unchecked((byte)(this.{name} >> 8));
        buffer[offset++] = unchecked((byte)(this.{name} >> 16));
        buffer[offset++] = unchecked((byte)(this.{name} >> 24));");
                break;
            case "Guid":
                s.Append($@"
        global::System.Buffer.BlockCopy(this.{name}.ToByteArray(), 0, buffer, offset, 16);
        offset += 16;");
                break;
            case "string":
            case "string?":
                s.Append($@"
        if (this.{name} != null)
        {{
            byte[] bytes = global::System.Text.Encoding.UTF8.GetBytes(this.{name}.Substring(0, global::System.Math.Min({maxLength}, this.{name}.Length)));");
                AppendWriteVarUInt(s, "bytes.Length", "            ");
                s.Append(@"
            global::System.Buffer.BlockCopy(bytes, 0, buffer, offset, bytes.Length);
            offset += bytes.Length;
        }
        else
        {
            buffer[offset++] = 0;
        }");
                break;
            case "byte[]":
            case "byte[]?":
                s.Append($@"
        if (this.{name} != null)
        {{
            int count = global::System.Math.Min(this.{name}.Length, {maxLength});");
                AppendWriteVarUInt(s, "count", "            ");
                s.Append($@"
            global::System.Buffer.BlockCopy(this.{name}, 0, buffer, offset, count);
            offset += count;
        }}
        else
        {{
            buffer[offset++] = 0;
        }}");
                break;
            default:
                throw new InvalidOperationException($"Unsupported type {type} on field {name}");
        }
    }

    /// <summary>Appends the Deserialize statements for one property; <paramref name="type"/> must satisfy <see cref="IsSupportedType"/>.</summary>
    private static void AppendDeserializeProperty(StringBuilder s, string name, string type)
    {
        switch (type)
        {
            case "bool":
                s.Append($@"
        ret.{name} = buffer[offset++] != 0;");
                break;
            case "int":
                s.Append($@"
        ret.{name} = buffer[offset++] | buffer[offset++] << 8 | buffer[offset++] << 16 | buffer[offset++] << 24;");
                break;
            case "Guid":
                s.Append($@"
        ret.{name} = new global::System.Guid(buffer.Slice(offset, 16));
        offset += 16;");
                break;
            case "string":
            case "string?":
                s.Append(@"
        {");
                AppendReadVarInt(s, "length", "            ");
                s.Append($@"
            if (length > 0)
            {{
                ret.{name} = global::System.Text.Encoding.UTF8.GetString(buffer.Slice(offset, length));
                offset += length;
            }}
        }}");
                break;
            case "byte[]":
            case "byte[]?":
                s.Append(@"
        {");
                AppendReadVarInt(s, "length", "            ");
                s.Append($@"
            if (length > 0)
            {{
                ret.{name} = buffer.Slice(offset, length).ToArray();
                offset += length;
            }}
        }}");
                break;
            default:
                throw new InvalidOperationException($"Unsupported type {type} on field {name}");
        }
    }
}
/// <summary>
/// Per-type data gathered before code generation: the attributed declaration, the semantic model and
/// compilation it was resolved against, and the type's simple name and containing namespace.
/// </summary>
internal class SyntaxInfo
{
    // Simple (unqualified) type name.
    public string name;

    // Containing namespace text; the generator stores null for the global namespace.
    public string nameSpace;

    // The declaration that carries [LobbyMessage].
    public TypeDeclarationSyntax syntaxNode;

    // Semantic model for syntaxNode's syntax tree.
    public SemanticModel semanticModel;

    // Compilation the semantic model was obtained from.
    internal Compilation compilation;
}
/// <summary>
/// Equality for (declaration, compilation) pairs that considers only the declaration node. Two pairs
/// holding the same node compare equal regardless of their compilations.
/// </summary>
internal class Comparer : IEqualityComparer<(TypeDeclarationSyntax, Compilation)>
{
    /// <summary>Shared instance; the comparer holds no state.</summary>
    public static readonly Comparer Instance = new Comparer();

    /// <summary>Compares the declaration nodes with <c>Equals</c>; the compilations are ignored.</summary>
    public bool Equals((TypeDeclarationSyntax, Compilation) x, (TypeDeclarationSyntax, Compilation) y)
    {
        var (leftNode, _) = x;
        var (rightNode, _) = y;
        return leftNode.Equals(rightNode);
    }

    /// <summary>Hashes the declaration node only, consistent with <see cref="Equals"/>.</summary>
    public int GetHashCode((TypeDeclarationSyntax, Compilation) obj)
    {
        var (node, _) = obj;
        return node.GetHashCode();
    }
}
/// <summary>
/// Marker attribute looked up by the lobby message source generator (metadata name
/// <c>LobbyServerDto.LobbyMessageAttribute</c>). It may be applied once to a class, struct or interface
/// and is not inherited by derived types.
/// </summary>
[AttributeUsage(AttributeTargets.Interface | AttributeTargets.Struct | AttributeTargets.Class, Inherited = false, AllowMultiple = false)]
public class LobbyMessageAttribute : Attribute
{
    /// <summary>Creates the marker; it carries no data.</summary>
    public LobbyMessageAttribute()
    {
    }
}
}