using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text;

namespace LobbyServerDto
{
    /// <summary>
    /// Incremental source generator that emits, for every type marked with
    /// <see cref="LobbyMessageAttribute"/>, a partial declaration containing a <c>TypeId</c> constant,
    /// an instance <c>Serialize(byte[] buffer)</c> method and a static
    /// <c>Deserialize(ReadOnlySpan&lt;byte&gt; buffer)</c> method.
    /// </summary>
    /// <remarks>
    /// Wire format written by the generated code:
    /// <list type="bullet">
    /// <item>The TypeId as a 7-bit variable-length unsigned integer.</item>
    /// <item>Then every non-static property declared in the attributed declaration, in declaration order:
    /// bool = 1 byte; int = 4 bytes little-endian; Guid = 16 bytes (<c>Guid.ToByteArray</c> layout);
    /// string / byte[] = 7-bit length prefix followed by the (UTF-8 for strings) bytes, truncated to the
    /// property's [MaxLength] or 256. A null value is written as length 0; on read, length 0 leaves the
    /// property at whatever value <c>new T()</c> gave it.</item>
    /// </list>
    /// TypeIds are assigned 0, 1, 2, ... in ordinal order of "Namespace.TypeName" so that they do not
    /// depend on the culture of the machine running the compiler.
    /// </remarks>
    [Generator]
    public class LobbyMessageSourceGenerator : IIncrementalGenerator
    {
        /// <summary>Upper bound for string/byte[] payloads when no smaller [MaxLength] is present.</summary>
        private const int DefaultMaxLength = 256;

        /// <summary>Reported when a property's type has no wire encoding; that type is then not generated.</summary>
        private static readonly DiagnosticDescriptor UnsupportedPropertyType = new DiagnosticDescriptor(
            id: "LOBBY001",
            title: "Unsupported lobby message property type",
            messageFormat: "Unknown type {0} on field {1}",
            category: "LobbyMessageSourceGenerator",
            defaultSeverity: DiagnosticSeverity.Error,
            isEnabledByDefault: true);

        /// <summary>The wire encodings the generator knows how to emit.</summary>
        private enum WireKind
        {
            Unsupported,
            Bool,
            Int32,
            Guid,
            String,
            Bytes,
        }

        /// <summary>Registers the pipeline: attributed type declarations, paired with the compilation, collected into one batch.</summary>
        public void Initialize(IncrementalGeneratorInitializationContext context)
        {
            var typeDeclarations = context.SyntaxProvider.ForAttributeWithMetadataName(
                "LobbyServerDto.LobbyMessageAttribute",
                predicate: static (node, token) =>
                    node is ClassDeclarationSyntax or StructDeclarationSyntax or RecordDeclarationSyntax or InterfaceDeclarationSyntax,
                transform: static (ctx, token) => (TypeDeclarationSyntax)ctx.TargetNode);

            // All messages are collected together because TypeIds are assigned across the whole set.
            var source = typeDeclarations
                .Combine(context.CompilationProvider)
                .WithComparer(Comparer.Instance)
                .Collect();

            context.RegisterSourceOutput(source, static (spc, batch) => GenerateCode(spc, batch));
        }

