From 991ced778f161add3fc492329e45e4aef8428b83 Mon Sep 17 00:00:00 2001 From: flx5 <1330854+flx5@users.noreply.github.com> Date: Thu, 30 Sep 2021 16:17:18 +0200 Subject: [PATCH] Remove useless wrappers tls and vncclient close their underlying connection anyways --- builder/xenserver/common/proxy.go | 73 +++++-------- builder/xenserver/common/ssh.go | 2 +- .../common/step_type_boot_command.go | 30 +---- builder/xenserver/common/vnc.go | 103 +++++++++--------- 4 files changed, 85 insertions(+), 123 deletions(-) diff --git a/builder/xenserver/common/proxy.go b/builder/xenserver/common/proxy.go index c791381..6f38ac2 100644 --- a/builder/xenserver/common/proxy.go +++ b/builder/xenserver/common/proxy.go @@ -11,27 +11,6 @@ import ( "time" ) -type ProxyWrapper struct { - Client *ssh.Client - - rawConnection net.Conn - sshConnection ssh.Conn -} - -func (w ProxyWrapper) Close() { - if w.Client != nil { - w.Client.Close() - } - - if w.sshConnection != nil { - w.sshConnection.Close() - } - - if w.rawConnection != nil { - w.rawConnection.Close() - } -} - func GetXenProxyAddress(state multistep.StateBag) (string, error) { proxyAddress, ok := state.GetOk("xen_proxy_address") @@ -66,7 +45,7 @@ func ConnectViaProxy(proxyAddress, address string) (net.Conn, error) { return c, nil } -func ConnectSSHWithProxy(proxyAddress, host string, port int, username string, password string) (*ProxyWrapper, error) { +func ConnectSSHWithProxy(proxyAddress, host string, port int, username string, password string) (*ssh.Client, error) { connection, err := ConnectViaProxy(proxyAddress, fmt.Sprintf("%s:%d", host, port)) if err != nil { @@ -111,16 +90,16 @@ func ConnectSSHWithProxy(proxyAddress, host string, port int, username string, p sshClient := ssh.NewClient(sshConn, sshChan, req) - wrapper := ProxyWrapper{ - Client: sshClient, - rawConnection: connection, - sshConnection: sshConn, - } - - return &wrapper, nil + return sshClient, nil } func CreatePortForwarding(proxyAddress, targetAddress string) (net.Listener, error) { + return CreateCustomPortForwarding(func() (net.Conn, error) { + return ConnectViaProxy(proxyAddress, targetAddress) + }) +} + +func CreateCustomPortForwarding(connectTarget func() (net.Conn, error)) (net.Listener, error) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -135,32 +114,22 @@ func CreatePortForwarding(proxyAddress, targetAddress string) (net.Listener, err continue } - go handleConnection(proxyAddress, targetAddress, accept) + go handleConnection(accept, connectTarget) } }() return listener, nil } -func handleConnection(proxyAddress, targetAddress string, accept net.Conn) { - defer accept.Close() - conn, err := ConnectViaProxy(proxyAddress, targetAddress) - - if err != nil { - fmt.Printf("[FORWARD] Connect proxy Error: %v", err) - return - } - - defer conn.Close() - +func serviceForwardedConnection(clientConn net.Conn, targetConn net.Conn) { txDone := make(chan struct{}) rxDone := make(chan struct{}) go func() { - _, err := io.Copy(conn, accept) + _, err := io.Copy(targetConn, clientConn) // Close conn so that other copy operation unblocks - conn.Close() + targetConn.Close() close(txDone) if err != nil { @@ -170,10 +139,10 @@ func handleConnection(proxyAddress, targetAddress string, accept net.Conn) { }() go func() { - _, err := io.Copy(accept, conn) + _, err := io.Copy(clientConn, targetConn) // Close accept so that other copy operation unblocks - accept.Close() + clientConn.Close() close(rxDone) if err != nil { @@ -185,3 +154,17 @@ func handleConnection(proxyAddress, targetAddress string, accept net.Conn) { <-txDone <-rxDone } + +func handleConnection(clientConn net.Conn, connectTarget func() (net.Conn, error)) { + defer clientConn.Close() + targetConn, err := connectTarget() + + if err != nil { + fmt.Printf("[FORWARD] Connect proxy Error: %v", err) + return + } + + defer targetConn.Close() + + serviceForwardedConnection(clientConn, targetConn) +} diff --git a/builder/xenserver/common/ssh.go b/builder/xenserver/common/ssh.go index 2c787a7..a131361 100644 --- a/builder/xenserver/common/ssh.go +++ b/builder/xenserver/common/ssh.go @@ -61,7 +61,7 @@ func ExecuteHostSSHCmd(state multistep.StateBag, cmd string) (stdout string, err defer sshClient.Close() - return doExecuteSSHCmd(cmd, sshClient.Client) + return doExecuteSSHCmd(cmd, sshClient) } func connectSSH(host string, port int, username string, password string) (*ssh.Client, error) { diff --git a/builder/xenserver/common/step_type_boot_command.go b/builder/xenserver/common/step_type_boot_command.go index e773085..197dc0f 100644 --- a/builder/xenserver/common/step_type_boot_command.go +++ b/builder/xenserver/common/step_type_boot_command.go @@ -27,7 +27,6 @@ type StepTypeBootCommand struct { func (self *StepTypeBootCommand) Run(ctx context.Context, state multistep.StateBag) multistep.StepAction { config := state.Get("commonconfig").(CommonConfig) ui := state.Get("ui").(packer.Ui) - c := state.Get("client").(*Connection) httpPort := state.Get("http_port").(int) var httpIP string @@ -42,44 +41,25 @@ func (self *StepTypeBootCommand) Run(ctx context.Context, state multistep.StateB return multistep.ActionContinue } - vmRef, err := c.client.VM.GetByNameLabel(c.session, config.VMName) - + location, err := GetVNCConsoleLocation(state) if err != nil { state.Put("error", err) ui.Error(err.Error()) return multistep.ActionHalt } - if len(vmRef) != 1 { - ui.Error(fmt.Sprintf("expected to find a single VM, instead found '%d'. Ensure the VM name is unique", len(vmRef))) - } - - consoles, err := c.client.VM.GetConsoles(c.session, vmRef[0]) - if err != nil { - state.Put("error", err) - ui.Error(err.Error()) - return multistep.ActionHalt - } - - if len(consoles) != 1 { - ui.Error(fmt.Sprintf("expected to find a VM console, instead found '%d'. Ensure there is only one console", len(consoles))) - return multistep.ActionHalt - } - - location, err := c.client.Console.GetLocation(c.session, consoles[0]) - ui.Say(fmt.Sprintf("Connecting to the VM console VNC over xapi via %s", location)) - vncConnectionWrapper, err := ConnectVNC(state, location) + vncClient, err := CreateVNCClient(state, location) if err != nil { ui.Error(err.Error()) return multistep.ActionHalt } - defer vncConnectionWrapper.Close() + defer vncClient.Close() - log.Printf("Connected to the VNC console: %s", vncConnectionWrapper.Client.DesktopName) + log.Printf("Connected to the VNC console: %s", vncClient.DesktopName) self.Ctx.Data = &bootCommandTemplateData{ config.VMName, @@ -87,7 +67,7 @@ func (self *StepTypeBootCommand) Run(ctx context.Context, state multistep.StateB uint(httpPort), } - vncDriver := bootcommand.NewVNCDriver(vncConnectionWrapper.Client, config.VNCConfig.BootKeyInterval) + vncDriver := bootcommand.NewVNCDriver(vncClient, config.VNCConfig.BootKeyInterval) ui.Say("Typing boot commands over VNC...") diff --git a/builder/xenserver/common/vnc.go b/builder/xenserver/common/vnc.go index 405eb6d..601774f 100644 --- a/builder/xenserver/common/vnc.go +++ b/builder/xenserver/common/vnc.go @@ -14,87 +14,86 @@ import ( "strings" ) -type VNCConnectionWrapper struct { - TlsConn *tls.Conn - rawConn net.Conn -} - -func (v VNCConnectionWrapper) Close() { - if v.TlsConn != nil { - _ = v.TlsConn.Close() - } - - if v.rawConn != nil { - _ = v.rawConn.Close() - } -} - -type VNCClientWrapper struct { - Client *vnc.ClientConn - - connection *VNCConnectionWrapper -} - -func (v VNCClientWrapper) Close() { - if v.Client != nil { - _ = v.Client.Close() - } - - v.connection.Close() -} - -func CreateVNCConnection(state multistep.StateBag, location string) (*VNCConnectionWrapper, error) { +func GetVNCConsoleLocation(state multistep.StateBag) (string, error) { xenClient := state.Get("client").(*Connection) - wrapper := VNCConnectionWrapper{} + config := state.Get("commonconfig").(CommonConfig) - var err error + vmRef, err := xenClient.client.VM.GetByNameLabel(xenClient.session, config.VMName) - target, err := getTcpAddress(location) if err != nil { - return nil, err + return "", err } - wrapper.rawConn, err = ConnectViaXenProxy(state, target) - if err != nil { - return nil, err + if len(vmRef) != 1 { + return "", fmt.Errorf("expected to find a single VM, instead found '%d'. Ensure the VM name is unique", len(vmRef)) } - wrapper.TlsConn, err = httpConnectRequest(location, string(xenClient.GetSessionRef()), wrapper.rawConn) + consoles, err := xenClient.client.VM.GetConsoles(xenClient.session, vmRef[0]) + if err != nil { - wrapper.Close() - return nil, err + return "", err } - return &wrapper, nil + if len(consoles) != 1 { + return "", fmt.Errorf("expected to find a VM console, instead found '%d'. Ensure there is only one console", len(consoles)) + } + + location, err := xenClient.client.Console.GetLocation(xenClient.session, consoles[0]) + + if err != nil { + return "", err + } + + return location, nil } -func ConnectVNC(state multistep.StateBag, location string) (*VNCClientWrapper, error) { - wrapper := VNCClientWrapper{} +func CreateVNCConnection(state multistep.StateBag, location string) (net.Conn, error) { + xenClient := state.Get("client").(*Connection) - var err error - - wrapper.connection, err = CreateVNCConnection(state, location) + target, err := GetTcpAddressFromURL(location) if err != nil { return nil, err } - wrapper.Client, err = vnc.Client(wrapper.connection.TlsConn, &vnc.ClientConfig{ + rawConn, err := ConnectViaXenProxy(state, target) + if err != nil { + return nil, err + } + + tlsConn, err := initializeVNCConnection(location, string(xenClient.GetSessionRef()), rawConn) + if err != nil { + rawConn.Close() + return nil, err + } + + return tlsConn, nil +} + +func CreateVNCClient(state multistep.StateBag, location string) (*vnc.ClientConn, error) { + var err error + + connection, err := CreateVNCConnection(state, location) + if err != nil { + return nil, err + } + + client, err := vnc.Client(connection, &vnc.ClientConfig{ Exclusive: false, }) if err != nil { - wrapper.Close() + connection.Close() return nil, err } - return &wrapper, nil + return client, nil } -func httpConnectRequest(location string, xenSessionRef string, proxyConnection net.Conn) (*tls.Conn, error) { +func initializeVNCConnection(location string, xenSessionRef string, rawConn net.Conn) (*tls.Conn, error) { tlsConfig := &tls.Config{ InsecureSkipVerify: true, } - tlsConnection := tls.Client(proxyConnection, tlsConfig) + tlsConnection := tls.Client(rawConn, tlsConfig) request, err := http.NewRequest(http.MethodConnect, location, http.NoBody) @@ -123,7 +122,7 @@ func httpConnectRequest(location string, xenSessionRef string, proxyConnection n return tlsConnection, nil } -func getTcpAddress(location string) (string, error) { +func GetTcpAddressFromURL(location string) (string, error) { parsedUrl, err := url.Parse(location) if err != nil { return "", err