summaryrefslogtreecommitdiff
path: root/server/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'server/server.go')
-rw-r--r--server/server.go112
1 files changed, 104 insertions, 8 deletions
diff --git a/server/server.go b/server/server.go
index 522cd17..369d5b1 100644
--- a/server/server.go
+++ b/server/server.go
@@ -6,12 +6,20 @@ import (
"log"
"fmt"
"time"
+ "strconv"
+ "strings"
"os/signal"
"git.citrons.xyz/metronode/classic"
"git.citrons.xyz/metronode/phony"
)
var SoftwareName = "Metronode"
+var supportedExtensions = []string {
+ "ExtEntityPositions",
+}
+var requiredExtensions = []string {
+ "ExtEntityPositions",
+}
type ServerInfo struct {
Name string
@@ -227,8 +235,9 @@ func (s *Server) newLevel(info levelInfo) (levelId, *level) {
return s.LastId, l
}
-func (s *Server) newPlayer(cl *client, name string) *player {
- pl := newPlayer(s, cl, name)
+func (s *Server) newPlayer(
+ cl *client, name string, ext map[string]bool) *player {
+ pl := newPlayer(s, cl, name, ext)
s.players[name] = pl
return pl
}
@@ -280,7 +289,8 @@ func (s *Server) kick(playerName string, reason string) bool {
}
func (s *Server) NewPlayer(
- from phony.Actor, cl *client, name string, reply func(*player)) {
+ from phony.Actor, cl *client, name string, ext map[string]bool,
+ reply func(*player)) {
s.Act(from, func() {
banReason, ok := s.worldState.Banned[name]
if ok {
@@ -293,12 +303,12 @@ func (s *Server) NewPlayer(
s.players[name].Act(s, func() {
s.players[name].kick("Replaced by new connection")
s.Act(s.players[name], func() {
- s.newPlayer(cl, name)
+ s.newPlayer(cl, name, ext)
s.GetPlayer(from, name, reply)
})
})
} else {
- s.newPlayer(cl, name)
+ s.newPlayer(cl, name, ext)
s.GetPlayer(from, name, reply)
}
})
@@ -347,6 +357,7 @@ type client struct {
conn net.Conn
username string
player *player
+ extensions map[string]bool
}
func newClient(server *Server, srvInfo ServerInfo, conn net.Conn) *client {
@@ -364,7 +375,9 @@ func (cl *client) performHandshake(conn net.Conn, srvInfo ServerInfo) {
cl.conn = conn
conn.SetDeadline(time.Now().Add(10 * time.Second))
- packet, err := classic.SReadPacket(conn)
+ var ext = make(map[string]bool)
+
+ packet, err := classic.SReadPacket(conn, ext)
if cl.handleError(err) != nil {
return
}
@@ -377,6 +390,11 @@ func (cl *client) performHandshake(conn net.Conn, srvInfo ServerInfo) {
)
}
cl.username = classic.UnpadString(pid.Username)
+ if pid.Ext == classic.UseCpe {
+ if !cl.cpeHandshake(conn, ext) {
+ return
+ }
+ }
default:
cl.disconnect("Expected handshake")
return
@@ -385,6 +403,7 @@ func (cl *client) performHandshake(conn net.Conn, srvInfo ServerInfo) {
cl.disconnect("Invalid player name")
return
}
+
err = classic.WritePacket(conn, &classic.ServerId {
Version: 7,
ServerName: classic.PadString(srvInfo.Name),
@@ -393,16 +412,93 @@ func (cl *client) performHandshake(conn net.Conn, srvInfo ServerInfo) {
if cl.handleError(err) != nil {
return
}
- cl.server.NewPlayer(cl, cl, cl.username, func(pl *player) {
+ cl.server.NewPlayer(cl, cl, cl.username, ext, func(pl *player) {
cl.player = pl
})
conn.SetDeadline(time.Time{})
}
+func (cl *client) cpeHandshake(conn net.Conn, ext map[string]bool) bool {
+ var (packet classic.Packet; err error)
+ err = classic.WritePacket(conn, &classic.ExtInfo {
+ AppName: classic.PadString(SoftwareName),
+ ExtensionCount: int16(len(supportedExtensions)),
+ })
+ if cl.handleError(err) != nil {
+ return false
+ }
+ for _, extString := range supportedExtensions {
+ var (name string; version = 1)
+ split := strings.Split(extString, ".")
+ name = split[0]
+ if len(split) > 1 {
+ version, err = strconv.Atoi(split[1])
+ if err != nil {
+ panic(err)
+ }
+ }
+ err = classic.WritePacket(conn, &classic.ExtEntry {
+ ExtName: classic.PadString(name),
+ Version: int32(version),
+ })
+ if cl.handleError(err) != nil {
+ return false
+ }
+ }
+ packet, err = classic.SReadPacket(conn, ext)
+ if cl.handleError(err) != nil {
+ return false
+ }
+ var count int
+ switch info := packet.(type) {
+ case *classic.ExtInfo:
+ log.Printf(
+ "%s is connecting via '%s' client with %d extensions",
+ cl.username, info.AppName, info.ExtensionCount,
+ )
+ count = int(info.ExtensionCount)
+ default:
+ cl.disconnect("Expected ExtInfo")
+ return false
+ }
+ for i := 0; i < count; i++ {
+ var extString string
+ packet, err = classic.SReadPacket(conn, ext)
+ if cl.handleError(err) != nil {
+ return false
+ }
+ switch entry := packet.(type) {
+ case *classic.ExtEntry:
+ extString = classic.UnpadString(entry.ExtName)
+ if entry.Version != 1 {
+ extString += "." + strconv.Itoa(int(entry.Version))
+ }
+ default:
+ cl.disconnect("Expected ExtEntry")
+ return false
+ }
+ ext[extString] = true
+ }
+ var extList []string
+ for extString := range ext {
+ extList = append(extList, extString)
+ }
+ log.Printf(
+ "%s has extensions: %s", cl.username, strings.Join(extList, ", "),
+ )
+ for _, req := range requiredExtensions {
+ if !ext[req] {
+ cl.disconnect("Missing required extension: " + req)
+ }
+ }
+ cl.extensions = ext
+ return true
+}
+
func (cl *client) readPackets(conn net.Conn) {
for {
- packet, err := classic.SReadPacket(conn)
+ packet, err := classic.SReadPacket(conn, cl.extensions)
cl.Act(nil, func() {
if cl.handleError(err) != nil {
return