        /// <summary>
        /// Resolves every collected declaration, assigns TypeIds in ordinal name order and adds one
        /// "<c>Namespace.TypeName.g.cs</c>" source per message type.
        /// </summary>
        private static void GenerateCode(SourceProductionContext context, ImmutableArray<(TypeDeclarationSyntax, Compilation)> source)
        {
            if (source.IsDefaultOrEmpty)
                return;

            var sortableSyntaxes = new Dictionary<string, SyntaxInfo>(StringComparer.Ordinal);
            foreach (var (syntaxNode, compilation) in source)
            {
                var semanticModel = compilation.GetSemanticModel(syntaxNode.SyntaxTree);
                var typeSymbol = semanticModel.GetDeclaredSymbol(syntaxNode, context.CancellationToken);
                if (typeSymbol == null)
                    continue; // skip only this declaration; the others can still be generated

                var ns = typeSymbol.ContainingNamespace.IsGlobalNamespace ? null : typeSymbol.ContainingNamespace.ToString();
                var name = typeSymbol.Name;
                var fullName = $"{ns}.{name}";

                // The key doubles as the AddSource hint name, which must be unique; keep the first declaration.
                if (sortableSyntaxes.ContainsKey(fullName))
                    continue;

                sortableSyntaxes.Add(fullName, new SyntaxInfo
                {
                    name = name,
                    nameSpace = ns,
                    semanticModel = semanticModel,
                    syntaxNode = syntaxNode,
                    compilation = compilation,
                });
            }

            int id = 0;
            // Ordinal ordering keeps TypeIds independent of the compiler host's culture.
            foreach (var foundClass in sortableSyntaxes.OrderBy(entry => entry.Key, StringComparer.Ordinal))
            {
                context.CancellationToken.ThrowIfCancellationRequested();

                // Consume the id even if this type fails, so the ids of the other types do not shift.
                int typeId = id++;
                var properties = CollectProperties(context, foundClass.Value);
                if (properties == null)
                    continue;

                context.AddSource($"{foundClass.Key}.g.cs", BuildSource(foundClass.Value, typeId, properties));
            }
        }

        /// <summary>
        /// Returns the serializable instance properties declared in <paramref name="info"/>'s declaration, in
        /// declaration order, or <c>null</c> after reporting LOBBY001 for every property whose type is unsupported.
        /// </summary>
        private static List<(string Name, WireKind Kind, int MaxLength)> CollectProperties(SourceProductionContext context, SyntaxInfo info)
        {
            var properties = new List<(string Name, WireKind Kind, int MaxLength)>();
            bool allSupported = true;

            foreach (var member in info.syntaxNode.Members)
            {
                if (member is not PropertyDeclarationSyntax declaration)
                    continue;

                var symbol = info.semanticModel.GetDeclaredSymbol(declaration, context.CancellationToken);
                // Static properties cannot be assigned through the generated "ret." instance, so they are not part of the message.
                if (symbol == null || symbol.IsStatic)
                    continue;

                var name = declaration.Identifier.ToString();
                var kind = GetWireKind(symbol.Type);
                if (kind == WireKind.Unsupported)
                {
                    context.ReportDiagnostic(Diagnostic.Create(
                        UnsupportedPropertyType, declaration.Type.GetLocation(), declaration.Type.ToString(), name));
                    allSupported = false;
                    continue;
                }

                properties.Add((name, kind, GetMaxLength(symbol)));
            }

            return allSupported ? properties : null;
        }

        /// <summary>Maps a property type (by its semantic symbol, so aliases like <c>Int32</c> also match) to a wire encoding.</summary>
        private static WireKind GetWireKind(ITypeSymbol type)
        {
            if (type is IArrayTypeSymbol array)
            {
                return array.Rank == 1 && array.ElementType.SpecialType == SpecialType.System_Byte
                    ? WireKind.Bytes
                    : WireKind.Unsupported;
            }

            switch (type.SpecialType)
            {
                case SpecialType.System_Boolean:
                    return WireKind.Bool;
                case SpecialType.System_Int32:
                    return WireKind.Int32;
                case SpecialType.System_String:
                    return WireKind.String;
            }

            return type.ToDisplayString() == "System.Guid" ? WireKind.Guid : WireKind.Unsupported;
        }

        /// <summary>
        /// Returns the smallest non-negative int argument of any attribute named "MaxLengthAttribute" on
        /// <paramref name="property"/>, capped at <see cref="DefaultMaxLength"/>. Attributes without such an
        /// argument (e.g. a parameterless [MaxLength]) are ignored.
        /// </summary>
        private static int GetMaxLength(ISymbol property)
        {
            int maxLength = DefaultMaxLength;
            foreach (var attribute in property.GetAttributes())
            {
                if (attribute.AttributeClass?.Name != "MaxLengthAttribute" || attribute.ConstructorArguments.Length == 0)
                    continue;

                if (attribute.ConstructorArguments[0].Value is int length && length >= 0 && length < maxLength)
                    maxLength = length;
            }
            return maxLength;
        }

