Remove useless wrappers

tls and vncclient close their underlying connection anyways
This commit is contained in:
flx5 2021-09-30 16:17:18 +02:00
parent 13c25063af
commit 991ced778f
4 changed files with 85 additions and 123 deletions

View File

@ -11,27 +11,6 @@ import (
"time"
)
type ProxyWrapper struct {
Client *ssh.Client
rawConnection net.Conn
sshConnection ssh.Conn
}
func (w ProxyWrapper) Close() {
if w.Client != nil {
w.Client.Close()
}
if w.sshConnection != nil {
w.sshConnection.Close()
}
if w.rawConnection != nil {
w.rawConnection.Close()
}
}
func GetXenProxyAddress(state multistep.StateBag) (string, error) {
proxyAddress, ok := state.GetOk("xen_proxy_address")
@ -66,7 +45,7 @@ func ConnectViaProxy(proxyAddress, address string) (net.Conn, error) {
return c, nil
}
func ConnectSSHWithProxy(proxyAddress, host string, port int, username string, password string) (*ProxyWrapper, error) {
func ConnectSSHWithProxy(proxyAddress, host string, port int, username string, password string) (*ssh.Client, error) {
connection, err := ConnectViaProxy(proxyAddress, fmt.Sprintf("%s:%d", host, port))
if err != nil {
@ -111,16 +90,16 @@ func ConnectSSHWithProxy(proxyAddress, host string, port int, username string, p
sshClient := ssh.NewClient(sshConn, sshChan, req)
wrapper := ProxyWrapper{
Client: sshClient,
rawConnection: connection,
sshConnection: sshConn,
}
return &wrapper, nil
return sshClient, nil
}
func CreatePortForwarding(proxyAddress, targetAddress string) (net.Listener, error) {
return CreateCustomPortForwarding(func() (net.Conn, error) {
return ConnectViaProxy(proxyAddress, targetAddress)
})
}
func CreateCustomPortForwarding(connectTarget func() (net.Conn, error)) (net.Listener, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
@ -135,32 +114,22 @@ func CreatePortForwarding(proxyAddress, targetAddress string) (net.Listener, err
continue
}
go handleConnection(proxyAddress, targetAddress, accept)
go handleConnection(accept, connectTarget)
}
}()
return listener, nil
}
func handleConnection(proxyAddress, targetAddress string, accept net.Conn) {
defer accept.Close()
conn, err := ConnectViaProxy(proxyAddress, targetAddress)
if err != nil {
fmt.Printf("[FORWARD] Connect proxy Error: %v", err)
return
}
defer conn.Close()
func serviceForwardedConnection(clientConn net.Conn, targetConn net.Conn) {
txDone := make(chan struct{})
rxDone := make(chan struct{})
go func() {
_, err := io.Copy(conn, accept)
_, err := io.Copy(targetConn, clientConn)
// Close conn so that other copy operation unblocks
conn.Close()
targetConn.Close()
close(txDone)
if err != nil {
@ -170,10 +139,10 @@ func handleConnection(proxyAddress, targetAddress string, accept net.Conn) {
}()
go func() {
_, err := io.Copy(accept, conn)
_, err := io.Copy(clientConn, targetConn)
// Close accept so that other copy operation unblocks
accept.Close()
clientConn.Close()
close(rxDone)
if err != nil {
@ -185,3 +154,17 @@ func handleConnection(proxyAddress, targetAddress string, accept net.Conn) {
<-txDone
<-rxDone
}
func handleConnection(clientConn net.Conn, connectTarget func() (net.Conn, error)) {
defer clientConn.Close()
targetConn, err := connectTarget()
if err != nil {
fmt.Printf("[FORWARD] Connect proxy Error: %v", err)
return
}
defer targetConn.Close()
serviceForwardedConnection(clientConn, targetConn)
}

View File

@ -61,7 +61,7 @@ func ExecuteHostSSHCmd(state multistep.StateBag, cmd string) (stdout string, err
defer sshClient.Close()
return doExecuteSSHCmd(cmd, sshClient.Client)
return doExecuteSSHCmd(cmd, sshClient)
}
func connectSSH(host string, port int, username string, password string) (*ssh.Client, error) {

View File

@ -27,7 +27,6 @@ type StepTypeBootCommand struct {
func (self *StepTypeBootCommand) Run(ctx context.Context, state multistep.StateBag) multistep.StepAction {
config := state.Get("commonconfig").(CommonConfig)
ui := state.Get("ui").(packer.Ui)
c := state.Get("client").(*Connection)
httpPort := state.Get("http_port").(int)
var httpIP string
@ -42,44 +41,25 @@ func (self *StepTypeBootCommand) Run(ctx context.Context, state multistep.StateB
return multistep.ActionContinue
}
vmRef, err := c.client.VM.GetByNameLabel(c.session, config.VMName)
location, err := GetVNCConsoleLocation(state)
if err != nil {
state.Put("error", err)
ui.Error(err.Error())
return multistep.ActionHalt
}
if len(vmRef) != 1 {
ui.Error(fmt.Sprintf("expected to find a single VM, instead found '%d'. Ensure the VM name is unique", len(vmRef)))
}
consoles, err := c.client.VM.GetConsoles(c.session, vmRef[0])
if err != nil {
state.Put("error", err)
ui.Error(err.Error())
return multistep.ActionHalt
}
if len(consoles) != 1 {
ui.Error(fmt.Sprintf("expected to find a VM console, instead found '%d'. Ensure there is only one console", len(consoles)))
return multistep.ActionHalt
}
location, err := c.client.Console.GetLocation(c.session, consoles[0])
ui.Say(fmt.Sprintf("Connecting to the VM console VNC over xapi via %s", location))
vncConnectionWrapper, err := ConnectVNC(state, location)
vncClient, err := CreateVNCClient(state, location)
if err != nil {
ui.Error(err.Error())
return multistep.ActionHalt
}
defer vncConnectionWrapper.Close()
defer vncClient.Close()
log.Printf("Connected to the VNC console: %s", vncConnectionWrapper.Client.DesktopName)
log.Printf("Connected to the VNC console: %s", vncClient.DesktopName)
self.Ctx.Data = &bootCommandTemplateData{
config.VMName,
@ -87,7 +67,7 @@ func (self *StepTypeBootCommand) Run(ctx context.Context, state multistep.StateB
uint(httpPort),
}
vncDriver := bootcommand.NewVNCDriver(vncConnectionWrapper.Client, config.VNCConfig.BootKeyInterval)
vncDriver := bootcommand.NewVNCDriver(vncClient, config.VNCConfig.BootKeyInterval)
ui.Say("Typing boot commands over VNC...")

View File

@ -14,87 +14,86 @@ import (
"strings"
)
type VNCConnectionWrapper struct {
TlsConn *tls.Conn
rawConn net.Conn
}
func (v VNCConnectionWrapper) Close() {
if v.TlsConn != nil {
_ = v.TlsConn.Close()
}
if v.rawConn != nil {
_ = v.rawConn.Close()
}
}
type VNCClientWrapper struct {
Client *vnc.ClientConn
connection *VNCConnectionWrapper
}
func (v VNCClientWrapper) Close() {
if v.Client != nil {
_ = v.Client.Close()
}
v.connection.Close()
}
func CreateVNCConnection(state multistep.StateBag, location string) (*VNCConnectionWrapper, error) {
func GetVNCConsoleLocation(state multistep.StateBag) (string, error) {
xenClient := state.Get("client").(*Connection)
wrapper := VNCConnectionWrapper{}
config := state.Get("commonconfig").(CommonConfig)
var err error
vmRef, err := xenClient.client.VM.GetByNameLabel(xenClient.session, config.VMName)
target, err := getTcpAddress(location)
if err != nil {
return nil, err
return "", err
}
wrapper.rawConn, err = ConnectViaXenProxy(state, target)
if err != nil {
return nil, err
if len(vmRef) != 1 {
return "", fmt.Errorf("expected to find a single VM, instead found '%d'. Ensure the VM name is unique", len(vmRef))
}
wrapper.TlsConn, err = httpConnectRequest(location, string(xenClient.GetSessionRef()), wrapper.rawConn)
consoles, err := xenClient.client.VM.GetConsoles(xenClient.session, vmRef[0])
if err != nil {
wrapper.Close()
return nil, err
return "", err
}
return &wrapper, nil
if len(consoles) != 1 {
return "", fmt.Errorf("expected to find a VM console, instead found '%d'. Ensure there is only one console", len(consoles))
}
location, err := xenClient.client.Console.GetLocation(xenClient.session, consoles[0])
if err != nil {
return "", err
}
return location, nil
}
func ConnectVNC(state multistep.StateBag, location string) (*VNCClientWrapper, error) {
wrapper := VNCClientWrapper{}
func CreateVNCConnection(state multistep.StateBag, location string) (net.Conn, error) {
xenClient := state.Get("client").(*Connection)
var err error
wrapper.connection, err = CreateVNCConnection(state, location)
target, err := GetTcpAddressFromURL(location)
if err != nil {
return nil, err
}
wrapper.Client, err = vnc.Client(wrapper.connection.TlsConn, &vnc.ClientConfig{
rawConn, err := ConnectViaXenProxy(state, target)
if err != nil {
return nil, err
}
tlsConn, err := initializeVNCConnection(location, string(xenClient.GetSessionRef()), rawConn)
if err != nil {
rawConn.Close()
return nil, err
}
return tlsConn, nil
}
func CreateVNCClient(state multistep.StateBag, location string) (*vnc.ClientConn, error) {
var err error
connection, err := CreateVNCConnection(state, location)
if err != nil {
return nil, err
}
client, err := vnc.Client(connection, &vnc.ClientConfig{
Exclusive: false,
})
if err != nil {
wrapper.Close()
connection.Close()
return nil, err
}
return &wrapper, nil
return client, nil
}
func httpConnectRequest(location string, xenSessionRef string, proxyConnection net.Conn) (*tls.Conn, error) {
func initializeVNCConnection(location string, xenSessionRef string, rawConn net.Conn) (*tls.Conn, error) {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
}
tlsConnection := tls.Client(proxyConnection, tlsConfig)
tlsConnection := tls.Client(rawConn, tlsConfig)
request, err := http.NewRequest(http.MethodConnect, location, http.NoBody)
@ -123,7 +122,7 @@ func httpConnectRequest(location string, xenSessionRef string, proxyConnection n
return tlsConnection, nil
}
func getTcpAddress(location string) (string, error) {
func GetTcpAddressFromURL(location string) (string, error) {
parsedUrl, err := url.Parse(location)
if err != nil {
return "", err