pkg/skel: refactor to use dependency injection
authorGabe Rosenhouse <rosenhouse@gmail.com>
Thu, 14 Jul 2016 02:12:06 +0000 (22:12 -0400)
committerGabe Rosenhouse <rosenhouse@gmail.com>
Thu, 14 Jul 2016 04:06:58 +0000 (00:06 -0400)
Extract dependencies on os to enable more complete unit test coverage

skel/skel.go
skel/skel_test.go
testutils/bad_reader.go [new file with mode: 0644]

index 9cf0391..484ef37 100644 (file)
@@ -18,6 +18,7 @@ package skel
 
 import (
        "fmt"
+       "io"
        "io/ioutil"
        "log"
        "os"
@@ -36,11 +37,14 @@ type CmdArgs struct {
        StdinData   []byte
 }
 
+type dispatcher struct {
+       Getenv func(string) string
+       Stdin  io.Reader
+}
+
 type reqForCmdEntry map[string]bool
 
-// PluginMain is the "main" for a plugin. It accepts
-// two callback functions for add and del commands.
-func PluginMain(cmdAdd, cmdDel func(_ *CmdArgs) error) {
+func (t *dispatcher) getCmdArgsFromEnv() (string, *CmdArgs, error) {
        var cmd, contID, netns, ifName, args, path string
 
        vars := []struct {
@@ -100,20 +104,21 @@ func PluginMain(cmdAdd, cmdDel func(_ *CmdArgs) error) {
 
        argsMissing := false
        for _, v := range vars {
-               *v.val = os.Getenv(v.name)
+               *v.val = t.Getenv(v.name)
                if v.reqForCmd[cmd] && *v.val == "" {
                        log.Printf("%v env variable missing", v.name)
+                       // TODO: test this logging ^^^  and log to stderr instead of stdout
                        argsMissing = true
                }
        }
 
        if argsMissing {
-               dieMsg("required env variables missing")
+               return "", nil, fmt.Errorf("required env variables missing")
        }
 
-       stdinData, err := ioutil.ReadAll(os.Stdin)
+       stdinData, err := ioutil.ReadAll(t.Stdin)
        if err != nil {
-               dieMsg("error reading from stdin: %v", err)
+               return "", nil, fmt.Errorf("error reading from stdin: %v", err)
        }
 
        cmdArgs := &CmdArgs{
@@ -124,6 +129,21 @@ func PluginMain(cmdAdd, cmdDel func(_ *CmdArgs) error) {
                Path:        path,
                StdinData:   stdinData,
        }
+       return cmd, cmdArgs, nil
+}
+
+func createTypedError(f string, args ...interface{}) *types.Error {
+       return &types.Error{
+               Code: 100,
+               Msg:  fmt.Sprintf(f, args...),
+       }
+}
+
+func (t *dispatcher) pluginMain(cmdAdd, cmdDel func(_ *CmdArgs) error) *types.Error {
+       cmd, cmdArgs, err := t.getCmdArgsFromEnv()
+       if err != nil {
+               return createTypedError(err.Error())
+       }
 
        switch cmd {
        case "ADD":
@@ -133,24 +153,31 @@ func PluginMain(cmdAdd, cmdDel func(_ *CmdArgs) error) {
                err = cmdDel(cmdArgs)
 
        default:
-               dieMsg("unknown CNI_COMMAND: %v", cmd)
+               return createTypedError("unknown CNI_COMMAND: %v", cmd)
        }
 
        if err != nil {
                if e, ok := err.(*types.Error); ok {
                        // don't wrap Error in Error
-                       dieErr(e)
+                       return e
                }
-               dieMsg(err.Error())
+               return createTypedError(err.Error())
        }
+       return nil
 }
 
-func dieMsg(f string, args ...interface{}) {
-       e := &types.Error{
-               Code: 100,
-               Msg:  fmt.Sprintf(f, args...),
+// PluginMain is the "main" for a plugin. It accepts
+// two callback functions for add and del commands.
+func PluginMain(cmdAdd, cmdDel func(_ *CmdArgs) error) {
+       caller := dispatcher{
+               Getenv: os.Getenv,
+               Stdin:  os.Stdin,
+       }
+
+       err := caller.pluginMain(cmdAdd, cmdDel)
+       if err != nil {
+               dieErr(err)
        }
-       dieErr(e)
 }
 
 func dieErr(e *types.Error) {
index a52e014..8a69d55 100644 (file)
 package skel
 
 import (
-       "os"
+       "errors"
+       "io"
+       "strings"
 
+       "github.com/containernetworking/cni/pkg/types"
+
+       "github.com/containernetworking/cni/pkg/testutils"
        . "github.com/onsi/ginkgo"
+       . "github.com/onsi/ginkgo/extensions/table"
        . "github.com/onsi/gomega"
 )
 
-var _ = Describe("Skel", func() {
+type fakeCmd struct {
+       CallCount int
+       Returns   struct {
+               Error error
+       }
+       Received struct {
+               CmdArgs *CmdArgs
+       }
+}
+
+func (c *fakeCmd) Func(args *CmdArgs) error {
+       c.CallCount++
+       c.Received.CmdArgs = args
+       return c.Returns.Error
+}
+
+var _ = Describe("dispatching to the correct callback", func() {
        var (
-               fNoop = func(_ *CmdArgs) error { return nil }
-               // fErr    = func(_ *CmdArgs) error { return errors.New("dummy") }
-               envVars = []struct {
-                       name string
-                       val  string
-               }{
-                       {"CNI_CONTAINERID", "dummy"},
-                       {"CNI_NETNS", "dummy"},
-                       {"CNI_IFNAME", "dummy"},
-                       {"CNI_ARGS", "dummy"},
-                       {"CNI_PATH", "dummy"},
-               }
+               environment     map[string]string
+               stdin           io.Reader
+               cmdAdd, cmdDel  *fakeCmd
+               dispatch        *dispatcher
+               expectedCmdArgs *CmdArgs
        )
 
-       It("Must be possible to set the env vars", func() {
-               for _, v := range envVars {
-                       err := os.Setenv(v.name, v.val)
-                       Expect(err).NotTo(HaveOccurred())
+       BeforeEach(func() {
+               environment = map[string]string{
+                       "CNI_COMMAND":     "ADD",
+                       "CNI_CONTAINERID": "some-container-id",
+                       "CNI_NETNS":       "/some/netns/path",
+                       "CNI_IFNAME":      "eth0",
+                       "CNI_ARGS":        "some;extra;args",
+                       "CNI_PATH":        "/some/cni/path",
+               }
+               stdin = strings.NewReader(`{ "some": "config" }`)
+               dispatch = &dispatcher{
+                       Getenv: func(key string) string { return environment[key] },
+                       Stdin:  stdin,
+               }
+               cmdAdd = &fakeCmd{}
+               cmdDel = &fakeCmd{}
+               expectedCmdArgs = &CmdArgs{
+                       ContainerID: "some-container-id",
+                       Netns:       "/some/netns/path",
+                       IfName:      "eth0",
+                       Args:        "some;extra;args",
+                       Path:        "/some/cni/path",
+                       StdinData:   []byte(`{ "some": "config" }`),
                }
        })
 
-       Context("When dummy environment variables are passed", func() {
+       var envVarChecker = func(envVar string, isRequired bool) {
+               delete(environment, envVar)
+
+               err := dispatch.pluginMain(cmdAdd.Func, cmdDel.Func)
+               if isRequired {
+                       Expect(err).To(Equal(&types.Error{
+                               Code: 100,
+                               Msg:  "required env variables missing",
+                       }))
+               } else {
+                       Expect(err).NotTo(HaveOccurred())
+               }
+       }
+
+       Context("when the CNI_COMMAND is ADD", func() {
+               It("extracts env vars and stdin data and calls cmdAdd", func() {
+                       err := dispatch.pluginMain(cmdAdd.Func, cmdDel.Func)
 
-               It("should not fail with ADD and noop callback", func() {
-                       err := os.Setenv("CNI_COMMAND", "ADD")
                        Expect(err).NotTo(HaveOccurred())
-                       PluginMain(fNoop, nil)
+                       Expect(cmdAdd.CallCount).To(Equal(1))
+                       Expect(cmdDel.CallCount).To(Equal(0))
+                       Expect(cmdAdd.Received.CmdArgs).To(Equal(expectedCmdArgs))
                })
 
-               // TODO: figure out howto mock printing and os.Exit()
-               // It("should fail with ADD and error callback", func() {
-               //      err := os.Setenv("CNI_COMMAND", "ADD")
-               //      Expect(err).NotTo(HaveOccurred())
-               //      PluginMain(fErr, nil)
-               // })
+               It("does not call cmdDel", func() {
+                       err := dispatch.pluginMain(cmdAdd.Func, cmdDel.Func)
 
-               It("should not fail with DEL and noop callback", func() {
-                       err := os.Setenv("CNI_COMMAND", "DEL")
                        Expect(err).NotTo(HaveOccurred())
-                       PluginMain(nil, fNoop)
+                       Expect(cmdDel.CallCount).To(Equal(0))
                })
 
-               // TODO: figure out howto mock printing and os.Exit()
-               // It("should fail with DEL and error callback", func() {
-               //      err := os.Setenv("CNI_COMMAND", "DEL")
-               //      Expect(err).NotTo(HaveOccurred())
-               //      PluginMain(fErr, nil)
-               // })
+               DescribeTable("required / optional env vars", envVarChecker,
+                       // TODO: Entry("command", "CNI_COMMAND", true),
+                       Entry("container id", "CNI_CONTAINER_ID", false),
+                       Entry("net ns", "CNI_NETNS", true),
+                       Entry("if name", "CNI_IFNAME", true),
+                       Entry("args", "CNI_ARGS", false),
+                       Entry("path", "CNI_PATH", true),
+               )
+       })
+
+       Context("when the CNI_COMMAND is DEL", func() {
+               BeforeEach(func() {
+                       environment["CNI_COMMAND"] = "DEL"
+               })
+
+               It("calls cmdDel with the env vars and stdin data", func() {
+                       err := dispatch.pluginMain(cmdAdd.Func, cmdDel.Func)
 
-               It("should not fail with DEL and no NETNS and noop callback", func() {
-                       err := os.Setenv("CNI_COMMAND", "DEL")
                        Expect(err).NotTo(HaveOccurred())
-                       err = os.Unsetenv("CNI_NETNS")
+                       Expect(cmdDel.CallCount).To(Equal(1))
+                       Expect(cmdDel.Received.CmdArgs).To(Equal(expectedCmdArgs))
+               })
+
+               It("does not call cmdAdd", func() {
+                       err := dispatch.pluginMain(cmdAdd.Func, cmdDel.Func)
+
                        Expect(err).NotTo(HaveOccurred())
-                       PluginMain(nil, fNoop)
+                       Expect(cmdAdd.CallCount).To(Equal(0))
                })
 
+               DescribeTable("required / optional env vars", envVarChecker,
+                       // TODO: Entry("command", "CNI_COMMAND", true),
+                       Entry("container id", "CNI_CONTAINER_ID", false),
+                       Entry("net ns", "CNI_NETNS", false),
+                       Entry("if name", "CNI_IFNAME", true),
+                       Entry("args", "CNI_ARGS", false),
+                       Entry("path", "CNI_PATH", true),
+               )
+       })
+
+       Context("when the CNI_COMMAND is unrecognized", func() {
+               BeforeEach(func() {
+                       environment["CNI_COMMAND"] = "NOPE"
+               })
+
+               It("does not call any cmd callback", func() {
+                       dispatch.pluginMain(cmdAdd.Func, cmdDel.Func)
+
+                       Expect(cmdAdd.CallCount).To(Equal(0))
+                       Expect(cmdDel.CallCount).To(Equal(0))
+               })
+
+               It("returns an error", func() {
+                       err := dispatch.pluginMain(cmdAdd.Func, cmdDel.Func)
+
+                       Expect(err).To(Equal(&types.Error{
+                               Code: 100,
+                               Msg:  "unknown CNI_COMMAND: NOPE",
+                       }))
+               })
+       })
+
+       Context("when stdin cannot be read", func() {
+               BeforeEach(func() {
+                       dispatch.Stdin = &testutils.BadReader{}
+               })
+
+               It("does not call any cmd callback", func() {
+                       dispatch.pluginMain(cmdAdd.Func, cmdDel.Func)
+
+                       Expect(cmdAdd.CallCount).To(Equal(0))
+                       Expect(cmdDel.CallCount).To(Equal(0))
+               })
+
+               It("wraps and returns the error", func() {
+                       err := dispatch.pluginMain(cmdAdd.Func, cmdDel.Func)
+
+                       Expect(err).To(Equal(&types.Error{
+                               Code: 100,
+                               Msg:  "error reading from stdin: banana",
+                       }))
+               })
+       })
+
+       Context("when the callback returns an error", func() {
+               Context("when it is a typed Error", func() {
+                       BeforeEach(func() {
+                               cmdAdd.Returns.Error = &types.Error{
+                                       Code: 1234,
+                                       Msg:  "insufficient something",
+                               }
+                       })
+
+                       It("returns the error as-is", func() {
+                               err := dispatch.pluginMain(cmdAdd.Func, cmdDel.Func)
+
+                               Expect(err).To(Equal(&types.Error{
+                                       Code: 1234,
+                                       Msg:  "insufficient something",
+                               }))
+                       })
+               })
+
+               Context("when it is an unknown error", func() {
+                       BeforeEach(func() {
+                               cmdAdd.Returns.Error = errors.New("potato")
+                       })
+
+                       It("wraps and returns the error", func() {
+                               err := dispatch.pluginMain(cmdAdd.Func, cmdDel.Func)
+
+                               Expect(err).To(Equal(&types.Error{
+                                       Code: 100,
+                                       Msg:  "potato",
+                               }))
+                       })
+               })
        })
 })
diff --git a/testutils/bad_reader.go b/testutils/bad_reader.go
new file mode 100644 (file)
index 0000000..b3c0e97
--- /dev/null
@@ -0,0 +1,32 @@
+// Copyright 2014 CNI authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutils
+
+import "errors"
+
+type BadReader struct {
+       Error error
+}
+
+func (r *BadReader) Read(buffer []byte) (int, error) {
+       if r.Error != nil {
+               return 0, r.Error
+       }
+       return 0, errors.New("banana")
+}
+
+func (r *BadReader) Close() error {
+       return nil
+}