        /// <summary>Returns the keyword(s) of the original declaration so the generated partial part matches it.</summary>
        private static string GetTypeKeyword(TypeDeclarationSyntax node)
        {
            if (node is RecordDeclarationSyntax record && record.ClassOrStructKeyword.Text.Length > 0)
                return "record " + record.ClassOrStructKeyword.Text; // "record class" / "record struct"
            return node.Keyword.Text;
        }

        /// <summary>Builds the complete generated source file for one message type.</summary>
        private static string BuildSource(SyntaxInfo info, int typeId, List<(string Name, WireKind Kind, int MaxLength)> properties)
        {
            var s = new StringBuilder();
            s.AppendLine("// <auto-generated/>");
            s.AppendLine("using System.Collections.Generic;");
            s.AppendLine("using System.Text;");
            s.AppendLine();

            bool hasNamespace = info.nameSpace != null;
            string typeIndent = hasNamespace ? "    " : "";
            string memberIndent = typeIndent + "    ";
            string bodyIndent = memberIndent + "    ";

            if (hasNamespace)
            {
                s.Append("namespace ").AppendLine(info.nameSpace);
                s.AppendLine("{");
            }

            WriteLine(s, typeIndent, $"public partial {GetTypeKeyword(info.syntaxNode)} {info.name}");
            WriteLine(s, typeIndent, "{");
            WriteLine(s, memberIndent, $"public const int TypeId = {typeId};");
            s.AppendLine();

            // Serialize: writes the TypeId prefix, then each property; returns the number of bytes written.
            WriteLine(s, memberIndent, "public int Serialize(byte[] buffer)");
            WriteLine(s, memberIndent, "{");
            WriteLine(s, bodyIndent, "int offset = 0;");
            AppendWriteVarUInt(s, bodyIndent, "(uint)TypeId");
            foreach (var property in properties)
                AppendSerializeProperty(s, bodyIndent, property.Name, property.Kind, property.MaxLength);
            WriteLine(s, bodyIndent, "return offset;");
            WriteLine(s, memberIndent, "}");
            s.AppendLine();

            // Deserialize: skips the TypeId prefix (its value is read but not checked), then reads each property.
            WriteLine(s, memberIndent, $"public static {info.name} Deserialize(global::System.ReadOnlySpan<byte> buffer)");
            WriteLine(s, memberIndent, "{");
            WriteLine(s, bodyIndent, "int offset = 0;");
            WriteLine(s, bodyIndent, $"{info.name} ret = new {info.name}();");
            WriteLine(s, bodyIndent, "{");
            AppendReadVarInt(s, bodyIndent + "    ", "count");
            WriteLine(s, bodyIndent, "}");
            foreach (var property in properties)
                AppendDeserializeProperty(s, bodyIndent, property.Name, property.Kind);
            WriteLine(s, bodyIndent, "return ret;");
            WriteLine(s, memberIndent, "}");

            WriteLine(s, typeIndent, "}");
            if (hasNamespace)
                s.AppendLine("}");

            return s.ToString();
        }

        /// <summary>Emits the statements that write one property into <c>buffer</c> at <c>offset</c>.</summary>
        private static void AppendSerializeProperty(StringBuilder s, string indent, string name, WireKind kind, int maxLength)
        {
            // "this." prevents a property named e.g. "buffer" or "offset" from being shadowed by generated locals.
            string value = "this." + name;
            string inner = indent + "    ";

            switch (kind)
            {
                case WireKind.Bool:
                    WriteLine(s, indent, $"buffer[offset++] = (byte)({value} == true ? 1 : 0);");
                    break;

                case WireKind.Int32:
                    WriteLine(s, indent, $"buffer[offset++] = (byte){value};");
                    WriteLine(s, indent, $"buffer[offset++] = (byte)({value} >> 8);");
                    WriteLine(s, indent, $"buffer[offset++] = (byte)({value} >> 16);");
                    WriteLine(s, indent, $"buffer[offset++] = (byte)({value} >> 24);");
                    break;

                case WireKind.Guid:
                    WriteLine(s, indent, $"global::System.Buffer.BlockCopy({value}.ToByteArray(), 0, buffer, offset, 16);");
                    WriteLine(s, indent, "offset += 16;");
                    break;

                case WireKind.String:
                    WriteLine(s, indent, $"if ({value} != null)");
                    WriteLine(s, indent, "{");
                    WriteLine(s, inner, $"var str1 = global::System.Text.Encoding.UTF8.GetBytes({value}.Substring(0, global::System.Math.Min({maxLength}, {value}.Length)));");
                    AppendWriteVarUInt(s, inner, "(uint)str1.Length");
                    WriteLine(s, inner, "global::System.Buffer.BlockCopy(str1, 0, buffer, offset, str1.Length);");
                    WriteLine(s, inner, "offset += str1.Length;");
                    WriteLine(s, indent, "}");
                    WriteLine(s, indent, "else");
                    WriteLine(s, indent, "{");
                    WriteLine(s, inner, "buffer[offset++] = 0;");
                    WriteLine(s, indent, "}");
                    break;

                case WireKind.Bytes:
                    WriteLine(s, indent, $"if ({value} != null)");
                    WriteLine(s, indent, "{");
                    WriteLine(s, inner, $"int len = global::System.Math.Min({value}.Length, {maxLength});");
                    AppendWriteVarUInt(s, inner, "(uint)len");
                    WriteLine(s, inner, $"global::System.Buffer.BlockCopy({value}, 0, buffer, offset, len);");
                    WriteLine(s, inner, "offset += len;");
                    WriteLine(s, indent, "}");
                    WriteLine(s, indent, "else");
                    WriteLine(s, indent, "{");
                    WriteLine(s, inner, "buffer[offset++] = 0;");
                    WriteLine(s, indent, "}");
                    break;
            }
        }

        /// <summary>Emits the statements that read one property from <c>buffer</c> into <c>ret</c>.</summary>
        private static void AppendDeserializeProperty(StringBuilder s, string indent, string name, WireKind kind)
        {
            string inner = indent + "    ";
            string innermost = inner + "    ";

            switch (kind)
            {
                case WireKind.Bool:
                    WriteLine(s, indent, $"ret.{name} = buffer[offset++] == 0 ? false : true;");
                    break;

                case WireKind.Int32:
                    WriteLine(s, indent, $"ret.{name} = (int)(buffer[offset++] | buffer[offset++] << 8 | buffer[offset++] << 16 | buffer[offset++] << 24);");
                    break;

                case WireKind.Guid:
                    WriteLine(s, indent, $"ret.{name} = new global::System.Guid(buffer.Slice(offset, 16));");
                    WriteLine(s, indent, "offset += 16;");
                    break;

                case WireKind.String:
                case WireKind.Bytes:
                    string read = kind == WireKind.String
                        ? "global::System.Text.Encoding.UTF8.GetString(buffer.Slice(offset, strLen))"
                        : "buffer.Slice(offset, strLen).ToArray()";
                    WriteLine(s, indent, "{");
                    AppendReadVarInt(s, inner, "strLen");
                    // Length 0 (null or empty on the sender) leaves the property untouched.
                    WriteLine(s, inner, "if (strLen > 0)");
                    WriteLine(s, inner, "{");
                    WriteLine(s, innermost, $"ret.{name} = {read};");
                    WriteLine(s, innermost, "offset += strLen;");
                    WriteLine(s, inner, "}");
                    WriteLine(s, indent, "}");
                    break;
            }
        }

        /// <summary>Emits a braced block that writes <paramref name="valueExpression"/> (a uint) as a 7-bit varint.</summary>
        private static void AppendWriteVarUInt(StringBuilder s, string indent, string valueExpression)
        {
            string inner = indent + "    ";
            WriteLine(s, indent, "{");
            WriteLine(s, inner, $"uint v = {valueExpression};");
            WriteLine(s, inner, "while (v >= 0x80)");
            WriteLine(s, inner, "{");
            WriteLine(s, inner + "    ", "buffer[offset++] = (byte)(v | 0x80);");
            WriteLine(s, inner + "    ", "v >>= 7;");
            WriteLine(s, inner, "}");
            WriteLine(s, inner, "buffer[offset++] = (byte)v;");
            WriteLine(s, indent, "}");
        }

        /// <summary>
        /// Emits statements (in the caller's scope) that declare <c>int <paramref name="target"/></c> and read a
        /// 7-bit varint into it, throwing FormatException after 5 bytes.
        /// </summary>
        private static void AppendReadVarInt(StringBuilder s, string indent, string target)
        {
            string inner = indent + "    ";
            WriteLine(s, indent, $"int {target} = 0;");
            WriteLine(s, indent, "int shift = 0;");
            WriteLine(s, indent, "byte b;");
            WriteLine(s, indent, "do");
            WriteLine(s, indent, "{");
            WriteLine(s, inner, "// Corrupted stream check: an Int32 needs at most 5 bytes (shift 0, 7, 14, 21, 28).");
            WriteLine(s, inner, "// Running past the end of the span throws IndexOutOfRangeException.");
            WriteLine(s, inner, "if (shift == 5 * 7)");
            WriteLine(s, inner + "    ", "throw new global::System.FormatException(\"Format_Bad7BitInt32\");");
            WriteLine(s, inner, "b = buffer[offset++];");
            WriteLine(s, inner, $"{target} |= (b & 0x7F) << shift;");
            WriteLine(s, inner, "shift += 7;");
            WriteLine(s, indent, "}");
            WriteLine(s, indent, "while ((b & 0x80) != 0);");
        }

        /// <summary>Appends <paramref name="indent"/>, <paramref name="text"/> and a newline.</summary>
        private static void WriteLine(StringBuilder s, string indent, string text)
        {
            s.Append(indent).AppendLine(text);
        }
    }

    /// <summary>Per-type data gathered before code generation.</summary>
    class SyntaxInfo
    {
        /// <summary>The declaration that carries [LobbyMessage]; only its members are serialized.</summary>
        public TypeDeclarationSyntax syntaxNode;
        /// <summary>Semantic model for <see cref="syntaxNode"/>'s syntax tree.</summary>
        public SemanticModel semanticModel;
        /// <summary>Containing namespace, or null for the global namespace.</summary>
        public string nameSpace;
        /// <summary>Simple (unqualified) type name.</summary>
        public string name;
        /// <summary>The compilation the declaration was resolved against.</summary>
        internal Compilation compilation;
    }

    /// <summary>
    /// Pipeline comparer that treats two items as equal when their syntax nodes are equal, ignoring the
    /// paired Compilation, so unrelated edits do not invalidate cached items.
    /// NOTE(review): semantic-only changes (e.g. a constant used in [MaxLength] defined in another file) may
    /// therefore not re-trigger generation until the declaration itself changes — confirm this is acceptable.
    /// </summary>
    class Comparer : IEqualityComparer<(TypeDeclarationSyntax, Compilation)>
    {
        public static readonly Comparer Instance = new Comparer();

        public bool Equals((TypeDeclarationSyntax, Compilation) x, (TypeDeclarationSyntax, Compilation) y)
        {
            return x.Item1.Equals(y.Item1);
        }

        public int GetHashCode((TypeDeclarationSyntax, Compilation) obj)
        {
            return obj.Item1.GetHashCode();
        }
    }

    /// <summary>Marks a partial type for which <see cref="LobbyMessageSourceGenerator"/> emits TypeId/Serialize/Deserialize.</summary>
    [AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Interface, AllowMultiple = false, Inherited = false)]
    public class LobbyMessageAttribute : Attribute
    {
    }
}